"""Chat API routes: request/response schemas and the ``POST /api/chat`` endpoint."""

import logging
from typing import Literal, List

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, ConfigDict, Field

from backend.services.chat import ChatService  # Import the service class

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api", tags=["chat"])


# Pydantic models for request/response validation
class ChatMessage(BaseModel):
    """A single message of the conversation sent by the client."""

    role: Literal["system", "user", "assistant"]
    content: str = Field(..., min_length=1)


class ChatRequest(BaseModel):
    """Request body of the chat endpoint."""

    messages: List[ChatMessage]
    top_k: int = Field(default=5, ge=1, le=10)
    focus_ids: List[str] | None = Field(
        default=None,
        description="Optional ids from previously shared results.",
    )


class ChatSource(BaseModel):
    """One source entry returned alongside the answer."""

    # Allow construction with either ``source_id=`` or the ``_id=`` alias.
    model_config = ConfigDict(populate_by_name=True)

    # Pydantic v2 treats an annotated attribute whose name starts with an
    # underscore as a *private attribute*, not a field: a plain ``_id: str``
    # is silently dropped on construction and never serialised. Declaring a
    # regular field with the ``_id`` alias keeps the ``_id=`` keyword and the
    # ``_id`` key in the JSON response (FastAPI serialises response models by
    # alias by default).
    source_id: str = Field(alias="_id")
    chunk_index: int
    heading: str | None
    debateurl: str | None
    snippet: str


class ChatResponse(BaseModel):
    """Response body of the chat endpoint."""

    answer: str
    sources: List[ChatSource]
    tables: List[dict] = Field(default_factory=list)
    focus_ids: List[str] = Field(default_factory=list)


# Instantiate the chat service once (can be reused for all requests)
chat_service = ChatService()


def _build_sources(raw_sources: object) -> List[ChatSource]:
    """Convert the service's raw ``sources`` value into ``ChatSource`` models.

    Anything that is not a list yields an empty result, and list entries that
    are not dicts are skipped, so one malformed entry does not fail the
    whole response.
    """
    if not isinstance(raw_sources, list):
        return []

    sources: List[ChatSource] = []
    for src in raw_sources:
        if not isinstance(src, dict):
            continue
        raw_id = src.get("_id")
        sources.append(
            ChatSource(
                # Coerce to str: the id is now a validated field, so a
                # non-string id would otherwise raise a validation error.
                _id="" if raw_id is None else str(raw_id),
                # ``or`` also covers keys that are present but set to None.
                chunk_index=src.get("chunk_index") or 0,
                heading=src.get("heading"),
                debateurl=src.get("debateurl"),
                snippet=src.get("snippet") or "",
            )
        )
    return sources


@router.post("/chat", response_model=ChatResponse)
def chat_endpoint(payload: ChatRequest) -> ChatResponse:
    """Handle a chat request by delegating to the shared ``ChatService``.

    Args:
        payload: The chat history and retrieval parameters from the client.

    Returns:
        The assistant's answer together with its sources, tables and
        focus ids.

    Raises:
        HTTPException: 400 when the service raises ``ValueError``; 500 for
            any other failure or when the service returns a non-dict result.
    """
    # Convert Pydantic models to dicts for the service
    messages = [msg.model_dump() for msg in payload.messages]

    try:
        result = chat_service.get_chat_response(
            messages=messages,
            top_k=payload.top_k,
            focus_ids=payload.focus_ids or [],
        )
    except HTTPException:
        # Let HTTP errors raised on purpose pass through unchanged instead of
        # being rewritten as a 500 by the generic handler below.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Keep the full traceback in the server log, but do not echo internal
        # exception text back to the client.
        logger.exception("UNHANDLED ERROR in chat_endpoint")
        raise HTTPException(status_code=500, detail="Internal server error") from e

    if not isinstance(result, dict):
        logger.error(
            "chat_endpoint: unexpected result type %s", type(result).__name__
        )
        raise HTTPException(status_code=500, detail="Internal server error")

    raw_answer = result.get("answer")
    if raw_answer is None:
        raw_answer = ""
    elif not isinstance(raw_answer, str):
        raw_answer = str(raw_answer)

    return ChatResponse(
        answer=raw_answer,
        sources=_build_sources(result.get("sources", [])),
        # ``or []`` guards against keys that are present but set to None.
        tables=result.get("tables") or [],
        focus_ids=result.get("focus_ids") or [],
    )