You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.7 KiB
86 lines
2.7 KiB
import logging
from typing import Literal, List

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from backend.services.chat import ChatService  # Import the service class
|
|
|
# Router for the chat API; every route below is mounted under "/api"
# and grouped under the "chat" tag in the generated OpenAPI docs.
router = APIRouter(tags=["chat"], prefix="/api")
|
|
|
# Pydantic schemas that validate the chat API's request and response bodies.


class ChatMessage(BaseModel):
    """A single message in the conversation history sent by the frontend."""

    # Only these three roles pass validation.
    role: Literal["system", "user", "assistant"]
    # Required; an empty string is rejected (min_length=1).
    content: str = Field(min_length=1)
|
|
|
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoint."""

    # Conversation history, each entry validated as a ChatMessage.
    messages: list[ChatMessage]
    # Forwarded to ChatService as ``top_k``; defaults to 5, bounded to 1..10.
    top_k: int = Field(5, ge=1, le=10)
    focus_ids: list[str] | None = Field(
        default=None,
        description="Optional ids from previously shared results.",
    )
|
|
|
class ChatSource(BaseModel):
    """One source entry returned alongside the assistant's answer.

    The identifier is stored on the ``id`` field but uses ``_id`` as its
    external (alias) name, so it can be populated with ``_id=...`` and is
    emitted as ``_id`` when serialized by alias.
    """

    # Pydantic treats underscore-prefixed annotations as private attributes,
    # not fields, so a plain ``_id: str`` was never validated, populated from
    # the constructor, or included in the response. Keep ``_id`` as the wire
    # name via an alias; populate_by_name also allows ``ChatSource(id=...)``.
    model_config = {"populate_by_name": True}

    id: str = Field(alias="_id")
    chunk_index: int
    heading: str | None
    debateurl: str | None
    snippet: str
|
|
|
class ChatResponse(BaseModel):
    """Response body produced by the chat endpoint."""

    answer: str
    sources: list[ChatSource]
    # default_factory gives every instance its own fresh empty list.
    # NOTE(review): the dict shape of a table is not fixed here; it is
    # whatever ChatService returns — confirm against the service.
    tables: list[dict] = Field(default_factory=list)
    focus_ids: list[str] = Field(default_factory=list)
|
|
|
# A single ChatService, built once at import time and shared by every
# request handled by this router.
chat_service: ChatService = ChatService()
|
|
|
def _to_chat_source(src: dict) -> ChatSource:
    """Build a ChatSource from one raw source dict returned by ChatService.

    Missing or ``None`` values fall back to empty defaults so that a sparse
    entry does not fail response validation.
    """
    raw_id = src.get("_id")
    return ChatSource(
        # Coerce a non-string id (if the service returns one) to str; None -> "".
        _id="" if raw_id is None else str(raw_id),
        chunk_index=src.get("chunk_index") or 0,
        heading=src.get("heading"),
        debateurl=src.get("debateurl"),
        snippet=src.get("snippet") or "",
    )


@router.post("/chat", response_model=ChatResponse)
def chat_endpoint(payload: ChatRequest) -> ChatResponse:
    """Handle a chat request from the frontend using ChatService.

    Args:
        payload (ChatRequest): The chat history and parameters from the frontend.

    Returns:
        ChatResponse: The assistant's answer, its sources, and any tables and
        focus ids reported by the service.

    Raises:
        HTTPException: 400 when the service raises ``ValueError`` (its message
            becomes the detail); 500 for any other exception, with a generic
            detail so internal error text is not exposed to clients.
    """
    # The service consumes plain dicts rather than Pydantic models.
    messages = [msg.model_dump() for msg in payload.messages]

    try:
        result = chat_service.get_chat_response(
            messages=messages,
            top_k=payload.top_k,
            focus_ids=payload.focus_ids or [],
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Keep the full traceback in server logs only; echoing str(e) to the
        # client could leak internal details.
        logging.getLogger(__name__).exception("Unhandled error in chat_endpoint")
        raise HTTPException(status_code=500, detail="Internal server error") from e

    raw_answer = result.get("answer")
    # A missing/None answer becomes "", never the literal text "None".
    answer = "" if raw_answer is None else str(raw_answer)

    raw_sources = result.get("sources")
    if not isinstance(raw_sources, list):
        raw_sources = []
    # Malformed (non-dict) entries are skipped instead of crashing on .get().
    sources = [_to_chat_source(src) for src in raw_sources if isinstance(src, dict)]

    return ChatResponse(
        answer=answer,
        sources=sources,
        # `or []` also covers an explicit None from the service.
        tables=result.get("tables") or [],
        focus_ids=result.get("focus_ids") or [],
    )