You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

86 lines
2.7 KiB

from typing import Literal, List
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from backend.services.chat import ChatService # Import the service class
# Router for the chat API: every route registered on it is served under the
# "/api" prefix and carries the "chat" tag.
router = APIRouter(prefix="/api", tags=["chat"])
# Pydantic models for request/response validation
# A single turn of the conversation supplied by the client.
class ChatMessage(BaseModel):
    # Who produced the turn; only these three role names are accepted.
    role: Literal["system", "user", "assistant"]
    # Text of the turn; required and must contain at least one character.
    content: str = Field(default=..., min_length=1)
# Request body accepted by the chat endpoint.
class ChatRequest(BaseModel):
    # Conversation history, validated turn by turn as ChatMessage objects.
    messages: list[ChatMessage]
    # Integer in the inclusive range 1..10; defaults to 5 when omitted.
    top_k: int = Field(5, ge=1, le=10)
    # May be omitted or null; otherwise a list of id strings.
    focus_ids: list[str] | None = Field(
        None,
        description="Optional ids from previously shared results.",
    )
# One source entry returned alongside the assistant's answer.
class ChatSource(BaseModel):
    # Pydantic treats underscore-prefixed annotations as *private attributes*,
    # not fields: a field literally named `_id` is never populated from the
    # constructor and never serialized. Declaring the field as `id` with the
    # alias "_id" keeps `ChatSource(_id=...)` working and keeps "_id" as the
    # key in the JSON produced by FastAPI (which serializes by alias).
    id: str = Field(alias="_id")
    chunk_index: int
    # Nullable but without a default, so callers must pass them explicitly
    # (None is accepted).
    heading: str | None
    debateurl: str | None
    snippet: str
# Response body returned by the chat endpoint.
class ChatResponse(BaseModel):
    # The assistant's reply text.
    answer: str
    # Source entries that accompany the answer.
    sources: list[ChatSource]
    # Both lists below default to a fresh empty list when not supplied.
    tables: list[dict] = Field(default_factory=list)
    focus_ids: list[str] = Field(default_factory=list)
# Instantiate the chat service once, at import time; this single instance is
# reused by chat_endpoint for every request.
# NOTE(review): chat_endpoint is a plain `def`, which FastAPI runs in a worker
# thread pool — confirm ChatService is safe to share across threads.
chat_service = ChatService()
def _list_or_empty(value):
    """Return *value* if it is a list, otherwise a new empty list."""
    return value if isinstance(value, list) else []


@router.post("/chat", response_model=ChatResponse)
def chat_endpoint(payload: ChatRequest) -> ChatResponse:
    """
    Handles chat requests from the frontend. Uses ChatService to generate a response.

    Args:
        payload (ChatRequest): The chat history and parameters from the frontend.

    Returns:
        ChatResponse: The assistant's answer, its sources, and any tables and
        focus ids reported by the service.

    Raises:
        HTTPException: 400 when the service raises ValueError (its message is
            passed through as the detail); 500 for any other failure.
    """
    import logging

    # Convert Pydantic models to dicts for the service
    messages = [msg.model_dump() for msg in payload.messages]
    try:
        result = chat_service.get_chat_response(
            messages=messages,
            top_k=payload.top_k,
            focus_ids=payload.focus_ids or [],
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Record the full traceback server-side, but do not echo the exception
        # text to the client: it can expose internal details.
        logging.getLogger(__name__).exception("UNHANDLED ERROR in chat_endpoint")
        raise HTTPException(status_code=500, detail="Internal server error") from e

    raw_answer = result.get("answer", "")
    if not isinstance(raw_answer, str):
        raw_answer = str(raw_answer)

    # A malformed "sources" value is treated as "no sources"; entries that are
    # not dicts are skipped for the same reason instead of crashing on .get().
    sources = [
        ChatSource(
            _id=src.get("_id", ""),
            chunk_index=src.get("chunk_index", 0),
            heading=src.get("heading"),
            debateurl=src.get("debateurl"),
            snippet=src.get("snippet", ""),
        )
        for src in _list_or_empty(result.get("sources", []))
        if isinstance(src, dict)
    ]

    # Normalize the remaining list-valued keys the same way as "sources", so a
    # missing, null, or non-list value does not fail response validation.
    return ChatResponse(
        answer=raw_answer,
        sources=sources,
        tables=_list_or_empty(result.get("tables", [])),
        focus_ids=_list_or_empty(result.get("focus_ids", [])),
    )