You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
2.1 KiB
61 lines
2.1 KiB
# diagnostics_embedding_check.py
#
# Compares query embeddings produced by the client's and the "talks"
# collection's embedding functions with a vector already stored in that
# collection, printing vector lengths and cosine similarities.

import os
import sys

import numpy as np

PROJECT_ROOT = "/home/lasse/riksdagen"

# Run from the project root and make its packages importable; the project
# import below depends on this, which is why it is not at the top.
os.chdir(PROJECT_ROOT)
sys.path.append(PROJECT_ROOT)

from _chromadb.chroma_client import chroma_db  # noqa: E402  (needs PROJECT_ROOT on sys.path)
|
|
|
# Look up the collection under test and show which embedding function the
# client and the collection each carry.
col = chroma_db.get_collection("talks")
client_fn = chroma_db.embedding_function
print("Client embedding function object:", client_fn)

# Many Chroma Collection objects keep their embedding function on the private
# `_embedding_function` attribute; fall back to None when it is absent.
col_emb_fn = getattr(col, "_embedding_function", None)
print("Collection _embedding_function:", col_emb_fn)

# Class names make two different embedding backends easy to spot.
print("client emb fn type:", type(client_fn).__name__)
if col_emb_fn is None:
    print("collection has no _embedding_function attribute")
else:
    print("collection emb fn type:", type(col_emb_fn).__name__)

# Sample query text embedded by every available embedding function.
query = "bete utomhus kossor"
|
def embed_with(fn, text):
    """Embed a single string with an embedding object and return one 1-D vector.

    The common embedding APIs are tried in order:
    ``fn.embed_query(text)``, then ``fn.embed_documents([text])[0]``, and
    finally ``fn([text])[0]`` (a plain callable taking a list of texts).

    Args:
        fn: The embedding function/object to use.
        text: The text to embed.

    Returns:
        numpy.ndarray: a 1-D float vector (``len()`` is the embedding dimension).

    Raises:
        ValueError: if ``fn`` is None, or if the result is not exactly one
            vector. A 2-D result with a single row is unwrapped to that row.
    """
    if fn is None:
        raise ValueError("no embedding function available to embed with")

    if hasattr(fn, "embed_query"):
        out = fn.embed_query(text)
    elif hasattr(fn, "embed_documents"):
        out = fn.embed_documents([text])[0]
    else:
        # Last resort: treat fn as a callable over a list of texts.
        out = fn([text])[0]

    vec = np.asarray(out, dtype=float)
    # NOTE(review): some embedding functions expose an `embed_query` that takes
    # a *list* of documents and returns a list of vectors. Called with a str,
    # such a function may embed each character separately. Accept a single-row
    # result, but refuse anything else instead of silently returning a
    # per-item vector.
    if vec.ndim == 2 and vec.shape[0] == 1:
        vec = vec[0]
    if vec.ndim != 1:
        raise ValueError(
            f"expected a single embedding vector, got an array of shape {vec.shape}"
        )
    return vec
|
|
|
# Embed the sample query with the client's embedding function; this one is
# expected to work, so failures are allowed to propagate.
client_fn = chroma_db.embedding_function
v_client = embed_with(client_fn, query)
print("client vector length:", len(v_client))

# Embed it again with the collection's own function when one was found.
# Failures are printed rather than raised so the remaining checks still run.
v_col = None
if col_emb_fn is not None:
    try:
        v_col = embed_with(col_emb_fn, query)
        print("collection vector length:", len(v_col))
    except Exception as exc:
        print("Failed to embed with collection's embedding function:", exc)
|
|
|
# Take one stored document embedding to compare the query vectors against.
# Collection.get() always returns ids, and "ids" is not an accepted `include`
# item, so request the embeddings directly. limit=1 avoids pulling every id in
# the collection just to use the first one.
stored = col.get(limit=1, include=["embeddings"])

if not stored.get("ids"):
    raise SystemExit("No ids in collection")

sample_id = stored["ids"][0]
print("stored sample id:", sample_id)

# "embeddings" holds one vector per returned id, so [0] is the vector itself
# (a further [0] would be just its first component).
stored_vec = stored["embeddings"][0]

print("stored vector length:", len(stored_vec))
|
|
|
def cos(a, b):
    """Return the cosine similarity of two vectors as a Python float.

    Args:
        a, b: 1-D sequences/arrays of numbers with the same length.

    Returns:
        float: dot(a, b) / (|a| * |b|), or ``nan`` when either vector has zero
        norm (the similarity is undefined; this avoids a division by zero).

    Raises:
        ValueError: if either input is not 1-D or their lengths differ, e.g.
            vectors from embedding models with different dimensions.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    if a.ndim != 1 or a.shape != b.shape:
        raise ValueError(
            f"cos() needs two 1-D vectors of equal length, "
            f"got shapes {a.shape} and {b.shape}"
        )

    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return float("nan")
    return float(np.dot(a, b) / denom)
|
|
|
def _print_cos(label, a, b):
    """Print ``label`` followed by cos(a, b), or a short reason if it is not computable.

    A ValueError from cos() is reported on the same line instead of aborting,
    so the remaining comparisons are still printed. For example, numpy raises
    ValueError when two vectors come from embedding models with different
    dimensions.
    """
    try:
        similarity = cos(a, b)
    except ValueError as err:
        print(label, f"n/a ({err})")
    else:
        print(label, similarity)


_print_cos("cos(client vs stored) =", v_client, stored_vec)

if v_col is not None:
    _print_cos("cos(client vs collection_fn) =", v_client, v_col)
    _print_cos("cos(collection_fn vs stored) =", v_col, stored_vec)
|
|
|