from __future__ import annotations

from typing import Any, Dict, List, Sequence, Optional, Tuple
from pydantic import BaseModel, Field
import backend.services.llm_tools
from _llm import LLM, get_tools, ChatCompletionMessage
from colorprinter import *
import json
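
# Loose aliases for the JSON-style dicts exchanged with the API layer
# (chat turns, harvested sources, and the response payload).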
ChatResponse = Dict[str, Any]
ChatSource = Dict[str, Any]
ChatMessage = Dict[str, Any]


class FirstAnswer(BaseModel):
    research_plan: str = Field(..., description="Your concise research plan")
    direct_reply: Optional[str] = Field(
        None,
        description="If the message is clearly not a question but a greeting or a comment, just respond in a friendly manner. Only use this if you're not making a plan!",
    )
    request_clarification: Optional[str] = Field(
        None,
        description="If you really don't understand the user's question, ask the user to clarify.",
    )


class FinalAnswer(BaseModel):
    final_answer: str = Field(..., description="Your final answer")
    explanation: str = Field(
        ...,
        description="Your short and non-technical explanation of how you arrived at the answer",
    )


class ChatService:
    """
    Handles retrieval-augmented replies by letting the LLM pick tools dynamically.
    """

    def __init__(self) -> None:
        """
        Prepare the LLM instance and cache available tool specifications.
        """
        system_message = """
You help users find information in speeches from the Swedish Riksdag. You have several tools available to search the speeches database; use these tools whenever you need data not present in earlier messages.

*Important operational rules:*
- Always read each tool's description and arguments carefully before calling it; follow the examples.
- When presenting results, cite sources by mentioning the talk titles and dates when available.
- You may call multiple tools in one conversation; if one tool doesn't return what you need, call another.
- Summarize and analyze findings continuously so you know what you have and what you still need. By including things like _id:s and other valuable information in your reasoning, they will be stored in your memory.

**Decision / tool-selection map:**
- Use `vector_search_talks(query, limit)` for semantic / concept matches (conceptual similarity, thematic clustering).
- Use `arango_search(query, parties, people, from_year, to_year, limit)` for ranked full-text searches (language-aware, boolean/phrase search, highlighted snippets).
- Use `search_documents(query)` for exact/structured queries, joins, and aggregations (you must write AQL; see the tool's docstring for templates).
- Use `fetch_documents(_ids)` to retrieve a list of entire documents when you need the full texts.

You can only request a tool call, not execute it directly. After you request a tool, wait for the results to come back as a message from the tool. Then analyze the results and decide your next step.

**Notes on different tools:**
- `arango_search` has three special features/parameters:
  1) `return_snippets` – If you want to get an overview of the results, use this parameter to get highlighted snippets instead of full documents.
  2) `results_to_user` – If the user has asked for e.g. "a list of talks mentioning...", "give me all speeches about..." – or in other ways indicates they want to see the actual results – set `results_to_user=True` so the results are sent to the user as they are. This is a good way of showing the user the actual search results while saving resources.
  3) `focus_ids` – If you've done a previous search with `results_to_user=True`, and the user has then asked a follow-up question that requires a more specific search within the previous results, set this parameter to True to search within the id:s from the last search.
- `arango_search` will return a lot of text if `return_snippets=True` is not set or the limit is above 20; the result will then be truncated. Tip: make use of the fact that the results from `arango_search` are ordered by relevance and start with the most relevant ones.


**When giving your final answer:**
- Always start with a short summary of your findings, before any detailed analysis or tables.
- Respond concisely, the user is not here for small talk.
- **IMPORTANT: Always format your answer using Markdown.** The frontend will convert it to HTML automatically.
- **IMPORTANT: Use inline citation numbers for ALL source citations.** Use the format `[1]`, `[2]`, etc. directly after the statement that references the source.
- **CRITICAL: Citations must be plain square brackets with numbers inside: `[1]`, `[2]`, `[3]`. Do NOT use Markdown footnote syntax like `[^1]` or special Unicode brackets like `【1】`.**
- **IMPORTANT: Always include a "Källor" (Sources) section at the end** with a numbered list matching your citations. Format each source as: `[1] Speaker name – Date – Brief context or quote`
- Example of correct citation format:
```
ROT-avdraget infördes 2009[1] och hade som syfte att minska svartarbete[2].

## Källor
[1] Eva Andersson – 2009-01-15 – Debatt om ROT-avdrag
[2] Per Svensson – 2009-02-20 – Diskussion om byggbranschen
```
- Make sure citation numbers are sequential ([1], [2], [3]...) and that every citation has a matching entry in the Källor section.
- Don't use internal _id or chunk_index fields in your answer; use human-readable information (speaker, date, topic).
- Don't ever make up quotes or facts; if you don't have enough information, say that you don't know, or call another tool to find more information.
- Answer in Swedish.
"""
        self.llm = LLM(model="vllm", system_message=system_message, temperature=0.15)
        self.tools = get_tools(exclude_tools=["sql_query"])
        self.max_tool_iterations = 20
    def get_chat_response(
        self,
        messages: Sequence[ChatMessage],
        top_k: int = 5,
        focus_ids: Optional[Sequence[str]] = None,
    ) -> ChatResponse:
        """
        Generate a reply while allowing the assistant to call registered tools.

        Args:
            messages: Ordered chat history including the latest user prompt.
            top_k: Maximum number of unique sources to expose to the client.
            focus_ids: Optional list of document ids shared with the user in earlier turns.

        Returns:
            Dict containing the assistant answer and harvested sources.
        """
        print_yellow("Messages in chat:")
        for msg in messages:
            print_yellow(msg)
        # Only prepend the system message ONCE
        full_messages = [{"role": "system", "content": self.llm.system_message}] + list(
            messages
        )

        question = self._latest_user_message(messages)
        if not question:
            raise ValueError("Conversation must contain at least one user message.")
        ids_part = None
        if 'INTRESSENT_IDS' in question:
            question_parts = question.split('INTRESSENT_IDS')
            question = question_parts[0].strip()
            ids_part = question_parts[1].strip()
        question = f"""A user has asked:
*{question}*\n
Make sure to understand the question and plan your research accordingly.
If it is in Swedish, make sure to understand it correctly.
If you need to clarify the question, ask the user to clarify."""
        if ids_part:
            question += f"""\nAs the user is interested in a certain person or persons, you can use the following list of intressent_id:s to find relevant speeches:\n{ids_part}."""
        print_yellow(
            f"[ChatService] Generating answer for {len(messages)} messages (top_k={top_k})."
        )
        collected_sources: List[ChatSource] = []
        collected_tables: List[Dict[str, Any]] = []
        response_message, tables, updated_focus_ids = self._run_tool_loop(
            full_messages,
            collected_sources,
            collected_tables,
            list(focus_ids or []),
        )
        answer_text = (
            response_message.final_answer
            if isinstance(response_message, FinalAnswer)
            else str(response_message)
        ).strip()
        deduped_sources = (
            self._deduplicate_sources(collected_sources, limit=top_k)
            if collected_sources
            else []
        )
        print_green(
            f"[ChatService] Completed answer with {len(deduped_sources)} collected sources."
        )
        return {
            "answer": answer_text,
            "sources": deduped_sources,
            "tables": tables,
            "focus_ids": updated_focus_ids,
        }
    def _run_tool_loop(
        self,
        messages: Sequence[ChatMessage],
        collected_sources: List[ChatSource],
        collected_tables: List[Dict[str, Any]],
        initial_focus_ids: List[str],
    ) -> Tuple[FinalAnswer, List[Dict[str, Any]], List[str]]:
        """
        Repeatedly call the LLM, executing tool calls as needed, until a final answer is produced.

        Args:
            messages: The current chat history (including the system message).
            collected_sources: List to collect sources from tool results.
            collected_tables: List to collect table payloads that should be forwarded to the user.
            initial_focus_ids: Document ids already shared with the user in earlier turns.

        Returns:
            Tuple containing the final assistant message, any tables to forward to the user, and the latest focus id list.
        """
        print_purple("[ChatService] Starting tool interaction loop.")
        current_messages: List[ChatMessage] = list(messages)
        active_focus_ids: List[str] = list(dict.fromkeys(initial_focus_ids))
        if active_focus_ids:
            current_messages.append(
                {
                    "role": "user",
                    "content": (
                        "Du har tidigare delat sökresultat med användaren. "
                        "Listan `focus_ids` innehåller deras dokument-id:n:\n"
                        f"{active_focus_ids}\n"
                        "Om du vill begränsa en ny arango_search till samma träffar anger du argumentet "
                        "`focus_ids=focus_ids`."
                    ),
                }
            )
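
        # Agent loop: each iteration either executes the tool calls requested by
        # the model or treats the model's plain-text reply as the final answer.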
        for i in range(self.max_tool_iterations):

            if i == self.max_tool_iterations - 1:
                print_red(f"[ChatService] Reached max iterations ({self.max_tool_iterations}). Forcing final answer.")
                current_messages.append({'role': 'user', 'content': '**IMPORTANT** You have reached the maximum number of tool calls. Please provide your final answer based on the information you have gathered so far.'})

            response: ChatCompletionMessage = self.llm.generate(
                messages=current_messages,
                tools=self.tools,
                model="vllm",
            )
            # Use getattr so this doesn't raise AttributeError when the model
            # doesn't return a reasoning/thinking block (which is the normal case).
            thinking = getattr(response, "reasoning_content", None)
            if thinking:
                print_blue("Thinking:", thinking)
            try:
                print_purple("Content:", response.content)
            except Exception as e:
                print_red(f"[ChatService] Error printing content response: {e}")

            tool_calls = getattr(response, "tool_calls", None)
            if tool_calls:
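                # Persist any reasoning the model produced so it stays in the
                # conversation history for later iterations (its working memory).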
                reasoning_content_attr = getattr(response, "reasoning_content", None)
                if reasoning_content_attr:
                    if isinstance(reasoning_content_attr, dict) and "content" in reasoning_content_attr:
                        reasoning_content = reasoning_content_attr["content"]
                    else:
                        reasoning_content = str(reasoning_content_attr)
                    current_messages.append(
                        {
                            "role": "assistant",
                            "content": reasoning_content,
                        }
                    )
                print_blue(
                    f"[ChatService] Model requested {len(tool_calls)} tool call(s)."
                )
                for tool_call in tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = tool_call.function.arguments
                    if isinstance(tool_args, dict) and "focus_ids" in tool_args:
                        # Allow the model to pass focus_ids=True or "focus_ids" as a shorthand for "use the saved list".
                        requested_focus = tool_args["focus_ids"]
                        if requested_focus is True or (
                            isinstance(requested_focus, str) and requested_focus.strip().lower() == "focus_ids"
                        ):
                            tool_args["focus_ids"] = list(active_focus_ids)
                        elif requested_focus in (False, None) and not active_focus_ids:
                            tool_args.pop("focus_ids")
                    print_blue(
                        f"[ChatService] Executing tool: {tool_name} with args: {tool_args}"
                    )

                    tool_func = self._get_tool_function(tool_name)
                    if tool_func is None:
                        print_blue(
                            f"[ChatService] Tool function '{tool_name}' not found!"
                        )
                        tool_result = f"ERROR: Tool '{tool_name}' not found."
                    else:
                        try:
                            tool_result = tool_func(**tool_args)
                        except Exception as e:
                            print_red(
                                f"[ChatService] Exception in tool '{tool_name}': {e}"
                            )
                            import traceback

                            traceback.print_exc()
                            tool_result = f"ERROR: {e}"
                    if isinstance(tool_result, dict) and tool_result.get("type") == "search_results":
                        payload = tool_result.get("payload", {})
                        results = payload.get("results", [])
                        focus_ids_from_tool = [
                            item["_id"] for item in results if isinstance(item, dict) and item.get("_id")
                        ]
                        payload["focus_ids"] = focus_ids_from_tool
                        collected_tables.append(payload)
                        active_focus_ids = focus_ids_from_tool or active_focus_ids
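
                        # If the tool call asked for results_to_user, hand the table
                        # straight to the frontend and end the loop with a short summary
                        # instead of feeding the raw rows back to the model.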
                        requested_direct_delivery = False
                        if isinstance(tool_args, dict) and "results_to_user" in tool_args:
                            requested_direct_delivery = self._is_truthy_flag(
                                tool_args.get("results_to_user")
                            )

                        if requested_direct_delivery:
                            hit_count = len(results)
                            query_value = ""
                            if isinstance(tool_args, dict):
                                query_value = str(tool_args.get("query") or "").strip()
                            if hit_count:
                                summary_text = (
                                    f"Jag hittade {hit_count} träffar"
                                    f"""{f' för sökningen "{query_value}"' if query_value else ''}. """
                                    "Tabellen visar detaljerna. Ställ gärna följdfrågor om du vill veta mer."
                                )
                            else:
                                summary_text = (
                                    "Sökningen gav tyvärr inga träffar. "
                                    "Du kan justera sökvillkoren eller ställa en ny fråga."
                                )
                            explanation = (
                                "Resultatet skickades direkt till användaren eftersom verktygsanropet angav results_to_user=True."
                            )
                            final_message = FinalAnswer(
                                final_answer=summary_text,
                                explanation=explanation,
                            )
                            return final_message, collected_tables, active_focus_ids

                        tool_result_string = json.dumps(payload, ensure_ascii=False)
                        current_messages.append(
                            {
                                "role": "system",
                                "content": (
                                    "Spara listan `focus_ids` för uppföljande frågor:\n"
                                    f"focus_ids = {focus_ids_from_tool}\n"
                                    "När du behöver arbeta vidare med dessa dokument använder du arango_search med argumentet "
                                    "`focus_ids=<denna_lista>`."
                                ),
                            }
                        )
                    else:
                        tool_result_string = str(tool_result)
                        if len(tool_result_string) > 12000:
                            print_red(
                                f"[ChatService] Tool result too long ({len(tool_result_string)} chars), truncating."
                            )
                            tool_result_string = (
                                f"{tool_result_string[:12000]} (...) [truncated]"
                            )

                    reminder = '\n\n**Remember:**\n- You can only use information from tool results when giving your final answer.\n- Do not make up facts or quotes.\n- If you lack information, say so or call another tool.\n- **Always use inline citations in the format [1], [2], [3] etc. Do NOT use [^1] or 【1】.**\n- **Always include a "Källor" section at the end with matching numbered sources.**\n- Always format your final answer in Markdown.\n- Always answer in Swedish.'
                    tool_message = {
                        "role": "tool",
                        "name": tool_name,
                        "content": f"Result from calling {tool_name}:\n{tool_result_string}.{reminder}",
                    }
                    if "ERROR" in tool_result_string:
                        print_red(
                            f"[ChatService] Tool result for '{tool_name.upper()}': {tool_message['content'][:200]}..."
                        )
                    else:
                        print_green(
                            f"[ChatService] Tool result for '{tool_name.upper()}': {tool_message['content'][:200]}..."
                        )
                    current_messages.append(tool_message)
                # Continue the loop with the updated message history (do NOT add system message again)
                continue
            elif response.content:
                final_content = getattr(response, "content", "")
                final_message = FinalAnswer(
                    final_answer=final_content,
                    explanation="Model provided a direct answer without requiring additional tools.",
                )
                return final_message, collected_tables, active_focus_ids

        # Safety net: the loop exhausted max_tool_iterations without the model
        # returning plain content, so report that instead of falling through with None.
        return (
            FinalAnswer(
                final_answer="Jag kunde tyvärr inte ta fram ett slutgiltigt svar inom det tillåtna antalet verktygsanrop.",
                explanation="The tool loop reached its iteration limit without a final answer from the model.",
            ),
            collected_tables,
            active_focus_ids,
        )
    def _get_tool_function(self, tool_name: str):
        """
        Retrieve the Python function for a given tool name.

        Args:
            tool_name: The name of the tool as specified in the tool call.

        Returns:
            The Python function, or None if not found.
        """
        # This assumes your tools are registered in backend.services.llm_tools
        # and get_tools() returns a list of tool specs with .name and .function
        for tool in self.tools:
            if hasattr(tool, "name") and tool.name == tool_name:
                return getattr(tool, "function", None)
        # Fallback: try to import from backend.services.llm_tools
        try:
            import backend.services.llm_tools as llm_tools

            return getattr(llm_tools, tool_name, None)
        except Exception:
            print_red(f"[ChatService] Could not import tool '{tool_name}'.")
            return None
    def _latest_user_message(self, messages: Sequence[ChatMessage]) -> str:
        """
        Fetch the most recent user utterance from the chat history.
        """
        for message in reversed(messages):
            if message.get("role") == "user":
                return message.get("content", "").strip()
        return ""
    def _normalize_chunk_index(self, value: Any, default: int = -1) -> int:
        """
        Convert raw chunk index values from tool outputs into integers.

        Args:
            value: The raw chunk index produced by a tool (can be str, float, etc.).
            default: Fallback index used when conversion is not possible.

        Returns:
            An integer chunk index compatible with the API schema.
        """
        if isinstance(value, bool):
            return default
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            stripped = value.strip()
            if stripped.startswith("+"):
                stripped = stripped[1:]
            if stripped.lstrip("-").isdigit():
                return int(stripped)
        return default
    def _deduplicate_sources(
        self, sources: List[ChatSource], limit: int
    ) -> List[ChatSource]:
        """
        Collapse duplicate tool outputs and enforce the requested limit.
        """
        unique: Dict[Tuple[Any, Any], ChatSource] = {}
        for source in sources:
            source_id = source.get("_id")
            chunk_index = self._normalize_chunk_index(source.get("chunk_index"))
            key = (source_id, chunk_index)
            if key in unique:
                continue
            snippet_value = source.get("snippet", "")
            snippet_text = self._trim_snippet(str(snippet_value))
            unique[key] = {
                "_id": source_id,
                "heading": source.get("heading"),
                "snippet": snippet_text,
                "chunk_index": chunk_index,
                "debateurl": source.get("debateurl") or source.get("debate_url"),
            }
        max_items = max(1, limit)
        return list(unique.values())[:max_items]
    def _is_truthy_flag(self, value: Any) -> bool:
        """
        Interpret diverse truthy representations (bools, strings, numbers) used in tool arguments.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value != 0
        if isinstance(value, str):
            normalized = value.strip().lower()
            return normalized in {"true", "1", "yes", "y", "ja"}
        return False
    def _trim_snippet(self, text: str, length: int = 400) -> str:
        """
        Truncate long snippets to keep the UI readable.
        """
        cleaned = text.strip()
        if len(cleaned) <= length:
            return cleaned
        return f"{cleaned[:length].rstrip()}…"

# ---- Test code ----
if __name__ == "__main__":
    service = ChatService()
    print("Registered tools:")
    for tool in service.tools:
        print(
            f" - {tool['function']['name']} - {tool['function']['description'][:100]}..."
        )
    test_messages = [
        {"role": "user", "content": "Hur många gånger har kärnkraft nämnts?"},
    ]
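
    # Minimal smoke test: send the sample question through the full tool loop.
    # This assumes the vllm backend, the Arango database, and the registered
    # tools are actually reachable from this environment.
    response = service.get_chat_response(test_messages, top_k=5)
    print(response["answer"])
    for source in response["sources"]:
        print(source)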