from __future__ import annotations

import json
import traceback
from typing import Any, Dict, List, Optional, Sequence, Tuple

from pydantic import BaseModel, Field

import backend.services.llm_tools
from _llm import LLM, ChatCompletionMessage, get_tools
from colorprinter import *

ChatResponse = Dict[str, Any]
ChatSource = Dict[str, Any]
ChatMessage = Dict[str, Any]


class FirstAnswer(BaseModel):
    research_plan: str = Field(..., description="Your concise research plan")
    direct_reply: Optional[str] = Field(
        None,
        description=(
            "If the message is clearly not a question but a greeting or a comment, "
            "just respond in a friendly manner. Only use this if you're not making a plan!"
        ),
    )
    request_clarification: Optional[str] = Field(
        None,
        description="If you really don't understand the user's question, ask the user to clarify.",
    )


class FinalAnswer(BaseModel):
    final_answer: str = Field(..., description="Your final answer")
    explanation: str = Field(
        ...,
        description="Your short and non-technical explanation of how you arrived at the answer",
    )


class ChatService:
    """
    Handles retrieval-augmented replies by letting the LLM pick tools dynamically.
    """

    def __init__(self) -> None:
        """
        Prepare the LLM instance and cache available tool specifications.
        """
        system_message = """
        You help users find information in speeches from the Swedish Riksdag. You have several tools available to search the speeches database; use these tools whenever you need data not present in earlier messages.

        *Important operational rules:*
        - Always read each tool's description and arguments carefully before calling it; follow the examples.
        - When presenting results, cite sources by mentioning the talk titles and dates when available.
        - You may call multiple tools in one conversation; if one tool doesn't return what you need, call another.
        - Summarize and analyze findings continuously so you know what you have and what you still need. By including things like _id:s and other valuable information in your reasoning, they will be stored in your memory.

        **Decision / tool-selection map:**
        - Use `vector_search_talks(query, limit)` for semantic / concept matches (conceptual similarity, thematic clustering).
        - Use `arango_search(query, parties, people, from_year, to_year, limit)` for ranked full-text searches (language-aware, boolean/phrase search, highlighted snippets).
        - Use `search_documents(query)` for exact/structured queries, joins, and aggregations (you must write AQL; see the tool's docstring for templates).
        - Use `fetch_documents(_ids)` to retrieve a list of entire documents when you need the full texts.

        You can only request a tool use, not use it directly. After you request a tool, wait for the results to come back as a message from the tool. Then analyze the results and decide your next step.

        **Notes on different tools:**
        - `arango_search` has three special features/parameters:
            1) `return_snippets` – If you want to get an overview of the results, use this parameter to get highlighted snippets instead of full documents.
            2) `results_to_user` – If the user has asked for e.g. "a list of talks mentioning...", "give me all speeches about..." – or in other ways indicates they want to see the actual results – set `results_to_user=True` so the results are sent to the user as they are. This is a good way of showing the user the actual search results, and it saves resources.
            3) `focus_ids` – If you've done a previous search with `results_to_user=True`, and the user has then asked a follow-up question that requires a more specific search within the previous results, set this parameter to True to search within the IDs from the last search.
        - `arango_search` will return a lot of text if `return_snippets` is not set to True or the limit is set above 20. The result will then be truncated. Tip: make use of the fact that the results from `arango_search` are ordered by relevance and start with the most relevant ones.

        **When giving your final answer:**
        - Always start with a short summary of your findings, before any detailed analysis or tables.
        - Respond concisely; the user is not here for small talk.
        - Make sure to include sources for your answer, but don't use the internal _id or chunk_index fields; instead, use date, title, etc.
        - When referring to a source, use footnotes like [1], [2], etc. at the end of the sentence where you mention it. *Remember to include a short bibliography at the end of your answer, listing all sources you used.*
        - Always format your final answer using Markdown (it will be translated to HTML by the frontend).
        - Don't ever make up quotes or facts; if you don't have enough information, say that you don't know, or call another tool to find more information.
        - Answer in Swedish.
        """
        self.llm = LLM(model="vllm", system_message=system_message, temperature=0.15)
        self.tools = get_tools(exclude_tools=["sql_query"])
        self.max_tool_iterations = 20
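        # Tool specs appear to be OpenAI-style dicts (shape inferred from the
        # demo block at the bottom of this file; illustrative only):
        # {"type": "function",
        #  "function": {"name": "arango_search", "description": "...", "parameters": {...}}}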

    def get_chat_response(
        self,
        messages: Sequence[ChatMessage],
        top_k: int = 5,
        focus_ids: Optional[Sequence[str]] = None,
    ) -> ChatResponse:
        """
        Generate a reply while allowing the assistant to call registered tools.

        Args:
            messages: Ordered chat history including the latest user prompt.
            top_k: Maximum number of unique sources to expose to the client.
            focus_ids: Optional list of document ids shared with the user in earlier turns.

        Returns:
            Dict containing the assistant answer and harvested sources.
        """
        print_yellow("Messages in chat:")
        for msg in messages:
            print_yellow(msg)
        # Only prepend the system message ONCE.
        full_messages = [{"role": "system", "content": self.llm.system_message}] + list(
            messages
        )

        question = self._latest_user_message(messages)
        if not question:
            # Check before the question is wrapped in the template below;
            # afterwards it can never be empty.
            raise ValueError("Conversation must contain at least one user message.")
        ids_part = None
        if "INTRESSENT_IDS" in question:
            question_parts = question.split("INTRESSENT_IDS")
            question = question_parts[0].strip()
            ids_part = question_parts[1].strip()
        question = f"""A user has asked:
*{question}*\n
Make sure to understand the question and plan your research accordingly.
If it is in Swedish, make sure to understand it correctly.
If you need to clarify the question, ask the user to clarify."""
        if ids_part:
            question += f"""\nAs the user is interested in a certain person or persons, you can use the following list of intressent_id:s to find relevant speeches:\n{ids_part}."""
        print_yellow(
            f"[ChatService] Generating answer for {len(messages)} messages (top_k={top_k})."
        )
        collected_sources: List[ChatSource] = []
        collected_tables: List[Dict[str, Any]] = []
        response_message, tables, updated_focus_ids = self._run_tool_loop(
            full_messages,
            collected_sources,
            collected_tables,
            list(focus_ids or []),
        )
        answer_text = (
            response_message.final_answer
            if isinstance(response_message, FinalAnswer)
            else str(response_message)
        ).strip()
        deduped_sources = (
            self._deduplicate_sources(collected_sources, limit=top_k)
            if collected_sources
            else []
        )
        print_green(
            f"[ChatService] Completed answer with {len(deduped_sources)} collected sources."
        )
        return {
            "answer": answer_text,
            "sources": deduped_sources,
            "tables": tables,
            "focus_ids": updated_focus_ids,
        }

    def _run_tool_loop(
        self,
        messages: Sequence[ChatMessage],
        collected_sources: List[ChatSource],
        collected_tables: List[Dict[str, Any]],
        initial_focus_ids: List[str],
    ) -> Tuple[FinalAnswer, List[Dict[str, Any]], List[str]]:
        """
        Repeatedly call the LLM, executing tool calls as needed, until a final answer is produced.

        Args:
            messages: The current chat history (including system message).
            collected_sources: List to collect sources from tool results.
            collected_tables: List to collect table payloads to forward to the user.
            initial_focus_ids: Document ids shared with the user in earlier turns.

        Returns:
            Tuple containing the final assistant message, any tables to forward to the user, and the latest focus id list.
        """
        print_purple("[ChatService] Starting tool interaction loop.")
        current_messages: List[ChatMessage] = list(messages)
        active_focus_ids: List[str] = list(dict.fromkeys(initial_focus_ids))
        if active_focus_ids:
            current_messages.append(
                {
                    "role": "user",
                    "content": (
                        "Du har tidigare delat sökresultat med användaren. "
                        "Listan `focus_ids` innehåller deras dokument-id:n:\n"
                        f"{active_focus_ids}\n"
                        "Om du vill begränsa en ny arango_search till samma träffar anger du argumentet "
                        "`focus_ids=focus_ids`."
                    ),
                }
            )

        for i in range(self.max_tool_iterations):

            if i == self.max_tool_iterations - 1:
                print_red(
                    f"[ChatService] Reached max iterations ({self.max_tool_iterations}). Forcing final answer."
                )
                current_messages.append(
                    {
                        "role": "user",
                        "content": (
                            "**IMPORTANT** You have reached the maximum number of tool calls. "
                            "Please provide your final answer based on the information you have gathered so far."
                        ),
                    }
                )

            response: ChatCompletionMessage = self.llm.generate(
                messages=current_messages,
                tools=self.tools,
                model="vllm",
            )
            # Log the model's intermediate reasoning and content when present.
            try:
                print_blue("Thinking:", response.reasoning_content)
            except Exception as e:
                print_red(f"[ChatService] Error printing thinking response: {e}")
            try:
                print_purple("Content:", response.content)
            except Exception as e:
                print_red(f"[ChatService] Error printing content response: {e}")
            # Duck typing: check for the expected attributes instead of a strict type,
            # which avoids issues if multiple ChatCompletionMessage classes exist in the project.
            tool_calls = getattr(response, "tool_calls", None)
            if tool_calls:
                raw_reasoning = getattr(response, "reasoning_content", None)
                if raw_reasoning:
                    if isinstance(raw_reasoning, dict) and "content" in raw_reasoning:
                        reasoning_content = raw_reasoning["content"]
                    else:
                        reasoning_content = str(raw_reasoning)
                    current_messages.append(
                        {
                            "role": "assistant",
                            "content": reasoning_content,
                        }
                    )
                print_blue(
                    f"[ChatService] Model requested {len(tool_calls)} tool call(s)."
                )
                for tool_call in tool_calls:
                    tool_name = tool_call.function.name
                    tool_args = tool_call.function.arguments
                    if isinstance(tool_args, dict) and "focus_ids" in tool_args:
                        # Allow the model to pass focus_ids=True or "focus_ids" as a
                        # shorthand for "use the saved list".
                        requested_focus = tool_args["focus_ids"]
                        if requested_focus is True or (
                            isinstance(requested_focus, str)
                            and requested_focus.strip().lower() == "focus_ids"
                        ):
                            tool_args["focus_ids"] = list(active_focus_ids)
                        elif requested_focus in (False, None) and not active_focus_ids:
                            tool_args.pop("focus_ids")
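                        # Illustrative mapping of the shorthand values handled above:
                        #   focus_ids=True        -> replaced with the saved list
                        #   focus_ids="focus_ids" -> replaced with the saved list
                        #   focus_ids=None/False  -> dropped when no list is saved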
                    print_blue(
                        f"[ChatService] Executing tool: {tool_name} with args: {tool_args}"
                    )

                    tool_func = self._get_tool_function(tool_name)
                    if tool_func is None:
                        print_blue(
                            f"[ChatService] Tool function '{tool_name}' not found!"
                        )
                        tool_result = f"ERROR: Tool '{tool_name}' not found."
                    else:
                        try:
                            tool_result = tool_func(**tool_args)
                        except Exception as e:
                            print_red(
                                f"[ChatService] Exception in tool '{tool_name}': {e}"
                            )
                            traceback.print_exc()
                            tool_result = f"ERROR: {e}"
                    if isinstance(tool_result, dict) and tool_result.get("type") == "search_results":
                        payload = tool_result.get("payload", {})
                        results = payload.get("results", [])
                        focus_ids_from_tool = [
                            item["_id"] for item in results if isinstance(item, dict) and item.get("_id")
                        ]
                        payload["focus_ids"] = focus_ids_from_tool
                        collected_tables.append(payload)
                        active_focus_ids = focus_ids_from_tool or active_focus_ids

                        requested_direct_delivery = False
                        if isinstance(tool_args, dict) and "results_to_user" in tool_args:
                            requested_direct_delivery = self._is_truthy_flag(
                                tool_args.get("results_to_user")
                            )

                        if requested_direct_delivery:
                            hit_count = len(results)
                            query_value = ""
                            if isinstance(tool_args, dict):
                                query_value = str(tool_args.get("query") or "").strip()
                            if hit_count:
                                query_part = f' för sökningen "{query_value}"' if query_value else ""
                                summary_text = (
                                    f"Jag hittade {hit_count} träffar{query_part}. "
                                    "Tabellen visar detaljerna. Ställ gärna följdfrågor om du vill veta mer."
                                )
                            else:
                                summary_text = (
                                    "Sökningen gav tyvärr inga träffar. "
                                    "Du kan justera sökvillkoren eller ställa en ny fråga."
                                )
                            explanation = (
                                "Resultatet skickades direkt till användaren eftersom verktygsanropet angav results_to_user=True."
                            )
                            final_message = FinalAnswer(
                                final_answer=summary_text,
                                explanation=explanation,
                            )
                            return final_message, collected_tables, active_focus_ids

                        tool_result_string = json.dumps(payload, ensure_ascii=False)
                        current_messages.append(
                            {
                                "role": "system",
                                "content": (
                                    "Spara listan `focus_ids` för uppföljande frågor:\n"
                                    f"focus_ids = {focus_ids_from_tool}\n"
                                    "När du behöver arbeta vidare med dessa dokument använder du arango_search med argumentet "
                                    "`focus_ids=<denna_lista>`."
                                ),
                            }
                        )
                    else:
                        tool_result_string = str(tool_result)
                        if len(tool_result_string) > 12000:
                            print_red(
                                f"[ChatService] Tool result too long ({len(tool_result_string)} chars), truncating."
                            )
                            tool_result_string = (
                                f"{tool_result_string[:12000]} (...) [truncated]"
                            )

                    reminder = (
                        "\n\n**Remember that you can only use the information you get using the tools "
                        "when giving your final answer. Do not make up any facts or quotes. If you do "
                        "not have enough information, say that you do not know, or call another tool to "
                        "find more information. Always give your final answer in Swedish.**"
                    )
                    tool_message = {
                        "role": "tool",
                        "name": tool_name,
                        "content": f"Result from calling {tool_name}:\n{tool_result_string}{reminder}",
                    }
                    if "ERROR" in tool_result_string:
                        print_red(
                            f"[ChatService] Tool result for '{tool_name.upper()}': {tool_message['content'][:200]}..."
                        )
                    else:
                        print_green(
                            f"[ChatService] Tool result for '{tool_name.upper()}': {tool_message['content'][:200]}..."
                        )
                    current_messages.append(tool_message)
                # Continue the loop with the updated message history (do NOT add the system message again).
                continue
            elif response.content:
                return response.content, collected_tables, active_focus_ids

        # Safety net: the loop can end without the model producing a final answer
        # (e.g. it kept requesting tools on the last iteration). Return a fallback
        # so callers always receive the expected three-tuple.
        return (
            FinalAnswer(
                final_answer="Jag kunde tyvärr inte ta fram ett slutgiltigt svar inom tillåtet antal verktygsanrop.",
                explanation="Maximalt antal verktygsanrop nåddes utan ett slutgiltigt svar.",
            ),
            collected_tables,
            active_focus_ids,
        )

    def _get_tool_function(self, tool_name: str):
        """
        Retrieve the Python function for a given tool name.

        Args:
            tool_name: The name of the tool as specified in the tool call.

        Returns:
            The Python function, or None if not found.
        """
        # This assumes your tools are registered in backend.services.llm_tools
        # and get_tools() returns a list of tool specs with .name and .function.
        for tool in self.tools:
            if hasattr(tool, "name") and tool.name == tool_name:
                return getattr(tool, "function", None)
        # Fallback: try to import from backend.services.llm_tools.
        try:
            import backend.services.llm_tools as llm_tools

            return getattr(llm_tools, tool_name, None)
        except Exception:
            print_red(f"[ChatService] Could not import tool '{tool_name}'.")
            return None

    def _latest_user_message(self, messages: Sequence[ChatMessage]) -> str:
        """
        Fetch the most recent user utterance from the chat history.
        """
        for message in reversed(messages):
            if message.get("role") == "user":
                return message.get("content", "").strip()
        return ""

    def _normalize_chunk_index(self, value: Any, default: int = -1) -> int:
        """
        Convert raw chunk index values from tool outputs into integers.

        Args:
            value: The raw chunk index produced by a tool (can be str, float, etc.).
            default: Fallback index used when conversion is not possible.

        Returns:
            An integer chunk index compatible with the API schema.
        """
        if isinstance(value, bool):
            return default
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            stripped = value.strip()
            if stripped.startswith("+"):
                stripped = stripped[1:]
            if stripped.lstrip("-").isdigit():
                return int(stripped)
        return default
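    # Illustrative conversions:
    #   _normalize_chunk_index(3)     -> 3
    #   _normalize_chunk_index(3.0)   -> 3
    #   _normalize_chunk_index("+7")  -> 7
    #   _normalize_chunk_index("-2")  -> -2
    #   _normalize_chunk_index(True)  -> -1 (bools are not valid indexes)
    #   _normalize_chunk_index("n/a") -> -1 (falls back to the default)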

    def _deduplicate_sources(
        self, sources: List[ChatSource], limit: int
    ) -> List[ChatSource]:
        """
        Collapse duplicate tool outputs and enforce the requested limit.
        """
        unique: Dict[Tuple[Any, Any], ChatSource] = {}
        for source in sources:
            source_id = source.get("_id")
            chunk_index = self._normalize_chunk_index(source.get("chunk_index"))
            key = (source_id, chunk_index)
            if key in unique:
                continue
            snippet_value = source.get("snippet", "")
            snippet_text = self._trim_snippet(str(snippet_value))
            unique[key] = {
                "_id": source_id,
                "heading": source.get("heading"),
                "snippet": snippet_text,
                "chunk_index": chunk_index,
                "debateurl": source.get("debateurl") or source.get("debate_url"),
            }
        max_items = max(1, limit)
        return list(unique.values())[:max_items]
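    # Illustrative: two hits sharing the key ("talks/123", 0) collapse into one
    # entry; with limit=5, at most the first five unique sources are returned.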

    def _is_truthy_flag(self, value: Any) -> bool:
        """
        Interpret diverse truthy representations (bools, strings, numbers) used in tool arguments.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, (int, float)):
            return value != 0
        if isinstance(value, str):
            normalized = value.strip().lower()
            return normalized in {"true", "1", "yes", "y", "ja"}
        return False
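    # Illustrative: _is_truthy_flag(True) -> True, _is_truthy_flag("Ja") -> True,
    # _is_truthy_flag(1) -> True, _is_truthy_flag("0") -> False, _is_truthy_flag(None) -> False.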

    def _trim_snippet(self, text: str, length: int = 400) -> str:
        """
        Truncate long snippets to keep the UI readable.
        """
        cleaned = text.strip()
        if len(cleaned) <= length:
            return cleaned
        return f"{cleaned[:length].rstrip()}…"


# ---- Test code ----
if __name__ == "__main__":
    service = ChatService()
    print("Registered tools:")
    for tool in service.tools:
        print(
            f" - {tool['function']['name']} - {tool['function']['description'][:100]}..."
        )
    test_messages = [
        {"role": "user", "content": "Hur många gånger har kärnkraft nämnts?"},
    ]
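
    # Uncomment for a live smoke test (assumes a reachable vLLM endpoint and the
    # search backends for the registered tools; illustrative usage only):
    # response = service.get_chat_response(test_messages)
    # print(response["answer"])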