From aafcf6b0dde712a4dda7d21796ee811247125fcd Mon Sep 17 00:00:00 2001 From: lasseedfast <> Date: Mon, 6 May 2024 07:45:50 +0200 Subject: [PATCH] Add methods to add persons to Chroma database --- _chroma.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 3 deletions(-) diff --git a/_chroma.py b/_chroma.py index 0f64c9d..7319a39 100644 --- a/_chroma.py +++ b/_chroma.py @@ -38,14 +38,68 @@ class ChromaDB: for collection in collections: print(collection.name) + def add_person_to_chroma(self, person): + """ + Adds a person to the Chroma database. + + Args: + person (dict): A dictionary containing information about the person. + + Returns: + None + """ + + collection = self.client.get_collection("mala_persons") + + # Lists to store the documents, metadatas and ids + documents = [] + metadatas = [] + ids = [] + + documents.append(person["name"]) + metadata = {"name": person["name"], "_key": person["_key"], 'info': "\n".join(person["info"])} + metadatas.append(metadata) + ids.append(person["_key"]) + + collection.add(documents=documents, metadatas=metadatas, ids=ids) + + def add_all_persons_to_chroma(self): + """ + Adds all persons to the Chroma collection. + + This method deletes the existing 'mala_persons' collection, creates a new collection, + and then adds all persons from the database whose 'confirmed' field is set to True. 
+ + Args: + None + + Returns: + None + """ + from _arango import arango + + self.client.delete_collection('mala_persons') + col = self.client.get_or_create_collection('mala_persons') + + db = arango.db + q = "for doc in persons filter doc.confirmed == true return doc" + persons = list(db.aql.execute(q)) + + for person in persons: + self.add_person_to_chroma(person) + + print('Persons in chroma:', col.count()) + + + + # Initialize the ChromaDB object chroma = ChromaDB() if __name__ == '__main__': + chroma = ChromaDB() + chroma.add_all_persons_to_chroma() - chroma.client.delete_collection('mala_persons') - col = chroma.client.get_or_create_collection('mala_persons') - print(col.count())