You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
2.9 KiB
105 lines
2.9 KiB
import os
from typing import Optional

import chromadb as db
from chromadb.api.client import Client
from chromadb.api.models.Collection import Collection
from chromadb.config import Settings
from chromadb.utils import embedding_functions
|
|
|
|
|
|
|
class ChromaDB:
    """
    Thin wrapper around a remote Chroma server used to index persons.

    Attributes set by ``__init__``:
        client: HTTP client for the Chroma server.
        embedding_function: Hugging Face embedding function configured for
            the "KBLab/sentence-bert-swedish-cased" model.
    """

    # Name of the Chroma collection that stores persons.
    PERSONS_COLLECTION = "mala_persons"

    # Environment variable consulted when no API key is passed explicitly.
    API_KEY_ENV_VAR = "CHROMA_HUGGINGFACE_API_KEY"

    def __init__(
        self,
        host: str = "192.168.1.10",
        port: int = 8001,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Create the Chroma HTTP client and the Hugging Face embedding function.

        Args:
            host: The host address of the Chroma database.
            port: The port number of the Chroma database.
            api_key: Hugging Face API token. When omitted, the value of the
                environment variable named by ``API_KEY_ENV_VAR`` is used.

        Raises:
            ValueError: If no API token is passed and the environment
                variable is unset or empty.
        """
        # SECURITY: the token used to be hard-coded in this file. Secrets must
        # not live in source control, so it is now supplied by the caller or
        # the environment. The token that was committed should be revoked.
        if api_key is None:
            api_key = os.environ.get(self.API_KEY_ENV_VAR)
        # Checked before any client is built so a misconfiguration fails fast
        # with a clear message.
        if not api_key:
            raise ValueError(
                "No Hugging Face API key provided: pass api_key or set the "
                f"{self.API_KEY_ENV_VAR} environment variable."
            )

        self.client: Client = db.HttpClient(
            settings=Settings(anonymized_telemetry=False),
            host=host,
            port=port,
        )

        # NOTE(review): this embedding function is stored on the instance but
        # is not passed to get_collection / get_or_create_collection anywhere
        # in this class, so collections opened here do not use it. Confirm
        # whether that is intended before wiring it in, since code that
        # queries the collection must use the same embedding function.
        self.embedding_function = embedding_functions.HuggingFaceEmbeddingFunction(
            api_key=api_key,
            model_name="KBLab/sentence-bert-swedish-cased",
        )

    def print_collections(self) -> None:
        """
        Print the name of every collection in the database, one per line.
        """
        for collection in self.client.list_collections():
            print(collection.name)

    def add_person_to_chroma(self, person) -> None:
        """
        Add a single person to the 'mala_persons' collection.

        The person's name is stored as the document text, the '_key' value is
        used as the Chroma id, and name, key and info are stored as metadata.

        Args:
            person (dict): Must contain "name" and "_key". "info" is expected
                to be an iterable of strings; it is joined with newlines. A
                missing or None "info" is stored as an empty string.

        Returns:
            None
        """
        collection = self.client.get_collection(self.PERSONS_COLLECTION)

        # Tolerate persons without an "info" field instead of raising
        # KeyError/TypeError.
        info_lines = person.get("info") or []
        metadata = {
            "name": person["name"],
            "_key": person["_key"],
            "info": "\n".join(info_lines),
        }

        # Chroma's add() takes parallel lists; each holds one entry here.
        collection.add(
            documents=[person["name"]],
            metadatas=[metadata],
            ids=[person["_key"]],
        )

    def _collection_exists(self, name: str) -> bool:
        """Return True if the Chroma server lists a collection called *name*."""
        # list_collections() yields objects with a .name attribute (as used in
        # print_collections). The getattr fallback also accepts plain name
        # strings, in case the installed chromadb returns those - TODO confirm.
        return any(
            getattr(entry, "name", entry) == name
            for entry in self.client.list_collections()
        )

    def add_all_persons_to_chroma(self) -> None:
        """
        Rebuild the 'mala_persons' collection from the Arango database.

        Deletes the 'mala_persons' collection if it exists, creates it again,
        and adds every document from the Arango 'persons' collection whose
        'confirmed' field is true. Finally prints the number of entries in
        the rebuilt collection.

        Returns:
            None
        """
        # Imported inside the method, as in the original, so importing this
        # module does not require _arango.
        from _arango import arango

        name = self.PERSONS_COLLECTION

        # Deleting a collection that does not exist raises, which made the
        # very first rebuild fail; only delete when it is actually present.
        if self._collection_exists(name):
            self.client.delete_collection(name)
        col = self.client.get_or_create_collection(name)

        # Named arango_db so it does not shadow the module-level `db` alias
        # for chromadb.
        arango_db = arango.db
        query = "for doc in persons filter doc.confirmed == true return doc"
        persons = list(arango_db.aql.execute(query))

        for person in persons:
            self.add_person_to_chroma(person)

        print('Persons in chroma:', col.count())
|
|
|
|
|
|
|
|
|
# Shared module-level instance so other modules can import `chroma` directly.
# Note that this constructs the ChromaDB object (and its HTTP client) as a
# side effect of importing this module.
chroma = ChromaDB()


if __name__ == '__main__':
    # Reuse the instance created above instead of constructing a second,
    # identical client just to run the rebuild.
    chroma.add_all_persons_to_chroma()
|
|
|
|
|
|
|
|
|
|