You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
6.1 KiB
217 lines
6.1 KiB
from __future__ import annotations

import asyncio
import logging
from datetime import datetime

import httpx
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from arango_client import arango
from info import debate_types, explainer, limit_warning, party_colors
from .schemas import (
    ChatRequest,
    ChatResponse,
    FeedbackRequest,
    FeedbackResponse,
    SearchRequest,
    SearchResponse,
    TalkHit,
)
from .services import ChatService, SearchService
from backend.routes.chat import router as chat_router
from .services.names_autocomplete import router as names_autocomplete_router
|
|
|
# Application object; the routes and routers below are all registered on it.
app = FastAPI(title="Riksdagen API", version="0.1.0")

# CORS is currently wide open: any origin, method and header, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive -- confirm the intended origin list before production use.
_CORS_SETTINGS = {
    "allow_origins": ["*"],  # tighten for production
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Module-level service singletons shared by the endpoint handlers.
search_service = SearchService()
chat_service = ChatService()

# Routers defined in other modules, mounted in the same order as before.
for _extra_router in (chat_router, names_autocomplete_router):
    app.include_router(_extra_router)
|
|
|
|
|
@app.get("/api/meta")
def meta():
    """
    Return the lookup data exposed under ``/api/meta``.

    Returns:
        dict: The module-level ``party_colors``, ``debate_types``,
        ``explainer`` and ``limit_warning`` values, keyed as ``parties``,
        ``debate_types``, ``explainer`` and ``limit_warning``.
    """
    return dict(
        parties=party_colors,
        debate_types=debate_types,
        explainer=explainer,
        limit_warning=limit_warning,
    )
|
|
|
|
|
@app.post("/api/search", response_model=SearchResponse)
def search(payload: SearchRequest):
    """
    Run a search and return the hits together with stats and the active filters.

    Each raw result from ``search_service.search`` is validated through
    ``TalkHit`` and serialized by alias. A result that fails conversion is
    logged (with traceback) and skipped so one bad row does not fail the
    whole request.

    Args:
        payload (SearchRequest): Search parameters and filters.

    Returns:
        dict: ``results`` (serialized hits), ``stats``, ``active_filters``
        (the filter values echoed back from the payload) and ``limit_reached``.
    """
    results, stats, limit_reached = search_service.search(
        payload, include_snippets=payload.include_snippets
    )

    logger = logging.getLogger(__name__)

    hits = []
    for idx, hit in enumerate(results):
        try:
            talk_hit = TalkHit(**hit)
            # Serialize using alias so 'id' is sent to frontend, not '_id'
            hits.append(talk_hit.dict(by_alias=True))
        except Exception:
            # Deliberate best-effort: keep the remaining results, but record
            # the failure through logging (traceback included) instead of
            # print() so it follows the application's log configuration.
            logger.exception(
                "Error converting result %s to TalkHit. Problematic result: %r",
                idx,
                hit,
            )
            continue

    return {
        "results": hits,
        "stats": stats,
        "active_filters": {
            "parties": payload.parties,
            "people": payload.people,
            "debates": payload.debates,
            "from_year": payload.from_year,
            "to_year": payload.to_year,
            "speaker_ids": payload.speaker_ids,
            "speaker": payload.speaker,
        },
        "limit_reached": limit_reached,
    }
|
|
|
|
|
@app.post("/api/chat", response_model=ChatResponse)
def chat(payload: ChatRequest) -> ChatResponse:
    """
    Produce an assistant answer with its sources through ``chat_service``.

    Args:
        payload (ChatRequest): Request carrying the chat messages and,
            optionally, a ``top_k`` or ``limit`` attribute.

    Returns:
        ChatResponse: The ``answer`` and ``sources`` returned by the service.

    Raises:
        HTTPException: 400 when ``payload.messages`` is empty.
    """
    history = payload.messages
    if not history:
        raise HTTPException(status_code=400, detail="messages cannot be empty")

    serialized = []
    for entry in history:
        serialized.append(entry.dict())

    # Prefer ``top_k``; fall back to ``limit`` only when ``top_k`` is None.
    requested = getattr(payload, "top_k", None)
    if requested is None:
        requested = getattr(payload, "limit", None)
    # Any falsy value (None or 0) ends up as the default of 5.
    result_count = requested if requested else 5

    outcome = chat_service.get_chat_response(
        messages=serialized,
        top_k=result_count,
    )
    return ChatResponse(answer=outcome["answer"], sources=outcome["sources"])
|
|
|
|
|
@app.get("/api/talk/{talk_id}")
async def get_talk(talk_id: str) -> dict:
    """
    Fetch a single talk document by its ID from the 'talks' collection.

    This endpoint accepts either:
    - A full _id like "talks/H40911"
    - Just the _key like "H40911" (will be prefixed with "talks/")

    The document is joined with the corresponding person from the 'people'
    collection using the intressent_id field.

    The response also includes lightweight navigation data (previous/next
    speeches) for the same date and kammaraktivitet, based on
    anforande_nummer.

    Args:
        talk_id (str): The talk ID (either full _id or just _key)

    Returns:
        dict: The selected talk fields merged with ``person`` and
        ``navigation`` (``previous`` / ``next`` talk ids, or null).

    Raises:
        HTTPException: 404 if talk not found
    """
    # If the ID doesn't contain a slash, assume it's just the _key and prefix
    # it with the collection name.
    # NOTE(review): an id that already contains a slash is passed through
    # unchanged, so it may name a collection other than 'talks' -- confirm
    # whether such ids should be rejected.
    full_id = talk_id if "/" in talk_id else f"talks/{talk_id}"

    # AQL query to fetch the talk and join with person data.
    # NOTE(review): TO_NUMBER() appears to return 0 for non-numeric input, in
    # which case IS_NUMBER(TO_NUMBER(...)) is always true and `num` is never
    # null -- verify against the ArangoDB documentation.
    query = """
    LET doc_full = DOCUMENT(@talk_id)
    FILTER doc_full != null

    /* Only keep the relevant fields from the talk */
    LET doc = KEEP(
        doc_full,
        [
            "anforandetext",
            "talare",
            "parti",
            "datum",
            "kammaraktivitet",
            "avsnittsrubrik",
            "titel",
            "anforande_nummer",
            "replik",
            "url_session",
            "url_audio"
        ]
    )

    /* Fetch person only if intressent_id exists */
    LET person_full = doc_full.intressent_id
        ? DOCUMENT(CONCAT("people/", doc_full.intressent_id))
        : null

    LET person = person_full
        ? KEEP(person_full, ["bild_url_192", "tilltalsnamn", "efternamn", "valkrets", "status"])
        : null

    /* Interpret anforande_nummer as number */
    LET num = IS_NUMBER(TO_NUMBER(doc.anforande_nummer)) ? TO_NUMBER(doc.anforande_nummer) : null

    LET previous = num != null
        ? FIRST(
            FOR t IN talks
                FILTER t.datum == doc.datum
                    AND t.kammaraktivitet == doc.kammaraktivitet
                    AND IS_NUMBER(TO_NUMBER(t.anforande_nummer))
                    AND TO_NUMBER(t.anforande_nummer) == num - 1
                RETURN t._id
        )
        : null

    LET next = num != null
        ? FIRST(
            FOR t IN talks
                FILTER t.datum == doc.datum
                    AND t.kammaraktivitet == doc.kammaraktivitet
                    AND IS_NUMBER(TO_NUMBER(t.anforande_nummer))
                    AND TO_NUMBER(t.anforande_nummer) == num + 1
                RETURN t._id
        )
        : null

    RETURN MERGE(doc, {
        person: person,
        navigation: {
            previous: previous,
            next: next
        }
    })
    """

    # execute_aql is called synchronously (its result is indexed directly), so
    # run it in the default executor; calling it inline would block the event
    # loop, and every other request, for the duration of the query.
    loop = asyncio.get_running_loop()
    results = await loop.run_in_executor(
        None,
        lambda: arango.execute_aql(query, bind_vars={"talk_id": full_id}),
    )

    if not results or results[0] is None:
        raise HTTPException(status_code=404, detail=f"Talk not found: {talk_id}")

    return results[0]
|
|
|