You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
360 lines
12 KiB
360 lines
12 KiB
from __future__ import annotations |
|
|
|
from dataclasses import dataclass |
|
from typing import Any, Dict, Iterable, List, Optional, Sequence |
|
from pydantic import BaseModel, Field |
|
|
|
import env_manager |
|
|
|
env_manager.set_env() |
|
|
|
from arango.collection import Collection # noqa: E402 |
|
from arango_client import arango # noqa: E402 |
|
from backend.services.search import SearchService # noqa: E402 |
|
from _chromadb.chroma_client import chroma_db # noqa: E402 |
|
|
|
|
|
class HitDocument(BaseModel):
    """
    Normalized representation of a search hit across tools, enabling consistent
    downstream handling.

    Attributes:
        id (Optional[str]): Fully qualified ArangoDB document identifier.
        key (Optional[str]): Document key without collection prefix.
        speaker (Optional[str]): Name of the speaker associated with the hit.
        party (Optional[str]): Party affiliation of the speaker.
        date (Optional[str]): ISO formatted document date (YYYY-MM-DD).
        snippet (Optional[str]): Contextual snippet or highlight from the document.
        text (Optional[str]): Full text of the document when available.
        score (Optional[float]): Relevance score supplied by the executing tool.
        metadata (Dict[str, Any]): Additional metadata specific to the originating
            tool that should be preserved.

    Methods:
        to_string(include_metadata=True) -> str:
            Renders the hit as a human-readable string with uppercase labels,
            including every field that is not None and, optionally, the
            metadata entries.
    """

    id: Optional[str] = Field(
        default=None, description="Fully qualified ArangoDB document identifier."
    )
    key: Optional[str] = Field(
        default=None, description="Document key without collection prefix."
    )
    speaker: Optional[str] = Field(
        default=None, description="Name of the speaker associated with the hit."
    )
    party: Optional[str] = Field(
        default=None, description="Party affiliation of the speaker."
    )
    date: Optional[str] = Field(
        default=None, description="ISO formatted document date (YYYY-MM-DD)."
    )
    snippet: Optional[str] = Field(
        default=None, description="Contextual snippet or highlight from the document."
    )
    text: Optional[str] = Field(
        default=None, description="Full text of the document when available."
    )
    score: Optional[float] = Field(
        default=None, description="Relevance score supplied by the executing tool."
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata specific to the originating tool that should be preserved.",
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render the hit as a human-readable string with uppercase labels.

        Each field that is not None becomes one segment: the upper-cased field
        name, a newline, then the value. Segments are separated by a blank
        line. Metadata entries are appended after the regular fields, in the
        same format, only when ``include_metadata`` is true.

        Args:
            include_metadata: Whether to append the entries of ``metadata``.
                Defaults to True.

        Returns:
            The formatted string. It is empty when no field is set and the
            metadata is either excluded or empty.
        """
        data: Dict[str, Any] = self.model_dump(exclude_none=True)
        # Metadata is flattened into its own labelled segments rather than
        # rendered as one raw dict, so take it out of the regular fields.
        metadata: Dict[str, Any] = data.pop("metadata", {})
        segments: List[str] = [
            f"{field_name.upper()}\n{field_value}"
            for field_name, field_value in data.items()
        ]
        if include_metadata:
            segments.extend(
                f"{meta_key.upper()}\n{meta_value}"
                for meta_key, meta_value in metadata.items()
            )
        return "\n\n".join(segments)
|
|
|
|
|
class HitsResponse(BaseModel):
    """
    Container for multiple HitDocument instances, with a helper for rendering
    the whole collection.

    Attributes:
        hits (List[HitDocument]): A list of collected search hits.

    Methods:
        to_string(include_metadata=True) -> str:
            Returns every hit rendered via HitDocument.to_string, separated by
            a visual divider, or an empty string when there are no hits.
    """

    hits: List[HitDocument] = Field(
        default_factory=list, description="Collected search hits."
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render all hits as a single string, separated by a visual divider.

        Args:
            include_metadata: Forwarded to each hit's ``to_string`` to decide
                whether its metadata entries are included. Defaults to True.

        Returns:
            All rendered hits joined by a ``---`` line surrounded by blank
            lines. Returns an empty string if there are no hits.
        """
        # Joining an empty sequence already yields "", so no special case is needed.
        return "\n\n---\n\n".join(
            hit.to_string(include_metadata=include_metadata) for hit in self.hits
        )
|
|
|
|
|
|
|
def ensure_read_only_aql(query: str) -> None:
    """
    Reject AQL statements that attempt to mutate data or omit a RETURN clause.

    This is a keyword heuristic, not a parser. Keywords are matched
    case-insensitively as whole words, whatever whitespace or punctuation
    surrounds them. A keyword followed by a newline, a tab or ``{`` is still
    caught. A keyword that is only part of a longer word (such as
    ``last_update`` or ``REGEX_REPLACE``) is not. ``MERGE`` is rejected only
    when it is not written as a function call, because ``MERGE(...)`` is a
    read-only AQL function.

    NOTE(review): running client queries under a read-only database user
    would be a stronger guarantee than any keyword filter.

    Args:
        query: Raw AQL statement from the client.

    Raises:
        ValueError: If the query contains a forbidden keyword or has no
            RETURN keyword.
    """
    import re

    normalized = query.upper()
    forbidden = (
        r"\b(?:INSERT|UPDATE|UPSERT|REMOVE|REPLACE|DELETE|DROP|TRUNCATE)\b"
        # Keep read-only calls such as MERGE(a, b) / MERGE (a, b) allowed.
        r"|\bMERGE\b(?!\s*\()"
    )
    if re.search(forbidden, normalized):
        raise ValueError("Only read-only AQL queries are allowed.")
    # Word boundaries also accept RETURN at the start of a line or query.
    if not re.search(r"\bRETURN\b", normalized):
        raise ValueError("AQL queries must include a RETURN clause.")
|
|
|
|
|
def strip_private_fields(document: Any) -> Any:
    """
    Remove large internal fields from a document dictionary.

    Args:
        document: A row returned by ArangoDB. Only dictionaries are sanitized.
            AQL rows can also be scalars, lists or null (e.g. from a query
            such as ``RETURN d.name``), and those are returned unchanged.

    Returns:
        A shallow copy of ``document`` without the ``chunks`` payload, or the
        value itself when it is not a dictionary. The input is never mutated.
    """
    if not isinstance(document, dict):
        # On a str, ``"chunks" in value`` is a substring test; on ints/None it
        # raises. Either way there is no key to delete, so pass it through.
        return document
    return {key: value for key, value in document.items() if key != "chunks"}
|
|
|
|
|
def search_documents(aql_query: str) -> Dict[str, Any]:
    """
    Run a read-only AQL query and package its rows with the query that produced them.

    Args:
        aql_query: Read-only AQL statement supplied by the client. It is
            validated by ``ensure_read_only_aql`` before execution.

    Returns:
        Mapping with the keys ``"aql"`` (the query string as given),
        ``"row_count"`` (number of rows) and ``"rows"`` (the rows, each passed
        through ``strip_private_fields``).
    """
    # run_aql_query performs the same validate -> execute -> sanitize sequence.
    rows = run_aql_query(aql_query)
    return {"aql": aql_query, "row_count": len(rows), "rows": rows}
|
|
|
|
|
def run_aql_query(aql_query: str) -> List[Dict[str, Any]]:
    """
    Validate and execute a read-only AQL query, returning its sanitized rows.

    Args:
        aql_query: AQL statement. ``ensure_read_only_aql`` rejects it with
            ``ValueError`` if it looks like it mutates data or lacks a RETURN
            clause.

    Returns:
        The result rows, each passed through ``strip_private_fields``.
    """
    ensure_read_only_aql(aql_query)
    cursor = arango.execute_aql(aql_query)
    return list(map(strip_private_fields, cursor))
|
|
|
|
|
def _get_existing_collection(name: str) -> Any:
    """
    Fetch an existing Chroma collection without creating new data.

    Args:
        name: Collection identifier.

    Returns:
        The requested Chroma collection object. It is annotated as ``Any``
        because it is a Chroma collection, not ArangoDB's ``Collection``.

    Raises:
        ValueError: If the collection is absent.
    """
    # NOTE(review): relies on the private ``_client`` attribute of the wrapper.
    client = chroma_db._client
    # Depending on the chromadb version, list_collections() yields collection
    # objects (with a ``.name``) or plain name strings; accept both.
    available = {
        getattr(collection, "name", collection)
        for collection in client.list_collections()
    }
    if name not in available:
        raise ValueError(f"Chroma collection '{name}' does not exist.")
    return client.get_collection(name=name)
|
|
|
|
|
def _as_int(value: Any, default: int = -1) -> int:
    """Coerce a Chroma metadata value to ``int``; return ``default`` when it is not integral."""
    if isinstance(value, int):
        return value
    if isinstance(value, float) and value.is_integer():
        return int(value)
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            # e.g. "", "1.5", "+-3" or "²" (isdigit() accepts it, int() does not).
            return default
    return default


def _first_batch(results: Any, key: str) -> List[Any]:
    """Return the per-query result list for the first query text under ``key`` (empty if absent)."""
    batches = results.get(key) or []
    return list(batches[0] or []) if batches else []


def vector_search(query: str, limit: int) -> List[Dict[str, Any]]:
    """
    Perform semantic search against the pre-built Chroma collection.

    Args:
        query: Free-form search text.
        limit: Maximum number of hits to return. Values below 1 return an
            empty list without querying Chroma.

    Returns:
        List of hit dictionaries with the keys ``_id``, ``heading``,
        ``snippet``, ``debateurl``, ``chunk_index`` (-1 when unknown) and
        ``score``. ``score`` is the raw Chroma distance for the hit, not a
        similarity. Hits without an id are dropped.

    Raises:
        ValueError: If the Chroma collection does not exist.
    """
    if limit < 1:
        # A non-positive limit cannot produce hits; skip the round trip.
        return []

    # NOTE(review): assumes the collection is named after the last component of
    # the Chroma directory path. Trailing slashes are ignored.
    collection_name = str(chroma_db.path).rstrip("/").split("/")[-1]
    chroma_collection = _get_existing_collection(collection_name)
    results = chroma_collection.query(query_texts=[query], n_results=limit)

    # Only one query text was sent, so only the first batch of each field matters.
    metadatas = _first_batch(results, "metadatas")
    documents = _first_batch(results, "documents")
    ids = _first_batch(results, "ids")
    distances = _first_batch(results, "distances")

    hits: List[Dict[str, Any]] = []
    for index, metadata in enumerate(metadatas):
        meta = metadata or {}
        document = documents[index] if index < len(documents) else ""
        identifier = ids[index] if index < len(ids) else ""
        chunk_index = meta.get("chunk_index")
        if chunk_index is None:
            # Fall back only when the key is missing: ``or`` would also discard 0.
            chunk_index = meta.get("index")
        hit = {
            "_id": meta.get("_id") or identifier,
            "heading": meta.get("heading") or meta.get("title") or meta.get("talare"),
            "snippet": meta.get("snippet") or meta.get("text") or document,
            "debateurl": meta.get("debateurl") or meta.get("debate_url"),
            "chunk_index": _as_int(chunk_index),
            "score": distances[index] if index < len(distances) else None,
        }
        if hit["_id"]:
            hits.append(hit)
    return hits
|
|
|
|
|
def fetch_documents(document_ids: Sequence[str], fields: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
    """
    Pull full documents by _id while stripping heavy fields.

    Args:
        document_ids: Fully qualified Arango document ids. Backslash
            separators are converted to forward slashes.
        fields: Optional subset of fields to return. When given and non-empty,
            each document is reduced to those requested fields it contains,
            and ``strip_private_fields`` is not applied.

    Returns:
        List of documents in the order of the requested ids. Ids that do not
        resolve to a document are skipped.
    """
    ids = [doc_id.replace("\\", "/") for doc_id in document_ids]
    if not ids:
        return []
    # Materialize once: a generator passed as ``fields`` would otherwise be
    # exhausted after the first document.
    wanted = list(fields) if fields is not None else []
    # DOCUMENT() yields null for ids it cannot resolve; drop those rows here
    # instead of failing later on ``None.get`` / strip_private_fields.
    query = """
    FOR id IN @document_ids
        LET doc = DOCUMENT(id)
        FILTER doc != null
        RETURN doc
    """
    documents = arango.execute_aql(query, bind_vars={"document_ids": ids})
    if wanted:
        return [{field: doc.get(field) for field in wanted if field in doc} for doc in documents]
    return [strip_private_fields(doc) for doc in documents]
|
|
|
|
|
@dataclass
class SearchPayload:
    """
    Lightweight container passed to SearchService.search.

    Built by ``arango_search`` from its arguments. Because this is a
    dataclass, the field declaration order below is also the positional
    order of the generated constructor, so do not reorder the fields.
    """
    # Search expression (the ``query`` argument of arango_search).
    q: str
    # Party filters; arango_search passes None when none are given.
    parties: Optional[List[str]]
    # Speaker name filters; arango_search passes None when none are given.
    people: Optional[List[str]]
    # Presumably debate filters — TODO confirm; arango_search always passes None.
    debates: Optional[List[str]]
    # Start year filter.
    from_year: Optional[int]
    # End year filter.
    to_year: Optional[int]
    # Maximum number of hits to return.
    limit: int
    # Whether only snippets should be returned.
    return_snippets: bool
    # Optional list of document ids restricting the search scope.
    focus_ids: Optional[List[str]]
    # Optional list of speaker identifiers.
    speaker_ids: Optional[List[str]]
    # Not set by arango_search; defaults to None. NOTE(review): semantics defined by SearchService.
    speaker: Optional[str] = None
|
|
|
|
|
def _as_string_list(values: Optional[Sequence[str]]) -> Optional[List[str]]:
    """Copy ``values`` into a new list (None when empty), treating a bare string as a single item."""
    if not values:
        return None
    if isinstance(values, str):
        # list("SD") would silently split one filter value into characters.
        return [values]
    return list(values)


def arango_search(
    query: str,
    limit: int,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Run an ArangoSearch query using the existing SearchService utilities.

    Args:
        query: Search expression (supports AND/OR/NOT and phrases).
        limit: Maximum number of hits to return.
        parties: Party filters. A single string is treated as one party.
        people: Speaker name filters. A single string is treated as one name.
        from_year: Start year filter.
        to_year: End year filter.
        return_snippets: Whether only snippets should be returned.
        focus_ids: Optional list restricting the search scope. A single
            string is treated as one id.
        speaker_ids: Optional list of speaker identifiers. A single string is
            treated as one id.

    Returns:
        Dictionary with ``results``, ``stats``, ``limit_reached``,
        ``return_snippets`` and ``focus_ids``. ``focus_ids`` holds the ``_id``
        of every returned hit, for use in follow-up queries.
    """
    payload = SearchPayload(
        q=query,
        parties=_as_string_list(parties),
        people=_as_string_list(people),
        debates=None,
        from_year=from_year,
        to_year=to_year,
        limit=limit,
        return_snippets=return_snippets,
        focus_ids=_as_string_list(focus_ids),
        speaker_ids=_as_string_list(speaker_ids),
    )
    service = SearchService()
    results, stats, limit_reached = service.search(
        payload=payload,
        include_snippets=True,
        return_snippets=return_snippets,
        focus_ids=payload.focus_ids,
    )
    return {
        "results": results,
        "stats": stats,
        "limit_reached": limit_reached,
        "return_snippets": return_snippets,
        # Ids of the returned hits, so a follow-up call can restrict its scope to them.
        "focus_ids": [hit["_id"] for hit in results if isinstance(hit, dict) and hit.get("_id")],
    }