You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
160 lines
5.0 KiB
160 lines
5.0 KiB
import chromadb as db |
|
from chromadb.utils import embedding_functions |
|
from chromadb.config import Settings |
|
from chromadb.api.client import Client |
|
from chromadb.api.models.Collection import Collection |
|
|
|
import os |
|
# Presumably set to keep the HuggingFace `tokenizers` library (used under the hood by
# sentence-transformers) from running in parallel / warning about it — TODO confirm intent.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
|
|
|
class ChromaDB:
    """
    A wrapper around a Chroma HTTP client plus the embedding function used for all collections.

    Collections managed here:
        mala_persons       one entry per confirmed person; the embedded document is the name.
        mala_persons_info  one entry per person; the document is the name followed by info lines.
    """

    def __init__(
        self,
        host: str = "192.168.1.10",
        port: int = 8001,
        model_name: str = "KBLab/sentence-bert-swedish-cased",
    ) -> None:
        """
        Initializes a ChromaDB object connected to a Chroma server over HTTP.

        Args:
            host (str, optional): The host address of the Chroma server. Defaults to "192.168.1.10".
            port (int, optional): The port number of the Chroma server. Defaults to 8001.
            model_name (str, optional): SentenceTransformer model used to embed documents and
                queries. Defaults to "KBLab/sentence-bert-swedish-cased".
        """
        self.client: Client = db.HttpClient(
            settings=Settings(anonymized_telemetry=False),
            host=host,
            port=port,
        )
        # Never hard-code API keys in this file: if a hosted embedding function (e.g. the
        # HuggingFace API one) is ever needed, read its key from the environment instead.
        self.embedding_function: embedding_functions.SentenceTransformerEmbeddingFunction = (
            embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model_name)
        )

    def _reset_collection(self, name):
        """Delete collection `name` if present and return a freshly created, empty one."""
        try:
            self.client.delete_collection(name)
        except Exception:
            # Best effort: on a first run the collection does not exist yet and
            # delete_collection raises; that is expected and safe to ignore.
            pass
        return self.client.get_or_create_collection(
            name, embedding_function=self.embedding_function
        )

    def print_collections(self):
        """
        Prints the name of every collection in the database, one per line.
        """
        for collection in self.client.list_collections():
            # Older chromadb releases return Collection objects here, newer ones return
            # plain collection names; handle both.
            print(getattr(collection, "name", collection))

    def add_person_to_chroma(self, person, collection=None):
        """
        Adds a person to the 'mala_persons' Chroma collection.

        The person's name is the embedded document; name, _key and the newline-joined
        info lines are stored as metadata, and _key is used as the Chroma id.

        Args:
            person (dict): Person document with "name" and "_key" keys and an optional
                "info" list of strings (missing or None is stored as an empty string).
            collection (optional): Collection to add to. When None (the default) the
                'mala_persons' collection is fetched or created first; pass it explicitly
                when adding many persons to avoid one lookup per person.

        Returns:
            None
        """
        if collection is None:
            collection = self.client.get_or_create_collection(
                "mala_persons", embedding_function=self.embedding_function
            )

        metadata = {
            "name": person["name"],
            "_key": person["_key"],
            "info": "\n".join(person.get("info") or []),
        }
        collection.add(
            documents=[person["name"]],
            metadatas=[metadata],
            ids=[person["_key"]],
        )

    def add_all_persons_to_chroma(self):
        """
        Rebuilds the 'mala_persons' Chroma collection from ArangoDB.

        Deletes the existing 'mala_persons' collection (if any), creates an empty one, adds
        every document of the ArangoDB 'persons' collection whose 'confirmed' field is true,
        and prints how many entries the collection then holds.

        Returns:
            None
        """
        from _arango import arango

        col = self._reset_collection("mala_persons")

        aql = "for doc in persons filter doc.confirmed == true return doc"
        # Materialize the cursor up front so a slow embedding loop cannot outlive it.
        persons = list(arango.db.aql.execute(aql))

        for person in persons:
            self.add_person_to_chroma(person, collection=col)

        print("Persons in chroma:", col.count())

    def add_all_person_info(self):
        """
        Rebuilds the 'mala_persons_info' Chroma collection from every ArangoDB person.

        Each document is the person's name, a newline, then their info lines joined by
        newlines; name and _key are stored as metadata and _key is the Chroma id.

        Returns:
            None
        """
        from _arango import arango

        col = self._reset_collection("mala_persons_info")

        persons = list(arango.db.collection("persons").all())
        for person in persons:
            doc = person["name"] + "\n" + "\n".join(person.get("info") or [])
            col.add(
                documents=[doc],
                metadatas=[{"name": person["name"], "_key": person["_key"]}],
                ids=[person["_key"]],
            )

    def query(self, collection, query_texts, n_results=5, where=None):
        """
        Runs a similarity query against an existing collection.

        Args:
            collection (str): Name of the collection to query.
            query_texts (str | list[str]): One query string or a list of query strings.
            n_results (int, optional): Number of results per query text. Defaults to 5.
            where (dict, optional): Chroma metadata filter. Defaults to None (no filter).

        Returns:
            Whatever Chroma's Collection.query returns for these arguments.
        """
        if isinstance(query_texts, str):
            query_texts = [query_texts]
        col = self.client.get_collection(
            collection, embedding_function=self.embedding_function
        )
        # `where or None` maps an empty filter to None, Chroma's own "no filter" default.
        return col.query(query_texts=query_texts, n_results=n_results, where=where or None)
|
|
|
def add_interrogations():
    """
    Splits every interrogation in the ArangoDB 'interrogations' collection into text chunks.

    Each interrogation's 'text' is split on blank lines and merged into chunks that aim for
    at most 1000 characters (measured with len) with 100 characters of overlap; a single
    paragraph longer than that still becomes one oversized chunk.

    Returns:
        dict: Mapping from each interrogation's '_key' to its list of chunk strings.
            Interrogations with missing or empty text map to an empty list.

    TODO: the chunks are not stored anywhere yet; adding them to a Chroma collection
    is still to be implemented.
    """
    from _arango import arango
    from langchain_text_splitters import CharacterTextSplitter

    text_splitter = CharacterTextSplitter(
        separator="\n\n",
        chunk_size=1000,
        chunk_overlap=100,
        length_function=len,
        is_separator_regex=False,
    )

    interrogations = list(arango.db.collection("interrogations").all())

    chunks_by_key = {}
    for interrogation in interrogations:
        text = interrogation.get("text") or ""
        chunks_by_key[interrogation["_key"]] = text_splitter.split_text(text) if text else []
    return chunks_by_key
|
|
|
|
|
|
|
# Shared module-level instance. Note that importing this module constructs the Chroma
# HTTP client and the SentenceTransformer embedding function as a side effect.
chroma = ChromaDB()

if __name__ == "__main__":
    # To rebuild the collections from ArangoDB, uncomment the relevant call:
    # chroma.add_all_persons_to_chroma()
    # chroma.add_all_person_info()
    chroma.print_collections()
|
|
|