import os
import re
from typing import Any, Dict, List, Tuple, Union

import chromadb
from chromadb.config import Settings
from dotenv import load_dotenv

from colorprinter.print_color import *
from models import ChunkSearchResults

load_dotenv(".env")


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                # database=db,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )

    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to return. Defaults to 6.
            n_sources (int, optional): The number of unique sources to return. Defaults to 3.
            max_retries (int, optional): The maximum number of retries for querying.
                Defaults to n_sources.
            where (dict, optional): Additional filtering criteria for the query. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the query.

        Returns:
            dict: A dictionary containing the query results with keys 'ids', 'metadatas',
                'documents', and 'distances'.
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results

        col = self.db.get_collection(collection)
        sources = []
        n = 0
        print("Collection", collection)
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        while True:
            n += 1
            if n > max_retries:
                break
            if where == {}:
                where = None

            print_rainbow(kwargs)
            print("N_results:", n_results)
            print("Sources:", sources)
            print("Query:", query)

            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )

            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_rainbow(r)
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break

            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])

            # Order result by distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))

            # Recompute the set of unique sources seen so far (result already
            # accumulates every returned metadata entry).
            sources = list({i["_id"] for i in result["metadatas"][0]})
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                # Trim every result list to the same length so ids, metadatas,
                # documents and distances stay aligned on the next retry.
                for k in result:
                    result[k][0] = result[k][0][
                        : n_results - (n_sources - len(sources))
                    ]
                if where and "_id" in where:
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
            else:
                break

        return result
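
    # The dict returned by query() mirrors Chroma's QueryResult layout: every
    # value is a list of lists, with the outer list indexed by query text
    # (only a single query string is issued here). Illustrative shape, not
    # real data:
    #
    #   {"ids": [["art_1_0", "art_2_3"]],
    #    "metadatas": [[{"_id": "articles/1"}, {"_id": "articles/2"}]],
    #    "documents": [["chunk text ...", "another chunk ..."]],
    #    "distances": [[0.21, 0.34]]}
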
    def search(
        self,
        query: str,
        collection: str,
        n_results: int = 6,
        n_sources: int = 3,
        where: dict = None,
        format_results: bool = False,
        **kwargs,
    ) -> Union[dict, ChunkSearchResults]:
        """
        An enhanced search method that provides a cleaner interface for
        querying and post-processing results.

        Args:
            query (str): The search query.
            collection (str): Collection name to search in.
            n_results (int): Maximum number of results to return.
            n_sources (int): Maximum number of unique sources to include.
            where (dict, optional): Additional filtering criteria.
            format_results (bool): Whether to return the results as a flat list
                of chunk dictionaries instead of the raw query dict.
            **kwargs: Additional arguments to pass to the query.

        Returns:
            dict | list: The raw query result dict, or a list of chunk
                dictionaries (document, metadata, distance, id) if
                format_results is True.
        """
        # Get raw query results with the existing query method
        result = self.query(
            query=query,
            collection=collection,
            n_results=n_results,
            n_sources=n_sources,
            where=where,
            **kwargs,
        )

        # If no formatting requested, return raw results
        if not format_results:
            return result

        # Process results into a flat list of dictionaries
        combined_chunks = []
        for doc, meta, dist, _id in zip(
            result["documents"][0],
            result["metadatas"][0],
            result["distances"][0],
            result["ids"][0],
        ):
            combined_chunks.append(
                {"document": doc, "metadata": meta, "distance": dist, "id": _id}
            )

        return combined_chunks

    def clean_result_text(self, documents: list) -> list:
        """
        Clean text in document results by removing footnote references.

        Args:
            documents (list): List of document dictionaries.

        Returns:
            list: Documents with cleaned text.
        """
        for doc in documents:
            if "document" in doc:
                doc["document"] = re.sub(r"\[\d+\]", "", doc["document"])
        return documents

    def filter_by_unique_sources(
        self, results: list, n_sources: int, source_key: str = "_id"
    ) -> Tuple[List, List]:
        """
        Filters search results to keep only a specified number of unique sources.

        Args:
            results (list): List of documents from search.
            n_sources (int): Maximum number of unique sources to include.
            source_key (str): The key in metadata that identifies the source.

        Returns:
            tuple: (filtered_results, remaining_results)
        """
        sources = set()
        filtered_results = []
        remaining_results = []

        for item in results:
            source_id = item["metadata"].get(source_key, "no_id")
            if source_id not in sources and len(sources) < n_sources:
                sources.add(source_id)
                filtered_results.append(item)
            else:
                remaining_results.append(item)

        return filtered_results, remaining_results

    def backfill_results(
        self, filtered_results: list, remaining_results: list, n_results: int
    ) -> list:
        """
        Adds additional results from remaining_results to filtered_results
        until n_results is reached.

        Args:
            filtered_results (list): Initial filtered results.
            remaining_results (list): Other results that can be added.
            n_results (int): Target number of total results.

        Returns:
            list: Combined results up to n_results.
        """
        if len(filtered_results) >= n_results:
            return filtered_results[:n_results]

        needed = n_results - len(filtered_results)
        return filtered_results + remaining_results[:needed]
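
    # filter_by_unique_sources() and backfill_results() operate on the flat
    # chunk dictionaries produced by search(format_results=True). Illustrative
    # shape of one chunk (values are made up):
    #
    #   {"document": "chunk text ...",
    #    "metadata": {"_id": "articles/1", "number": 0},
    #    "distance": 0.21,
    #    "id": "art_1_0"}
    #
    # search_chunks() below chains the helpers: search each collection, sort
    # by distance, keep the closest chunk per source (up to n_sources),
    # backfill up to n_results, then strip footnote markers from the text.
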
    def search_chunks(
        self,
        query: str,
        collections: List[str],
        n_results: int = 7,
        n_sources: int = 4,
        where: dict = None,
        **kwargs,
    ) -> ChunkSearchResults:
        """
        Complete pipeline for processing chunks: search, filter, clean, and format.

        Args:
            query (str): The search query.
            collections (List[str]): List of collection names to search.
            n_results (int): Maximum number of results to return.
            n_sources (int): Maximum number of unique sources to include.
            where (dict, optional): Additional filtering criteria.
            **kwargs: Additional arguments to pass to search.

        Returns:
            list: Processed chunk dictionaries (document, metadata, distance, id),
                sorted by distance and limited to n_sources unique sources.
        """
        combined_chunks = []
        if isinstance(collections, str):
            collections = [collections]

        # Search all collections
        for collection in collections:
            chunks = self.search(
                query=query,
                collection=collection,
                n_results=n_results,
                n_sources=n_sources,
                where=where,
                format_results=True,
                **kwargs,
            )
            for chunk in chunks:
                combined_chunks.append(
                    {
                        "document": chunk["document"],
                        "metadata": chunk["metadata"],
                        "distance": chunk["distance"],
                        "id": chunk["id"],
                    }
                )

        # Sort results by distance
        combined_chunks.sort(key=lambda x: x["distance"])

        # Filter by unique sources and backfill
        closest_chunks, remaining_chunks = self.filter_by_unique_sources(
            combined_chunks, n_sources
        )
        closest_chunks = self.backfill_results(
            closest_chunks, remaining_chunks, n_results
        )

        # Clean text
        closest_chunks = self.clean_result_text(closest_chunks)

        return closest_chunks

    def add_document(self, _id, collection: str, document: str, metadata: dict = None):
        """
        Adds a single document to a specified collection in the database.

        Args:
            _id (str): Arango ID for the document, used as a unique identifier.
            collection (str): The name of the collection to add the document to.
            document (str): The document text to be added.
            metadata (dict, optional): Metadata to be associated with the document.
                Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        if metadata is None:
            metadata = {}
        col.add(ids=[_id], documents=[document], metadatas=[metadata])

    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Adds chunks to a specified collection in the database.

        Args:
            collection (str): The name of the collection to add chunks to.
            chunks (list): A list of text chunks to be added to the collection.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each chunk.
                Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        ids = []
        metadatas = []
        # Enumerate the chunks so each one gets its own position number and a
        # unique "<_key>_<number>" ID; copy the metadata dict per chunk so the
        # entries do not all share (and overwrite) the same object.
        for number, _ in enumerate(chunks):
            if metadata:
                chunk_metadata = dict(metadata)
                chunk_metadata["number"] = number
                metadatas.append(chunk_metadata)
            else:
                metadatas.append({})
            ids.append(f"{_key}_{number}")
        col.add(ids=ids, metadatas=metadatas, documents=chunks)

    def get_collection(self, collection: str) -> chromadb.Collection:
        """
        Retrieves a collection from the database, creating it if it does not exist.

        Args:
            collection (str): The name of the collection to retrieve.

        Returns:
            chromadb.Collection: The requested collection.
        """
        return self.db.get_or_create_collection(collection)
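
# Typical usage (illustrative; assumes a reachable Chroma deployment and an
# existing collection named "articles"):
#
#     chroma = ChromaDB()
#     chunks = chroma.search_chunks(
#         query="What is open science?",
#         collections=["articles"],
#         n_results=7,
#         n_sources=4,
#         where={"_id": {"$in": ["articles/123", "articles/456"]}},
#     )
#
# The `where` argument uses Chroma's metadata-filter syntax. Note that query()
# mutates the filter between retries (dropping `_id`s that have already been
# returned), so pass a copy if the original dict is needed afterwards.
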
def is_reference_chunk(text: str) -> bool:
    """
    Determine if a text chunk primarily consists of academic references.

    Args:
        text (str): Text chunk to analyze.

    Returns:
        bool: True if the chunk appears to be mainly references.
    """
    # Count significant reference indicators
    indicators = 0

    # Check for DOI links (very strong indicator)
    doi_matches = len(re.findall(r"https?://doi\.org/10\.\d+/\S+", text))
    if doi_matches >= 2:
        # Multiple DOIs almost certainly means references
        return True
    elif doi_matches == 1:
        indicators += 3

    # Check for citation patterns with year, volume, pages (e.g. 2018;178:551–60)
    citation_patterns = len(re.findall(r"\d{4};\d+:\d+[-–]\d+", text))
    indicators += citation_patterns * 2

    # Check for year patterns in brackets [YYYY]
    year_brackets = len(re.findall(r"\[\d{4}\]", text))
    indicators += year_brackets

    # Check for multiple lines starting with author name patterns
    lines = [line.strip() for line in text.split("\n") if line.strip()]
    author_started_lines = 0
    for line in lines:
        # Common pattern in references: starts with Author Name(s)
        if re.match(r"^\s*[A-Z][a-z]+\s+[A-Z][a-z]+", line):
            author_started_lines += 1

    # If multiple lines start with author names (common in reference lists)
    if author_started_lines >= 2:
        indicators += 2

    # Check for academic reference terms
    if re.search(r"\bet al\b|\bet al\.\b", text, re.IGNORECASE):
        indicators += 1

    # Return True if we have sufficient indicators
    return indicators >= 4  # Adjust threshold as needed


if __name__ == "__main__":
    from pprint import pprint

    chroma = ChromaDB()
    print(chroma.db.list_collections())
    print("DB", chroma.db.database)
    print("VERSION", chroma.db.get_version())

    result = chroma.search_chunks(
        query="What is Open Science?",
        collections="lasse__other_documents",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )
    pprint(result)

    collection = chroma.db.get_or_create_collection("lasse__other_documents")
    result = collection.query(
        query_texts="What is Open Science?",
        n_results=2,
    )
    pprint(result)
    # print_rainbow(result["metadatas"][0])
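
    # Offline sanity check (illustrative sample text, no database needed):
    # is_reference_chunk() should flag a chunk containing several DOI links
    # and author/volume/page patterns as a reference list.
    sample_chunk = (
        "Smith J, Jones A. Open science in practice. J Open Res. "
        "2018;178:551–60. https://doi.org/10.1000/example\n"
        "Brown K, Lee M. Data sharing revisited. Sci Data. "
        "2019;12:101–10. https://doi.org/10.1000/example2"
    )
    print("Reference chunk?", is_reference_chunk(sample_chunk))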