branch: master

- Created a template for providers.yaml to define API providers and models.
- Added a new providers.yaml file with initial provider configurations.
- Implemented fix_things.py to update chunk documents in ArangoDB.
- Developed make_arango_embeddings.py to generate embeddings for talks and store them in ArangoDB.
- Introduced sync_talks.py to synchronize new speeches from riksdagen.se and process them.
- Added notes.md for documentation on riksdagsgruppen login details.
- Created test_make_arango_embeddings.py for integration testing of embedding generation.
- Implemented test_gpu.py to test image input handling with vLLM.
parent
3ba8c3340a
commit
88e0244429
37 changed files with 2282 additions and 373 deletions
Binary file not shown.
@ -0,0 +1,144 @@ |
||||
#!/usr/bin/env python3 |
||||
""" |
||||
System Resource Monitor - Logs system stats to help diagnose SSH connectivity issues. |
||||
|
||||
This script monitors: |
||||
- CPU usage |
||||
- Memory usage |
||||
- Disk usage |
||||
- Network connectivity |
||||
- SSH service status |
||||
- System load |
||||
- Active connections |
||||
|
||||
Run continuously to capture when the system becomes unreachable. |
||||
""" |
||||
|
||||
import psutil |
||||
import time |
||||
import logging |
||||
from datetime import datetime |
||||
from pathlib import Path |
||||
|
||||
# Setup logging to file with rotation
# NOTE(review): despite the comment above, no rotating handler is configured —
# a plain FileHandler grows unboundedly; consider logging.handlers.RotatingFileHandler.
# Writing to /var/log normally requires root privileges — TODO confirm run context.
log_file = Path("/var/log/system_monitor.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()  # Also print to console
    ]
)
||||
|
||||
def check_ssh_service() -> dict:
    """
    Check if SSH service is running.

    Returns:
        dict: Service status information. On success contains 'running'
        (bool) and 'status' (systemctl output); on failure contains
        'running': False and an 'error' string.
    """
    try:
        import subprocess

        probe = subprocess.run(
            ['systemctl', 'is-active', 'ssh'],
            capture_output=True,
            text=True,
            timeout=5,
        )
        is_active = probe.returncode == 0
        return {'running': is_active, 'status': probe.stdout.strip()}
    except Exception as exc:
        # systemctl missing, timeout, etc. — report rather than raise.
        return {'running': False, 'error': str(exc)}
||||
|
||||
def get_system_stats() -> dict:
    """
    Collect current system statistics.

    Returns:
        dict: System statistics including CPU, memory, disk, network,
        load averages, and the number of open connections.
    """
    # cpu_percent blocks for its 1-second sampling interval, so take it first.
    cpu_pct = psutil.cpu_percent(interval=1)
    core_count = psutil.cpu_count()

    mem = psutil.virtual_memory()
    swap_info = psutil.swap_memory()

    root_disk = psutil.disk_usage('/')

    net = psutil.net_io_counters()

    # 1-, 5-, and 15-minute load averages.
    one_min, five_min, fifteen_min = psutil.getloadavg()

    open_connections = len(psutil.net_connections())

    gib = 1024 ** 3
    return {
        'cpu_percent': cpu_pct,
        'cpu_count': core_count,
        'memory_percent': mem.percent,
        'memory_available_gb': mem.available / gib,
        'swap_percent': swap_info.percent,
        'disk_percent': root_disk.percent,
        'disk_free_gb': root_disk.free / gib,
        'network_bytes_sent': net.bytes_sent,
        'network_bytes_recv': net.bytes_recv,
        'load_1min': one_min,
        'load_5min': five_min,
        'load_15min': fifteen_min,
        'connections': open_connections,
    }
||||
|
||||
def monitor_loop(interval_seconds: int = 60):
    """
    Main monitoring loop that logs system stats at regular intervals.

    Runs forever; each cycle samples the system, logs one summary line,
    and escalates to WARNING/ERROR when thresholds are exceeded.

    Args:
        interval_seconds: How often to log stats (default: 60 seconds)
    """
    logging.info("Starting system monitoring...")

    while True:
        try:
            stats = get_system_stats()
            ssh_status = check_ssh_service()

            # One pipe-separated summary line per cycle.
            summary = " | ".join([
                f"CPU: {stats['cpu_percent']:.1f}%",
                f"MEM: {stats['memory_percent']:.1f}% ({stats['memory_available_gb']:.2f}GB free)",
                f"DISK: {stats['disk_percent']:.1f}% ({stats['disk_free_gb']:.2f}GB free)",
                f"LOAD: {stats['load_1min']:.2f} {stats['load_5min']:.2f} {stats['load_15min']:.2f}",
                f"CONN: {stats['connections']}",
                f"SSH: {ssh_status.get('status', 'unknown')}",
            ])

            # Threshold ladder: first matching condition wins.
            if stats['cpu_percent'] > 90:
                logging.warning(f"HIGH CPU! {summary}")
            elif stats['memory_percent'] > 90:
                logging.warning(f"HIGH MEMORY! {summary}")
            elif stats['disk_percent'] > 90:
                logging.warning(f"HIGH DISK USAGE! {summary}")
            elif stats['load_1min'] > stats['cpu_count'] * 2:
                logging.warning(f"HIGH LOAD! {summary}")
            elif not ssh_status.get('running'):
                logging.error(f"SSH SERVICE DOWN! {summary}")
            else:
                logging.info(summary)

            time.sleep(interval_seconds)

        except Exception as e:
            # Keep monitoring even if a single sample fails.
            logging.error(f"Error in monitoring loop: {e}")
            time.sleep(interval_seconds)


if __name__ == "__main__":
    monitor_loop(interval_seconds=60)  # Log every 60 seconds
||||
@ -0,0 +1,19 @@ |
||||
[Unit] |
||||
Description=Riksdagen daily talk sync |
||||
# Wait for network before starting |
||||
After=network-online.target |
||||
Wants=network-online.target |
||||
|
||||
[Service] |
||||
Type=oneshot |
||||
User=lasse |
||||
WorkingDirectory=/home/lasse/riksdagen |
||||
# Loads ARANGO_PWD and other env vars from the project .env file |
||||
EnvironmentFile=/home/lasse/riksdagen/.env |
||||
ExecStart=/home/lasse/riksdagen/.venv/bin/python /home/lasse/riksdagen/scripts/sync_talks.py |
||||
# Log stdout/stderr to the systemd journal (view with: journalctl -u riksdagen-sync) |
||||
StandardOutput=journal |
||||
StandardError=journal |
||||
|
||||
[Install] |
||||
WantedBy=multi-user.target |
||||
@ -0,0 +1,11 @@ |
||||
[Unit] |
||||
Description=Run riksdagen daily talk sync at 06:00 |
||||
|
||||
[Timer] |
||||
# Run every day at 06:00 |
||||
OnCalendar=*-*-* 06:00:00 |
||||
# If the server was off at 06:00, run the job as soon as it comes back up |
||||
Persistent=true |
||||
|
||||
[Install] |
||||
WantedBy=timers.target |
||||
@ -0,0 +1,10 @@ |
||||
[Interface]
# SECURITY NOTE(review): the PrivateKey below (and the PresharedKey in the
# [Peer] section) are secrets — they must not be committed to version
# control. Rotate these keys and move this file out of the repository.
PrivateKey = yDRb0EYZkUZCuYax44lSBAP3vmN+mPdDQEh2hAQ10lY=
||||
Address = 10.156.168.2/24 |
||||
DNS = 1.1.1.1, 1.0.0.1 |
||||
|
||||
[Peer] |
||||
PublicKey = 6gwhWDypmpxrGaobEh8xZIXvRIKdp0pWH6YWZ9F8twY= |
||||
PresharedKey = XAD5qpUMr0Ouz2azeXfH7J5tE3iSi5XJOdzdrUTSbRg= |
||||
Endpoint = 98.128.172.165:51820 |
||||
AllowedIPs = 0.0.0.0/0, ::0/0 |
||||
@ -0,0 +1,6 @@ |
||||
""" |
||||
Public entry points for the Riksdagen MCP server package. |
||||
""" |
||||
from .server import run |
||||
|
||||
__all__ = ("run",) |
||||
@ -0,0 +1,24 @@ |
||||
from __future__ import annotations |
||||
|
||||
import os |
||||
import secrets |
||||
|
||||
|
||||
def validate_token(provided_token: str) -> None:
    """
    Ensure the caller supplied the expected bearer token.

    Args:
        provided_token: Token received from the MCP client.

    Raises:
        RuntimeError: If the server token is not configured.
        PermissionError: If the token is missing or incorrect.
    """
    expected = os.getenv("MCP_SERVER_TOKEN")
    if not expected:
        raise RuntimeError("MCP_SERVER_TOKEN environment variable must be set for authentication.")
    if not provided_token:
        raise PermissionError("Missing MCP access token.")
    # Constant-time comparison avoids leaking token contents via timing.
    tokens_match = secrets.compare_digest(provided_token, expected)
    if not tokens_match:
        raise PermissionError("Invalid MCP access token.")
