diff --git a/_chromadb.py b/_chromadb.py index 8c87cdb..986bd25 100644 --- a/_chromadb.py +++ b/_chromadb.py @@ -7,19 +7,22 @@ from chromadb.config import Settings from dotenv import load_dotenv from chromadb.utils import embedding_functions - load_dotenv('.chroma_env') + class ChromaDB: - def __init__(self): - self.db = chromadb.HttpClient( - host="https://lasseedfast.se/chroma_ev_cars", - settings=Settings( - chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", - chroma_client_auth_credentials=os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS"), - chroma_auth_token_transport_header=os.getenv("CHROMA_AUTH_TOKEN_TRANSPORT_HEADER") + def __init__(self, local_deployment: bool = False, db='sci_articles'): + if local_deployment: + self.db = chromadb.PersistentClient(f'chroma_{db}') + else: + self.db = chromadb.HttpClient( + host=os.getenv('CHROMA_HOST'), + settings=Settings( + chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider", + chroma_client_auth_credentials=os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS"), + chroma_auth_token_transport_header=os.getenv("CHROMA_AUTH_TOKEN_TRANSPORT_HEADER") + ) ) - ) max_characters = 2200 self.ts = MarkdownSplitter(max_characters)