import os

import chromadb
from chromadb.config import Settings
from dotenv import load_dotenv

from colorprinter.print_color import *

load_dotenv(".env")


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            # Local on-disk database in a directory named after the database
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            # Remote deployment with token authentication, configured via .env
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )

    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to return. Defaults to 6.
            n_sources (int, optional): The number of unique sources to return. Defaults to 3.
            max_retries (int, optional): The maximum number of query retries. Defaults to n_sources.
            where (dict, optional): Additional filtering criteria for the query. Defaults to None.
            **kwargs: Additional keyword arguments passed to the underlying collection query.

        Returns:
            dict: The query results with keys 'ids', 'metadatas', 'documents', and 'distances'.
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results

        col = self.db.get_collection(collection)
        sources = []
        n = 0
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        while True:
            n += 1
            if n > max_retries:
                break
            if where == {}:
                where = None
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break

            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])

            # Order the accumulated results by distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))

            # Track the unique source documents retrieved so far
            sources = list({i["_id"] for i in result["metadatas"][0]})
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                # Trim the accumulated results so the remaining slots can be
                # filled by chunks from sources not yet seen
                for k, v in result.items():
                    if k not in r["included"]:
                        continue
                    result[k][0] = v[0][: n_results - (n_sources - len(sources))]
                # Exclude already-seen sources from the next retry
                if where and "_id" in where:
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
            else:
                break
        return result

    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Add chunks to a specified collection in the database.

        Args:
            collection (str): The name of the collection to add chunks to.
            chunks (list): A list of text chunks to be added to the collection.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each chunk. Defaults to None.

        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        ids = []
        metadatas = []
        for number, _chunk in enumerate(chunks):
            # Copy the metadata so every chunk gets its own dict with its own number
            if metadata:
                metadatas.append({**metadata, "number": number})
            else:
                metadatas.append({})
            ids.append(f"{_key}_{number}")
        col.add(ids=ids, metadatas=metadatas, documents=chunks)


if __name__ == "__main__":
    chroma = ChromaDB()
    print(chroma.db.list_collections())
    result = chroma.query(
        query="What is Open Science?",
        collection="sci_articles",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )
    print_rainbow(result["metadatas"][0])
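
    # Hedged usage sketch (illustrative only): index a few chunks for one source,
    # then restrict a follow-up query to that source with the same
    # where={"_id": {"$in": [...]}} filter that query() narrows between retries.
    # The collection name "demo_articles", the example chunks, and the "_id"
    # value "demo_article" are assumptions for demonstration, not real data.
    chroma.add_chunks(
        collection="demo_articles",
        chunks=[
            "Open Science promotes transparent and accessible research.",
            "Preprints make findings available before formal peer review.",
        ],
        _key="demo_article",
        metadata={"_id": "demo_article"},
    )
    filtered = chroma.query(
        query="How are research findings shared openly?",
        collection="demo_articles",
        n_results=2,
        n_sources=1,
        where={"_id": {"$in": ["demo_article"]}},
    )
    print_rainbow(filtered["metadatas"][0])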