You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
163 lines
5.6 KiB
163 lines
5.6 KiB
import chromadb |
|
import os |
|
from chromadb.config import Settings |
|
from dotenv import load_dotenv |
|
from colorprinter.print_color import * |
|
|
|
# Load variables from a local ".env" file into the process environment at
# import time, so the os.getenv("CHROMA_...") lookups in ChromaDB.__init__
# can find them.
load_dotenv(".env")
|
|
|
|
|
class ChromaDB:
    """Convenience wrapper around a Chroma client.

    The underlying client is stored on ``self.db``. It is either a
    ``chromadb.PersistentClient`` (local deployment) or a
    ``chromadb.HttpClient`` configured with token authentication.
    """

    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        """
        Create the Chroma client.

        Args:
            local_deployment (bool, optional): When True, open a persistent
                client at the path ``chroma_<db>``. Defaults to False.
            db (str, optional): Name used to build the local path. Only used
                when ``local_deployment`` is True. Defaults to "sci_articles".
            host (str, optional): Host of the Chroma server. When falsy, the
                ``CHROMA_HOST`` environment variable is used instead.

        The HTTP client reads its token from ``CHROMA_CLIENT_AUTH_CREDENTIALS``
        and the header name from ``CHROMA_AUTH_TOKEN_TRANSPORT_HEADER``.
        """
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
            return

        if not host:
            host = os.getenv("CHROMA_HOST")
        credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
        auth_token_transport_header = os.getenv("CHROMA_AUTH_TOKEN_TRANSPORT_HEADER")
        self.db = chromadb.HttpClient(
            host=host,
            settings=Settings(
                chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                chroma_client_auth_credentials=credentials,
                chroma_auth_token_transport_header=auth_token_transport_header,
            ),
        )

    @staticmethod
    def _sort_by_distance(result):
        """Reorder every populated column of ``result`` by ascending distance."""
        distances = result["distances"][0]
        if not distances:
            # Distances were not returned, so there is nothing to order by.
            return
        order = sorted(range(len(distances)), key=distances.__getitem__)
        for column in result.values():
            rows = column[0]
            # Columns that were not requested stay empty and are left alone.
            if len(rows) == len(distances):
                column[0] = [rows[i] for i in order]

    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        """
        Query the vector database for relevant documents.

        The collection is queried repeatedly (at most ``max_retries`` times)
        until the accumulated hits come from at least ``n_sources`` distinct
        sources, where a source is the ``_id`` value of a hit's metadata.
        A retry is only possible when ``where`` has the shape
        ``{"_id": {"$in": [...]}}``: sources that were already found are
        removed from that list before the next query. Because hits from all
        attempts are accumulated, the result may hold more than ``n_results``
        rows.

        Args:
            query (str): The query text to search for.
            collection (str): The name of the collection to search within.
            n_results (int, optional): The number of results to request. Defaults to 6.
            n_sources (int, optional): The number of unique sources to aim for.
                Capped at ``n_results``. Defaults to 3.
            max_retries (int, optional): The maximum number of queries to issue.
                Defaults to ``n_sources`` when falsy.
            where (dict, optional): Additional filtering criteria for the query.
                The caller's dict is never modified. Defaults to None.
            **kwargs: Additional keyword arguments to pass to the query.

        Returns:
            dict: A dictionary containing the query results with keys 'ids',
            'metadatas', 'documents', and 'distances', each holding one inner
            list. Rows are ordered by ascending distance and stay aligned
            across the four lists.
        """
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        n_sources = min(n_sources, n_results)

        col = self.db.get_collection(collection)

        # Retrying only makes sense when already-found sources can be excluded.
        id_filter = where.get("_id") if isinstance(where, dict) else None
        can_narrow = isinstance(id_filter, dict) and "$in" in id_filter

        sources = []  # distinct source ids seen so far, in discovery order
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        attempt = 0
        while True:
            attempt += 1
            if attempt > max_retries:
                break

            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if not r["ids"][0]:
                if not result["ids"][0]:
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break

            # Append the new hits; columns that were not returned are skipped.
            for key in result:
                batch = r.get(key)
                if batch:
                    result[key][0].extend(batch[0])

            self._sort_by_distance(result)

            # Count each source once, however many attempts returned it.
            for meta in result["metadatas"][0]:
                source_id = meta.get("_id") if meta else None
                if source_id is not None and source_id not in sources:
                    sources.append(source_id)

            if len(sources) >= n_sources or attempt == max_retries:
                break
            if not can_narrow:
                break

            found = set(sources)
            remaining = [i for i in where["_id"]["$in"] if i not in found]
            if not remaining:
                break

            # Another query follows: keep only the closest hits so the missing
            # sources have room. Every column (ids included) is cut to the
            # same length so rows stay aligned.
            keep = n_results - (n_sources - len(sources))
            for column in result.values():
                column[0] = column[0][:keep]

            # Build a fresh filter instead of editing the caller's dict.
            where = {**where, "_id": {**where["_id"], "$in": remaining}}

        return result

    def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
        """
        Adds chunks to a specified collection in the database.

        Each chunk gets the ID ``"<_key>_<index>"``, where ``index`` is its
        zero-based position in ``chunks``, so identical chunk texts still get
        distinct IDs. When ``metadata`` is given, each chunk receives its own
        copy with the chunk's index stored under ``"number"``.

        Args:
            collection (str): The name of the collection to add chunks to.
                It is created if it does not exist.
            chunks (list): A list of chunks to be added to the collection.
                Nothing is done when it is empty.
            _key: A key used to generate unique IDs for the chunks.
            metadata (dict, optional): Metadata to be associated with each chunk.
                The caller's dict is not modified. Defaults to None.

        Returns:
            None
        """
        if not chunks:
            return

        col = self.db.get_or_create_collection(collection)
        ids = [f"{_key}_{number}" for number in range(len(chunks))]
        if metadata:
            # One dict per chunk; sharing a single dict would leave every
            # chunk with the last chunk's "number".
            metadatas = [{**metadata, "number": number} for number in range(len(chunks))]
        else:
            # NOTE(review): Chroma's metadata validation is understood to
            # reject empty dicts, so no metadata is sent at all - confirm
            # against the client version in use.
            metadatas = None
        col.add(ids=ids, metadatas=metadatas, documents=chunks)
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: connect using the environment configuration and list
    # the collections available on the server.
    import sys

    from colorprinter.print_color import *

    chroma = ChromaDB()
    print(chroma.db.list_collections())
    # Stop after the connectivity check. sys.exit() is used rather than the
    # bare exit() helper, which is injected by the site module for
    # interactive sessions and is not meant for use in programs.
    sys.exit()

    # Example query, intentionally unreachable while the sys.exit() above is
    # in place; remove that call to run it.
    result = chroma.query(
        query="What is Open Science)",
        collection="sci_articles",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )
    print_rainbow(result["metadatas"][0])
|
|
|