You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

360 lines
12 KiB

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence
from pydantic import BaseModel, Field
import env_manager
env_manager.set_env()
from arango.collection import Collection # noqa: E402
from arango_client import arango # noqa: E402
from backend.services.search import SearchService # noqa: E402
from _chromadb.chroma_client import chroma_db # noqa: E402
class HitDocument(BaseModel):
    """
    Normalized representation of a search hit across tools.

    Each search tool maps its raw result onto this model so that downstream
    code can handle hits in one consistent shape.

    Attributes:
        id: Fully qualified ArangoDB document identifier.
        key: Document key without collection prefix.
        speaker: Name of the speaker associated with the hit.
        party: Party affiliation of the speaker.
        date: ISO formatted document date (YYYY-MM-DD).
        snippet: Contextual snippet or highlight from the document.
        text: Full text of the document when available.
        score: Relevance score supplied by the executing tool.
        metadata: Additional tool-specific metadata that should be preserved.

    Methods:
        to_string(include_metadata=True): Render the hit as labelled text.
    """

    id: Optional[str] = Field(
        default=None, description="Fully qualified ArangoDB document identifier."
    )
    key: Optional[str] = Field(
        default=None, description="Document key without collection prefix."
    )
    speaker: Optional[str] = Field(
        default=None, description="Name of the speaker associated with the hit."
    )
    party: Optional[str] = Field(
        default=None, description="Party affiliation of the speaker."
    )
    date: Optional[str] = Field(
        default=None, description="ISO formatted document date (YYYY-MM-DD)."
    )
    snippet: Optional[str] = Field(
        default=None, description="Contextual snippet or highlight from the document."
    )
    text: Optional[str] = Field(
        default=None, description="Full text of the document when available."
    )
    score: Optional[float] = Field(
        default=None, description="Relevance score supplied by the executing tool."
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata specific to the originating tool that should be preserved.",
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render the hit as a human-readable string with uppercase labels.

        Each populated field becomes one segment of the form ``LABEL``,
        newline, value. Segments are separated by a blank line. Fields whose
        value is None are omitted.

        Args:
            include_metadata: When True (the default), every entry of
                ``metadata`` is rendered as an additional segment after the
                regular fields. When False, metadata is left out entirely.

        Returns:
            The formatted string; empty when nothing is populated.
        """
        data: Dict[str, Any] = self.model_dump(exclude_none=True)
        # Always pop metadata so the raw dict is never rendered as a single
        # "METADATA" segment; its entries are expanded individually below.
        metadata: Dict[str, Any] = data.pop("metadata", {})

        segments: List[str] = [
            f"{field_name.upper()}\n{field_value}"
            for field_name, field_value in data.items()
        ]
        # Previously the flag was accepted but ignored, so metadata was
        # always emitted regardless of what the caller asked for.
        if include_metadata:
            segments.extend(
                f"{meta_key.upper()}\n{meta_value}"
                for meta_key, meta_value in metadata.items()
            )
        return "\n\n".join(segments)
class HitsResponse(BaseModel):
    """
    Container model holding a list of HitDocument instances.

    Attributes:
        hits: Collected search hits (defaults to an empty list).

    Methods:
        to_string(include_metadata=True): Render every hit and join the
            renderings with a visual divider.
    """

    hits: List[HitDocument] = Field(
        default_factory=list, description="Collected search hits."
    )

    def to_string(self, include_metadata=True) -> str:
        """
        Render all hits as one string, separated by a visual divider.

        Args:
            include_metadata: Forwarded to each hit's ``to_string`` call.

        Returns:
            The rendered hits joined by a ``---`` line surrounded by blank
            lines, or an empty string when there are no hits.
        """
        divider = "\n\n---\n\n"
        rendered_hits = [
            hit.to_string(include_metadata=include_metadata) for hit in self.hits
        ]
        # Joining an empty list yields "", which covers the no-hits case.
        return divider.join(rendered_hits)
def ensure_read_only_aql(query: str) -> None:
    """
    Reject AQL statements that attempt to mutate data or omit a RETURN clause.

    The check is a case-insensitive textual heuristic: a query is rejected
    when any forbidden keyword is followed by a whitespace character (space,
    tab, newline, ...), and it must contain a RETURN keyword that is preceded
    by whitespace or the start of the query and followed by whitespace.

    Args:
        query: Raw AQL statement from the client.

    Raises:
        ValueError: If the query looks unsafe or has no RETURN clause.
    """
    import re  # local import keeps this block self-contained

    normalized = query.upper()
    forbidden = (
        "INSERT",
        "UPDATE",
        "UPSERT",
        "REMOVE",
        "REPLACE",
        "DELETE",
        "DROP",
        "TRUNCATE",
        "MERGE",
    )
    # Match a keyword followed by ANY whitespace, not just a literal space,
    # so multi-line or tab-separated statements ("REMOVE\n doc IN ...") can
    # no longer slip past the filter. A trailing whitespace requirement is
    # kept so function calls such as MERGE(...) remain allowed.
    # NOTE(review): a keyword directly followed by a non-whitespace token
    # (e.g. "INSERT{...}") is still not detected, and keywords inside string
    # literals are still rejected. Treat this as defence in depth and
    # prefer read-only database credentials as the real safeguard.
    write_pattern = r"(?:" + "|".join(forbidden) + r")\s"
    if re.search(write_pattern, normalized):
        raise ValueError("Only read-only AQL queries are allowed.")

    # Accept RETURN at the very start or after any whitespace, so a query
    # whose RETURN sits at the beginning of a new line is not rejected.
    if not re.search(r"(?:^|\s)RETURN\s", normalized):
        raise ValueError("AQL queries must include a RETURN clause.")
def strip_private_fields(document: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of a document without large internal fields.

    Currently only the ``chunks`` key is removed. The input is never
    modified; a new shallow-copied dictionary is returned.

    Args:
        document: Document returned by ArangoDB. Values that are not
            dictionaries (for example scalar or null rows produced by an
            AQL ``RETURN`` expression) are passed through unchanged.

    Returns:
        Sanitized copy without chunk payloads, or the input itself when it
        is not a dictionary.
    """
    # AQL rows are not guaranteed to be objects; membership tests and `del`
    # on ints/None would raise TypeError, so let such rows through as-is.
    if not isinstance(document, dict):
        return document
    # Build a new dict instead of deleting in place so the caller's object
    # is left intact, as the "copy" contract promises.
    return {key: value for key, value in document.items() if key != "chunks"}
def search_documents(aql_query: str) -> Dict[str, Any]:
    """
    Run a read-only AQL query and report the rows alongside the query text.

    The query is validated with ``ensure_read_only_aql`` before execution and
    every returned row is passed through ``strip_private_fields``.

    Args:
        aql_query: Read-only AQL statement supplied by the client.

    Returns:
        Dictionary with the executed AQL string (``aql``), the number of rows
        (``row_count``) and the rows themselves (``rows``).
    """
    ensure_read_only_aql(aql_query)

    cursor = arango.execute_aql(aql_query)
    sanitized_rows = list(map(strip_private_fields, cursor))

    response: Dict[str, Any] = {}
    response["aql"] = aql_query
    response["row_count"] = len(sanitized_rows)
    response["rows"] = sanitized_rows
    return response
def run_aql_query(aql_query: str) -> List[Dict[str, Any]]:
    """
    Run a read-only AQL query and return its rows.

    Args:
        aql_query: Read-only AQL statement; validated with
            ``ensure_read_only_aql`` before it is executed.

    Returns:
        List of result rows, each passed through ``strip_private_fields``.
    """
    ensure_read_only_aql(aql_query)

    sanitized_rows: List[Dict[str, Any]] = []
    for row in arango.execute_aql(aql_query):
        sanitized_rows.append(strip_private_fields(row))
    return sanitized_rows
def _get_existing_collection(name: str) -> Collection:
    """
    Fetch an existing Chroma collection without creating new data.

    The collection is looked up in the client's collection listing first, so
    a missing collection results in an error instead of being created.

    Args:
        name: Collection identifier.

    Returns:
        The requested collection, as returned by the Chroma client.

    Raises:
        ValueError: If the collection is absent.
    """
    # NOTE(review): the `Collection` annotation resolves to the class imported
    # from `arango.collection`, while the object returned here comes from the
    # Chroma client - confirm and adjust the annotation if appropriate.
    client = chroma_db._client

    # Depending on the installed chromadb release, list_collections() yields
    # either collection objects or plain collection names; accept both.
    available = {
        entry if isinstance(entry, str) else entry.name
        for entry in client.list_collections()
    }
    if name not in available:
        raise ValueError(f"Chroma collection '{name}' does not exist.")
    return client.get_collection(name=name)
def vector_search(query: str, limit: int) -> List[Dict[str, Any]]:
    """
    Perform semantic search against the pre-built Chroma collection.

    The collection name is taken from the last ``/``-separated component of
    ``chroma_db.path``. Only the result list for the single query text is
    read, and hits without an identifier are dropped.

    Args:
        query: Free-form search text.
        limit: Maximum number of hits to return.

    Returns:
        List of hit dictionaries with the keys ``_id``, ``heading``,
        ``snippet``, ``debateurl``, ``chunk_index`` and ``score``.
        ``chunk_index`` is -1 when no usable index is present, and ``score``
        is the raw distance reported by the collection query (None when no
        distances are returned).

    Raises:
        ValueError: If the Chroma collection does not exist.
    """
    collection_name = chroma_db.path.split("/")[-1]
    chroma_collection = _get_existing_collection(collection_name)
    results = chroma_collection.query(
        query_texts=[query],
        n_results=limit,
    )

    metadatas = results.get("metadatas") or []
    documents = results.get("documents") or []
    ids = results.get("ids") or []
    distances = results.get("distances") or []

    def as_int(value: Any, default: int = -1) -> int:
        """Coerce ints, integral floats and digit strings; else `default`."""
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str) and value.strip().lstrip("+-").isdigit():
            # The guard above still admits strings int() cannot parse
            # (e.g. "+-5"), so fall back instead of raising.
            try:
                return int(value)
            except ValueError:
                return default
        return default

    def first_present(*candidates: Any) -> Any:
        """Return the first candidate that is neither None nor ''."""
        for candidate in candidates:
            if candidate is not None and candidate != "":
                return candidate
        return None

    hits: List[Dict[str, Any]] = []
    for index, metadata in enumerate(metadatas[0] if metadatas else []):
        meta = metadata or {}
        document = documents[0][index] if documents else ""
        identifier = ids[0][index] if ids else ""
        hit = {
            "_id": meta.get("_id") or identifier,
            "heading": meta.get("heading") or meta.get("title") or meta.get("talare"),
            "snippet": meta.get("snippet") or meta.get("text") or document,
            "debateurl": meta.get("debateurl") or meta.get("debate_url"),
            # A plain `or` would treat a legitimate chunk index of 0 as
            # missing and report -1 (or the fallback key) instead.
            "chunk_index": as_int(
                first_present(meta.get("chunk_index"), meta.get("index"))
            ),
            "score": distances[0][index] if distances else None,
        }
        if hit["_id"]:
            hits.append(hit)
    return hits
