from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional, Sequence from pydantic import BaseModel, Field import env_manager env_manager.set_env() from arango.collection import Collection # noqa: E402 from arango_client import arango # noqa: E402 from backend.services.search import SearchService # noqa: E402 from _chromadb.chroma_client import chroma_db # noqa: E402 class HitDocument(BaseModel): """ HitDocument is a Pydantic model that provides a normalized representation of a search hit across various tools, enabling consistent downstream handling. Attributes: id (Optional[str]): Fully qualified ArangoDB document identifier. key (Optional[str]): Document key without collection prefix. speaker (Optional[str]): Name of the speaker associated with the hit. party (Optional[str]): Party affiliation of the speaker. date (Optional[str]): ISO formatted document date (YYYY-MM-DD). snippet (Optional[str]): Contextual snippet or highlight from the document. text (Optional[str]): Full text of the document when available. score (Optional[float]): Relevance score supplied by the executing tool. metadata (Dict[str, Any]): Additional metadata specific to the originating tool that should be preserved. Methods: to_string() -> str: Renders the hit as a human-readable string with uppercase labels, including all present fields and metadata. """ """Normalized representation of a search hit across tools to enable consistent downstream handling.""" id: Optional[str] = Field( default=None, description="Fully qualified ArangoDB document identifier." ) key: Optional[str] = Field( default=None, description="Document key without collection prefix." ) speaker: Optional[str] = Field( default=None, description="Name of the speaker associated with the hit." ) party: Optional[str] = Field( default=None, description="Party affiliation of the speaker." ) date: Optional[str] = Field( default=None, description="ISO formatted document date (YYYY-MM-DD)." ) snippet: Optional[str] = Field( default=None, description="Contextual snippet or highlight from the document." ) text: Optional[str] = Field( default=None, description="Full text of the document when available." ) score: Optional[float] = Field( default=None, description="Relevance score supplied by the executing tool." ) metadata: Dict[str, Any] = Field( default_factory=dict, description="Additional metadata specific to the originating tool that should be preserved.", ) def to_string(self, include_metadata: bool = True) -> str: """ Render the object as a human-readable string with uppercase labels. Args: include_metadata (bool, optional): Whether to include metadata fields in the output. Defaults to True. Returns: str: A formatted string representation of the object, with each field and its value separated by double newlines, and field names in uppercase. """ data: Dict[str, Any] = self.model_dump(exclude_none=True) metadata: Dict[str, Any] = data.pop("metadata", {}) segments: List[str] = [] for field_name, field_value in data.items(): segments.append(f"{field_name.upper()}\n{field_value}") for meta_key, meta_value in metadata.items(): segments.append(f"{meta_key.upper()}\n{meta_value}") return "\n\n".join(segments) class HitsResponse(BaseModel): """ HitsResponse is a Pydantic model that serves as a container for multiple HitDocument instances, providing utility methods for formatting and rendering the collection. Attributes: hits (List[HitDocument]): A list of collected search hits. Methods: to_string() -> str: Returns a string representation of all hits, separated by a visual divider. If there are no hits, returns an empty string. """ hits: List[HitDocument] = Field( default_factory=list, description="Collected search hits." ) def to_string(self, include_metadata=True) -> str: """ Render all hits as a single string, separated by a visual divider. Args: include_metadata (bool, optional): Whether to include metadata in each hit's string representation. Defaults to True. Returns: str: A single string containing all hits, separated by "\n\n---\n\n". Returns an empty string if there are no hits. """ """Render all hits as a single string separated by a visual divider.""" if not self.hits: return "" return "\n\n---\n\n".join( hit.to_string(include_metadata=include_metadata) for hit in self.hits ) def ensure_read_only_aql(query: str) -> None: """ Reject AQL statements that attempt to mutate data or omit a RETURN clause. Args: query: Raw AQL statement from the client. Raises: ValueError: If the query looks unsafe. """ normalized = query.upper() forbidden = ( "INSERT ", "UPDATE ", "UPSERT ", "REMOVE ", "REPLACE ", "DELETE ", "DROP ", "TRUNCATE ", "UPSERT ", "MERGE ", ) if any(keyword in normalized for keyword in forbidden): raise ValueError("Only read-only AQL queries are allowed.") if " RETURN " not in normalized and not normalized.strip().startswith("RETURN "): raise ValueError("AQL queries must include a RETURN clause.") def strip_private_fields(document: Dict[str, Any]) -> Dict[str, Any]: """ Remove large internal fields from a document dictionary. Args: document: Document returned by ArangoDB. Returns: Sanitized copy without chunk payloads. """ if "chunks" in document: del document["chunks"] return document def search_documents(aql_query: str) -> Dict[str, Any]: """ Execute a read-only AQL query and return the result set together with the query string. Args: aql_query: Read-only AQL statement supplied by the client. Returns: Dictionary containing the executed AQL string, row count, and result rows. """ ensure_read_only_aql(aql_query) rows = [strip_private_fields(doc) for doc in arango.execute_aql(aql_query)] return { "aql": aql_query, "row_count": len(rows), "rows": rows, } def run_aql_query(aql_query: str) -> List[Dict[str, Any]]: """ Execute a read-only AQL query and return the rows. Args: aql_query: Read-only AQL statement. Returns: List of result rows. """ ensure_read_only_aql(aql_query) return [strip_private_fields(doc) for doc in arango.execute_aql(aql_query)] def _get_existing_collection(name: str) -> Collection: """ Fetch an existing Chroma collection without creating new data. Args: name: Collection identifier. Returns: The requested collection. Raises: ValueError: If the collection is absent. """ available = {collection.name for collection in chroma_db._client.list_collections()} if name not in available: raise ValueError(f"Chroma collection '{name}' does not exist.") return chroma_db._client.get_collection(name=name) def vector_search(query: str, limit: int) -> List[Dict[str, Any]]: """ Perform semantic search against the pre-built Chroma collection. Args: query: Free-form search text. limit: Maximum number of hits to return. Returns: List of hit dictionaries with metadata and scores. """ collection_name = chroma_db.path.split("/")[-1] # ...existing code... chroma_collection = _get_existing_collection(collection_name) results = chroma_collection.query( query_texts=[query], n_results=limit, ) metadatas = results.get("metadatas") or [] documents = results.get("documents") or [] ids = results.get("ids") or [] distances = results.get("distances") or [] def as_int(value: Any, default: int = -1) -> int: if isinstance(value, int): return value if isinstance(value, float) and value.is_integer(): return int(value) if isinstance(value, str) and value.strip().lstrip("+-").isdigit(): return int(value) return default hits: List[Dict[str, Any]] = [] for index, metadata in enumerate(metadatas[0] if metadatas else []): meta = metadata or {} document = documents[0][index] if documents else "" identifier = ids[0][index] if ids else "" hit = { "_id": meta.get("_id") or identifier, "heading": meta.get("heading") or meta.get("title") or meta.get("talare"), "snippet": meta.get("snippet") or meta.get("text") or document, "debateurl": meta.get("debateurl") or meta.get("debate_url"), "chunk_index": as_int(meta.get("chunk_index") or meta.get("index")), "score": distances[0][index] if distances else None, } if hit["_id"]: hits.append(hit) return hits def fetch_documents(document_ids: Sequence[str], fields: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]: """ Pull full documents by _id while stripping heavy fields. Args: document_ids: Iterable with fully qualified Arango document ids. fields: Optional subset of fields to return. Returns: List of sanitized documents. """ ids = [doc_id.replace("\\", "/") for doc_id in document_ids] query = """ FOR id IN @document_ids RETURN DOCUMENT(id) """ documents = arango.execute_aql(query, bind_vars={"document_ids": ids}) if fields: return [{field: doc.get(field) for field in fields if field in doc} for doc in documents] return [strip_private_fields(doc) for doc in documents] @dataclass class SearchPayload: """ Lightweight container passed to SearchService.search. """ q: str parties: Optional[List[str]] people: Optional[List[str]] debates: Optional[List[str]] from_year: Optional[int] to_year: Optional[int] limit: int return_snippets: bool focus_ids: Optional[List[str]] speaker_ids: Optional[List[str]] speaker: Optional[str] = None def arango_search( query: str, limit: int, parties: Optional[Sequence[str]] = None, people: Optional[Sequence[str]] = None, from_year: Optional[int] = None, to_year: Optional[int] = None, return_snippets: bool = False, focus_ids: Optional[Sequence[str]] = None, speaker_ids: Optional[Sequence[str]] = None, ) -> Dict[str, Any]: """ Run an ArangoSearch query using the existing SearchService utilities. Args: query: Search expression (supports AND/OR/NOT and phrases). limit: Maximum number of hits to return. parties: Party filters. people: Speaker name filters. from_year: Start year filter. to_year: End year filter. return_snippets: Whether only snippets should be returned. focus_ids: Optional list restricting the search scope. speaker_ids: Optional list of speaker identifiers. Returns: Dictionary containing results, stats, limit flag, and focus_ids for follow-up queries. """ payload = SearchPayload( q=query, parties=list(parties) if parties else None, people=list(people) if people else None, debates=None, from_year=from_year, to_year=to_year, limit=limit, return_snippets=return_snippets, focus_ids=list(focus_ids) if focus_ids else None, speaker_ids=list(speaker_ids) if speaker_ids else None, ) service = SearchService() results, stats, limit_reached = service.search( payload=payload, include_snippets=True, return_snippets=return_snippets, focus_ids=payload.focus_ids, ) return { "results": results, "stats": stats, "limit_reached": limit_reached, "return_snippets": return_snippets, "focus_ids": [hit["_id"] for hit in results if isinstance(hit, dict) and hit.get("_id")], }