You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 

165 lines
5.7 KiB

import chromadb
import os
from chromadb.config import Settings
from dotenv import load_dotenv
from colorprinter.print_color import *
load_dotenv(".env")
class ChromaDB:
    """Wrapper around a Chroma client for querying and storing text chunks.

    The client is stored on ``self.db``. It is either a local
    ``chromadb.PersistentClient`` or a ``chromadb.HttpClient`` configured for
    token authentication from environment variables.
    """

    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        """
        Create the underlying Chroma client.
        Args:
            local_deployment (bool, optional): When true, open a persistent
                client at the path ``chroma_<db>``; otherwise connect over
                HTTP. Defaults to False.
            db (str, optional): Suffix of the local persistence path. Only
                used when ``local_deployment`` is true. Defaults to "sci_articles".
            host (str, optional): Host of the Chroma server. Falls back to the
                ``CHROMA_HOST`` environment variable when not given.
        """
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )

    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        The collection is queried up to ``max_retries`` times. After each round
        the merged hits are sorted by distance and the distinct ``_id`` values
        found in their metadata are counted as sources. When fewer than
        ``n_sources`` sources have been found and ``where`` restricts ``_id``
        with an ``$in`` list, the already-found sources are removed from that
        list and the query is repeated. Without such a filter only one round
        is made.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to return. Defaults to 6.
            n_sources (int, optional): The number of unique sources to return.
                Capped at ``n_results``. Defaults to 3.
            max_retries (int, optional): The maximum number of query rounds.
                Defaults to None, meaning ``n_sources`` (as passed in).
            where (dict, optional): Additional filtering criteria for the query.
                An empty dict is treated as no filter. The dict passed in is
                never modified. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the query.
        Returns:
            dict: A dictionary containing the query results with keys 'ids',
            'metadatas', 'documents', and 'distances'. Each value is a list
            holding one list; the four inner lists are index-aligned and
            ordered by ascending distance.
        Notes:
            Every returned metadata entry must contain an ``"_id"`` key, and
            the query must return metadatas, documents and distances
            (presumably Chroma's default ``include`` -- confirm before
            overriding ``include`` through ``kwargs``).
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        n_sources = min(n_sources, n_results)

        # An empty filter is passed on as "no filter".
        where = where or None

        col = self.db.get_collection(collection)
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        sources = []  # distinct source ids, in order of first appearance
        attempt = 0

        while attempt < max_retries:
            attempt += 1
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if not r["ids"][0]:
                if result["ids"][0]:
                    print_red("No more results found in vector database.")
                else:
                    print_red("No results found in vector database.")
                break

            # Merge this round's hits into the accumulated lists.
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])

            # Re-rank everything collected so far by distance, keeping the
            # four parallel lists aligned.
            ranked = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda row: row[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*ranked))

            # Count each source once, even though ``result`` still contains
            # the hits of earlier rounds.
            for meta in result["metadatas"][0]:
                source_id = meta["_id"]
                if source_id not in sources:
                    sources.append(source_id)

            if len(sources) >= n_sources or attempt >= max_retries:
                break

            # Make room for hits from the sources that are still missing.
            # All four lists are trimmed together so ids stay aligned with
            # their distances, metadata and documents.
            keep = n_results - (n_sources - len(sources))
            for key in result:
                result[key][0] = result[key][0][:keep]

            # Only an ``_id``/``$in`` filter can be narrowed. Otherwise a
            # repeat query would return the same documents again, so stop.
            id_filter = where["_id"] if where and "_id" in where else None
            if not (isinstance(id_filter, dict) and "$in" in id_filter):
                break
            remaining = [i for i in id_filter["$in"] if i not in sources]
            if not remaining:
                break
            # Build a new filter instead of editing the caller's dict.
            where = {**where, "_id": {**id_filter, "$in": remaining}}

        return result

    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Adds chunks to a specified collection in the database.

        Each chunk is stored under the ID ``f"{_key}_{number}"``, where
        ``number`` is the chunk's zero-based position in ``chunks``. The same
        position is stored under the ``"number"`` key of that chunk's metadata.

        Args:
            collection (str): The name of the collection to add chunks to. It
                is created if it does not exist.
            chunks (list): A list of chunks to be added to the collection.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each
                chunk. Each chunk gets its own copy; the dict passed in is not
                modified. Defaults to None.
        Returns:
            None
        """
        col = self.db.get_or_create_collection(collection)
        if not chunks:
            # Nothing to store.
            return

        shared = metadata or {}
        ids = []
        metadatas = []
        for number, _chunk in enumerate(chunks):
            ids.append(f"{_key}_{number}")
            # A fresh dict per chunk keeps each "number" distinct and leaves
            # the caller's metadata untouched.
            metadatas.append({**shared, "number": number})
        col.add(ids=ids, metadatas=metadatas, documents=chunks)
if __name__ == "__main__":
    # Manual smoke test: connect with the default settings and show which
    # collections the server exposes.
    from colorprinter.print_color import *

    chroma = ChromaDB()
    available_collections = chroma.db.list_collections()
    print(available_collections)
    exit()

    # NOTE: everything below is unreachable while the exit() call above is in
    # place. It is kept as a ready-made example of a query.
    demo_arguments = {
        "query": "What is Open Science)",
        "collection": "sci_articles",
        "n_results": 2,
        "n_sources": 3,
        "max_retries": 4,
    }
    result = chroma.query(**demo_arguments)
    print_rainbow(result["metadatas"][0])