def fetch_documents(document_ids: Sequence[str], fields: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
    """
    Pull full documents by _id while stripping heavy fields.

    Backslashes in the supplied ids are replaced by forward slashes before
    the lookup. Ids that do not resolve to a document are skipped.

    Args:
        document_ids: Iterable with fully qualified Arango document ids.
        fields: Optional subset of fields to return. When given (and not
            empty), each result contains only those of the requested fields
            that exist on the document.

    Returns:
        List of sanitized documents. Without ``fields`` each document is
        passed through ``strip_private_fields``.
    """
    ids = [doc_id.replace("\\", "/") for doc_id in document_ids]
    query = """
    FOR id IN @document_ids
    RETURN DOCUMENT(id)
    """
    # DOCUMENT() yields null for ids that cannot be resolved; drop those rows
    # so the projections below never operate on None.
    documents = [
        doc
        for doc in arango.execute_aql(query, bind_vars={"document_ids": ids})
        if doc is not None
    ]

    # Materialize once: `fields` may be a one-shot iterable (e.g. a
    # generator), which would otherwise be exhausted after the first document.
    selected = list(fields) if fields is not None else []
    if selected:
        return [
            {field: doc.get(field) for field in selected if field in doc}
            for doc in documents
        ]
    return [strip_private_fields(doc) for doc in documents]
@dataclass
class SearchPayload:
    """
    Lightweight container passed to SearchService.search.

    All fields except ``speaker`` are required and are accepted positionally
    in the order declared below.
    """

    # Search expression.
    q: str
    # Optional filters; None means "no filter supplied".
    parties: Optional[List[str]]
    people: Optional[List[str]]
    debates: Optional[List[str]]
    # Optional year bounds.
    from_year: Optional[int]
    to_year: Optional[int]
    # Maximum number of hits requested.
    limit: int
    # Whether snippets should be returned.
    return_snippets: bool
    # Optional id lists narrowing the search.
    focus_ids: Optional[List[str]]
    speaker_ids: Optional[List[str]]
    # The only field with a default value.
    speaker: Optional[str] = None
def arango_search(
    query: str,
    limit: int,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Run an ArangoSearch query through the existing SearchService utilities.

    The arguments are packed into a SearchPayload (empty or missing sequence
    filters become None, ``debates`` is always None) and handed to a fresh
    SearchService instance.

    Args:
        query: Search expression (supports AND/OR/NOT and phrases).
        limit: Maximum number of hits to return.
        parties: Party filters.
        people: Speaker name filters.
        from_year: Start year filter.
        to_year: End year filter.
        return_snippets: Whether only snippets should be returned.
        focus_ids: Optional list restricting the search scope.
        speaker_ids: Optional list of speaker identifiers.

    Returns:
        Dictionary with ``results``, ``stats``, ``limit_reached``, the
        ``return_snippets`` flag, and ``focus_ids`` collected from the
        returned hits for follow-up queries.
    """

    def _list_or_none(values: Optional[Sequence[str]]) -> Optional[List[str]]:
        # Empty and missing filters are both represented as None.
        if values:
            return list(values)
        return None

    payload = SearchPayload(
        q=query,
        parties=_list_or_none(parties),
        people=_list_or_none(people),
        debates=None,
        from_year=from_year,
        to_year=to_year,
        limit=limit,
        return_snippets=return_snippets,
        focus_ids=_list_or_none(focus_ids),
        speaker_ids=_list_or_none(speaker_ids),
    )

    search_outcome = SearchService().search(
        payload=payload,
        include_snippets=True,
        return_snippets=return_snippets,
        focus_ids=payload.focus_ids,
    )
    results, stats, limit_reached = search_outcome

    # Collect the ids of dict-shaped hits so callers can narrow later queries.
    follow_up_ids = []
    for hit in results:
        if isinstance(hit, dict) and hit.get("_id"):
            follow_up_ids.append(hit["_id"])

    return {
        "results": results,
        "stats": stats,
        "limit_reached": limit_reached,
        "return_snippets": return_snippets,
        "focus_ids": follow_up_ids,
    }