You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

55 lines
2.1 KiB

import os
import sys
# Make the project root the working directory for the rest of the process.
os.chdir("/home/lasse/riksdagen")
# Put the project root on sys.path so local modules can be imported
# (presumably including `config` below -- confirm).
sys.path.append("/home/lasse/riksdagen")
from typing import Any, List
import chromadb
from dotenv import load_dotenv
# Load variables from a .env file into the environment. This runs before the
# `config` import, presumably so `config` can read them -- confirm.
load_dotenv()
# NOTE(review): `embedding_dimensions` is not referenced in the visible part
# of this module; verify whether it is still needed.
from config import embedding_dimensions
# Directory passed to chromadb.PersistentClient as the on-disk store location.
CHROMA_PATH = "/home/lasse/riksdagen/chromadb_data"
# Collection name used when a caller does not specify one.
DEFAULT_CHROMA_COLLECTION = "talks_embeddings"
def _get_chroma_collection(collection_name=None) -> Any:
    """Open the persistent Chroma store and fetch a single collection from it.

    Args:
        collection_name: Name of the collection to fetch. When ``None``,
            ``DEFAULT_CHROMA_COLLECTION`` is used instead.

    Returns:
        The collection object handed back by ``get_collection``.

    Note:
        The collection is looked up, never created; ``get_collection`` is
        expected to raise when it does not exist (the exact exception type
        depends on the installed chromadb version).
    """
    target_name = (
        DEFAULT_CHROMA_COLLECTION if collection_name is None else collection_name
    )
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    return client.get_collection(name=target_name)
def query_similar(embeddings: list = None, texts: list[str] = None, n_results: int = 8) -> List[dict[str, Any]]:
    """Query the default Chroma collection for the entries closest to the input.

    Args:
        embeddings: Query embedding(s), forwarded unchanged as
            ``query_embeddings``.
        texts: Query text(s), forwarded unchanged as ``query_texts``.
        n_results: Number of hits requested from Chroma per query.

    Returns:
        A list with one dict per hit, each holding the keys ``'metadata'``,
        ``'document'``, ``'distance'`` and ``'id'``. Only the hits for the
        FIRST query are returned; if several embeddings/texts are passed, the
        results for the remaining ones are discarded.

    Note:
        No validation of ``embeddings``/``texts`` happens here; Chroma decides
        which combinations are acceptable (typically exactly one of the two
        should be given -- confirm against the installed chromadb version).
    """
    # Reuse the shared helper instead of building a second client by hand;
    # with no argument it resolves to DEFAULT_CHROMA_COLLECTION.
    collection = _get_chroma_collection()
    results = collection.query(
        query_embeddings=embeddings,
        query_texts=texts,
        n_results=n_results,
        include=["metadatas", "documents", "distances"],
    )
    # Every field is a list of per-query lists; index 0 selects the first query.
    # "ids" is not part of `include` but is still read from the result here.
    first_query_hits = zip(
        results.get("metadatas", [[]])[0],
        results.get("documents", [[]])[0],
        results.get("distances", [[]])[0],
        results.get("ids", [[]])[0],
    )
    return [
        {'metadata': metadata, 'document': document, 'distance': distance, 'id': identifier}
        for metadata, document, distance, identifier in first_query_hits
    ]
if __name__ == "__main__":
    # Manual smoke check: print the names of all collections in the store.
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    collection_names = [entry.name for entry in client.list_collections()]
    print("Available collections:", collection_names)