||||
@ -0,0 +1,130 @@ |
||||
import ssl |
||||
import socket |
||||
from datetime import datetime |
||||
from typing import Any, Dict, List, Optional |
||||
import argparse |
||||
import pprint |
||||
import sys # added to detect whether --host was passed |
||||
|
||||
def fetch_certificate(host: str, port: int = 443, server_hostname: Optional[str] = None, timeout: float = 5.0) -> Dict[str, Any]:
    """
    Fetch the TLS certificate from host:port. This function intentionally
    uses a non-verifying SSL context to retrieve the certificate even if it
    doesn't validate, so we can inspect its fields.

    Parameters:
    - host: TCP connect target (can be an IP or hostname)
    - port: TCP port (default 443)
    - server_hostname: SNI value to send. If None, server_hostname = host.
    - timeout: socket connect timeout in seconds

    Returns:
        Dictionary with the peer certificate (as returned by SSLSocket.getpeercert()) and additional metadata.
    """
    # Local import: this block's module header is defined elsewhere.
    from datetime import timezone

    if server_hostname is None:
        server_hostname = host

    # Create an SSL context that does NOT verify so we can always fetch the cert.
    # NOTE(review): with verify_mode=CERT_NONE, SSLSocket.getpeercert() returns
    # an empty dict on CPython (the decoded cert is only populated for verified
    # connections) — the downstream report may be empty. Consider
    # getpeercert(binary_form=True) plus a DER parser if full inspection is needed.
    context = ssl.create_default_context()
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE

    with socket.create_connection((host, port), timeout=timeout) as sock:
        with context.wrap_socket(sock, server_hostname=server_hostname) as sslsock:
            cert = sslsock.getpeercert()
            peer_cipher = sslsock.cipher()
            # datetime.utcnow() is deprecated (Python 3.12); build the same
            # "Z"-suffixed UTC timestamp from an aware datetime instead.
            peertime = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    info: Dict[str, Any] = {
        "peer_certificate": cert,
        "cipher": peer_cipher,
        "fetched_at": peertime,
        "server_hostname_used": server_hostname,
    }
    return info
||||
|
||||
def parse_san(cert: Dict[str, Any]) -> List[str]:
    """Return list of DNS names from subjectAltName (if any)."""
    entries = cert.get("subjectAltName", ())
    # Each entry is a (type, value) pair; keep only DNS names.
    return [value for kind, value in entries if kind.lower() == "dns"]
||||
|
||||
def format_subject(cert: Dict[str, Any]) -> str:
    """Return a short human-friendly subject string."""
    # The subject is a tuple of RDNs, each a tuple of (name, value) pairs;
    # flatten them all into "name=value" segments.
    pairs: List[str] = []
    for rdn in cert.get("subject", ()):
        pairs.extend(f"{name}={value}" for name, value in rdn)
    return ", ".join(pairs)
||||
|
||||
def check_hostname_match(cert: Dict[str, Any], hostname: str) -> bool:
    """
    Check whether the certificate matches hostname.

    The previous implementation called ssl.match_hostname(), which was
    deprecated in Python 3.7 and removed in 3.12 — and because the
    AttributeError was swallowed by the bare except, the function silently
    returned False for every host on 3.12+. This version performs the DNS
    name comparison directly: it checks the subjectAltName DNS entries,
    falling back to the subject commonName only when no DNS SANs exist
    (per RFC 6125), and honours a single '*' left-most wildcard label.

    Args:
        cert: Certificate dict as returned by SSLSocket.getpeercert().
        hostname: Hostname to compare against.

    Returns:
        True if the certificate covers hostname, False otherwise.
    """
    try:
        candidates = [val for typ, val in cert.get("subjectAltName", ()) if typ.lower() == "dns"]
        if not candidates:
            # Fallback: commonName is only consulted when no DNS SANs exist.
            for rdn in cert.get("subject", ()):
                for key, val in rdn:
                    if key == "commonName":
                        candidates.append(val)
        target = hostname.lower().rstrip(".")
        return any(_dns_name_matches(pattern.lower(), target) for pattern in candidates)
    except Exception:
        # Malformed cert structures are treated as a non-match, as before.
        return False


def _dns_name_matches(pattern: str, host: str) -> bool:
    """Return True if a certificate DNS name (possibly a '*.' wildcard) covers host."""
    if pattern == host:
        return True
    labels = pattern.split(".")
    host_labels = host.split(".")
    # Only a complete '*' as the left-most label is honoured, and it must
    # match exactly one label (no "*.com" matching "a.b.com").
    if len(labels) >= 2 and labels[0] == "*" and len(labels) == len(host_labels):
        return labels[1:] == host_labels[1:]
    return False
||||
|
||||
def print_report(host: str, port: int, server_hostname: Optional[str]) -> None:
    """
    Fetch certificate and print a readable report.

    Args:
        host: TCP connect target (IP or hostname).
        port: TCP port to connect to.
        server_hostname: SNI value to send; falls back to host when None.
    """
    # NOTE(review): fetch_certificate uses CERT_NONE, under which
    # getpeercert() may return an empty dict — the fields below can all be
    # blank in that case; confirm against a live connection.
    info = fetch_certificate(host=host, port=port, server_hostname=server_hostname)
    cert = info["peer_certificate"]

    print(f"Connected target: {host}:{port}")
    print(f"SNI sent: {info['server_hostname_used']}")
    print(f"Cipher: {info['cipher']}")
    print(f"Fetched at (UTC): {info['fetched_at']}")
    print()

    print("Subject:")
    print(" ", format_subject(cert))
    print()

    print("Issuer:")
    # Issuer shares the subject's RDN structure; flatten to "k=v" pairs.
    issuer = cert.get("issuer", ())
    issuer_parts = []
    for rdn in issuer:
        for k, v in rdn:
            issuer_parts.append(f"{k}={v}")
    print(" ", ", ".join(issuer_parts))
    print()

    sans = parse_san(cert)
    print("Subject Alternative Names (SANs):")
    if sans:
        for n in sans:
            print(" -", n)
    else:
        print(" (none)")

    not_before = cert.get("notBefore")
    not_after = cert.get("notAfter")
    print()
    print("Validity:")
    print(" notBefore:", not_before)
    print(" notAfter: ", not_after)

    match = check_hostname_match(cert, server_hostname or host)
    print()
    print(f"Hostname match for '{server_hostname or host}':", "YES" if match else "NO")

    # For debugging show the full cert dict if requested
    # pprint.pprint(cert)
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments and print the certificate report."""
    parser = argparse.ArgumentParser(description="Fetch and inspect TLS certificate from a host (SNI-aware).")
    # make --host optional and default to api.rixdagen.se so running without args works
    parser.add_argument("--host", "-H", required=False, default="api.rixdagen.se", help="Host or IP to connect to (TCP target). Defaults to api.rixdagen.se")
    parser.add_argument("--port", "-p", type=int, default=443, help="Port to connect to (default 443).")
    parser.add_argument("--sni", help="SNI hostname to send. If omitted, the --host value is used as SNI.")
    args = parser.parse_args()

    # Notify when using the default host for quick testing.  The previous
    # membership test ('--host' not in sys.argv) missed the '--host=value'
    # form, which argparse also accepts.
    host_given = any(
        arg == "--host" or arg == "-H" or arg.startswith("--host=")
        for arg in sys.argv[1:]
    )
    if not host_given:
        print("No --host provided: defaulting to api.rixdagen.se (you can override with --host or -H)")

    print_report(host=args.host, port=args.port, server_hostname=args.sni)


if __name__ == "__main__":
    main()
||||
@ -0,0 +1,114 @@ |
||||
""" |
||||
RiksdagenTools MCP server (HTTP only, compatible with current FastMCP version) |
||||
""" |
||||
from __future__ import annotations |
||||
import asyncio |
||||
import logging |
||||
import os |
||||
import inspect |
||||
from typing import Any, Dict, List, Optional, Sequence |
||||
from fastmcp import FastMCP |
||||
from mcp_server.auth import validate_token |
||||
from mcp_server import tools |
||||
|
||||
# Transport configuration, overridable via environment variables.
HOST = os.getenv("MCP_HOST", "127.0.0.1")
PORT = int(os.getenv("MCP_PORT", "8010"))
PATH = os.getenv("MCP_PATH", "/mcp")
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()

logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("mcp_server")

# FastMCP application instance; the @app.tool() functions below register on it.
app = FastMCP("RiksdagenTools")
||||
|
||||
# --- tools unchanged ---
@app.tool()
async def search_documents(token: str, aql_query: str) -> Dict[str, Any]:
    """Execute a read-only AQL query via tools.search_documents (token required)."""
    validate_token(token)
    # The Arango client is synchronous; run it off the event loop.
    return await asyncio.to_thread(tools.search_documents, aql_query)
||||
|
||||
@app.tool()
async def aql_query(token: str, query: str) -> List[Dict[str, Any]]:
    """Execute a read-only AQL query and return the result rows (token required)."""
    validate_token(token)
    return await asyncio.to_thread(tools.run_aql_query, query)
||||
|
||||
@app.tool()
async def vector_search_talks(token: str, query: str, limit: int = 8) -> List[Dict[str, Any]]:
    """Semantic search over talks via tools.vector_search (token required)."""
    validate_token(token)
    return await asyncio.to_thread(tools.vector_search, query, limit)
||||
|
||||
@app.tool()
async def fetch_documents(
    token: str, document_ids: Sequence[str], fields: Optional[Sequence[str]] = None
) -> List[Dict[str, Any]]:
    """Fetch documents by id, optionally restricted to the given fields (token required)."""
    validate_token(token)
    # Normalize sequences to plain lists before crossing the thread boundary.
    return await asyncio.to_thread(
        tools.fetch_documents, list(document_ids), list(fields) if fields else None
    )
||||
|
||||
@app.tool()
async def arango_search(
    token: str,
    query: str,
    limit: int = 20,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Search talks via tools.arango_search with optional filters (token required).

    Args:
        token: Bearer token; validated before any work is done.
        query: Search phrase.
        limit: Maximum number of hits to return.
        parties, people, from_year, to_year, return_snippets, focus_ids,
        speaker_ids: Optional filters forwarded verbatim to
            tools.arango_search (their exact semantics are defined there —
            not visible from this module).

    Returns:
        Result dictionary produced by tools.arango_search.
    """
    validate_token(token)
    # NOTE: arguments are passed positionally — the order below must stay
    # in sync with tools.arango_search's signature.
    return await asyncio.to_thread(
        tools.arango_search,
        query,
        limit,
        parties,
        people,
        from_year,
        to_year,
        return_snippets,
        focus_ids,
        speaker_ids,
    )
