You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

61 lines
2.1 KiB

# diagnostics_embedding_check.py
import os, sys, numpy as np
os.chdir("/home/lasse/riksdagen")
sys.path.append("/home/lasse/riksdagen")
from _chromadb.chroma_client import chroma_db
# Look up the "talks" collection and show which embedding function the
# client wrapper carries.
col = chroma_db.get_collection("talks")
print("Client embedding function object:", chroma_db.embedding_function)

# The collection's own embedding function is read from the private
# `_embedding_function` attribute; None means "absent or unset".
col_emb_fn = getattr(col, "_embedding_function", None)
print("Collection _embedding_function:", col_emb_fn)

# Print the class names so a difference between the two is easy to spot.
print("client emb fn type:", type(chroma_db.embedding_function).__name__)
if col_emb_fn is None:
    print("collection has no _embedding_function attribute")
else:
    print("collection emb fn type:", type(col_emb_fn).__name__)

# Test query that gets embedded with both functions (when available).
query = "bete utomhus kossor"
def embed_with(fn, text):
    """Embed a single string with *fn* and return its vector.

    Three calling conventions are tried, in this order:

    1. ``fn.embed_query(text)`` - the result is returned as is.
    2. ``fn.embed_documents([text])`` - the first element is returned.
    3. ``fn([text])`` - the first element is returned.

    Args:
        fn: Embedding function object or plain callable.
        text: The string to embed.

    Returns:
        Whatever vector object the selected API produces for *text*.

    Raises:
        TypeError: If *fn* offers none of the three conventions (for
            example when it is ``None``).
    """
    embed_query = getattr(fn, "embed_query", None)
    if callable(embed_query):
        return embed_query(text)

    embed_documents = getattr(fn, "embed_documents", None)
    if callable(embed_documents):
        return embed_documents([text])[0]

    # Last resort: call the object directly with a one-element batch.
    if not callable(fn):
        raise TypeError(
            "embedding function {!r} has no embed_query/embed_documents "
            "method and is not callable".format(fn)
        )
    return fn([text])[0]
# Embed the test query with the client's embedding function. This call is
# deliberately unguarded: every comparison further down needs this vector.
v_client = embed_with(chroma_db.embedding_function, query)
print("client vector length:", len(v_client))

# Embed the same query with the collection's own function, when it has one.
# A failure here is reported but does not stop the script.
v_col = None
if col_emb_fn is not None:
    try:
        v_col = embed_with(col_emb_fn, query)
        print("collection vector length:", len(v_col))
    except Exception as e:
        print("Failed to embed with collection's embedding function:", e)

# Pick one stored record to compare against. Ids are always part of a Chroma
# `get` result and "ids" is not an accepted `include` value, so request no
# extra fields and limit the lookup to a single record.
ids = col.get(limit=1, include=[]).get("ids", [])
if not ids:
    raise SystemExit("No ids in collection")
sample_id = ids[0]
stored = col.get(ids=[sample_id], include=["embeddings"])
embeddings = stored.get("embeddings")
if embeddings is None or len(embeddings) == 0:
    raise SystemExit("No stored embedding for id {}".format(sample_id))
# `embeddings` holds one entry per returned id, so a single index yields the
# vector. Should that entry itself be nested (several vectors for one id),
# fall back to the first of them.
stored_vec = np.asarray(embeddings[0], dtype=float)
if stored_vec.ndim > 1:
    stored_vec = stored_vec[0]
print("stored vector length:", len(stored_vec))
def cos(a, b):
    """Return the cosine similarity of vectors *a* and *b* as a float.

    Both inputs are converted to flat float arrays first, so lists, tuples
    and numpy arrays (including ones with a leading singleton axis) are
    accepted.

    Args:
        a: First vector (any array-like of numbers).
        b: Second vector (any array-like of numbers).

    Returns:
        The cosine of the angle between the vectors, or NaN when either
        vector has zero norm (the similarity is undefined in that case).

    Raises:
        ValueError: If the vectors do not have the same number of elements.
    """
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    if a.size != b.size:
        raise ValueError(
            "cannot compare vectors of different length: {} vs {}".format(
                a.size, b.size
            )
        )
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # Same result the plain division would give, without the 0/0 warning.
        return float("nan")
    return float(np.dot(a, b) / denom)
# Report the pairwise cosine similarities. Each pair is guarded on its own:
# a dimension mismatch (ValueError from the dot product) is itself a
# diagnostic result and must not prevent the remaining pairs from printing.
comparisons = [("cos(client vs stored) =", v_client, stored_vec)]
if v_col is not None:
    # Only available when the collection's embedding function worked.
    comparisons.append(("cos(client vs collection_fn) =", v_client, v_col))
    comparisons.append(("cos(collection_fn vs stored) =", v_col, stored_vec))
for label, left, right in comparisons:
    try:
        print(label, cos(left, right))
    except ValueError as err:
        print(label, "not comparable:", err)