You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
2.1 KiB
55 lines
2.1 KiB
import os

import sys

# Set /home/lasse/riksdagen as the working directory. This runs at import
# time, so merely importing this module changes the process-wide cwd.
# NOTE(review): hard-coded, machine-specific path — consider making it
# configurable (e.g. via an environment variable).
os.chdir("/home/lasse/riksdagen")

# Add the project root to the module search path so that the local `config`
# module imported below can be found; this must run before that import.
sys.path.append("/home/lasse/riksdagen")

from typing import Any, List

import chromadb

from dotenv import load_dotenv

# Load environment variables (python-dotenv reads them from a `.env` file)
# before `config` is imported; presumably `config` reads some settings from
# the environment — confirm.
load_dotenv()

# NOTE(review): `embedding_dimensions` is not referenced in the visible part
# of this file; it may be used elsewhere or be a leftover import.
from config import embedding_dimensions

# On-disk location of the persistent ChromaDB store.
CHROMA_PATH = "/home/lasse/riksdagen/chromadb_data"

# Collection used when no explicit collection name is supplied.
DEFAULT_CHROMA_COLLECTION = "talks_embeddings"
|
|
|
def _get_chroma_collection(collection_name=None) -> Any:
    """Open the persistent Chroma store and fetch one collection from it.

    Args:
        collection_name: Name of the collection to fetch. ``None`` selects
            ``DEFAULT_CHROMA_COLLECTION``.

    Returns:
        The Chroma collection object returned by ``get_collection``.

    Raises:
        Whatever ``get_collection`` raises when the collection has not been
        created yet (the exact exception type depends on the chromadb version).
    """
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    wanted = DEFAULT_CHROMA_COLLECTION if collection_name is None else collection_name
    return client.get_collection(name=wanted)
|
|
|
|
|
def query_similar(
    embeddings: list = None,
    texts: list[str] = None,
    n_results: int = 8,
    collection_name=None,
) -> List[dict[str, Any]]:
    """Query a ChromaDB collection for the documents most similar to the input.

    Pass either precomputed query ``embeddings`` or raw query ``texts``; both
    are forwarded unchanged to ``collection.query``, which decides how they
    are handled.

    Args:
        embeddings: Query embeddings, forwarded as ``query_embeddings``.
        texts: Query strings, forwarded as ``query_texts``.
        n_results: Number of nearest neighbours requested per query.
        collection_name: Collection to search; ``None`` (the default) uses
            ``DEFAULT_CHROMA_COLLECTION``.

    Returns:
        A list with one dict per hit, each having the keys ``'metadata'``,
        ``'document'``, ``'distance'`` and ``'id'``. Only the hits for the
        FIRST query are returned; if several embeddings/texts are passed, the
        results for the remaining queries are discarded.
    """
    collection = _get_chroma_collection(collection_name)

    results = collection.query(
        query_embeddings=embeddings,
        query_texts=texts,
        n_results=n_results,
        include=["metadatas", "documents", "distances"],
    )

    def first_query(key: str) -> list:
        # Chroma returns one inner list per query. Fall back to an empty hit
        # list if the field is missing, None, or has no per-query entries,
        # instead of failing with TypeError/IndexError.
        per_query = results.get(key) or [[]]
        return per_query[0]

    return [
        {"metadata": metadata, "document": document, "distance": distance, "id": identifier}
        for metadata, document, distance, identifier in zip(
            first_query("metadatas"),
            first_query("documents"),
            first_query("distances"),
            first_query("ids"),
        )
    ]
|
|
|
|
|
if __name__ == "__main__":
    # Example usage: print the names of the collections in the persistent store.
    chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
    collections = chroma_client.list_collections()
    # Depending on the chromadb version, list_collections() yields either
    # Collection objects (with a .name attribute) or plain name strings
    # (e.g. chromadb 0.6.x); accept both instead of crashing on `.name`.
    names = [getattr(col, "name", col) for col in collections]
    print("Available collections:", names)