From d1261d2d902dc948d408211d360e36e6e8df5368 Mon Sep 17 00:00:00 2001 From: lasseedfast Date: Thu, 17 Oct 2024 14:45:09 +0200 Subject: [PATCH] Refactor _chromadb.py to improve code structure and readability --- _chromadb.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/_chromadb.py b/_chromadb.py index 8c87cdb..986bd25 100644 --- a/_chromadb.py +++ b/_chromadb.py @@ -7,19 +7,22 @@ from chromadb.config import Settings from dotenv import load_dotenv from chromadb.utils import embedding_functions - load_dotenv('.chroma_env') + class ChromaDB: - def __init__(self): - self.db = chromadb.HttpClient( - host="https://lasseedfast.se/chroma_ev_cars", - settings=Settings( - chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", - chroma_client_auth_credentials=os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS"), - chroma_auth_token_transport_header=os.getenv("CHROMA_AUTH_TOKEN_TRANSPORT_HEADER") + def __init__(self, local_deployment: bool = False, db='sci_articles'): + if local_deployment: + self.db = chromadb.PersistentClient(f'chroma_{db}') + else: + self.db = chromadb.HttpClient( + host=os.getenv('CHROMA_HOST'), + settings=Settings( + chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", + chroma_client_auth_credentials=os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS"), + chroma_auth_token_transport_header=os.getenv("CHROMA_AUTH_TOKEN_TRANSPORT_HEADER") + ) ) - ) max_characters = 2200 self.ts = MarkdownSplitter(max_characters)