You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
331 lines
14 KiB
331 lines
14 KiB
import chromadb |
|
import os |
|
import sys |
|
# Set /home/lasse/riksdagen as working directory |
|
os.chdir("/home/lasse/riksdagen") |
|
sys.path.append("/home/lasse/riksdagen") |
|
|
|
from chromadb import Collection |
|
from chromadb.api import ClientAPI |
|
|
|
import bootstrap # Ensure sys.path and working directory are set |
|
from config import chromadb_path, embedding_model |
|
import re |
|
from typing import Dict, List, Any, Tuple, Optional |
|
from chromadb.utils.embedding_functions import OllamaEmbeddingFunction |
|
|
|
class ChromaClient:
    """Wrapper around a persistent ChromaDB client.

    Collections handed out by this wrapper are bound to an
    ``OllamaEmbeddingFunction`` built from ``config.embedding_model``, so
    ``query_texts`` are embedded with that model. The class also contains a
    small parser that turns a Google-like search string into ChromaDB
    ``where`` / ``where_document`` filters.
    """

    # Tokenizer for parse_search_query. A token is either a single parenthesis
    # or a "term": a run of plain characters and/or complete quoted strings,
    # so ``author:"John Smith"`` and ``regex:"(a|b)"`` each stay one token.
    # An unmatched quote is consumed as an ordinary character so the
    # tokenizer never stalls.
    _TOKEN_RE = re.compile(
        r"""
        [()]                    # a single parenthesis
        | (?:                   # or a term made of one or more of:
            [^\s()"']+          #   plain characters
            | "[^"]*"           #   a double-quoted string (may hold spaces/parens)
            | '[^']*'           #   a single-quoted string
            | ["']              #   an unmatched quote, taken literally
          )+
        """,
        re.VERBOSE,
    )

    # Comparison prefix of a field value, e.g. ">=2020" -> (">=", "2020").
    _COMPARISON_RE = re.compile(r'^(>=|<=|>|<|!=)(.+)$')

    _COMPARISON_OPERATORS = {
        '>': '$gt',
        '<': '$lt',
        '>=': '$gte',
        '<=': '$lte',
        '!=': '$ne',
    }

    # Complement of each metadata operator. ChromaDB's `where` syntax has no
    # generic $not, so NOT is pushed down to the leaves via these.
    _NEGATED_METADATA_OPERATORS = {
        '$eq': '$ne',
        '$ne': '$eq',
        '$gt': '$lte',
        '$lte': '$gt',
        '$lt': '$gte',
        '$gte': '$lt',
        '$in': '$nin',
        '$nin': '$in',
    }

    # Complement of each document-content operator.
    # NOTE(review): $not_regex requires a chromadb release that supports
    # $regex; confirm against the installed version.
    _NEGATED_DOCUMENT_OPERATORS = {
        '$contains': '$not_contains',
        '$not_contains': '$contains',
        '$regex': '$not_regex',
        '$not_regex': '$regex',
    }

    _DOCUMENT_CONTAINS_FIELDS = frozenset({'document_contains', 'doc_contains', 'contains'})
    _DOCUMENT_REGEX_FIELDS = frozenset({'document_regex', 'doc_regex', 'regex'})
    _LOGICAL_OPERATORS = frozenset({'AND', 'OR', 'NOT'})

    def __init__(self, path: str | None = chromadb_path, embedding_url: str = '192.168.1.10:33405'):
        """Open the persistent client at *path* and set up the Ollama embedder.

        Args:
            path: Directory of the persistent ChromaDB store (defaults to
                ``config.chromadb_path``).
            embedding_url: Address of the Ollama server used for embeddings.
                The default is the address that was previously hard-coded.
        """
        self.path: str | None = path
        self._client: ClientAPI = self._init_client()
        self.embedding_function = OllamaEmbeddingFunction(model_name=embedding_model, url=embedding_url)

    def _init_client(self) -> ClientAPI:
        """Create the on-disk ChromaDB client for ``self.path``."""
        # chromadb.PersistentClient is a factory function rather than a class,
        # so the result is annotated with the ClientAPI interface it returns.
        return chromadb.PersistentClient(path=self.path)

    def get_collection(self, name: str) -> Collection:
        """Return the collection *name*, creating it first if it does not exist.

        The returned collection is bound to ``self.embedding_function``.
        """
        # One get_or_create call replaces a list -> create -> get sequence.
        # It cannot race with a concurrent creator. It does not depend on the
        # element type of list_collections(), since some chromadb releases
        # return plain names there. It also attaches the Ollama embedder: a
        # plain get_collection(name) without it may embed query texts with a
        # different (default) embedding function.
        return self._client.get_or_create_collection(name=name, embedding_function=self.embedding_function)

    def create_collection(self, name: str) -> Collection:
        """Return the collection *name* (created if missing), bound to the Ollama embedder."""
        return self._client.get_or_create_collection(name=name, embedding_function=self.embedding_function)

    def parse_search_query(self, query: str) -> Tuple[Optional[Dict], Optional[Dict]]:
        """
        Parse a Google-like search query into ChromaDB metadata and document filters.

        Supported syntax:
        - Field searches: author:Smith, year:2020, category:politics
          (field names are lower-cased; numeric-looking values become int/float)
        - Comparisons: year:>2020, year:<=2024, year:>=2020, party:!=S
        - Ranges: year:2020..2024 (year >= 2020 AND year <= 2024); open-ended
          ranges such as year:2020.. or year:..2024 are allowed. A range whose
          bounds are not numeric is matched as the literal text.
        - Document content: document_contains:"climate change", document_regex:"\\d{4}"
          (aliases: doc_contains / contains, doc_regex / regex)
        - Logical operators: AND, OR, NOT (case-insensitive). Adjacent terms are
          implicitly ANDed. Precedence: NOT binds tightest, then AND, then OR.
        - Grouping: (author:Smith OR author:Johnson) AND year:>2020
        - Quoted values: author:"John Smith" (quotes allow spaces and parentheses)

        Terms without a ``field:`` prefix (bare words or quoted phrases) are
        ignored. Parsing is lenient: dangling operators are skipped, a missing
        ``)`` is implied at the end, and an unmatched ``)`` is ignored.

        NOT is expressed with complementary operators (e.g. ``=`` -> ``$ne``,
        ``>`` -> ``$lte``, contains -> ``$not_contains``). NOTE(review): how
        ChromaDB treats records that lack the field under these operators has
        not been verified here.

        Args:
            query (str): The search query string in Google-like syntax

        Returns:
            Tuple[Optional[Dict], Optional[Dict]]: A tuple containing:
            - metadata_filter: ChromaDB ``where`` dict (or None if no metadata filters)
            - document_filter: ChromaDB ``where_document`` dict (or None if no document filters)

        Raises:
            ValueError: If OR or NOT would have to combine a metadata condition
                with a document-content condition. ChromaDB applies ``where``
                and ``where_document`` as two separate filters that must both
                match, so such a query cannot be expressed.

        Examples:
            >>> client.parse_search_query("author:Smith AND year:>2020")
            ({"$and": [{"author": "Smith"}, {"year": {"$gt": 2020}}]}, None)

            >>> client.parse_search_query("year:2020..2024 AND document_contains:'climate'")
            ({"$and": [{"year": {"$gte": 2020}}, {"year": {"$lte": 2024}}]}, {"$contains": "climate"})

            >>> client.parse_search_query("(author:Smith OR author:Johnson) AND NOT year:<2020")
            ({"$and": [{"$or": [{"author": "Smith"}, {"author": "Johnson"}]}, {"year": {"$gte": 2020}}]}, None)
        """
        if not query or not query.strip():
            return None, None

        tokens = self._TOKEN_RE.findall(query.strip())
        filters = self._parse_token_stream(tokens)
        return filters if filters is not None else (None, None)

    def _parse_token_stream(self, tokens: List[str]) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
        """Recursive-descent parse of *tokens* into a (metadata, document) filter pair.

        Grammar (lowest to highest precedence):
            or_expr  := and_expr (OR and_expr)*
            and_expr := unary ((AND)? unary)*
            unary    := NOT unary | '(' or_expr ')' | term

        Returns None when the tokens contain no usable condition.
        """
        position = 0

        def current() -> Optional[str]:
            return tokens[position] if position < len(tokens) else None

        def current_operator() -> Optional[str]:
            token = current()
            if token is not None and token.upper() in self._LOGICAL_OPERATORS:
                return token.upper()
            return None

        def parse_or():
            nonlocal position
            operands = [parse_and()]
            while current_operator() == 'OR':
                position += 1
                operands.append(parse_and())
            return self._or_filters(operands)

        def parse_and():
            nonlocal position
            operands = []
            while True:
                token = current()
                operator = current_operator()
                if token is None or token == ')' or operator == 'OR':
                    break
                if operator == 'AND':
                    position += 1  # explicit AND is the same as adjacency
                    continue
                operands.append(parse_unary())
            return self._and_filters(operands)

        def parse_unary():
            nonlocal position
            token = current()
            operator = current_operator()
            if token is None or token == ')' or operator in ('AND', 'OR'):
                return None  # e.g. a dangling NOT: nothing to negate here
            position += 1
            if operator == 'NOT':
                return self._negate_filters(parse_unary())
            if token == '(':
                inner = parse_or()
                if current() == ')':
                    position += 1  # a missing ')' is implied at end of input
                return inner
            return self._parse_term(token)

        clauses = []
        while position < len(tokens):
            clauses.append(parse_or())
            if position < len(tokens):
                position += 1  # parse_or only stops early at an unmatched ')'
        return self._and_filters(clauses)

    def _parse_term(self, token: str) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
        """Turn one ``field:value`` token into a (metadata, document) pair, or None if ignored."""
        field, separator, value = token.partition(':')
        field = field.lower().strip()
        # Bare words and quoted phrases carry no field and are ignored.
        if not separator or not field or field[0] in '"\'':
            return None
        value = value.strip().strip('"\'')  # remove surrounding quotes if present

        if field in self._DOCUMENT_CONTAINS_FIELDS:
            return None, {"$contains": value}
        if field in self._DOCUMENT_REGEX_FIELDS:
            return None, {"$regex": value}

        condition = self._parse_field_condition(field, value)
        return (condition, None) if condition else None

    @staticmethod
    def _combine(operator: str, conditions: List[Dict]) -> Optional[Dict]:
        """Join *conditions* under *operator* ("$and"/"$or").

        Nested groups that use the same operator are spliced in. A single
        condition is returned as-is, because ChromaDB expects at least two
        operands in a logical list.
        """
        merged: List[Dict] = []
        for condition in conditions:
            if len(condition) == 1 and operator in condition:
                merged.extend(condition[operator])
            else:
                merged.append(condition)
        if not merged:
            return None
        return merged[0] if len(merged) == 1 else {operator: merged}

    def _and_filters(self, clauses: List) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
        """AND a list of (metadata, document) pairs; None entries are skipped.

        AND distributes over the two filter channels, so the metadata parts and
        the document parts are combined independently.
        """
        present = [clause for clause in clauses if clause is not None]
        metadata = self._combine('$and', [meta for meta, _ in present if meta is not None])
        document = self._combine('$and', [doc for _, doc in present if doc is not None])
        if metadata is None and document is None:
            return None
        return metadata, document

    def _or_filters(self, clauses: List) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
        """OR a list of (metadata, document) pairs; None entries are skipped.

        Raises:
            ValueError: If the operands mix metadata and document conditions.
        """
        present = [clause for clause in clauses if clause is not None]
        if not present:
            return None
        if len(present) == 1:
            return present[0]
        if all(doc is None for _, doc in present):
            return self._combine('$or', [meta for meta, _ in present]), None
        if all(meta is None for meta, _ in present):
            return None, self._combine('$or', [doc for _, doc in present])
        raise ValueError(
            "OR cannot combine metadata conditions with document-content conditions: "
            "ChromaDB applies 'where' and 'where_document' as separate filters that must both match."
        )

    def _negate_filters(self, clause) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
        """Negate a (metadata, document) pair; None (nothing parsed) stays None.

        Raises:
            ValueError: If the pair constrains both metadata and document content.
        """
        if clause is None:
            return None
        metadata, document = clause
        if metadata is not None and document is not None:
            raise ValueError(
                "NOT cannot be applied to a group that mixes metadata and document-content conditions."
            )
        if metadata is not None:
            return self._negate_metadata(metadata), None
        return None, self._negate_document(document)

    def _negate_metadata(self, condition: Dict) -> Dict:
        """Return the complement of a metadata condition produced by this parser (De Morgan for groups)."""
        if '$and' in condition:
            return {'$or': [self._negate_metadata(child) for child in condition['$and']]}
        if '$or' in condition:
            return {'$and': [self._negate_metadata(child) for child in condition['$or']]}
        (field, criterion), = condition.items()
        if isinstance(criterion, dict):
            (operator, operand), = criterion.items()
            return {field: {self._NEGATED_METADATA_OPERATORS[operator]: operand}}
        return {field: {'$ne': criterion}}  # plain equality

    def _negate_document(self, condition: Dict) -> Dict:
        """Return the complement of a document condition produced by this parser (De Morgan for groups)."""
        if '$and' in condition:
            return {'$or': [self._negate_document(child) for child in condition['$and']]}
        if '$or' in condition:
            return {'$and': [self._negate_document(child) for child in condition['$or']]}
        (operator, operand), = condition.items()
        return {self._NEGATED_DOCUMENT_OPERATORS[operator]: operand}

    def _parse_field_condition(self, field: str, value: str) -> Optional[Dict]:
        """
        Parse a single field:value condition into a ChromaDB filter condition.

        Handles various syntaxes:
        - Simple equality: field:value -> {"field": "value"}
        - Comparisons: field:>10 -> {"field": {"$gt": 10}}
        - Ranges: field:2020..2024 -> {"$and": [{"field": {"$gte": 2020}}, {"field": {"$lte": 2024}}]}
          (one side may be empty for an open range; non-numeric bounds fall back
          to literal equality with the raw value, because ChromaDB's ordering
          operators are for numbers)

        Args:
            field (str): The field name (e.g., "year", "author", "category")
            value (str): The field value, possibly with operators or range syntax

        Returns:
            Optional[Dict]: ChromaDB filter condition dict
        """
        if '..' in value:
            start_text, _, end_text = value.partition('..')
            bounds = []
            for text, operator in ((start_text.strip(), '$gte'), (end_text.strip(), '$lte')):
                if not text:
                    continue  # open-ended side, e.g. "2020.." or "..2024"
                bound = self._convert_value(text)
                if not isinstance(bound, (int, float)):
                    return {field: value}  # not a numeric range: match the literal text
                bounds.append({field: {operator: bound}})
            if not bounds:
                return {field: value}  # a bare ".." has no bounds
            return bounds[0] if len(bounds) == 1 else {"$and": bounds}

        comparison = self._COMPARISON_RE.match(value)
        if comparison:
            symbol, operand = comparison.groups()
            # NOTE(review): a non-numeric operand for >, <, >=, <= is passed
            # through unchanged; ChromaDB is expected to reject it at query time.
            return {field: {self._COMPARISON_OPERATORS[symbol]: self._convert_value(operand.strip())}}

        return {field: self._convert_value(value)}

    def _convert_value(self, value: str) -> Any:
        """
        Convert a string value to the appropriate Python type.

        Values containing '.' are tried as float, others as int; anything that
        fails to convert is kept as a string. Surrounding matching quotes are
        removed (and the result kept as a string).

        Args:
            value (str): The string value to convert

        Returns:
            Any: The converted value (int, float, or str)
        """
        if not value:
            return value

        if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
            return value[1:-1]

        try:
            return float(value) if '.' in value else int(value)
        except ValueError:
            return value

    def query_collection(
        self,
        collection: Collection,
        query_texts: list[str] | str | None = None,
        n_results: int = 5,
        query_embeddings: list[list[float]] | None = None,
        where: dict | None = None,
        where_document: dict | None = None,
    ) -> list[dict]:
        """
        Query a ChromaDB collection and return de-duplicated, distance-sorted results.

        Args:
            collection (Collection): The ChromaDB collection to query against.
            query_texts (list[str] | str | None): Text(s) to search for. A single
                string is treated as a one-element list. Takes precedence over
                query_embeddings when both are given.
            n_results (int, optional): Maximum number of results per query. Defaults to 5.
            query_embeddings (list[list[float]] | None, optional): Pre-computed
                query embeddings, used only when no query_texts are given.
            where (dict | None, optional): Metadata filter, e.g.
                {"author": "John Doe"}, {"page": {"$gt": 10}},
                {"$and": [{"author": "John"}, {"year": {"$gte": 2020}}]},
                {"$or": [{"category": "news"}, {"category": "politics"}]},
                {"status": {"$in": ["published", "draft"]}}.
                If None, no metadata filtering is applied.
            where_document (dict | None, optional): Document-content filter, e.g.
                {"$contains": "climate change"} or {"$regex": r"\\b\\d{4}\\b"}.
                If None, no document filtering is applied.

        Returns:
            list[dict]: Unique results across all queries, each with keys
            'metadata', 'document', 'distance' and 'id'. When an id appears for
            several queries, the entry with the lowest distance is kept. Sorted
            by ascending distance.

        Raises:
            ValueError: If neither query_texts nor query_embeddings is provided.
        """
        if isinstance(query_texts, str):
            query_texts = [query_texts]
        has_texts = bool(query_texts)
        # len() rather than truthiness so array-like embeddings (e.g. numpy) work.
        has_embeddings = query_embeddings is not None and len(query_embeddings) > 0
        if not has_texts and not has_embeddings:
            # Explicit raise: an `assert` would be stripped under `python -O`.
            raise ValueError("Either query_texts or query_embeddings must be provided.")

        query_params: dict[str, Any] = {'n_results': n_results}
        if has_texts:
            query_params['query_texts'] = query_texts
        else:
            query_params['query_embeddings'] = query_embeddings
        if where is not None:
            query_params['where'] = where
        if where_document is not None:
            query_params['where_document'] = where_document

        results = collection.query(**query_params)

        # Chroma returns one inner list per query; a missing column yields no rows.
        ids = results.get("ids") or []
        metadatas = results.get("metadatas") or []
        documents = results.get("documents") or []
        distances = results.get("distances") or []

        best_by_id: dict = {}
        for query_idx, query_ids in enumerate(ids):
            rows = zip(
                metadatas[query_idx] if query_idx < len(metadatas) else [],
                documents[query_idx] if query_idx < len(documents) else [],
                distances[query_idx] if query_idx < len(distances) else [],
                query_ids,
            )
            for metadata, document, distance, identifier in rows:
                known = best_by_id.get(identifier)
                if known is None or distance < known['distance']:
                    best_by_id[identifier] = {
                        'metadata': metadata,
                        'document': document,
                        'distance': distance,
                        'id': identifier,
                    }

        return sorted(best_by_id.values(), key=lambda row: row['distance'])
|
|
|
|
|
# Shared client instance, created when this module is imported.
chroma_db = ChromaClient()

# --- Tests ---


if __name__ == "__main__":
    # Manual smoke test against a real store; the collection name comes from
    # the CHROMA_TALK_COLLECTION environment variable.
    collection_name = os.getenv("CHROMA_TALK_COLLECTION")
    if not collection_name:
        sys.exit("Set CHROMA_TALK_COLLECTION to the name of the collection to query.")

    collection = chroma_db.get_collection(collection_name)
    print(collection.count())

    query = 'betyg grundskola'

    # Wrapper API: de-duplicated results sorted by distance.
    results = chroma_db.query_collection(
        collection=collection,
        query_texts=[query],
        n_results=3,
    )
    for res in results:
        print(res['document'])
        print('---')

    # Raw Chroma API on the same collection, for comparison.
    print(collection.get(limit=10))
    raw_results = collection.query(query_texts=[query], n_results=3)
    for row in zip(
        raw_results['metadatas'][0],
        raw_results['documents'][0],
        raw_results['distances'][0],
        raw_results['ids'][0],
    ):
        print(row)
|
|
|
|