#!/usr/bin/env python3
"""
Convert stored embeddings to plain Python lists for an existing Chroma collection.
Usage:
# Dry run (inspect first 50 ids)
python scripts/convert_embeddings_to_lists.py --collection talks --limit 50 --dry-run
# Full run (no dry run)
python scripts/convert_embeddings_to_lists.py --collection talks
Notes:
- Run from your project root (same env you use to access chroma_db).
- Back up chromadb_data before running.
"""
import argparse
import json
import os
import sys
import time
from pathlib import Path

import numpy as np

# Use the same bootstrapping as the rest of the project so the same chroma
# client and embedding function are loaded. Adjust the paths if necessary.
os.chdir("/home/lasse/riksdagen")
sys.path.append("/home/lasse/riksdagen")

from _chromadb.chroma_client import chroma_db

CHECKPOINT_DIR = Path("var/chroma_repair")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
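# Checkpoint layout (one JSON file per --checkpoint-name, default: the collection name):
#   {"last_index": <index of the next batch to process>, "processed_ids": [...]}
# Only last_index is updated by this script; processed_ids is initialized but unused.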


def normalize_embedding(emb):
    """
    Convert a single embedding to a plain Python list[float].
    Accepts numpy arrays, array-likes, and lists.
    """
    # numpy ndarray
    if isinstance(emb, np.ndarray):
        return emb.tolist()
    # Some array-likes (pandas/other) may have tolist()
    if hasattr(emb, "tolist") and not isinstance(emb, list):
        try:
            return emb.tolist()
        except Exception:
            pass
    # If it's already a list of numbers, convert elements to float
    if isinstance(emb, list):
        return [float(x) for x in emb]
    # Last resort: try iterating
    try:
        return [float(x) for x in emb]
    except Exception:
        raise ValueError(f"Cannot normalize embedding of type: {type(emb)}")
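

# Example: normalize_embedding(np.array([0.1, 0.2])) -> [0.1, 0.2] (plain Python floats).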


def chunked_iter(iterable, n):
    """Yield successive chunks of up to n items from iterable."""
    it = iter(iterable)
    while True:
        chunk = []
        try:
            for _ in range(n):
                chunk.append(next(it))
        except StopIteration:
            pass
        if not chunk:
            break
        yield chunk


def load_checkpoint(name):
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        # json.load needs a file object, so open the checkpoint file first
        with open(path) as f:
            return json.load(f)
    return {"last_index": 0, "processed_ids": []}


def save_checkpoint(name, data):
    path = CHECKPOINT_DIR / f"{name}.json"
    with open(path, "w") as f:
        json.dump(data, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection", required=True, help="Chroma collection name (e.g. talks)")
    parser.add_argument("--batch", type=int, default=1000, help="Batch size for update (default 1000)")
    parser.add_argument("--dry-run", action="store_true", help="Dry run: don't write updates, just report")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of ids to process (for testing)")
    parser.add_argument("--checkpoint-name", default=None, help="Name for checkpoint file (defaults to collection name)")
    args = parser.parse_args()

    coll_name = args.collection
    checkpoint_name = args.checkpoint_name or coll_name

    print(f"Connecting to Chroma collection '{coll_name}'...")
    col = chroma_db.get_collection(coll_name)

    # Get the full list of ids. For ~600k ids this is fine to hold in memory;
    # for larger collections, switch to the paginated fetch sketched below.
    all_info = col.get(include=[])  # include=[] returns only the ids
    ids = list(all_info.get("ids", []))
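    # Paginated alternative for collections too large to list all ids in memory
    # (sketch only; uses Chroma's limit/offset parameters on Collection.get):
    #     ids, offset = [], 0
    #     while True:
    #         page = col.get(include=[], limit=50_000, offset=offset)
    #         if not page["ids"]:
    #             break
    #         ids.extend(page["ids"])
    #         offset += len(page["ids"])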
    total_ids = len(ids)
    if args.limit:
        ids = ids[: args.limit]
        total_process = len(ids)
    else:
        total_process = total_ids
    print(f"Found {total_ids} ids in collection; will process {total_process} ids (limit={args.limit})")

    # Load checkpoint and resume from the last completed batch
    ck = load_checkpoint(checkpoint_name)
    start_index = ck.get("last_index", 0)
    print(f"Resuming at index {start_index}")

    # Iterate in batches starting from last_index
    processed = 0
    for i in range(start_index, total_process, args.batch):
        batch_ids = ids[i : i + args.batch]
        print(f"\nProcessing batch {i}..{i + len(batch_ids) - 1} (count={len(batch_ids)})")

        # Fetch full info for this batch. Only the embeddings are needed for the
        # repair; documents/metadatas are included so they can be spot-checked.
        try:
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])
        except Exception as e:
            print("Error fetching batch:", e)
            # Retry once after a short sleep
            time.sleep(2)
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])

        batch_embeddings = items.get("embeddings", [])
        # items.get("ids") should match batch_ids order; if not, align by ids
        ids_from_get = items.get("ids", batch_ids)
        if len(ids_from_get) != len(batch_ids):
            print("Warning: length mismatch between requested ids and returned ids")

        # Normalize embeddings to plain Python lists of floats
        normalized_embeddings = []
        failed = False
        for idx, emb in enumerate(batch_embeddings):
            try:
                norm = normalize_embedding(emb)
            except Exception as e:
                print(f"Failed to normalize embedding for id {ids_from_get[idx]}: {e}")
                failed = True
                break
            normalized_embeddings.append(norm)
        if failed:
            print("Stopping at this batch due to failures. Adjust the batch size and retry from the checkpoint.")
            break

        # Dry run: just print stats and continue
        if args.dry_run:
            # Show a sample
            sample_i = min(3, len(normalized_embeddings))
            print("Sample normalized embedding lengths:", [len(normalized_embeddings[k]) for k in range(sample_i)])
            # Optionally inspect the first few floats
            print("Sample values (first 6 floats):", [normalized_embeddings[k][:6] for k in range(sample_i)])
        else:
            # Update the collection in place (update replaces embeddings for existing ids)
            try:
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            except Exception as e:
                print("Update failed, retrying once after short sleep:", e)
                time.sleep(2)
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            print(f"Updated {len(normalized_embeddings)} embeddings in collection '{coll_name}'")

        # Checkpoint progress after each batch
        ck["last_index"] = i + len(batch_ids)
        save_checkpoint(checkpoint_name, ck)
        processed += len(batch_ids)

    print(f"\nDone. Processed {processed} ids. Checkpoint saved to {CHECKPOINT_DIR / (checkpoint_name + '.json')}")
    print("Reminder: run a few queries to validate search quality.")


if __name__ == "__main__":
    main()