#!/usr/bin/env python3
"""
Convert stored embeddings to plain Python lists for an existing Chroma collection.

Usage:
    # Dry run (inspect first 50 ids)
    python scripts/convert_embeddings_to_lists.py --collection talks --limit 50 --dry-run

    # Full run (no dry run)
    python scripts/convert_embeddings_to_lists.py --collection talks

Notes:
- Run from your project root (same env you use to access chroma_db).
- Back up chromadb_data before running.
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

# Use the same imports/bootstrapping as the rest of the project so the same
# chroma client and embedding function are loaded. Adjust the paths if necessary.
os.chdir("/home/lasse/riksdagen")
sys.path.append("/home/lasse/riksdagen")

import numpy as np

from _chromadb.chroma_client import chroma_db

CHECKPOINT_DIR = Path("var/chroma_repair")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)


def normalize_embedding(emb):
    """
    Convert a single embedding to a plain Python list[float].
    Accepts numpy arrays, array-likes, and lists.
    """
    # numpy ndarray
    if isinstance(emb, np.ndarray):
        return emb.tolist()
    # Some array-likes (pandas/other) may have tolist()
    if hasattr(emb, "tolist") and not isinstance(emb, list):
        try:
            return emb.tolist()
        except Exception:
            pass
    # If it's already a list of numbers, convert elements to float
    if isinstance(emb, list):
        return [float(x) for x in emb]
    # Last resort: try iterating
    try:
        return [float(x) for x in emb]
    except Exception:
        raise ValueError("Cannot normalize embedding of type: %s" % type(emb))


def chunked_iter(iterable, n):
    """Yield successive chunks of at most n items from an iterable."""
    it = iter(iterable)
    while True:
        chunk = []
        try:
            for _ in range(n):
                chunk.append(next(it))
        except StopIteration:
            pass
        if not chunk:
            break
        yield chunk


def load_checkpoint(name):
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        # json.load expects a file object, not a path, so open the file first
        with open(path) as f:
            return json.load(f)
    return {"last_index": 0, "processed_ids": []}


def save_checkpoint(name, data):
    path = CHECKPOINT_DIR / f"{name}.json"
    with open(path, "w") as f:
        json.dump(data, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection", required=True, help="Chroma collection name (e.g. talks)")
    parser.add_argument("--batch", type=int, default=1000, help="Batch size for update (default 1000)")
    parser.add_argument("--dry-run", action="store_true", help="Dry run: don't write updates, just report")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of ids to process (for testing)")
    parser.add_argument("--checkpoint-name", default=None, help="Name for checkpoint file (defaults to collection name)")
    args = parser.parse_args()

    coll_name = args.collection
    checkpoint_name = args.checkpoint_name or coll_name

    print(f"Connecting to Chroma collection '{coll_name}'...")
    col = chroma_db.get_collection(coll_name)

    # Get the full list of ids. For ~600k ids this is fine to hold in memory;
    # switch to a paginated fetch (limit/offset) if that ever becomes a problem.
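    # NOTE (assumption about the installed Chroma version): passing include=[]
    # asks get() to return only the ids (ids are always included), which keeps
    # memory usage low even for a large collection.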
    all_info = col.get(include=[])  # may return {'ids': [...]} as in your env
    ids = list(all_info.get("ids", []))
    total_ids = len(ids)
    if args.limit:
        ids = ids[: args.limit]
        total_process = len(ids)
    else:
        total_process = total_ids

    print(f"Found {total_ids} ids in collection; will process {total_process} ids (limit={args.limit})")

    # Load checkpoint so an interrupted run can resume where it left off
    ck = load_checkpoint(checkpoint_name)
    start_index = ck.get("last_index", 0)
    print(f"Resuming at index {start_index}")

    # Iterate in batches starting from last_index
    processed = 0
    for i in range(start_index, total_process, args.batch):
        batch_ids = ids[i : i + args.batch]
        print(f"\nProcessing batch {i}..{i + len(batch_ids) - 1} (count={len(batch_ids)})")

        # Fetch full info for this batch. Only embeddings are needed for the
        # repair; documents/metadatas are included for manual verification.
        try:
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])
        except Exception as e:
            print("Error fetching batch:", e)
            # retry once after a short sleep
            time.sleep(2)
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])

        # Embeddings may come back as a list or a numpy array depending on the
        # Chroma version; guard against a missing/None value.
        batch_embeddings = items.get("embeddings")
        if batch_embeddings is None:
            batch_embeddings = []

        # items["ids"] should match batch_ids order; warn if the lengths differ
        ids_from_get = items.get("ids", batch_ids)
        if len(ids_from_get) != len(batch_ids):
            print("Warning: length mismatch between requested ids and returned ids")

        # Normalize embeddings
        normalized_embeddings = []
        failed = False
        for idx, emb in enumerate(batch_embeddings):
            try:
                norm = normalize_embedding(emb)
            except Exception as e:
                print(f"Failed to normalize embedding for id {ids_from_get[idx]}: {e}")
                failed = True
                break
            normalized_embeddings.append(norm)

        if failed:
            print("Stopping at this batch due to failures. You can adjust the batch size and retry.")
            break

        if args.dry_run:
            # Dry run: just print stats for a small sample and continue
            sample_i = min(3, len(normalized_embeddings))
            print("Sample normalized embedding lengths:", [len(normalized_embeddings[k]) for k in range(sample_i)])
            print("Sample values (first 6 floats):", [normalized_embeddings[k][:6] for k in range(sample_i)])
        else:
            # Update the collection in place. update() overwrites the stored
            # embeddings for the given ids (the ids must already exist).
            try:
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            except Exception as e:
                print("Update failed, retrying once after short sleep:", e)
                time.sleep(2)
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            print(f"Updated {len(normalized_embeddings)} embeddings in collection '{coll_name}'")

        # Checkpoint progress. Skip during dry runs so a later full run does not
        # resume past ids that were never actually updated.
        if not args.dry_run:
            ck["last_index"] = i + len(batch_ids)
            save_checkpoint(checkpoint_name, ck)
        processed += len(batch_ids)

    print(f"\nDone. Processed {processed} ids. Checkpoint saved to {CHECKPOINT_DIR / (checkpoint_name + '.json')}")
    print("Reminder: run a few queries to validate search quality.")


if __name__ == "__main__":
    main()
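
# A minimal sketch of the validation step mentioned in the final reminder,
# assuming the collection was created with an embedding function so that
# query_texts can be used; the query text below is a hypothetical placeholder.
#
#   from _chromadb.chroma_client import chroma_db
#   col = chroma_db.get_collection("talks")
#   res = col.query(query_texts=["example query text"], n_results=5)
#   print(res["ids"])
#   print(res["distances"])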