||||
|
||||
@app.tool()
async def ping() -> str:
    """
    Lightweight test tool for connectivity checks.

    Returns:
        A short confirmation string ("ok"). This tool is intentionally
        unauthenticated so it can be used to validate the transport/proxy
        (e.g. nginx -> backend) without presenting credentials.
    """
    return "ok"
||||
|
||||
# --- Entrypoint ---
def run() -> None:
    """Start the FastMCP server over streamable HTTP on HOST:PORT at PATH."""
    log.info(
        "Starting RiksdagenTools MCP server (HTTP) on http://%s:%d%s",
        HOST,
        PORT,
        PATH,
    )

    try:
        # Pass host, port, and path directly to run() method
        app.run(
            transport="streamable-http",
            host=HOST,
            port=PORT,
            path=PATH,
        )
    except Exception:
        # Log the full traceback, then propagate so a supervisor sees the failure.
        log.exception("Unexpected error while running the MCP server.")
        raise


if __name__ == "__main__":
    run()
||||
@ -0,0 +1,168 @@ |
||||
""" |
||||
Test script for the RiksdagenTools MCP server (HTTP transport). |
||||
This script connects to the MCP server via Streamable HTTP and tests your main tools. |
||||
Ensure the MCP server is running and that the environment variable MCP_SERVER_TOKEN is set. |
||||
Also adjust MCP_SERVER_URL if needed. |
||||
""" |
||||
|
||||
import os |
||||
import asyncio |
||||
from typing import Any, Dict, List, Optional, Sequence |
||||
|
||||
from mcp.client.streamable_http import streamablehttp_client |
||||
from mcp.client.session import ClientSession # adjusted import per SDK version |
||||
|
||||
# NOTE(review): a real-looking token is hard-coded as the fallback default —
# secrets should not live in source; require MCP_SERVER_TOKEN to be set instead.
TOKEN: str = os.environ.get("MCP_SERVER_TOKEN", "2q89rwpfaiukdjshp298n3qw")
# Use the public HTTPS endpoint by default so tests target the nginx proxy.
SERVER_URL: str = os.environ.get("MCP_SERVER_URL", "https://api.rixdagen.se/mcp")
||||
|
||||
async def run_tests() -> None:
    """
    Attempt to connect to SERVER_URL and run the tool tests. On SSL certificate
    verification failures, optionally retry against the backend IP if
    MCP_FALLBACK_TO_IP=1 is set (or a custom MCP_FALLBACK_URL is provided).
    """
    async def run_with_url(url: str) -> None:
        # Open one streamable-HTTP MCP session and exercise each tool in turn.
        print(f"Connecting to server URL: {url}")
        async with streamablehttp_client(
            url=url,
            headers={ "Authorization": f"Bearer {TOKEN}" }
        ) as (read_stream, write_stream, get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                # initialize the session (if needed)
                init_result = await session.initialize()
                print("Initialized session:", init_result)

                print("\nListing available tools...")
                tools_info = await session.list_tools()
                print("Tools:", [ tool.name for tool in tools_info.tools ])

                # Test aql_query
                print("\n== Testing aql_query ==")
                result1 = await session.call_tool(
                    "aql_query",
                    arguments={
                        "token": TOKEN,
                        "query": "FOR doc IN talks LIMIT 2 RETURN { _id: doc._id, talare: doc.talare }"
                    }
                )
                print("aql_query result:", result1)

                # Test search_documents
                print("\n== Testing search_documents ==")
                result2 = await session.call_tool(
                    "search_documents",
                    arguments={
                        "token": TOKEN,
                        "aql_query": "FOR doc IN talks LIMIT 2 RETURN { _id: doc._id, talare: doc.talare }"
                    }
                )
                print("search_documents result:", result2)

                # Test vector_search_talks
                print("\n== Testing vector_search_talks ==")
                result3 = await session.call_tool(
                    "vector_search_talks",
                    arguments={
                        "token": TOKEN,
                        "query": "klimatförändringar",
                        "limit": 2
                    }
                )
                print("vector_search_talks result:", result3)

                # Test fetch_documents
                print("\n== Testing fetch_documents ==")
                # try to pull out IDs from result3 if available
                # NOTE(review): assumes result3 exposes an 'output' list of
                # dicts — confirm against the SDK's CallToolResult shape.
                doc_ids: List[str]
                maybe = result3
                if hasattr(maybe, "output") and isinstance(maybe.output, list) and maybe.output:
                    doc_ids = [ maybe.output[0].get("_id", "") ]
                else:
                    doc_ids = ["talks/1"]
                result4 = await session.call_tool(
                    "fetch_documents",
                    arguments={
                        "token": TOKEN,
                        "document_ids": doc_ids
                    }
                )
                print("fetch_documents result:", result4)

                # Test arango_search
                print("\n== Testing arango_search ==")
                result5 = await session.call_tool(
                    "arango_search",
                    arguments={
                        "token": TOKEN,
                        "query": "klimat",
                        "limit": 2
                    }
                )
                print("arango_search result:", result5)

    # try primary URL first
    try:
        await run_with_url(SERVER_URL)
    except Exception as e:  # capture failures from streamablehttp_client / httpx
        # Classify the failure by substring matching on the exception text.
        err_str = str(e).lower()
        ssl_fail = "certificate_verify_failed" in err_str or "hostname mismatch" in err_str or "certificate verify failed" in err_str
        gateway_fail = "502" in err_str or "bad gateway" in err_str or "502 bad gateway" in err_str

        if gateway_fail:
            print("Received 502 Bad Gateway from the proxy when connecting to the server URL.")
            # If user explicitly set fallback env var, retry against backend IP or custom fallback
            fallback_flag = os.environ.get("MCP_FALLBACK_TO_IP", "0").lower() in ("1", "true", "yes")
            if fallback_flag:
                fallback_url = os.environ.get("MCP_FALLBACK_URL", "http://127.0.0.1:8010/mcp")
                print(f"Retrying with fallback URL (MCP_FALLBACK_URL or default backend): {fallback_url}")
                await run_with_url(fallback_url)
                return
            print("")
            print("Possible causes:")
            print("- The proxy (nginx) couldn't reach the backend (backend down, wrong proxy_pass or path).")
            print("- Proxy buffering or HTTP version issues interfering with streaming transport.")
            print("")
            print("Options:")
            print("- Bypass the proxy and target the backend directly:")
            print("  export MCP_SERVER_URL='http://127.0.0.1:8010/mcp'")
            print("- Or enable automatic fallback to the backend (insecure) for testing:")
            print("  export MCP_FALLBACK_TO_IP=1")
            print("  # optionally override the fallback target")
            print("  export MCP_FALLBACK_URL='http://127.0.0.1:8010/mcp'")
            print("- Check the proxy's error log (e.g. /var/log/nginx/error.log) for upstream errors.")
            print("")
            # re-raise so caller still sees the error if they don't follow guidance
            raise

        if ssl_fail:
            print("SSL certificate verification failed while connecting to the server URL.")
            # If user explicitly set fallback env var, retry against backend IP or custom fallback
            fallback_flag = os.environ.get("MCP_FALLBACK_TO_IP", "0").lower() in ("1", "true", "yes")
            if fallback_flag:
                fallback_url = os.environ.get("MCP_FALLBACK_URL", "http://192.168.1.10:8010/mcp")
                print(f"Retrying with fallback URL (MCP_FALLBACK_URL or default backend IP): {fallback_url}")
                await run_with_url(fallback_url)
                return
            # Otherwise give actionable guidance
            print("")
            print("Possible causes:")
            print("- The TLS certificate served for api.rixdagen.se does not match that hostname on this machine.")
            print("")
            print("Options:")
            print("- Set MCP_SERVER_URL to the backend HTTP address to bypass TLS: e.g.")
            print("  export MCP_SERVER_URL='http://192.168.1.10:8010/mcp'")
            print("- Or enable automatic fallback to the backend IP for testing (insecure):")
            print("  export MCP_FALLBACK_TO_IP=1")
            print("  # optionally override the fallback target")
            print("  export MCP_FALLBACK_URL='http://192.168.1.10:8010/mcp'")
            print("")
            # re-raise so caller still sees the error if they don't follow guidance
            raise
        # Not an SSL or gateway failure: re-raise
        raise
||||
|
||||
def main() -> None:
    """Synchronous entry point wrapping the async test run."""
    asyncio.run(run_tests())


if __name__ == "__main__":
    main()
||||
@ -0,0 +1,81 @@ |
||||
import os |
||||
import sys |
||||
import asyncio |
||||
from typing import Tuple, Any |
||||
|
||||
from mcp.client.streamable_http import streamablehttp_client |
||||
from mcp.client.session import ClientSession |
||||
|
||||
SERVER_URL: str = os.environ.get("MCP_SERVER_URL", "http://127.0.0.1:8010/mcp") |
||||
TOKEN: str = os.environ.get("MCP_SERVER_TOKEN", "") |
||||
|
||||
|
||||
async def _extract_ping_result(res: Any) -> Any:
    """
    Extract a sensible value from various CallToolResult shapes returned by the SDK.

    Handles:
    - objects with 'structuredContent' (dict) -> use 'result' or first value
    - objects with 'output' attribute
    - objects with 'content' list containing a TextContent with .text
    - plain scalars
    """
    # Newer SDK responses commonly carry a structuredContent dict.
    structured = getattr(res, "structuredContent", None)
    if isinstance(structured, dict):
        if "result" in structured:
            return structured["result"]
        # No 'result' key: fall back to the first value, if any.
        for first_value in structured.values():
            return first_value

    # Older / alternative shape: an 'output' attribute.
    if hasattr(res, "output"):
        return res.output

    # Textual content list (TextContent objects).
    content = getattr(res, "content", None)
    if isinstance(content, list) and content:
        head = content[0]
        # Some SDK TextContent exposes 'text'; otherwise stringify.
        return head.text if hasattr(head, "text") else str(head)

    # Final fallback: try direct indexing, else string conversion.
    try:
        return res["result"]
    except Exception:
        return str(res)
||||
|
||||
|
||||
async def check_ping(url: str, token: str) -> Tuple[bool, str]:
    """
    Connect to the MCP server at `url` and call the 'ping' tool.

    Returns:
        (ok, message) where ok is True if ping returned "ok", otherwise False.
    """
    # Send the bearer header only when a token is configured; the 'ping'
    # tool itself does not require authentication.
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        async with streamablehttp_client(url=url, headers=headers) as (read_stream, write_stream, _):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()
                res = await session.call_tool("ping", arguments={})
                output = await _extract_ping_result(res)  # robust extractor
                if output == "ok":
                    return True, "ping -> ok"
                return False, f"unexpected ping response: {output!r}"
    except Exception as e:
        # Any transport/protocol failure is reported as a non-ok result.
        return False, f"error connecting/calling ping: {e!r}"
||||
|
||||
|
||||
def main() -> None:
    """Run the ping check once, print the outcome, and exit non-zero on failure."""
    ok, msg = asyncio.run(check_ping(SERVER_URL, TOKEN))
    print(msg)
    if not ok:
        # Non-zero exit so scripts/monitoring can detect the failure.
        sys.exit(1)


if __name__ == "__main__":
    main()
||||
@ -0,0 +1,360 @@ |
||||
from __future__ import annotations |
||||
|
||||
from dataclasses import dataclass |
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence |
||||
from pydantic import BaseModel, Field |
||||
|
||||
import env_manager |
||||
|
||||
env_manager.set_env() |
||||
|
||||
from arango.collection import Collection # noqa: E402 |
||||
from arango_client import arango # noqa: E402 |
||||
from backend.services.search import SearchService # noqa: E402 |
||||
from _chromadb.chroma_client import chroma_db # noqa: E402 |
||||
|
||||
|
||||
class HitDocument(BaseModel):
    """
    Normalized representation of a search hit across the various tools,
    enabling consistent downstream handling.

    Attributes:
        id (Optional[str]): Fully qualified ArangoDB document identifier.
        key (Optional[str]): Document key without collection prefix.
        speaker (Optional[str]): Name of the speaker associated with the hit.
        party (Optional[str]): Party affiliation of the speaker.
        date (Optional[str]): ISO formatted document date (YYYY-MM-DD).
        snippet (Optional[str]): Contextual snippet or highlight from the document.
        text (Optional[str]): Full text of the document when available.
        score (Optional[float]): Relevance score supplied by the executing tool.
        metadata (Dict[str, Any]): Additional tool-specific metadata to preserve.
    """
    # (A second, redundant class docstring literal was removed here — it was a
    # dead no-op statement duplicating the docstring above.)

    id: Optional[str] = Field(
        default=None, description="Fully qualified ArangoDB document identifier."
    )
    key: Optional[str] = Field(
        default=None, description="Document key without collection prefix."
    )
    speaker: Optional[str] = Field(
        default=None, description="Name of the speaker associated with the hit."
    )
    party: Optional[str] = Field(
        default=None, description="Party affiliation of the speaker."
    )
    date: Optional[str] = Field(
        default=None, description="ISO formatted document date (YYYY-MM-DD)."
    )
    snippet: Optional[str] = Field(
        default=None, description="Contextual snippet or highlight from the document."
    )
    text: Optional[str] = Field(
        default=None, description="Full text of the document when available."
    )
    score: Optional[float] = Field(
        default=None, description="Relevance score supplied by the executing tool."
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata specific to the originating tool that should be preserved.",
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render the object as a human-readable string with uppercase labels.

        Args:
            include_metadata (bool, optional): Whether to include metadata fields
                in the output. Defaults to True.

        Returns:
            str: A formatted string representation, each field rendered as an
            uppercase label followed by its value, segments joined by double
            newlines.
        """
        data: Dict[str, Any] = self.model_dump(exclude_none=True)
        metadata: Dict[str, Any] = data.pop("metadata", {})
        segments: List[str] = []
        for field_name, field_value in data.items():
            segments.append(f"{field_name.upper()}\n{field_value}")
        # Fix: the include_metadata flag was previously ignored and metadata
        # was always rendered regardless of the caller's choice.
        if include_metadata:
            for meta_key, meta_value in metadata.items():
                segments.append(f"{meta_key.upper()}\n{meta_value}")
        return "\n\n".join(segments)
||||
|
||||
|
||||
class HitsResponse(BaseModel):
    """
    Container for multiple HitDocument instances with helpers for rendering
    the collection as text.

    Attributes:
        hits (List[HitDocument]): A list of collected search hits.
    """

    hits: List[HitDocument] = Field(
        default_factory=list, description="Collected search hits."
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render all hits as a single string, separated by a visual divider.

        Args:
            include_metadata (bool, optional): Forwarded to each hit's
                to_string(). Defaults to True.

        Returns:
            str: All hits joined by "\n\n---\n\n"; an empty string if there
            are no hits.
        """
        # (A redundant second docstring literal inside this method was removed —
        # it was a dead no-op statement.  The parameter also gained an explicit
        # bool annotation for consistency with HitDocument.to_string.)
        if not self.hits:
            return ""
        return "\n\n---\n\n".join(
            hit.to_string(include_metadata=include_metadata) for hit in self.hits
        )
||||
|
||||
|
||||
|
||||
def ensure_read_only_aql(query: str) -> None:
    """
    Reject AQL statements that attempt to mutate data or omit a RETURN clause.

    Args:
        query: Raw AQL statement from the client.

    Raises:
        ValueError: If the query contains a data-modification keyword or has
            no RETURN clause.
    """
    import re

    normalized = query.upper()
    # FIX: the original matched substrings like "INSERT " (with a trailing
    # space), so "INSERT\n..." bypassed the filter while a legitimate query
    # such as "FOR d IN c\nRETURN d" was rejected (no " RETURN " substring).
    # Tokenizing on word characters is whitespace-agnostic. The original
    # list also contained "UPSERT " twice; a set removes the duplication.
    tokens = set(re.findall(r"[A-Z_]+", normalized))
    forbidden = {
        "INSERT",
        "UPDATE",
        "UPSERT",
        "REMOVE",
        "REPLACE",
        "DELETE",
        "DROP",
        "TRUNCATE",
        "MERGE",
    }
    if tokens & forbidden:
        raise ValueError("Only read-only AQL queries are allowed.")
    if "RETURN" not in tokens:
        raise ValueError("AQL queries must include a RETURN clause.")
||||
|
||||
|
||||
def strip_private_fields(document: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove large internal fields from a document dictionary.

    Args:
        document: Document returned by ArangoDB.

    Returns:
        Sanitized shallow copy without chunk payloads.
    """
    # FIX: the docstring promised a "sanitized copy", but the original code
    # deleted "chunks" from the caller's dict in place. Build a copy instead
    # so callers keep their original document intact.
    return {key: value for key, value in document.items() if key != "chunks"}
||||
|
||||
|
||||
def search_documents(aql_query: str) -> Dict[str, Any]:
    """
    Execute a read-only AQL query and return the result set together with
    the query string.

    Args:
        aql_query: Read-only AQL statement supplied by the client.

    Returns:
        Dictionary containing the executed AQL string, row count, and
        sanitized result rows.
    """
    # Validate before touching the database at all.
    ensure_read_only_aql(aql_query)
    sanitized_rows = []
    for document in arango.execute_aql(aql_query):
        sanitized_rows.append(strip_private_fields(document))
    return {
        "aql": aql_query,
        "row_count": len(sanitized_rows),
        "rows": sanitized_rows,
    }
||||
|
||||
|
||||
def run_aql_query(aql_query: str) -> List[Dict[str, Any]]:
    """
    Execute a read-only AQL query and return the rows.

    Args:
        aql_query: Read-only AQL statement.

    Returns:
        List of result rows with heavy internal fields stripped.
    """
    ensure_read_only_aql(aql_query)
    rows: List[Dict[str, Any]] = []
    for document in arango.execute_aql(aql_query):
        rows.append(strip_private_fields(document))
    return rows
||||
|
||||
|
||||
def _get_existing_collection(name: str) -> Collection:
    """
    Fetch an existing Chroma collection without creating new data.

    Args:
        name: Collection identifier.

    Returns:
        The requested collection.

    Raises:
        ValueError: If the collection is absent.
    """
    client = chroma_db._client
    existing_names = [collection.name for collection in client.list_collections()]
    if name in existing_names:
        return client.get_collection(name=name)
    raise ValueError(f"Chroma collection '{name}' does not exist.")
||||
|
||||
|
||||
def vector_search(query: str, limit: int) -> List[Dict[str, Any]]:
    """
    Perform semantic search against the pre-built Chroma collection.

    Args:
        query: Free-form search text.
        limit: Maximum number of hits to return.

    Returns:
        List of hit dictionaries with metadata and scores. Hits whose _id
        cannot be determined are dropped.
    """
    # The collection shares its name with the last path component of the
    # Chroma data directory.
    collection_name = chroma_db.path.split("/")[-1]
    chroma_collection = _get_existing_collection(collection_name)
    results = chroma_collection.query(
        query_texts=[query],
        n_results=limit,
    )
    # Chroma returns parallel, per-query lists; we only issued one query,
    # so index [0] below selects its result columns.
    metadatas = results.get("metadatas") or []
    documents = results.get("documents") or []
    ids = results.get("ids") or []
    distances = results.get("distances") or []

    def as_int(value: Any, default: int = -1) -> int:
        """Best-effort integer coercion; returns *default* when impossible."""
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            # BUG FIX: the old check stripped *all* leading "+"/"-" signs
            # before isdigit(), so strings like "+-5" passed the check but
            # int("+-5") raised ValueError. Let int() itself decide.
            try:
                return int(value.strip())
            except ValueError:
                return default
        return default

    hits: List[Dict[str, Any]] = []
    for index, metadata in enumerate(metadatas[0] if metadatas else []):
        meta = metadata or {}
        document = documents[0][index] if documents else ""
        identifier = ids[0][index] if ids else ""
        hit = {
            "_id": meta.get("_id") or identifier,
            "heading": meta.get("heading") or meta.get("title") or meta.get("talare"),
            "snippet": meta.get("snippet") or meta.get("text") or document,
            "debateurl": meta.get("debateurl") or meta.get("debate_url"),
            "chunk_index": as_int(meta.get("chunk_index") or meta.get("index")),
            "score": distances[0][index] if distances else None,
        }
        # Skip hits we cannot link back to a source document.
        if hit["_id"]:
            hits.append(hit)
    return hits
||||
|
||||
|
||||
def fetch_documents(document_ids: Sequence[str], fields: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
    """
    Pull full documents by _id while stripping heavy fields.

    Args:
        document_ids: Iterable with fully qualified Arango document ids.
        fields: Optional subset of fields to return.

    Returns:
        List of sanitized documents. Ids that do not resolve to an existing
        document are silently skipped.
    """
    # Tolerate ids that arrive with backslashes instead of "/" separators.
    ids = [doc_id.replace("\\", "/") for doc_id in document_ids]
    query = """
    FOR id IN @document_ids
        RETURN DOCUMENT(id)
    """
    # BUG FIX: DOCUMENT(id) yields null for unknown ids; the original code
    # then crashed on None.get(...). Filter missing documents out first.
    documents = [
        doc
        for doc in arango.execute_aql(query, bind_vars={"document_ids": ids})
        if doc is not None
    ]
    if fields:
        return [{field: doc.get(field) for field in fields if field in doc} for doc in documents]
    return [strip_private_fields(doc) for doc in documents]
||||
|
||||
|
||||
@dataclass
class SearchPayload:
    """
    Lightweight container passed to SearchService.search.
    """
    # Free-text search expression (the service's query string).
    q: str
    # Party filters; None disables party filtering.
    parties: Optional[List[str]]
    # Speaker-name filters; None disables name filtering.
    people: Optional[List[str]]
    # Debate-id filters; arango_search in this module always passes None.
    debates: Optional[List[str]]
    # Inclusive year-range bounds; either may be None.
    from_year: Optional[int]
    to_year: Optional[int]
    # Maximum number of hits to return.
    limit: int
    # Whether only snippets (not full documents) should be returned.
    return_snippets: bool
    # Optional list of document ids restricting the search scope.
    focus_ids: Optional[List[str]]
    # Optional list of speaker identifiers.
    speaker_ids: Optional[List[str]]
    # Single-speaker filter; unused by arango_search (defaults to None).
    speaker: Optional[str] = None
||||
|
||||
|
||||
def arango_search(
    query: str,
    limit: int,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Run an ArangoSearch query using the existing SearchService utilities.

    Args:
        query: Search expression (supports AND/OR/NOT and phrases).
        limit: Maximum number of hits to return.
        parties: Party filters.
        people: Speaker name filters.
        from_year: Start year filter.
        to_year: End year filter.
        return_snippets: Whether only snippets should be returned.
        focus_ids: Optional list restricting the search scope.
        speaker_ids: Optional list of speaker identifiers.

    Returns:
        Dictionary containing results, stats, limit flag, and focus_ids for
        follow-up queries.
    """

    def _as_list(value: Optional[Sequence[str]]) -> Optional[List[str]]:
        # Normalize optional sequences to plain lists, or None when empty.
        return list(value) if value else None

    payload = SearchPayload(
        q=query,
        parties=_as_list(parties),
        people=_as_list(people),
        debates=None,
        from_year=from_year,
        to_year=to_year,
        limit=limit,
        return_snippets=return_snippets,
        focus_ids=_as_list(focus_ids),
        speaker_ids=_as_list(speaker_ids),
    )
    results, stats, limit_reached = SearchService().search(
        payload=payload,
        include_snippets=True,
        return_snippets=return_snippets,
        focus_ids=payload.focus_ids,
    )
    # Collect the ids of the returned hits so the caller can narrow a
    # follow-up search to exactly this result set.
    followup_ids = [
        hit["_id"] for hit in results if isinstance(hit, dict) and hit.get("_id")
    ]
    return {
        "results": results,
        "stats": stats,
        "limit_reached": limit_reached,
        "return_snippets": return_snippets,
        "focus_ids": followup_ids,
    }
||||
|
After Width: | Height: | Size: 153 KiB |
@ -0,0 +1,113 @@ |
||||
# Template for providers.yaml |
||||
# |
||||
# You can add any OpenAI API compatible provider to the "providers" list. |
||||
# For each provider you must also specify a list of models, along with model abilities. |
||||
# |
||||
# All fields are required unless marked as optional. |
||||
# |
||||
# Refer to your provider's API documentation for specific
# details such as model identifiers, capabilities, etc.
||||
# |
||||
# Note: Since the OpenAI API is not a standard we can't guarantee that all |
||||
# providers will work correctly with Raycast AI. |
||||
# |
||||
# To use this template, rename this file to `providers.yaml`.
||||
# |
||||
providers: |
||||
- id: perplexity |
||||
name: Perplexity |
||||
base_url: https://api.perplexity.ai |
||||
# Specify at least one api key if authentication is required. |
||||
# Optional if authentication is not required or is provided elsewhere. |
||||
# If individual models require separate api keys, then specify a separate `key` for each model's `provider` |
||||
api_keys: |
||||
perplexity: PERPLEXITY_KEY |
||||
# Optional additional parameters sent to the `/chat/completions` endpoint |
||||
additional_parameters: |
||||
return_images: true |
||||
web_search_options: |
||||
search_context_size: medium |
||||
# Specify all models to use with the current provider |
||||
models: |
||||
- id: sonar # `id` must match the identifier used by the provider |
||||
name: Sonar # name visible in Raycast |
||||
provider: perplexity # Only required if mapping to a specific api key |
||||
description: Perplexity AI model for general-purpose queries # optional |
||||
context: 128000 # refer to provider's API documentation |
||||
# Optional abilities - all child properties are also optional. |
||||
# If you specify abilities incorrectly the model may fail to work as expected in Raycast AI. |
||||
# Refer to provider's API documentation for model abilities. |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: true |
||||
system_message: |
||||
supported: true |
||||
tools: |
||||
supported: false |
||||
reasoning_effort: |
||||
supported: false |
||||
- id: sonar-pro |
||||
name: Sonar Pro |
||||
description: Perplexity AI model for complex queries |
||||
context: 200000 |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: true |
||||
system_message: |
||||
supported: true |
||||
# provider with multiple api keys |
||||
- id: my_provider |
||||
name: My Provider |
||||
base_url: http://localhost:4000 |
||||
api_keys: |
||||
openai: OPENAI_KEY |
||||
anthropic: ANTHROPIC_KEY |
||||
models: |
||||
- id: gpt-4o |
||||
name: "GPT-4o" |
||||
context: 200000 |
||||
provider: openai # matches "openai" in api_keys |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: true |
||||
system_message: |
||||
supported: true |
||||
tools: |
||||
supported: true |
||||
- id: claude-sonnet-4 |
||||
name: "Claude Sonnet 4" |
||||
context: 200000 |
||||
provider: anthropic # matches "anthropic" in api_keys |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: true |
||||
system_message: |
||||
supported: true |
||||
tools: |
||||
supported: true |
||||
- id: litellm |
||||
name: LiteLLM |
||||
base_url: http://localhost:4000 |
||||
# No `api_keys` - authentication is provided by the LiteLLM config |
||||
models: |
||||
- id: anthropic/claude-sonnet-4-20250514 |
||||
name: "Claude Sonnet 4" |
||||
context: 200000 |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: true |
||||
system_message: |
||||
supported: true |
||||
tools: |
||||
supported: true |
||||
|
||||
@ -0,0 +1,20 @@ |
||||
providers: |
||||
- id: vllm |
||||
name: vLLM Instance |
||||
base_url: https://lasseedfast.se/vllm |
||||
    api_keys:
      # NOTE(review): plaintext API key committed to the repository — rotate
      # this key and load it from an environment variable or secret store.
      vllm: "ap98sfoiuajcnwe89sozchnsw9oeacislh"
||||
models: |
||||
- id: "gpt-oss:20b" |
||||
name: "GPT OSS 20B (Main)" |
||||
provider: vllm |
||||
context: 16000 |
||||
abilities: |
||||
temperature: |
||||
supported: true |
||||
vision: |
||||
supported: false |
||||
system_message: |
||||
supported: true |
||||
tools: |
||||
supported: true |
||||
@ -1,182 +0,0 @@ |
||||
#!/usr/bin/env python3 |
||||
""" |
||||
Convert stored embeddings to plain Python lists for an existing Chroma collection. |
||||
|
||||
Usage: |
||||
# Dry run (inspect first 50 ids) |
||||
python scripts/convert_embeddings_to_lists.py --collection talks --limit 50 --dry-run |
||||
|
||||
# Full run (no dry run) |
||||
python scripts/convert_embeddings_to_lists.py --collection talks |
||||
|
||||
Notes: |
||||
- Run from your project root (same env you use to access chroma_db). |
||||
- Back up chromadb_data before running. |
||||
""" |
||||
import argparse |
||||
import json |
||||
import os |
||||
import time |
||||
from pathlib import Path |
||||
from typing import List |
||||
import math |
||||
import sys |
||||
|
||||
# Use the same imports/bootstrapping as you already have in your project |
||||
# so the same chroma client and embedding function are loaded. |
||||
# Adjust the import path if necessary. |
||||
os.chdir("/home/lasse/riksdagen") |
||||
sys.path.append("/home/lasse/riksdagen") |
||||
|
||||
import numpy as np |
||||
from _chromadb.chroma_client import chroma_db |
||||
|
||||
CHECKPOINT_DIR = Path("var/chroma_repair") |
||||
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) |
||||
|
||||
def normalize_embedding(emb):
    """
    Convert a single embedding to a plain Python list[float].

    Accepts numpy arrays, objects exposing tolist(), plain lists, and any
    other iterable of numbers.
    """
    # Fast path: numpy arrays convert directly.
    if isinstance(emb, np.ndarray):
        return emb.tolist()
    # Other array-likes (pandas etc.) may expose tolist() as well.
    if not isinstance(emb, list) and hasattr(emb, "tolist"):
        try:
            return emb.tolist()
        except Exception:
            pass
    # Already a list of numbers: coerce each element to float.
    if isinstance(emb, list):
        return [float(value) for value in emb]
    # Last resort: treat it as a generic iterable.
    try:
        return [float(value) for value in emb]
    except Exception:
        raise ValueError("Cannot normalize embedding of type: %s" % type(emb))
||||
|
||||
def chunked_iter(iterable, n):
    """Yield successive lists of at most *n* items from *iterable*."""
    pending = []
    for item in iterable:
        pending.append(item)
        if len(pending) == n:
            yield pending
            pending = []
    # Flush the final, possibly short, chunk.
    if pending:
        yield pending
||||
|
||||
def load_checkpoint(name):
    """
    Load checkpoint state for *name* from CHECKPOINT_DIR.

    Returns:
        dict: Previously saved progress state, or a fresh default when no
        checkpoint file exists yet.
    """
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        # BUG FIX: json.load() expects a file object, not a Path — the
        # original `json.load(path)` raised AttributeError at runtime.
        with open(path) as handle:
            return json.load(handle)
    return {"last_index": 0, "processed_ids": []}
||||
|
||||
def save_checkpoint(name, data):
    """Persist checkpoint state for *name* as JSON in CHECKPOINT_DIR."""
    target = CHECKPOINT_DIR / f"{name}.json"
    target.write_text(json.dumps(data))
||||
|
||||
def main():
    """
    CLI entry point: normalize every stored embedding in a Chroma collection
    to a plain Python list, processing ids in resumable batches with a JSON
    checkpoint on disk.

    Flags:
        --collection       (required) Chroma collection name
        --batch            batch size per update call
        --dry-run          report only, write nothing
        --limit            cap on total ids processed (testing)
        --checkpoint-name  checkpoint file name (defaults to collection)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection", required=True, help="Chroma collection name (e.g. talks)")
    parser.add_argument("--batch", type=int, default=1000, help="Batch size for update (default 1000)")
    parser.add_argument("--dry-run", action="store_true", help="Dry run: don't write updates, just report")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of ids to process (for testing)")
    parser.add_argument("--checkpoint-name", default=None, help="Name for checkpoint file (defaults to collection name)")
    args = parser.parse_args()

    coll_name = args.collection
    checkpoint_name = args.checkpoint_name or coll_name

    print(f"Connecting to Chroma collection '{coll_name}'...")
    col = chroma_db.get_collection(coll_name)

    # Get the full list of ids. For 600k this should be okay to hold in memory,
    # but if you need a more streaming approach, tell me and I can adapt.
    all_info = col.get(include=[])  # may return {'ids': [...]} as in your env
    ids = list(all_info.get("ids", []))
    total_ids = len(ids)
    if args.limit:
        ids = ids[: args.limit]
        total_process = len(ids)
    else:
        total_process = total_ids

    print(f"Found {total_ids} ids in collection; will process {total_process} ids (limit={args.limit})")

    # load checkpoint so an interrupted run resumes where it stopped
    ck = load_checkpoint(checkpoint_name)
    start_index = ck.get("last_index", 0)
    print(f"Resuming at index {start_index}")

    # iterate in batches starting from last_index
    processed = 0
    for i in range(start_index, total_process, args.batch):
        batch_ids = ids[i : i + args.batch]
        print(f"\nProcessing batch {i}..{i+len(batch_ids)-1} (count={len(batch_ids)})")

        # fetch full info for this batch (documents, metadatas, embeddings)
        # we only need embeddings for this repair, but include docs/meta for verification if you want
        try:
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])
        except Exception as e:
            print("Error fetching batch:", e)
            # do a small retry after sleep
            time.sleep(2)
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])

        batch_embeddings = items.get("embeddings", [])
        # items.get("ids") should match batch_ids order; if not, align by ids
        ids_from_get = items.get("ids", batch_ids)
        if len(ids_from_get) != len(batch_ids):
            print("Warning: length mismatch between requested ids and returned ids")

        # Normalize embeddings; abort the batch on the first failure so we
        # never write a partially converted batch back.
        normalized_embeddings = []
        failed = False
        for idx, emb in enumerate(batch_embeddings):
            try:
                norm = normalize_embedding(emb)
            except Exception as e:
                print(f"Failed to normalize embedding for id {ids_from_get[idx]}: {e}")
                failed = True
                break
            normalized_embeddings.append(norm)

        if failed:
            print("Skipping this batch due to failures. You can adjust batch size and retry.")
            break

        # Dry-run: just print stats and continue
        if args.dry_run:
            # show a sample
            sample_i = min(3, len(normalized_embeddings))
            print("Sample normalized embedding lengths:", [len(normalized_embeddings[k]) for k in range(sample_i)])
            # Optionally inspect first few floats
            print("Sample values (first 6 floats):", [normalized_embeddings[k][:6] for k in range(sample_i)])
        else:
            # Update the collection in place (update will upsert embeddings for given ids)
            try:
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            except Exception as e:
                print("Update failed, retrying once after short sleep:", e)
                time.sleep(2)
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)

            print(f"Updated {len(normalized_embeddings)} embeddings in collection '{coll_name}'")

        # checkpoint progress (even in dry-run, so resume behaves the same)
        ck["last_index"] = i + len(batch_ids)
        save_checkpoint(checkpoint_name, ck)
        processed += len(batch_ids)

    print(f"\nDone. Processed {processed} ids. Checkpoint saved to {CHECKPOINT_DIR / (checkpoint_name + '.json')}")
    print("Reminder: run a few queries to validate search quality.")
||||
|
||||
if __name__ == "__main__": |
||||
main() |
||||
@ -0,0 +1,25 @@ |
||||
from arango_client import arango

# One-off migration: chunk documents written by the old ChromaDB pipeline
# lack parent_id and carry legacy Chroma fields. Re-link them to their
# parent talk and drop the obsolete fields.
chunks_collection = arango.db.collection("chunks")

q = """
FOR chunk IN chunks
    FILTER chunk.parent_id == null
    RETURN chunk
"""

cursor = arango.db.aql.execute(q, batch_size=1000, count=True, ttl=360)
updated_docs = []
n = 0
for doc in cursor:
    n += 1
    doc['collection'] = 'talks'
    # FIX: use pop() with a default — `del` raised KeyError for documents
    # that never had the legacy fields. ("chroma_collecton" is the actual,
    # misspelled field name used when these docs were written.)
    doc.pop('chroma_collecton', None)
    doc.pop('chroma_id', None)
    # The chunk key is "{talk_key}:{index}"; the part before ":" is the talk.
    doc['parent_id'] = f"talks/{doc['_key'].split(':')[0]}"
    updated_docs.append(doc)
    if len(updated_docs) >= 100:
        chunks_collection.update_many(updated_docs, merge=False, silent=True)
        updated_docs = []
        print(f"Updated {n} documents", end="\r")
# FIX: only flush the final partial batch when it is non-empty.
if updated_docs:
    chunks_collection.update_many(updated_docs, merge=False, silent=True)
||||
@ -0,0 +1,173 @@ |
||||
import os |
||||
import sys |
||||
import logging |
||||
|
||||
# Silence the per-request HTTP logs from the ollama/httpx library |
||||
logging.getLogger("httpx").setLevel(logging.WARNING) |
||||
|
||||
os.chdir("/home/lasse/riksdagen") |
||||
sys.path.append("/home/lasse/riksdagen") |
||||
|
||||
from arango_client import arango |
||||
from ollama import Client as Ollama |
||||
from arango.collection import Collection |
||||
from concurrent.futures import ThreadPoolExecutor, as_completed |
||||
from typing import List, Dict |
||||
from time import sleep |
||||
from utils import TextChunker |
||||
|
||||
|
||||
def make_embeddings(texts: List[str]) -> List[List[float]]:
    """
    Generate embeddings for a list of texts using Ollama.

    Args:
        texts (List[str]): Text strings to embed.

    Returns:
        List[List[float]]: One embedding vector per input text.
    """
    client = Ollama(host='192.168.1.12:33405')
    response = client.embed(
        model="qwen3-embedding:latest",
        input=texts,
        dimensions=384,
    )
    return response.embeddings
||||
|
||||
|
||||
def process_chunk_batch(chunk_batch: List[Dict]) -> List[Dict]:
    """
    Generate embeddings for a batch of chunks and attach them in place.

    Args:
        chunk_batch (List[Dict]): Chunk dicts, each with a 'text' field.

    Returns:
        List[Dict]: The same list, with an 'embedding' field added to each dict.
    """
    # Brief pause to avoid hammering the embedding server when many worker
    # threads fire at once.
    sleep(1)
    vectors = make_embeddings([chunk['text'] for chunk in chunk_batch])
    for chunk, vector in zip(chunk_batch, vectors):
        chunk['embedding'] = vector
    return chunk_batch
||||
|
||||
|
||||
def make_arango_embeddings() -> int:
    """
    Chunks and embeds all talks that are not yet represented in the 'chunks' collection.

    For each talk that has no chunks in the collection yet:
    - If the talk document already has a 'chunks' field (legacy path), those are used.
    - Otherwise the speech text is split into chunks using TextChunker.
    Embedding vectors are generated via Ollama and stored in the 'chunks' collection.

    Each chunk document in ArangoDB has:
        _key      : "{talk_key}:{chunk_index}" (unique within the collection)
        text      : the chunk text
        index     : chunk index within the talk
        parent_id : "talks/{talk_key}" (links back to the source talk)
        collection: "talks"
        embedding : the vector (list of floats)

    Returns:
        int: Total number of chunk documents inserted/updated.
    """
    if not arango.db.has_collection("chunks"):
        chunks_collection: Collection = arango.db.create_collection("chunks")
    else:
        chunks_collection: Collection = arango.db.collection("chunks")

    # Find every talk that has no entry yet in the chunks collection.
    # The inner FOR loop returns [] if no match exists (acts as NOT EXISTS).
    cursor = arango.db.aql.execute(
        """
        FOR p IN talks
            FILTER p.anforandetext != null AND p.anforandetext != ""
            FILTER (
                FOR c IN chunks
                    FILTER c.parent_id == p._id
                    LIMIT 1
                    RETURN 1
            ) == []
            RETURN {
                _key: p._key,
                _id: p._id,
                anforandetext: p.anforandetext,
                chunks: p.chunks
            }
        """,
        batch_size=1000,
        ttl=360,
    )

    n = 0
    embed_batch_size = 20  # Number of chunks per Ollama call
    chunk_batches: List[List[Dict]] = []
    # PERF FIX: the chunker was previously constructed once per talk inside
    # the loop; it is loop-invariant, so create it a single time up front.
    chunker = TextChunker(chunk_limit=500)

    for talk in cursor:
        talk_key = talk["_key"]
        parent_id = f"talks/{talk_key}"

        if talk.get("chunks"):
            # Legacy path: chunks were previously generated and stored on the
            # talk document. Strip out the old ChromaDB-specific fields and
            # assign a proper _key.
            _chunks = []
            for chunk in talk["chunks"]:
                idx = chunk.get("index", 0)
                _chunks.append({
                    "_key": f"{talk_key}:{idx}",
                    "text": chunk["text"],
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                })
        else:
            # New path: chunk the speech text directly with TextChunker.
            text = (talk.get("anforandetext") or "").strip()
            text_chunks = chunker.chunk(text)
            _chunks = [
                {
                    "_key": f"{talk_key}:{idx}",
                    "text": content,
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                }
                for idx, content in enumerate(text_chunks)
                if content and content.strip()
            ]

        # Split into batches for embedding
        for i in range(0, len(_chunks), embed_batch_size):
            batch = _chunks[i : i + embed_batch_size]
            if batch:
                chunk_batches.append(batch)

    # Embed all batches in parallel (Ollama calls are I/O-bound, threads are fine)
    total_batches = len(chunk_batches)
    completed_batches = 0
    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(process_chunk_batch, batch) for batch in chunk_batches]
        processed_chunks: List[Dict] = []
        for future in as_completed(futures):
            result = future.result()
            completed_batches += 1
            processed_chunks.extend(result)
            print(f"Embedding batches: {completed_batches}/{total_batches} | chunks ready to insert: {len(processed_chunks)}", end="\r")
            # Insert in batches of 100 to keep HTTP payloads small
            if len(processed_chunks) >= 100:
                n += len(processed_chunks)
                chunks_collection.insert_many(processed_chunks, overwrite=True)
                processed_chunks = []
        if processed_chunks:
            n += len(processed_chunks)
            chunks_collection.insert_many(processed_chunks, overwrite=True)

    print(f"\nDone. Inserted/updated {n} chunks in ArangoDB.")
    return n
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
make_arango_embeddings() |
||||
@ -0,0 +1,5 @@ |
||||
### Inlogg riksdagsgruppen från Fojo-hackathon |
||||
https://arango.lasseedfast.se |
||||
riksdagsgruppen |
||||
popre4-cygcuz-viHjyc |
||||
|
||||
@ -0,0 +1,177 @@ |
||||
""" |
||||
Synkroniserar nya anföranden från riksdagen.se till databasen och processar dem. |
||||
|
||||
Pipeline (körs dagligen via systemd timer): |
||||
1. Ladda ned årets anföranden från riksdagen.se (ersätter tidigare nerladdning) |
||||
2. Infoga nya anföranden i ArangoDB (hoppar över redan existerande) |
||||
3. Tilldela debatt-ID:n till anföranden som saknar det |
||||
4. Bygg embeddings för datum som saknar chunks |
||||
5. Generera sammanfattningar för datum som saknar summary |
||||
|
||||
Kör manuellt: python scripts/sync_talks.py |
||||
""" |
||||
|
||||
import os |
||||
import sys |
||||
import logging |
||||
from datetime import datetime |
||||
from io import BytesIO |
||||
from urllib.request import urlopen |
||||
from zipfile import ZipFile |
||||
|
||||
# Make sure we run from the project root so local modules can be found.
os.chdir("/home/lasse/riksdagen")
sys.path.append("/home/lasse/riksdagen")

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# System prompt used by the LLM when summarizing debates.
# NOTE: the prompt text is deliberately Swedish — summaries must be in Swedish.
SYSTEM_MESSAGE = """Din uppgift är att sammanfatta debatter i Sveriges riksdag.
Du kommer först att få enskilda tal som du ska sammanfatta var för sig, efter det ska du sammanfatta hela debatten.
Sammanfattningarna ska vara på svenska och vara koncisa och informativa.
Det är viktigt att du förstår vad som är kärnan i varje tal och debatt, fokusera därför på de argument och sakförhållanden som framförs.
"""
||||
|
||||
|
||||
def get_current_session_year() -> int:
    """
    Return the starting year of the current Riksdag session.

    Sessions run September-August, so:
        - January-August 2026    -> session started Sep 2025 -> returns 2025
        - September-December 2025 -> session started Sep 2025 -> returns 2025

    Returns:
        int: Four-digit starting year of the current session (e.g. 2025).
    """
    today = datetime.now()
    return today.year if today.month >= 9 else today.year - 1
||||
|
||||
|
||||
def download_current_year(session_year: int) -> str:
    """
    Download and extract the ZIP archive for the given Riksdag session,
    replacing any previously downloaded files for that year.

    The Riksdag keeps updating the same ZIP file during an ongoing session,
    so we must re-download it every run to pick up new speeches.

    Args:
        session_year (int): Starting year of the session (e.g. 2025 for 2025/26).

    Returns:
        str: Path to the directory the files were extracted into.
    """
    suffix = str(session_year + 1)[2:]  # e.g. "26" for 2026
    archive_name = f"anforande-{session_year}{suffix}"
    url = f"https://data.riksdagen.se/dataset/anforande/{archive_name}.json.zip"
    dir_path = os.path.join("talks", archive_name)

    logger.info(f"Downloading {url} → {dir_path}")
    os.makedirs(dir_path, exist_ok=True)

    # Wipe old files so we always work from a fresh copy.
    for filename in os.listdir(dir_path):
        os.remove(os.path.join(dir_path, filename))

    with urlopen(url) as resp:
        with ZipFile(BytesIO(resp.read())) as zf:
            zf.extractall(dir_path)

    count = len(os.listdir(dir_path))
    logger.info(f"Extracted {count} files to {dir_path}")
    return dir_path
||||
|
||||
|
||||
def get_unsummarized_dates() -> list[str]:
    """
    Fetch dates from ArangoDB that have speeches without a summary.

    Returns:
        list[str]: Sorted date strings, e.g. ["2026-02-10", "2026-02-11"].
    """
    # Imported lazily so merely importing this module does not open a
    # database connection.
    from arango_client import arango

    cursor = arango.db.aql.execute(
        """
        FOR doc IN talks
            FILTER doc.summary == null
            RETURN DISTINCT doc.datum
        """,
        ttl=300,
    )
    dates = sorted(cursor)
    logger.info(f"Found {len(dates)} dates with unsummarized talks")
    return dates
||||
|
||||
|
||||
def sync() -> None:
    """
    Run the full sync pipeline:
        1. Download this session's speeches
        2. Insert new speeches into ArangoDB
        3. Assign debate IDs
        4. Build embeddings for new dates
        5. Generate summaries for new dates
    """
    logger.info("=== Starting daily riksdagen sync ===")

    # --- Stage 1: Download ---
    session_year = get_current_session_year()
    logger.info(f"Current session year: {session_year}/{session_year + 1}")
    dir_path = download_current_year(session_year)

    # --- Stage 2: Insert new speeches into ArangoDB ---
    # update_folder() fetches all existing _keys from the database and skips
    # them, so only new speeches are inserted.
    logger.info("Stage 2: Inserting new talks into ArangoDB...")
    from scripts.documents_to_arango import update_folder

    new_talks = update_folder(os.path.abspath(dir_path))
    logger.info(f"Stage 2 complete: {new_talks} new talks inserted")

    # --- Stage 3: Assign debate IDs ---
    # Speeches missing the 'debate' field are grouped into debates based on
    # date and on whether they are replies or not.
    logger.info("Stage 3: Assigning debate IDs to talks missing them...")
    from scripts.debates import make_debate_ids

    make_debate_ids()
    logger.info("Stage 3 complete")

    # --- Stage 4: Chunk + build embeddings in ArangoDB ---
    # make_arango_embeddings() finds all speeches lacking chunks in the
    # 'chunks' collection, chunks the text, generates vectors via Ollama
    # and stores everything directly in ArangoDB. ChromaDB is not used.
    logger.info("Stage 4: Chunking and embedding new talks into ArangoDB...")
    from scripts.make_arango_embeddings import make_arango_embeddings

    total_chunks = make_arango_embeddings()
    logger.info(f"Stage 4 complete: {total_chunks} chunks created")

    # --- Stage 5: Generate summaries ---
    # process_debate_date() automatically skips speeches that already have a
    # summary, so it is safe to run repeatedly.
    new_dates = get_unsummarized_dates()
    if new_dates:
        logger.info(f"Stage 5: Generating summaries for {len(new_dates)} dates...")
        from scripts.debates import process_debate_date

        for date in new_dates:
            process_debate_date(date, SYSTEM_MESSAGE)
        logger.info(f"Stage 5 complete: summaries generated for {len(new_dates)} dates")
    else:
        logger.info("Stage 5: No unsummarized dates, skipping")

    logger.info("=== Sync complete ===")
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the full daily sync pipeline when executed as a script.
    sync()
||||
@ -0,0 +1,57 @@ |
||||
from arango_client import arango |
||||
from scripts.make_arango_embeddings import process_chunk_batch |
||||
from arango.collection import Collection |
||||
from typing import List, Dict |
||||
|
||||
def test_full_make_arango_embeddings_for_one_talk() -> None:
    """
    Integration test for the full make_arango_embeddings chain:
    - Fetches a specific talk document from ArangoDB.
    - Processes its chunks to generate embeddings.
    - Inserts/updates those chunks in the 'chunks' collection.
    - Verifies that the chunks were updated in ArangoDB.

    This test requires ArangoDB and Ollama to be running and accessible.
    """
    # The _id of the talk we want to process
    target_id: str = "talks/000004cc-b896-e611-9441-00262d0d7125"

    # Get the talks and chunks collections
    talks_collection: Collection = arango.db.collection("talks")
    chunks_collection: Collection = arango.db.collection("chunks")

    # Fetch the talk document
    talk: Dict = talks_collection.get(target_id)
    assert talk is not None, f"Talk with _id {target_id} not found"
    assert "chunks" in talk and talk["chunks"], "Talk has no chunks"

    # Prepare chunks for embedding
    processed_chunks: List[Dict] = []
    for chunk in talk["chunks"]:
        key: str = chunk["chroma_id"].split("/")[-1]
        chunk["_key"] = key.split(":")[-1]
        chunk["parent_id"] = target_id
        chunk["collection"] = "talks"
        # Drop legacy ChromaDB fields not needed for embedding.
        # NOTE: 'chroma_collecton' (sic) matches the misspelled key as it
        # appears in the stored documents — do not "fix" the spelling here.
        chunk.pop("chroma_id", None)
        chunk.pop("chroma_collecton", None)
        processed_chunks.append(chunk)

    # Generate embeddings for all chunks
    processed_chunks = process_chunk_batch(processed_chunks)

    # Insert/update chunks in the 'chunks' collection
    chunks_collection.insert_many(processed_chunks, overwrite=True)

    # Verify that the chunks were updated in ArangoDB
    for chunk in processed_chunks:
        db_chunk = chunks_collection.get(chunk["_key"])
        # Fixed: report the key of the chunk being checked, not the
        # talk-level key as before.
        assert db_chunk is not None, f"Chunk {chunk['_key']} not found in DB"
        assert "embedding" in db_chunk, "Chunk missing embedding in DB"
        assert isinstance(db_chunk["embedding"], list), "Embedding is not a list"
        print(f"Chunk {chunk['_key']} updated with embedding of length {len(db_chunk['embedding'])}")
||||
|
||||
if __name__ == "__main__":
    # Allow running this integration test directly as a script. Under
    # pytest the function is collected by name instead; guarding the call
    # means importing this module no longer triggers a live ArangoDB/Ollama
    # round-trip at import/collection time.
    test_full_make_arango_embeddings_for_one_talk()
||||
Loading…
Reference in new issue