- Created a template for providers.yaml to define API providers and models. - Added a new providers.yaml file with initial provider configurations. - Implemented fix_things.py to update chunk documents in ArangoDB. - Developed make_arango_embeddings.py to generate embeddings for talks and store them in ArangoDB. - Introduced sync_talks.py to synchronize new speeches from riksdagen.se and process them. - Added notes.md for documentation on riksdagsgruppen login details. - Created test_make_arango_embeddings.py for integration testing of embedding generation. - Implemented test_gpu.py to test image input handling with vLLM.master
parent
3ba8c3340a
commit
88e0244429
37 changed files with 2282 additions and 373 deletions
Binary file not shown.
@ -0,0 +1,144 @@ |
|||||||
|
#!/usr/bin/env python3 |
||||||
|
""" |
||||||
|
System Resource Monitor - Logs system stats to help diagnose SSH connectivity issues. |
||||||
|
|
||||||
|
This script monitors: |
||||||
|
- CPU usage |
||||||
|
- Memory usage |
||||||
|
- Disk usage |
||||||
|
- Network connectivity |
||||||
|
- SSH service status |
||||||
|
- System load |
||||||
|
- Active connections |
||||||
|
|
||||||
|
Run continuously to capture when the system becomes unreachable. |
||||||
|
""" |
||||||
|
|
||||||
|
import psutil |
||||||
|
import time |
||||||
|
import logging |
||||||
|
from datetime import datetime |
||||||
|
from pathlib import Path |
||||||
|
|
||||||
|
# Configure logging: persist to a file under /var/log and mirror to stdout.
log_file = Path("/var/log/system_monitor.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
)
||||||
|
|
||||||
|
def check_ssh_service() -> dict:
    """Query systemd for the state of the SSH daemon.

    Returns:
        dict: ``{'running': bool, 'status': str}`` on a successful query, or
        ``{'running': False, 'error': str}`` when ``systemctl`` could not be
        executed (missing binary, timeout, ...).
    """
    try:
        import subprocess

        proc = subprocess.run(
            ['systemctl', 'is-active', 'ssh'],
            capture_output=True,
            text=True,
            timeout=5,
        )
    except Exception as exc:
        return {'running': False, 'error': str(exc)}
    # 'systemctl is-active' exits 0 only when the unit is active.
    return {'running': proc.returncode == 0, 'status': proc.stdout.strip()}
|
|
||||||
|
def get_system_stats() -> dict:
    """Collect a one-shot snapshot of system statistics via psutil.

    Returns:
        dict: CPU, memory, swap, disk, network, load-average and
        connection-count figures (GB values are derived with 1024**3).
    """
    gib = 1024 ** 3

    # cpu_percent(interval=1) blocks for one second to measure real usage.
    cpu_pct = psutil.cpu_percent(interval=1)
    n_cpus = psutil.cpu_count()

    mem = psutil.virtual_memory()
    swp = psutil.swap_memory()
    root_disk = psutil.disk_usage('/')
    net = psutil.net_io_counters()
    load1, load5, load15 = psutil.getloadavg()
    n_conn = len(psutil.net_connections())

    return {
        'cpu_percent': cpu_pct,
        'cpu_count': n_cpus,
        'memory_percent': mem.percent,
        'memory_available_gb': mem.available / gib,
        'swap_percent': swp.percent,
        'disk_percent': root_disk.percent,
        'disk_free_gb': root_disk.free / gib,
        'network_bytes_sent': net.bytes_sent,
        'network_bytes_recv': net.bytes_recv,
        'load_1min': load1,
        'load_5min': load5,
        'load_15min': load15,
        'connections': n_conn,
    }
|
|
||||||
|
def monitor_loop(interval_seconds: int = 60) -> None:
    """
    Main monitoring loop that logs system stats at regular intervals.

    Each iteration gathers stats via get_system_stats() / check_ssh_service()
    and emits one summary line.  Fix: every breached threshold is now reported
    independently — the previous elif-chain logged only the FIRST breach, so
    an SSH outage (the very condition this script exists to catch) could be
    silently masked by a simultaneous CPU/memory/disk/load spike.

    Args:
        interval_seconds: How often to log stats (default: 60 seconds).
    """
    logging.info("Starting system monitoring...")

    while True:
        try:
            stats = get_system_stats()
            ssh_status = check_ssh_service()

            log_message = (
                f"CPU: {stats['cpu_percent']:.1f}% | "
                f"MEM: {stats['memory_percent']:.1f}% ({stats['memory_available_gb']:.2f}GB free) | "
                f"DISK: {stats['disk_percent']:.1f}% ({stats['disk_free_gb']:.2f}GB free) | "
                f"LOAD: {stats['load_1min']:.2f} {stats['load_5min']:.2f} {stats['load_15min']:.2f} | "
                f"CONN: {stats['connections']} | "
                f"SSH: {ssh_status.get('status', 'unknown')}"
            )

            # Warning thresholds — each checked independently, not elif-chained.
            alerted = False
            if stats['cpu_percent'] > 90:
                logging.warning(f"HIGH CPU! {log_message}")
                alerted = True
            if stats['memory_percent'] > 90:
                logging.warning(f"HIGH MEMORY! {log_message}")
                alerted = True
            if stats['disk_percent'] > 90:
                logging.warning(f"HIGH DISK USAGE! {log_message}")
                alerted = True
            if stats['load_1min'] > stats['cpu_count'] * 2:
                logging.warning(f"HIGH LOAD! {log_message}")
                alerted = True
            if not ssh_status.get('running'):
                # SSH being down is the primary event of interest — always
                # logged at ERROR, never hidden behind resource warnings.
                logging.error(f"SSH SERVICE DOWN! {log_message}")
                alerted = True
            if not alerted:
                logging.info(log_message)

            time.sleep(interval_seconds)

        except Exception as e:
            # Keep the monitor alive even if a single collection pass fails.
            logging.error(f"Error in monitoring loop: {e}")
            time.sleep(interval_seconds)


if __name__ == "__main__":
    monitor_loop(interval_seconds=60)  # Log every 60 seconds
@ -0,0 +1,19 @@ |
|||||||
|
[Unit] |
||||||
|
Description=Riksdagen daily talk sync |
||||||
|
# Wait for network before starting |
||||||
|
After=network-online.target |
||||||
|
Wants=network-online.target |
||||||
|
|
||||||
|
[Service] |
||||||
|
Type=oneshot |
||||||
|
User=lasse |
||||||
|
WorkingDirectory=/home/lasse/riksdagen |
||||||
|
# Loads ARANGO_PWD and other env vars from the project .env file |
||||||
|
EnvironmentFile=/home/lasse/riksdagen/.env |
||||||
|
ExecStart=/home/lasse/riksdagen/.venv/bin/python /home/lasse/riksdagen/scripts/sync_talks.py |
||||||
|
# Log stdout/stderr to the systemd journal (view with: journalctl -u riksdagen-sync) |
||||||
|
StandardOutput=journal |
||||||
|
StandardError=journal |
||||||
|
|
||||||
|
[Install] |
||||||
|
WantedBy=multi-user.target |
||||||
@ -0,0 +1,11 @@ |
|||||||
|
[Unit] |
||||||
|
Description=Run riksdagen daily talk sync at 06:00 |
||||||
|
|
||||||
|
[Timer] |
||||||
|
# Run every day at 06:00 |
||||||
|
OnCalendar=*-*-* 06:00:00 |
||||||
|
# If the server was off at 06:00, run the job as soon as it comes back up |
||||||
|
Persistent=true |
||||||
|
|
||||||
|
[Install] |
||||||
|
WantedBy=timers.target |
||||||
@ -0,0 +1,10 @@ |
|||||||
|
[Interface] |
||||||
|
PrivateKey = yDRb0EYZkUZCuYax44lSBAP3vmN+mPdDQEh2hAQ10lY= |
||||||
|
Address = 10.156.168.2/24 |
||||||
|
DNS = 1.1.1.1, 1.0.0.1 |
||||||
|
|
||||||
|
[Peer] |
||||||
|
PublicKey = 6gwhWDypmpxrGaobEh8xZIXvRIKdp0pWH6YWZ9F8twY= |
||||||
|
PresharedKey = XAD5qpUMr0Ouz2azeXfH7J5tE3iSi5XJOdzdrUTSbRg= |
||||||
|
Endpoint = 98.128.172.165:51820 |
||||||
|
AllowedIPs = 0.0.0.0/0, ::0/0 |
||||||
@ -0,0 +1,6 @@ |
|||||||
|
""" |
||||||
|
Public entry points for the Riksdagen MCP server package. |
||||||
|
""" |
||||||
|
from .server import run |
||||||
|
|
||||||
|
__all__ = ("run",) |
||||||
@ -0,0 +1,24 @@ |
|||||||
|
from __future__ import annotations |
||||||
|
|
||||||
|
import os |
||||||
|
import secrets |
||||||
|
|
||||||
|
|
||||||
|
def validate_token(provided_token: str) -> None:
    """
    Ensure the caller supplied the expected bearer token.

    Args:
        provided_token: Token received from the MCP client.

    Raises:
        RuntimeError: If the server token is not configured.
        PermissionError: If the token is missing or incorrect.
    """
    expected = os.getenv("MCP_SERVER_TOKEN")

    if not expected:
        raise RuntimeError(
            "MCP_SERVER_TOKEN environment variable must be set for authentication."
        )
    if not provided_token:
        raise PermissionError("Missing MCP access token.")

    # Constant-time comparison to avoid leaking token bytes via timing.
    tokens_match = secrets.compare_digest(provided_token, expected)
    if not tokens_match:
        raise PermissionError("Invalid MCP access token.")
@ -0,0 +1,130 @@ |
|||||||
|
import ssl |
||||||
|
import socket |
||||||
|
from datetime import datetime |
||||||
|
from typing import Any, Dict, List, Optional |
||||||
|
import argparse |
||||||
|
import pprint |
||||||
|
import sys # added to detect whether --host was passed |
||||||
|
|
||||||
|
def fetch_certificate(host: str, port: int = 443, server_hostname: Optional[str] = None, timeout: float = 5.0) -> Dict[str, Any]:
    """
    Fetch the TLS certificate from host:port.

    A non-verifying SSL context is used deliberately so the certificate can
    be retrieved even when it would not validate.

    FIX: CPython only *parses* the peer certificate when it was validated,
    so with CERT_NONE ``SSLSocket.getpeercert()`` returns an EMPTY dict —
    the original code therefore never had anything to inspect.  The raw DER
    bytes (and a PEM rendering) are now captured as well so an unvalidated
    certificate can still be examined.

    Parameters:
        host: TCP connect target (can be an IP or hostname).
        port: TCP port (default 443).
        server_hostname: SNI value to send. If None, the host value is used.
        timeout: socket connect timeout in seconds.

    Returns:
        Dict with 'peer_certificate' (parsed dict; empty when unvalidated),
        'peer_certificate_der', 'peer_certificate_pem', 'cipher',
        'fetched_at' and 'server_hostname_used'.
    """
    from datetime import timezone  # local import: module header imports only datetime

    if server_hostname is None:
        server_hostname = host

    # Create an SSL context that does NOT verify so we can always fetch the cert.
    context = ssl.create_default_context()
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE

    with socket.create_connection((host, port), timeout=timeout) as sock:
        with context.wrap_socket(sock, server_hostname=server_hostname) as sslsock:
            cert = sslsock.getpeercert()                  # {} when not validated
            der = sslsock.getpeercert(binary_form=True)   # always available
            peer_cipher = sslsock.cipher()
            # datetime.utcnow() is deprecated (3.12+); build the same naive
            # "...Z" string from an aware UTC timestamp instead.
            peertime = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    info: Dict[str, Any] = {
        "peer_certificate": cert,
        "peer_certificate_der": der,
        "peer_certificate_pem": ssl.DER_cert_to_PEM_cert(der) if der else None,
        "cipher": peer_cipher,
        "fetched_at": peertime,
        "server_hostname_used": server_hostname,
    }
    return info
|
|
||||||
|
def parse_san(cert: Dict[str, Any]) -> List[str]:
    """Extract the DNS entries from a certificate's subjectAltName field."""
    entries = cert.get("subjectAltName", ())
    return [value for kind, value in entries if kind.lower() == "dns"]
|
|
||||||
|
def format_subject(cert: Dict[str, Any]) -> str:
    """Render the certificate subject as a comma-separated "key=value" string."""
    pairs = [
        f"{name}={value}"
        for rdn in cert.get("subject", ())
        for name, value in rdn
    ]
    return ", ".join(pairs)
|
|
||||||
|
def check_hostname_match(cert: Dict[str, Any], hostname: str) -> bool:
    """
    Check whether *cert* (a getpeercert()-style dict) covers *hostname*.

    FIX: ``ssl.match_hostname`` was deprecated in 3.7 and REMOVED in Python
    3.12; the original blanket ``except Exception`` swallowed the resulting
    AttributeError, so on modern interpreters this always returned False.
    When match_hostname is unavailable we fall back to a manual RFC 6125-style
    comparison (exact match or single leftmost-label wildcard) against the
    SAN DNS entries, using the subject commonName only when no SANs exist.

    Returns True on a match, False otherwise.
    """
    matcher = getattr(ssl, "match_hostname", None)
    if matcher is not None:
        try:
            matcher(cert, hostname)
            return True
        except Exception:
            return False

    def _label_match(pattern: str, name: str) -> bool:
        # Exact match, or "*.example.com"-style wildcard on the first label only.
        pattern, name = pattern.lower(), name.lower()
        if pattern == name:
            return True
        if pattern.startswith("*."):
            head, sep, rest = name.partition(".")
            return bool(sep) and head != "" and rest == pattern[2:]
        return False

    dns_names = [v for t, v in cert.get("subjectAltName", ()) if t.lower() == "dns"]
    if not dns_names:
        # RFC 6125: fall back to CN only when no DNS-type SAN is present.
        for rdn in cert.get("subject", ()):
            for key, value in rdn:
                if key == "commonName":
                    dns_names.append(value)
    return any(_label_match(p, hostname) for p in dns_names)
|
|
||||||
|
def print_report(host: str, port: int, server_hostname: Optional[str]) -> None:
    """Fetch the certificate for host:port and print a human-readable summary."""
    info = fetch_certificate(host=host, port=port, server_hostname=server_hostname)
    cert = info["peer_certificate"]

    # Connection metadata first.
    print(f"Connected target: {host}:{port}")
    print(f"SNI sent: {info['server_hostname_used']}")
    print(f"Cipher: {info['cipher']}")
    print(f"Fetched at (UTC): {info['fetched_at']}")
    print()

    print("Subject:")
    print(" ", format_subject(cert))
    print()

    print("Issuer:")
    issuer_pairs = [f"{k}={v}" for rdn in cert.get("issuer", ()) for k, v in rdn]
    print(" ", ", ".join(issuer_pairs))
    print()

    sans = parse_san(cert)
    print("Subject Alternative Names (SANs):")
    if not sans:
        print(" (none)")
    else:
        for name in sans:
            print(" -", name)

    print()
    print("Validity:")
    print(" notBefore:", cert.get("notBefore"))
    print(" notAfter: ", cert.get("notAfter"))

    effective_name = server_hostname or host
    verdict = "YES" if check_hostname_match(cert, effective_name) else "NO"
    print()
    print(f"Hostname match for '{effective_name}':", verdict)

    # For debugging show the full cert dict if requested
    # pprint.pprint(cert)
|
|
||||||
|
def main() -> None:
    """CLI entry point: parse arguments and print the certificate report."""
    parser = argparse.ArgumentParser(description="Fetch and inspect TLS certificate from a host (SNI-aware).")
    # --host is optional and defaults to api.rixdagen.se so that running the
    # script with no arguments still does something useful.
    parser.add_argument("--host", "-H", required=False, default="api.rixdagen.se", help="Host or IP to connect to (TCP target). Defaults to api.rixdagen.se")
    parser.add_argument("--port", "-p", type=int, default=443, help="Port to connect to (default 443).")
    parser.add_argument("--sni", help="SNI hostname to send. If omitted, the --host value is used as SNI.")
    args = parser.parse_args()

    # Notify when using the default host for quick testing.
    # FIX: the original check ('--host' in sys.argv) missed the
    # "--host=example.com" spelling and the attached short form "-Hexample.com",
    # wrongly printing the "No --host provided" notice even when a host was given.
    host_given = any(
        arg == "--host" or arg.startswith("--host=") or arg.startswith("-H")
        for arg in sys.argv[1:]
    )
    if not host_given:
        print("No --host provided: defaulting to api.rixdagen.se (you can override with --host or -H)")

    print_report(host=args.host, port=args.port, server_hostname=args.sni)
||||||
@ -0,0 +1,114 @@ |
|||||||
|
""" |
||||||
|
RiksdagenTools MCP server (HTTP only, compatible with current FastMCP version) |
||||||
|
""" |
||||||
|
from __future__ import annotations |
||||||
|
import asyncio |
||||||
|
import logging |
||||||
|
import os |
||||||
|
import inspect |
||||||
|
from typing import Any, Dict, List, Optional, Sequence |
||||||
|
from fastmcp import FastMCP |
||||||
|
from mcp_server.auth import validate_token |
||||||
|
from mcp_server import tools |
||||||
|
|
||||||
|
HOST = os.getenv("MCP_HOST", "127.0.0.1") |
||||||
|
PORT = int(os.getenv("MCP_PORT", "8010")) |
||||||
|
PATH = os.getenv("MCP_PATH", "/mcp") |
||||||
|
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() |
||||||
|
|
||||||
|
logging.basicConfig( |
||||||
|
level=LOG_LEVEL, |
||||||
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", |
||||||
|
datefmt="%Y-%m-%d %H:%M:%S", |
||||||
|
) |
||||||
|
log = logging.getLogger("mcp_server") |
||||||
|
|
||||||
|
app = FastMCP("RiksdagenTools") |
||||||
|
|
||||||
|
# --- tools unchanged ---
@app.tool()
async def search_documents(token: str, aql_query: str) -> Dict[str, Any]:
    validate_token(token)
    # Blocking DB work is pushed to a worker thread to keep the loop free.
    return await asyncio.to_thread(tools.search_documents, aql_query)


@app.tool()
async def aql_query(token: str, query: str) -> List[Dict[str, Any]]:
    validate_token(token)
    return await asyncio.to_thread(tools.run_aql_query, query)


@app.tool()
async def vector_search_talks(token: str, query: str, limit: int = 8) -> List[Dict[str, Any]]:
    validate_token(token)
    return await asyncio.to_thread(tools.vector_search, query, limit)


@app.tool()
async def fetch_documents(
    token: str, document_ids: Sequence[str], fields: Optional[Sequence[str]] = None
) -> List[Dict[str, Any]]:
    validate_token(token)
    # Materialize the sequences so the worker thread gets plain lists.
    wanted_fields = list(fields) if fields else None
    return await asyncio.to_thread(tools.fetch_documents, list(document_ids), wanted_fields)


@app.tool()
async def arango_search(
    token: str,
    query: str,
    limit: int = 20,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    validate_token(token)
    return await asyncio.to_thread(
        tools.arango_search,
        query,
        limit,
        parties,
        people,
        from_year,
        to_year,
        return_snippets,
        focus_ids,
        speaker_ids,
    )
||||||
|
|
||||||
|
@app.tool()
async def ping() -> str:
    """
    Lightweight connectivity-check tool.

    Returns:
        The literal string "ok".  Intentionally unauthenticated so the
        transport/proxy chain (e.g. nginx -> backend) can be validated
        without presenting credentials.
    """
    return "ok"
|
|
||||||
|
# --- Entrypoint ---
def run() -> None:
    """Start the FastMCP server over streamable HTTP and block until exit."""
    log.info(
        "Starting RiksdagenTools MCP server (HTTP) on http://%s:%d%s",
        HOST,
        PORT,
        PATH,
    )
    try:
        # host/port/path are forwarded straight to FastMCP's runner.
        app.run(transport="streamable-http", host=HOST, port=PORT, path=PATH)
    except Exception:
        log.exception("Unexpected error while running the MCP server.")
        raise


if __name__ == "__main__":
    run()
||||||
@ -0,0 +1,168 @@ |
|||||||
|
""" |
||||||
|
Test script for the RiksdagenTools MCP server (HTTP transport). |
||||||
|
This script connects to the MCP server via Streamable HTTP and tests your main tools. |
||||||
|
Ensure the MCP server is running and that the environment variable MCP_SERVER_TOKEN is set. |
||||||
|
Also adjust MCP_SERVER_URL if needed. |
||||||
|
""" |
||||||
|
|
||||||
|
import os |
||||||
|
import asyncio |
||||||
|
from typing import Any, Dict, List, Optional, Sequence |
||||||
|
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client |
||||||
|
from mcp.client.session import ClientSession # adjusted import per SDK version |
||||||
|
|
||||||
|
# SECURITY FIX: the bearer token must come from the environment — the previous
# hardcoded fallback committed a real credential to version control.  That
# token should be rotated on the server.
TOKEN: str = os.environ.get("MCP_SERVER_TOKEN", "")
# Default to the public HTTPS endpoint so tests target the nginx proxy.
SERVER_URL: str = os.environ.get("MCP_SERVER_URL", "https://api.rixdagen.se/mcp")
||||||
|
|
||||||
|
async def run_tests() -> None: |
||||||
|
""" |
||||||
|
Attempt to connect to SERVER_URL and run the tool tests. On SSL certificate |
||||||
|
verification failures, optionally retry against the backend IP if |
||||||
|
MCP_FALLBACK_TO_IP=1 is set (or a custom MCP_FALLBACK_URL is provided). |
||||||
|
""" |
||||||
|
async def run_with_url(url: str) -> None: |
||||||
|
print(f"Connecting to server URL: {url}") |
||||||
|
async with streamablehttp_client( |
||||||
|
url=url, |
||||||
|
headers={ "Authorization": f"Bearer {TOKEN}" } |
||||||
|
) as (read_stream, write_stream, get_session_id): |
||||||
|
async with ClientSession(read_stream, write_stream) as session: |
||||||
|
# initialize the session (if needed) |
||||||
|
init_result = await session.initialize() |
||||||
|
print("Initialized session:", init_result) |
||||||
|
|
||||||
|
print("\nListing available tools...") |
||||||
|
tools_info = await session.list_tools() |
||||||
|
print("Tools:", [ tool.name for tool in tools_info.tools ]) |
||||||
|
|
||||||
|
# Test aql_query |
||||||
|
print("\n== Testing aql_query ==") |
||||||
|
result1 = await session.call_tool( |
||||||
|
"aql_query", |
||||||
|
arguments={ |
||||||
|
"token": TOKEN, |
||||||
|
"query": "FOR doc IN talks LIMIT 2 RETURN { _id: doc._id, talare: doc.talare }" |
||||||
|
} |
||||||
|
) |
||||||
|
print("aql_query result:", result1) |
||||||
|
|
||||||
|
# Test search_documents |
||||||
|
print("\n== Testing search_documents ==") |
||||||
|
result2 = await session.call_tool( |
||||||
|
"search_documents", |
||||||
|
arguments={ |
||||||
|
"token": TOKEN, |
||||||
|
"aql_query": "FOR doc IN talks LIMIT 2 RETURN { _id: doc._id, talare: doc.talare }" |
||||||
|
} |
||||||
|
) |
||||||
|
print("search_documents result:", result2) |
||||||
|
|
||||||
|
# Test vector_search_talks |
||||||
|
print("\n== Testing vector_search_talks ==") |
||||||
|
result3 = await session.call_tool( |
||||||
|
"vector_search_talks", |
||||||
|
arguments={ |
||||||
|
"token": TOKEN, |
||||||
|
"query": "klimatförändringar", |
||||||
|
"limit": 2 |
||||||
|
} |
||||||
|
) |
||||||
|
print("vector_search_talks result:", result3) |
||||||
|
|
||||||
|
# Test fetch_documents |
||||||
|
print("\n== Testing fetch_documents ==") |
||||||
|
# try to pull out IDs from result3 if available |
||||||
|
doc_ids: List[str] |
||||||
|
maybe = result3 |
||||||
|
if hasattr(maybe, "output") and isinstance(maybe.output, list) and maybe.output: |
||||||
|
doc_ids = [ maybe.output[0].get("_id", "") ] |
||||||
|
else: |
||||||
|
doc_ids = ["talks/1"] |
||||||
|
result4 = await session.call_tool( |
||||||
|
"fetch_documents", |
||||||
|
arguments={ |
||||||
|
"token": TOKEN, |
||||||
|
"document_ids": doc_ids |
||||||
|
} |
||||||
|
) |
||||||
|
print("fetch_documents result:", result4) |
||||||
|
|
||||||
|
# Test arango_search |
||||||
|
print("\n== Testing arango_search ==") |
||||||
|
result5 = await session.call_tool( |
||||||
|
"arango_search", |
||||||
|
arguments={ |
||||||
|
"token": TOKEN, |
||||||
|
"query": "klimat", |
||||||
|
"limit": 2 |
||||||
|
} |
||||||
|
) |
||||||
|
print("arango_search result:", result5) |
||||||
|
|
||||||
|
# try primary URL first |
||||||
|
try: |
||||||
|
await run_with_url(SERVER_URL) |
||||||
|
except Exception as e: # capture failures from streamablehttp_client / httpx |
||||||
|
err_str = str(e).lower() |
||||||
|
ssl_fail = "certificate_verify_failed" in err_str or "hostname mismatch" in err_str or "certificate verify failed" in err_str |
||||||
|
gateway_fail = "502" in err_str or "bad gateway" in err_str or "502 bad gateway" in err_str |
||||||
|
|
||||||
|
if gateway_fail: |
||||||
|
print("Received 502 Bad Gateway from the proxy when connecting to the server URL.") |
||||||
|
# If user explicitly set fallback env var, retry against backend IP or custom fallback |
||||||
|
fallback_flag = os.environ.get("MCP_FALLBACK_TO_IP", "0").lower() in ("1", "true", "yes") |
||||||
|
if fallback_flag: |
||||||
|
fallback_url = os.environ.get("MCP_FALLBACK_URL", "http://127.0.0.1:8010/mcp") |
||||||
|
print(f"Retrying with fallback URL (MCP_FALLBACK_URL or default backend): {fallback_url}") |
||||||
|
await run_with_url(fallback_url) |
||||||
|
return |
||||||
|
print("") |
||||||
|
print("Possible causes:") |
||||||
|
print("- The proxy (nginx) couldn't reach the backend (backend down, wrong proxy_pass or path).") |
||||||
|
print("- Proxy buffering or HTTP version issues interfering with streaming transport.") |
||||||
|
print("") |
||||||
|
print("Options:") |
||||||
|
print("- Bypass the proxy and target the backend directly:") |
||||||
|
print(" export MCP_SERVER_URL='http://127.0.0.1:8010/mcp'") |
||||||
|
print("- Or enable automatic fallback to the backend (insecure) for testing:") |
||||||
|
print(" export MCP_FALLBACK_TO_IP=1") |
||||||
|
print(" # optionally override the fallback target") |
||||||
|
print(" export MCP_FALLBACK_URL='http://127.0.0.1:8010/mcp'") |
||||||
|
print("- Check the proxy's error log (e.g. /var/log/nginx/error.log) for upstream errors.") |
||||||
|
print("") |
||||||
|
# re-raise so caller still sees the error if they don't follow guidance |
||||||
|
raise |
||||||
|
|
||||||
|
if ssl_fail: |
||||||
|
print("SSL certificate verification failed while connecting to the server URL.") |
||||||
|
# If user explicitly set fallback env var, retry against backend IP or custom fallback |
||||||
|
fallback_flag = os.environ.get("MCP_FALLBACK_TO_IP", "0").lower() in ("1", "true", "yes") |
||||||
|
if fallback_flag: |
||||||
|
fallback_url = os.environ.get("MCP_FALLBACK_URL", "http://192.168.1.10:8010/mcp") |
||||||
|
print(f"Retrying with fallback URL (MCP_FALLBACK_URL or default backend IP): {fallback_url}") |
||||||
|
await run_with_url(fallback_url) |
||||||
|
return |
||||||
|
# Otherwise give actionable guidance |
||||||
|
print("") |
||||||
|
print("Possible causes:") |
||||||
|
print("- The TLS certificate served for api.rixdagen.se does not match that hostname on this machine.") |
||||||
|
print("") |
||||||
|
print("Options:") |
||||||
|
print("- Set MCP_SERVER_URL to the backend HTTP address to bypass TLS: e.g.") |
||||||
|
print(" export MCP_SERVER_URL='http://192.168.1.10:8010/mcp'") |
||||||
|
print("- Or enable automatic fallback to the backend IP for testing (insecure):") |
||||||
|
print(" export MCP_FALLBACK_TO_IP=1") |
||||||
|
print(" # optionally override the fallback target") |
||||||
|
print(" export MCP_FALLBACK_URL='http://192.168.1.10:8010/mcp'") |
||||||
|
print("") |
||||||
|
# re-raise so caller still sees the error if they don't follow guidance |
||||||
|
raise |
||||||
|
# Not an SSL or gateway failure: re-raise |
||||||
|
raise |
||||||
|
|
||||||
|
def main() -> None:
    """Synchronous wrapper so the test module can be run as a script."""
    asyncio.run(run_tests())


if __name__ == "__main__":
    main()
||||||
@ -0,0 +1,81 @@ |
|||||||
|
import os |
||||||
|
import sys |
||||||
|
import asyncio |
||||||
|
from typing import Tuple, Any |
||||||
|
|
||||||
|
from mcp.client.streamable_http import streamablehttp_client |
||||||
|
from mcp.client.session import ClientSession |
||||||
|
|
||||||
|
SERVER_URL: str = os.environ.get("MCP_SERVER_URL", "http://127.0.0.1:8010/mcp") |
||||||
|
TOKEN: str = os.environ.get("MCP_SERVER_TOKEN", "") |
||||||
|
|
||||||
|
|
||||||
|
async def _extract_ping_result(res: Any) -> Any: |
||||||
|
""" |
||||||
|
Extract a sensible value from various CallToolResult shapes returned by the SDK. |
||||||
|
|
||||||
|
Handles: |
||||||
|
- objects with 'structuredContent' (dict) -> use 'result' or first value |
||||||
|
- objects with 'output' attribute |
||||||
|
- objects with 'content' list containing a TextContent with .text |
||||||
|
- plain scalars |
||||||
|
""" |
||||||
|
# structuredContent is common in newer SDK responses |
||||||
|
if hasattr(res, "structuredContent") and isinstance(res.structuredContent, dict): |
||||||
|
# prefer a 'result' key |
||||||
|
if "result" in res.structuredContent: |
||||||
|
return res.structuredContent["result"] |
||||||
|
# fallback to any first value |
||||||
|
for v in res.structuredContent.values(): |
||||||
|
return v |
||||||
|
|
||||||
|
# older / alternative shape: an 'output' attribute |
||||||
|
if hasattr(res, "output"): |
||||||
|
return res.output |
||||||
|
|
||||||
|
# textual content list (TextContent objects) |
||||||
|
if hasattr(res, "content") and isinstance(res.content, list) and res.content: |
||||||
|
first = res.content[0] |
||||||
|
# Some SDK TextContent exposes 'text' |
||||||
|
if hasattr(first, "text"): |
||||||
|
return first.text |
||||||
|
# fallback to stringifying the object |
||||||
|
return str(first) |
||||||
|
|
||||||
|
# final fallback: try direct indexing or string conversion |
||||||
|
try: |
||||||
|
return res["result"] |
||||||
|
except Exception: |
||||||
|
return str(res) |
||||||
|
|
||||||
|
|
||||||
|
async def check_ping(url: str, token: str) -> Tuple[bool, str]:
    """
    Connect to the MCP server at `url` and invoke its 'ping' tool.

    Returns:
        (ok, message): ok is True only when the tool answered "ok".
    """
    headers = {"Authorization": f"Bearer {token}"} if token else {}
    try:
        async with streamablehttp_client(url=url, headers=headers) as (reader, writer, _):
            async with ClientSession(reader, writer) as session:
                await session.initialize()
                raw = await session.call_tool("ping", arguments={})
                value = await _extract_ping_result(raw)  # robust extractor
    except Exception as exc:
        return False, f"error connecting/calling ping: {exc!r}"
    if value == "ok":
        return True, "ping -> ok"
    return False, f"unexpected ping response: {value!r}"
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Run the ping check once; exit with status 1 on failure."""
    ok, message = asyncio.run(check_ping(SERVER_URL, TOKEN))
    print(message)
    if not ok:
        sys.exit(1)


if __name__ == "__main__":
    main()
||||||
@ -0,0 +1,360 @@ |
|||||||
|
from __future__ import annotations |
||||||
|
|
||||||
|
from dataclasses import dataclass |
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Sequence |
||||||
|
from pydantic import BaseModel, Field |
||||||
|
|
||||||
|
import env_manager |
||||||
|
|
||||||
|
env_manager.set_env() |
||||||
|
|
||||||
|
from arango.collection import Collection # noqa: E402 |
||||||
|
from arango_client import arango # noqa: E402 |
||||||
|
from backend.services.search import SearchService # noqa: E402 |
||||||
|
from _chromadb.chroma_client import chroma_db # noqa: E402 |
||||||
|
|
||||||
|
|
||||||
|
class HitDocument(BaseModel):
    """
    Normalized representation of a search hit across tools, enabling
    consistent downstream handling.

    (Fix: the class previously carried a second, redundant bare-string
    docstring after this one — a dead statement — which has been removed.)

    Attributes:
        id (Optional[str]): Fully qualified ArangoDB document identifier.
        key (Optional[str]): Document key without collection prefix.
        speaker (Optional[str]): Name of the speaker associated with the hit.
        party (Optional[str]): Party affiliation of the speaker.
        date (Optional[str]): ISO formatted document date (YYYY-MM-DD).
        snippet (Optional[str]): Contextual snippet or highlight from the document.
        text (Optional[str]): Full text of the document when available.
        score (Optional[float]): Relevance score supplied by the executing tool.
        metadata (Dict[str, Any]): Additional tool-specific metadata to preserve.
    """

    id: Optional[str] = Field(
        default=None, description="Fully qualified ArangoDB document identifier."
    )
    key: Optional[str] = Field(
        default=None, description="Document key without collection prefix."
    )
    speaker: Optional[str] = Field(
        default=None, description="Name of the speaker associated with the hit."
    )
    party: Optional[str] = Field(
        default=None, description="Party affiliation of the speaker."
    )
    date: Optional[str] = Field(
        default=None, description="ISO formatted document date (YYYY-MM-DD)."
    )
    snippet: Optional[str] = Field(
        default=None, description="Contextual snippet or highlight from the document."
    )
    text: Optional[str] = Field(
        default=None, description="Full text of the document when available."
    )
    score: Optional[float] = Field(
        default=None, description="Relevance score supplied by the executing tool."
    )
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Additional metadata specific to the originating tool that should be preserved.",
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render the hit as a human-readable string with uppercase labels.

        Args:
            include_metadata (bool, optional): Whether to append metadata
                entries. Defaults to True.  (Fix: this flag was previously
                accepted but ignored — metadata was always included.)

        Returns:
            str: "LABEL\\nvalue" segments joined by blank lines.
        """
        data: Dict[str, Any] = self.model_dump(exclude_none=True)
        metadata: Dict[str, Any] = data.pop("metadata", {})
        segments: List[str] = [
            f"{field_name.upper()}\n{field_value}"
            for field_name, field_value in data.items()
        ]
        if include_metadata:  # honor the flag (was ignored before)
            segments.extend(
                f"{meta_key.upper()}\n{meta_value}"
                for meta_key, meta_value in metadata.items()
            )
        return "\n\n".join(segments)
|
|
||||||
|
|
||||||
|
class HitsResponse(BaseModel):
    """
    Container for multiple HitDocument instances, with utilities for rendering
    the whole collection as text.

    Attributes:
        hits (List[HitDocument]): A list of collected search hits.

    Methods:
        to_string() -> str:
            Returns a string representation of all hits, separated by a visual
            divider. If there are no hits, returns an empty string.
    """

    hits: List[HitDocument] = Field(
        default_factory=list, description="Collected search hits."
    )

    def to_string(self, include_metadata: bool = True) -> str:
        """
        Render all hits as a single string, separated by a visual divider.

        Args:
            include_metadata (bool, optional): Whether to include metadata in
                each hit's string representation. Defaults to True.

        Returns:
            str: A single string containing all hits, separated by
            "\n\n---\n\n". Returns an empty string if there are no hits.
        """
        # Note: a redundant second docstring statement that shadowed nothing
        # was removed here; behavior is unchanged.
        if not self.hits:
            return ""
        return "\n\n---\n\n".join(
            hit.to_string(include_metadata=include_metadata) for hit in self.hits
        )
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_read_only_aql(query: str) -> None:
    """
    Reject AQL statements that attempt to mutate data or omit a RETURN clause.

    Args:
        query: Raw AQL statement from the client.

    Raises:
        ValueError: If the query contains a data-modification keyword or does
            not contain a RETURN clause.
    """
    normalized = query.upper()
    # Data-modification operations, each matched with a trailing space.
    # NOTE(review): this is a heuristic -- a keyword inside a string literal
    # will false-positive, and a keyword at the very end of the query (no
    # trailing space) is missed. The duplicate "UPSERT " entry in the original
    # tuple has been removed (pure redundancy, no behavior change).
    forbidden = (
        "INSERT ",
        "UPDATE ",
        "UPSERT ",
        "REMOVE ",
        "REPLACE ",
        "DELETE ",
        "DROP ",
        "TRUNCATE ",
        "MERGE ",
    )
    if any(keyword in normalized for keyword in forbidden):
        raise ValueError("Only read-only AQL queries are allowed.")
    # Accept either an embedded " RETURN " or a query that starts with RETURN.
    if " RETURN " not in normalized and not normalized.strip().startswith("RETURN "):
        raise ValueError("AQL queries must include a RETURN clause.")
||||||
|
|
||||||
|
|
||||||
|
def strip_private_fields(document: Dict[str, Any]) -> Dict[str, Any]:
    """
    Remove large internal fields from a document dictionary.

    Args:
        document: Document returned by ArangoDB.

    Returns:
        Sanitized shallow copy without chunk payloads; the input dictionary is
        left untouched.
    """
    # Bug fix: the original deleted "chunks" from the caller's dict in place,
    # while the docstring promised a sanitized copy. Build a copy instead so
    # callers still holding the original document are not surprised.
    return {key: value for key, value in document.items() if key != "chunks"}
||||||
|
|
||||||
|
|
||||||
|
def search_documents(aql_query: str) -> Dict[str, Any]:
    """
    Execute a read-only AQL query and return the result set together with the
    query string.

    Args:
        aql_query: Read-only AQL statement supplied by the client.

    Returns:
        Dictionary containing the executed AQL string, the row count, and the
        sanitized result rows.
    """
    ensure_read_only_aql(aql_query)
    sanitized_rows: List[Dict[str, Any]] = []
    for document in arango.execute_aql(aql_query):
        sanitized_rows.append(strip_private_fields(document))
    return {
        "aql": aql_query,
        "row_count": len(sanitized_rows),
        "rows": sanitized_rows,
    }
||||||
|
|
||||||
|
|
||||||
|
def run_aql_query(aql_query: str) -> List[Dict[str, Any]]:
    """
    Execute a read-only AQL query and return the sanitized rows.

    Args:
        aql_query: Read-only AQL statement.

    Returns:
        List of result rows with heavy internal fields removed.
    """
    ensure_read_only_aql(aql_query)
    rows: List[Dict[str, Any]] = []
    for document in arango.execute_aql(aql_query):
        rows.append(strip_private_fields(document))
    return rows
||||||
|
|
||||||
|
|
||||||
|
def _get_existing_collection(name: str) -> Collection:
    """
    Fetch an existing Chroma collection without creating new data.

    Args:
        name: Collection identifier.

    Returns:
        The requested collection.

    Raises:
        ValueError: If the collection is absent.
    """
    # NOTE(review): this reaches into the private `_client` attribute of the
    # shared chroma_db wrapper -- it works against the current client but may
    # break on chromadb upgrades; confirm whether the wrapper exposes a public
    # list/get API instead.
    available = {collection.name for collection in chroma_db._client.list_collections()}
    if name not in available:
        raise ValueError(f"Chroma collection '{name}' does not exist.")
    return chroma_db._client.get_collection(name=name)
||||||
|
|
||||||
|
|
||||||
|
def vector_search(query: str, limit: int) -> List[Dict[str, Any]]:
    """
    Perform semantic search against the pre-built Chroma collection.

    Args:
        query: Free-form search text.
        limit: Maximum number of hits to return.

    Returns:
        List of hit dictionaries with metadata and scores.
    """
    # The collection is assumed to be named after the last path segment of the
    # Chroma storage directory -- TODO confirm this naming convention holds.
    collection_name = chroma_db.path.split("/")[-1]
    chroma_collection = _get_existing_collection(collection_name)
    results = chroma_collection.query(
        query_texts=[query],
        n_results=limit,
    )
    # Chroma returns one inner list per query text; we only sent one query,
    # so every field below is indexed with [0].
    metadatas = results.get("metadatas") or []
    documents = results.get("documents") or []
    ids = results.get("ids") or []
    distances = results.get("distances") or []

    def as_int(value: Any, default: int = -1) -> int:
        # Best-effort conversion of metadata values (int, whole float, or
        # numeric string) to int; anything else yields the sentinel default.
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str) and value.strip().lstrip("+-").isdigit():
            return int(value)
        return default

    hits: List[Dict[str, Any]] = []
    for index, metadata in enumerate(metadatas[0] if metadatas else []):
        meta = metadata or {}
        document = documents[0][index] if documents else ""
        identifier = ids[0][index] if ids else ""
        hit = {
            # Prefer the explicit Arango id stored in metadata; fall back to
            # the Chroma id for the same position.
            "_id": meta.get("_id") or identifier,
            "heading": meta.get("heading") or meta.get("title") or meta.get("talare"),
            "snippet": meta.get("snippet") or meta.get("text") or document,
            "debateurl": meta.get("debateurl") or meta.get("debate_url"),
            "chunk_index": as_int(meta.get("chunk_index") or meta.get("index")),
            # NOTE(review): this is the raw Chroma *distance* (presumably
            # lower = more similar), not a similarity score -- confirm how
            # downstream consumers interpret it.
            "score": distances[0][index] if distances else None,
        }
        # Hits without any usable id cannot be followed up and are dropped.
        if hit["_id"]:
            hits.append(hit)
    return hits
||||||
|
|
||||||
|
|
||||||
|
def fetch_documents(document_ids: Sequence[str], fields: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
    """
    Pull full documents by _id while stripping heavy fields.

    Args:
        document_ids: Iterable with fully qualified Arango document ids.
        fields: Optional subset of fields to return.

    Returns:
        List of sanitized documents.
    """
    # Normalize ids that arrive with backslashes into the canonical
    # "collection/key" form. NOTE(review): confirm which clients actually send
    # backslash-separated ids -- this looks like a workaround for an escaping
    # quirk upstream.
    ids = [doc_id.replace("\\", "/") for doc_id in document_ids]
    query = """
    FOR id IN @document_ids
        RETURN DOCUMENT(id)
    """
    documents = arango.execute_aql(query, bind_vars={"document_ids": ids})
    if fields:
        # Projection path: return only the requested fields that are present.
        # Note this bypasses strip_private_fields, so a caller explicitly
        # asking for "chunks" will receive it.
        return [{field: doc.get(field) for field in fields if field in doc} for doc in documents]
    return [strip_private_fields(doc) for doc in documents]
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SearchPayload:
    """
    Lightweight container passed to SearchService.search.
    """

    # Free-text search expression (supports AND/OR/NOT and phrases).
    q: str
    # Optional party-abbreviation filters.
    parties: Optional[List[str]]
    # Optional speaker-name filters.
    people: Optional[List[str]]
    # Optional debate identifiers restricting the search.
    debates: Optional[List[str]]
    # Start-year filter (None = unbounded).
    from_year: Optional[int]
    # End-year filter (None = unbounded).
    to_year: Optional[int]
    # Maximum number of hits to return.
    limit: int
    # Whether only snippets (rather than full texts) should be returned.
    return_snippets: bool
    # Optional document ids restricting the search scope.
    focus_ids: Optional[List[str]]
    # Optional speaker identifiers.
    speaker_ids: Optional[List[str]]
    # Optional single speaker name; only field with a default, so callers may
    # omit it.
    speaker: Optional[str] = None
||||||
|
|
||||||
|
|
||||||
|
def arango_search(
    query: str,
    limit: int,
    parties: Optional[Sequence[str]] = None,
    people: Optional[Sequence[str]] = None,
    from_year: Optional[int] = None,
    to_year: Optional[int] = None,
    return_snippets: bool = False,
    focus_ids: Optional[Sequence[str]] = None,
    speaker_ids: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """
    Run an ArangoSearch query using the existing SearchService utilities.

    Args:
        query: Search expression (supports AND/OR/NOT and phrases).
        limit: Maximum number of hits to return.
        parties: Party filters.
        people: Speaker name filters.
        from_year: Start year filter.
        to_year: End year filter.
        return_snippets: Whether only snippets should be returned.
        focus_ids: Optional list restricting the search scope.
        speaker_ids: Optional list of speaker identifiers.

    Returns:
        Dictionary containing results, stats, limit flag, and focus_ids for
        follow-up queries.
    """
    # Incoming sequences are copied into plain lists so the payload owns its
    # data and callers can pass tuples or other sequence types.
    payload = SearchPayload(
        q=query,
        parties=list(parties) if parties else None,
        people=list(people) if people else None,
        debates=None,  # debate filtering is not exposed through this entry point
        from_year=from_year,
        to_year=to_year,
        limit=limit,
        return_snippets=return_snippets,
        focus_ids=list(focus_ids) if focus_ids else None,
        speaker_ids=list(speaker_ids) if speaker_ids else None,
    )
    service = SearchService()
    results, stats, limit_reached = service.search(
        payload=payload,
        include_snippets=True,
        return_snippets=return_snippets,
        focus_ids=payload.focus_ids,
    )
    return {
        "results": results,
        "stats": stats,
        "limit_reached": limit_reached,
        "return_snippets": return_snippets,
        # Echo the hit ids back so a follow-up call can narrow its scope to
        # exactly these documents.
        "focus_ids": [hit["_id"] for hit in results if isinstance(hit, dict) and hit.get("_id")],
    }
||||||
|
After Width: | Height: | Size: 153 KiB |
@ -0,0 +1,113 @@ |
|||||||
|
# Template for providers.yaml |
||||||
|
# |
||||||
|
# You can add any OpenAI API compatible provider to the "providers" list. |
||||||
|
# For each provider you must also specify a list of models, along with model abilities. |
||||||
|
# |
||||||
|
# All fields are required unless marked as optional. |
||||||
|
# |
||||||
|
# Refer to your provider's API documentation for specific |
||||||
|
# details such as model identifiers, capabilities etc |
||||||
|
# |
||||||
|
# Note: Since the OpenAI API is not a standard we can't guarantee that all |
||||||
|
# providers will work correctly with Raycast AI. |
||||||
|
# |
||||||
|
# To use this template rename as `providers.yaml` |
||||||
|
# |
||||||
|
providers: |
||||||
|
- id: perplexity |
||||||
|
name: Perplexity |
||||||
|
base_url: https://api.perplexity.ai |
||||||
|
# Specify at least one api key if authentication is required. |
||||||
|
# Optional if authentication is not required or is provided elsewhere. |
||||||
|
# If individual models require separate api keys, then specify a separate `key` for each model's `provider` |
||||||
|
api_keys: |
||||||
|
perplexity: PERPLEXITY_KEY |
||||||
|
# Optional additional parameters sent to the `/chat/completions` endpoint |
||||||
|
additional_parameters: |
||||||
|
return_images: true |
||||||
|
web_search_options: |
||||||
|
search_context_size: medium |
||||||
|
# Specify all models to use with the current provider |
||||||
|
models: |
||||||
|
- id: sonar # `id` must match the identifier used by the provider |
||||||
|
name: Sonar # name visible in Raycast |
||||||
|
provider: perplexity # Only required if mapping to a specific api key |
||||||
|
description: Perplexity AI model for general-purpose queries # optional |
||||||
|
context: 128000 # refer to provider's API documentation |
||||||
|
# Optional abilities - all child properties are also optional. |
||||||
|
# If you specify abilities incorrectly the model may fail to work as expected in Raycast AI. |
||||||
|
# Refer to provider's API documentation for model abilities. |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: true |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
tools: |
||||||
|
supported: false |
||||||
|
reasoning_effort: |
||||||
|
supported: false |
||||||
|
- id: sonar-pro |
||||||
|
name: Sonar Pro |
||||||
|
description: Perplexity AI model for complex queries |
||||||
|
context: 200000 |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: true |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
# provider with multiple api keys |
||||||
|
- id: my_provider |
||||||
|
name: My Provider |
||||||
|
base_url: http://localhost:4000 |
||||||
|
api_keys: |
||||||
|
openai: OPENAI_KEY |
||||||
|
anthropic: ANTHROPIC_KEY |
||||||
|
models: |
||||||
|
- id: gpt-4o |
||||||
|
name: "GPT-4o" |
||||||
|
context: 200000 |
||||||
|
provider: openai # matches "openai" in api_keys |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: true |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
tools: |
||||||
|
supported: true |
||||||
|
- id: claude-sonnet-4 |
||||||
|
name: "Claude Sonnet 4" |
||||||
|
context: 200000 |
||||||
|
provider: anthropic # matches "anthropic" in api_keys |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: true |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
tools: |
||||||
|
supported: true |
||||||
|
- id: litellm |
||||||
|
name: LiteLLM |
||||||
|
base_url: http://localhost:4000 |
||||||
|
# No `api_keys` - authentication is provided by the LiteLLM config |
||||||
|
models: |
||||||
|
- id: anthropic/claude-sonnet-4-20250514 |
||||||
|
name: "Claude Sonnet 4" |
||||||
|
context: 200000 |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: true |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
tools: |
||||||
|
supported: true |
||||||
|
|
||||||
@ -0,0 +1,20 @@ |
|||||||
|
providers: |
||||||
|
- id: vllm |
||||||
|
name: vLLM Instance |
||||||
|
base_url: https://lasseedfast.se/vllm |
||||||
|
api_keys: |
||||||
|
vllm: "ap98sfoiuajcnwe89sozchnsw9oeacislh" |
||||||
|
models: |
||||||
|
- id: "gpt-oss:20b" |
||||||
|
name: "GPT OSS 20B (Main)" |
||||||
|
provider: vllm |
||||||
|
context: 16000 |
||||||
|
abilities: |
||||||
|
temperature: |
||||||
|
supported: true |
||||||
|
vision: |
||||||
|
supported: false |
||||||
|
system_message: |
||||||
|
supported: true |
||||||
|
tools: |
||||||
|
supported: true |
||||||
@ -1,182 +0,0 @@ |
|||||||
#!/usr/bin/env python3 |
|
||||||
""" |
|
||||||
Convert stored embeddings to plain Python lists for an existing Chroma collection. |
|
||||||
|
|
||||||
Usage: |
|
||||||
# Dry run (inspect first 50 ids) |
|
||||||
python scripts/convert_embeddings_to_lists.py --collection talks --limit 50 --dry-run |
|
||||||
|
|
||||||
# Full run (no dry run) |
|
||||||
python scripts/convert_embeddings_to_lists.py --collection talks |
|
||||||
|
|
||||||
Notes: |
|
||||||
- Run from your project root (same env you use to access chroma_db). |
|
||||||
- Back up chromadb_data before running. |
|
||||||
""" |
|
||||||
import argparse |
|
||||||
import json |
|
||||||
import os |
|
||||||
import time |
|
||||||
from pathlib import Path |
|
||||||
from typing import List |
|
||||||
import math |
|
||||||
import sys |
|
||||||
|
|
||||||
# Use the same imports/bootstrapping as you already have in your project |
|
||||||
# so the same chroma client and embedding function are loaded. |
|
||||||
# Adjust the import path if necessary. |
|
||||||
os.chdir("/home/lasse/riksdagen") |
|
||||||
sys.path.append("/home/lasse/riksdagen") |
|
||||||
|
|
||||||
import numpy as np |
|
||||||
from _chromadb.chroma_client import chroma_db |
|
||||||
|
|
||||||
CHECKPOINT_DIR = Path("var/chroma_repair") |
|
||||||
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) |
|
||||||
|
|
||||||
def normalize_embedding(emb):
    """
    Convert a single embedding to a plain Python list[float].

    Accepts numpy arrays, tolist()-capable array-likes, plain lists, and as a
    last resort any iterable of numbers.
    """
    # Fast path: numpy arrays convert directly.
    if isinstance(emb, np.ndarray):
        return emb.tolist()
    # Other array-likes (pandas Series etc.) may also expose tolist().
    if not isinstance(emb, list) and hasattr(emb, "tolist"):
        try:
            return emb.tolist()
        except Exception:
            pass
    # Plain lists: coerce every element to float.
    if isinstance(emb, list):
        return [float(value) for value in emb]
    # Last resort: try iterating whatever we were given.
    try:
        return [float(value) for value in emb]
    except Exception:
        raise ValueError("Cannot normalize embedding of type: %s" % type(emb))
|
||||||
|
|
||||||
def chunked_iter(iterable, n):
    """Yield consecutive lists of up to *n* items drawn from *iterable*."""
    source = iter(iterable)
    while True:
        batch = []
        try:
            for _ in range(n):
                batch.append(next(source))
        except StopIteration:
            # Source exhausted mid-batch; yield whatever was collected.
            pass
        if not batch:
            return
        yield batch
|
||||||
|
|
||||||
def load_checkpoint(name):
    """
    Load the checkpoint dict for *name* from CHECKPOINT_DIR.

    Args:
        name: Checkpoint file stem (usually the collection name).

    Returns:
        The stored checkpoint dict, or the default
        {"last_index": 0, "processed_ids": []} when no checkpoint exists yet.
    """
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        # Bug fix: json.load() expects a file object, not a Path --
        # json.load(path) raised AttributeError on every resume attempt.
        with open(path) as f:
            return json.load(f)
    return {"last_index": 0, "processed_ids": []}
|
||||||
|
|
||||||
def save_checkpoint(name, data):
    """Persist the checkpoint dict for *name* under CHECKPOINT_DIR as JSON."""
    target = CHECKPOINT_DIR / f"{name}.json"
    with open(target, "w") as handle:
        json.dump(data, handle)
|
|
||||||
def main():
    """
    Convert all stored embeddings of one Chroma collection to plain Python
    lists.

    Walks the collection in batches, normalizes each embedding with
    normalize_embedding(), and (unless --dry-run) writes the lists back via
    col.update(). Progress is checkpointed after every batch so an interrupted
    run can resume where it stopped.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection", required=True, help="Chroma collection name (e.g. talks)")
    parser.add_argument("--batch", type=int, default=1000, help="Batch size for update (default 1000)")
    parser.add_argument("--dry-run", action="store_true", help="Dry run: don't write updates, just report")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of ids to process (for testing)")
    parser.add_argument("--checkpoint-name", default=None, help="Name for checkpoint file (defaults to collection name)")
    args = parser.parse_args()

    coll_name = args.collection
    checkpoint_name = args.checkpoint_name or coll_name

    print(f"Connecting to Chroma collection '{coll_name}'...")
    col = chroma_db.get_collection(coll_name)

    # Get the full list of ids. For 600k this should be okay to hold in memory,
    # but if you need a more streaming approach, tell me and I can adapt.
    all_info = col.get(include=[])  # may return {'ids': [...]} as in your env
    ids = list(all_info.get("ids", []))
    total_ids = len(ids)
    if args.limit:
        ids = ids[: args.limit]
        total_process = len(ids)
    else:
        total_process = total_ids

    print(f"Found {total_ids} ids in collection; will process {total_process} ids (limit={args.limit})")

    # Load checkpoint and resume from the last completed batch boundary.
    ck = load_checkpoint(checkpoint_name)
    start_index = ck.get("last_index", 0)
    print(f"Resuming at index {start_index}")

    # Iterate in batches starting from last_index.
    processed = 0
    for i in range(start_index, total_process, args.batch):
        batch_ids = ids[i : i + args.batch]
        print(f"\nProcessing batch {i}..{i+len(batch_ids)-1} (count={len(batch_ids)})")

        # Fetch full info for this batch (documents, metadatas, embeddings).
        # We only need embeddings for this repair, but include docs/meta for
        # verification if you want.
        try:
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])
        except Exception as e:
            print("Error fetching batch:", e)
            # Single retry after a short sleep; a second failure aborts the run.
            time.sleep(2)
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])

        batch_embeddings = items.get("embeddings", [])
        # items.get("ids") should match batch_ids order; if not, align by ids.
        ids_from_get = items.get("ids", batch_ids)
        if len(ids_from_get) != len(batch_ids):
            print("Warning: length mismatch between requested ids and returned ids")

        # Normalize embeddings; a single failure abandons the whole batch.
        normalized_embeddings = []
        failed = False
        for idx, emb in enumerate(batch_embeddings):
            try:
                norm = normalize_embedding(emb)
            except Exception as e:
                print(f"Failed to normalize embedding for id {ids_from_get[idx]}: {e}")
                failed = True
                break
            normalized_embeddings.append(norm)

        if failed:
            print("Skipping this batch due to failures. You can adjust batch size and retry.")
            break

        # Dry-run: just print stats and continue without writing.
        if args.dry_run:
            # Show a small sample for manual inspection.
            sample_i = min(3, len(normalized_embeddings))
            print("Sample normalized embedding lengths:", [len(normalized_embeddings[k]) for k in range(sample_i)])
            # Optionally inspect first few floats.
            print("Sample values (first 6 floats):", [normalized_embeddings[k][:6] for k in range(sample_i)])
        else:
            # Update the collection in place (update will upsert embeddings
            # for the given ids).
            try:
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            except Exception as e:
                print("Update failed, retrying once after short sleep:", e)
                time.sleep(2)
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)

            print(f"Updated {len(normalized_embeddings)} embeddings in collection '{coll_name}'")

        # Checkpoint progress after each batch (written even on dry runs).
        ck["last_index"] = i + len(batch_ids)
        save_checkpoint(checkpoint_name, ck)
        processed += len(batch_ids)

    print(f"\nDone. Processed {processed} ids. Checkpoint saved to {CHECKPOINT_DIR / (checkpoint_name + '.json')}")
    print("Reminder: run a few queries to validate search quality.")
|
||||||
@ -0,0 +1,25 @@ |
|||||||
|
from arango_client import arango |
||||||
|
|
||||||
|
# Collection holding per-talk text chunks.
chunks_collection = arango.db.collection("chunks")

# Select every chunk that has not yet been linked to its parent talk.
q = """
FOR chunk IN chunks
    FILTER chunk.parent_id == null
    RETURN chunk
"""

cursor = arango.db.aql.execute(q, batch_size=1000, count=True, ttl=360)
updated_docs = []
n = 0
for doc in cursor:
    n += 1
    doc['collection'] = 'talks'
    # NOTE(review): the key is spelled 'chroma_collecton' (sic) -- confirm the
    # stored documents really use this misspelled field name, otherwise this
    # `del` raises KeyError on the first document.
    del doc['chroma_collecton']
    del doc['chroma_id']
    # Chunk keys look like "<talk_key>:<index>"; rebuild the parent reference
    # from the part before the colon.
    doc['parent_id'] = f"talks/{doc['_key'].split(':')[0]}"
    updated_docs.append(doc)
    # Flush in batches of 100 to keep request payloads small.
    if len(updated_docs) >= 100:
        chunks_collection.update_many(updated_docs, merge=False, silent=True)
        updated_docs = []
        print(f"Updated {n} documents", end="\r")
# Flush the final partial batch. NOTE(review): called even when the list is
# empty -- confirm update_many accepts an empty list.
chunks_collection.update_many(updated_docs, merge=False, silent=True)
||||||
@ -0,0 +1,173 @@ |
|||||||
|
import os |
||||||
|
import sys |
||||||
|
import logging |
||||||
|
|
||||||
|
# Silence the per-request HTTP logs from the ollama/httpx library |
||||||
|
logging.getLogger("httpx").setLevel(logging.WARNING) |
||||||
|
|
||||||
|
os.chdir("/home/lasse/riksdagen") |
||||||
|
sys.path.append("/home/lasse/riksdagen") |
||||||
|
|
||||||
|
from arango_client import arango |
||||||
|
from ollama import Client as Ollama |
||||||
|
from arango.collection import Collection |
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
||||||
|
from typing import List, Dict |
||||||
|
from time import sleep |
||||||
|
from utils import TextChunker |
||||||
|
|
||||||
|
|
||||||
|
def make_embeddings(texts: List[str]) -> List[List[float]]:
    """
    Generate embeddings for a list of texts using Ollama.

    Args:
        texts (List[str]): List of text strings to embed.

    Returns:
        List[List[float]]: List of embedding vectors, one per input text.
    """
    # A fresh client is created per call -- presumably to stay safe under the
    # ThreadPoolExecutor in make_arango_embeddings; verify before sharing one.
    # TODO: move the hard-coded host to configuration.
    ollama_client = Ollama(host='192.168.1.12:33405')
    embeddings = ollama_client.embed(
        model="qwen3-embedding:latest",
        input=texts,
        # NOTE(review): confirm the Ollama embed API honours `dimensions`;
        # downstream code expects 384-dimensional vectors.
        dimensions=384,
    )
    return embeddings.embeddings
||||||
|
|
||||||
|
|
||||||
|
def process_chunk_batch(chunk_batch: List[Dict]) -> List[Dict]:
    """
    Generate embeddings for a batch of chunks and attach them in place.

    Args:
        chunk_batch (List[Dict]): Chunk dicts, each carrying a 'text' field.

    Returns:
        List[Dict]: The same list, with an 'embedding' vector added to each
        dict.
    """
    # Brief pause so concurrent workers do not hammer the embedding server.
    sleep(1)
    vectors = make_embeddings([chunk['text'] for chunk in chunk_batch])
    for position, chunk in enumerate(chunk_batch):
        chunk['embedding'] = vectors[position]
    return chunk_batch
||||||
|
|
||||||
|
|
||||||
|
def make_arango_embeddings() -> int:
    """
    Chunks and embeds all talks that are not yet represented in the 'chunks'
    collection.

    For each talk that has no chunks in the collection yet:
      - If the talk document already has a 'chunks' field (legacy path), those
        are used.
      - Otherwise the speech text is split into chunks using TextChunker.
    Embedding vectors are generated via Ollama and stored in the 'chunks'
    collection.

    Each chunk document in ArangoDB has:
        _key      : "{talk_key}:{chunk_index}" (unique within the collection)
        text      : the chunk text
        index     : chunk index within the talk
        parent_id : "talks/{talk_key}" (links back to the source talk)
        collection: "talks"
        embedding : the vector (list of floats)

    Returns:
        int: Total number of chunk documents inserted/updated.
    """
    # Create the target collection on first run; reuse it afterwards.
    if not arango.db.has_collection("chunks"):
        chunks_collection: Collection = arango.db.create_collection("chunks")
    else:
        chunks_collection: Collection = arango.db.collection("chunks")

    # Find every talk that has no entry yet in the chunks collection.
    # The inner FOR loop returns [] if no match exists (acts as NOT EXISTS).
    cursor = arango.db.aql.execute(
        """
        FOR p IN talks
            FILTER p.anforandetext != null AND p.anforandetext != ""
            FILTER (
                FOR c IN chunks
                    FILTER c.parent_id == p._id
                    LIMIT 1
                    RETURN 1
            ) == []
            RETURN {
                _key: p._key,
                _id: p._id,
                anforandetext: p.anforandetext,
                chunks: p.chunks
            }
        """,
        batch_size=1000,
        ttl=360,
    )

    n = 0
    embed_batch_size = 20  # Number of chunks per Ollama call
    chunk_batches: List[List[Dict]] = []

    for talk in cursor:
        talk_key = talk["_key"]
        parent_id = f"talks/{talk_key}"

        if talk.get("chunks"):
            # Legacy path: chunks were previously generated and stored on the
            # talk document. Strip out the old ChromaDB-specific fields and
            # assign a proper _key.
            _chunks = []
            for chunk in talk["chunks"]:
                idx = chunk.get("index", 0)
                _chunks.append({
                    "_key": f"{talk_key}:{idx}",
                    "text": chunk["text"],
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                })
        else:
            # New path: chunk the speech text directly with TextChunker.
            text = (talk.get("anforandetext") or "").strip()
            text_chunks = TextChunker(chunk_limit=500).chunk(text)
            _chunks = [
                {
                    "_key": f"{talk_key}:{idx}",
                    "text": content,
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                }
                for idx, content in enumerate(text_chunks)
                if content and content.strip()
            ]

        # Split into batches for embedding.
        for i in range(0, len(_chunks), embed_batch_size):
            batch = _chunks[i : i + embed_batch_size]
            if batch:
                chunk_batches.append(batch)

    # Embed all batches in parallel (Ollama calls are I/O-bound, threads are
    # fine). NOTE(review): future.result() re-raises any embedding error and
    # aborts the whole run -- consider per-batch error handling.
    total_batches = len(chunk_batches)
    completed_batches = 0
    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(process_chunk_batch, batch) for batch in chunk_batches]
        processed_chunks: List[Dict] = []
        for future in as_completed(futures):
            result = future.result()
            completed_batches += 1
            processed_chunks.extend(result)
            print(f"Embedding batches: {completed_batches}/{total_batches} | chunks ready to insert: {len(processed_chunks)}", end="\r")
            # Insert in batches of 100 to keep HTTP payloads small.
            if len(processed_chunks) >= 100:
                n += len(processed_chunks)
                chunks_collection.insert_many(processed_chunks, overwrite=True)
                processed_chunks = []
        # Flush whatever is left after the final future completes.
        if processed_chunks:
            n += len(processed_chunks)
            chunks_collection.insert_many(processed_chunks, overwrite=True)

    print(f"\nDone. Inserted/updated {n} chunks in ArangoDB.")
    return n
||||||
|
|
||||||
|
|
||||||
|
# Allow running the embedding sync directly as a script.
if __name__ == "__main__":
    make_arango_embeddings()
||||||
@ -0,0 +1,5 @@ |
|||||||
|
### Inlogg riksdagsgruppen från Fojo-hackathon |
||||||
|
https://arango.lasseedfast.se |
||||||
|
riksdagsgruppen |
||||||
|
popre4-cygcuz-viHjyc |
||||||
|
|
||||||
@ -0,0 +1,177 @@ |
|||||||
|
""" |
||||||
|
Synkroniserar nya anföranden från riksdagen.se till databasen och processar dem. |
||||||
|
|
||||||
|
Pipeline (körs dagligen via systemd timer): |
||||||
|
1. Ladda ned årets anföranden från riksdagen.se (ersätter tidigare nerladdning) |
||||||
|
2. Infoga nya anföranden i ArangoDB (hoppar över redan existerande) |
||||||
|
3. Tilldela debatt-ID:n till anföranden som saknar det |
||||||
|
4. Bygg embeddings för datum som saknar chunks |
||||||
|
5. Generera sammanfattningar för datum som saknar summary |
||||||
|
|
||||||
|
Kör manuellt: python scripts/sync_talks.py |
||||||
|
""" |
||||||
|
|
||||||
|
import os |
||||||
|
import sys |
||||||
|
import logging |
||||||
|
from datetime import datetime |
||||||
|
from io import BytesIO |
||||||
|
from urllib.request import urlopen |
||||||
|
from zipfile import ZipFile |
||||||
|
|
||||||
|
# Säkerställ att vi kör från projektroten och att lokala moduler hittas |
||||||
|
os.chdir("/home/lasse/riksdagen") |
||||||
|
sys.path.append("/home/lasse/riksdagen") |
||||||
|
|
||||||
|
logging.basicConfig( |
||||||
|
level=logging.INFO, |
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s", |
||||||
|
) |
||||||
|
logger = logging.getLogger(__name__) |
||||||
|
|
||||||
|
# Systemprompt som används av LLM:en vid sammanfattning av debatter |
||||||
|
SYSTEM_MESSAGE = """Din uppgift är att sammanfatta debatter i Sveriges riksdag. |
||||||
|
Du kommer först att få enskilda tal som du ska sammanfatta var för sig, efter det ska du sammanfatta hela debatten. |
||||||
|
Sammanfattningarna ska vara på svenska och vara koncisa och informativa. |
||||||
|
Det är viktigt att du förstår vad som är kärnan i varje tal och debatt, fokusera därför på de argument och sakförhållanden som framförs. |
||||||
|
""" |
||||||
|
|
||||||
|
|
||||||
|
def get_current_session_year() -> int:
    """
    Return the starting year of the current Riksdag session.

    The parliamentary session runs September–August, so:
    - January–August 2026  -> session started Sep 2025 -> returns 2025
    - September–December 2025 -> session started Sep 2025 -> returns 2025

    Returns:
        int: Four-digit starting year of the current session (e.g. 2025).
    """
    today = datetime.now()
    # Before September we are still inside the session that began last year.
    return today.year if today.month >= 9 else today.year - 1
||||||
|
|
||||||
|
|
||||||
|
def download_current_year(session_year: int) -> str:
    """
    Download and extract the ZIP archive for the given Riksdag session,
    replacing any previously downloaded files for that year.

    The Riksdag continuously updates the same ZIP file during an ongoing
    session, so it must be re-downloaded on every run to pick up new speeches.

    Args:
        session_year (int): Starting year of the session (e.g. 2025 for 2025/26).

    Returns:
        str: Path to the directory the files were extracted into.
    """
    second_part = str(session_year + 1)[2:]  # e.g. "26" for 2026
    url = f"https://data.riksdagen.se/dataset/anforande/anforande-{session_year}{second_part}.json.zip"
    folder_name = f"anforande-{session_year}{second_part}"
    dir_path = os.path.join("talks", folder_name)

    logger.info(f"Downloading {url} → {dir_path}")
    os.makedirs(dir_path, exist_ok=True)

    # Clear out old files so we always work from a fresh copy.
    # Only remove regular files: os.remove() raises on a directory entry,
    # which would abort the whole sync if the archive ever contains one.
    for name in os.listdir(dir_path):
        path = os.path.join(dir_path, name)
        if os.path.isfile(path):
            os.remove(path)

    # A timeout keeps the daily systemd job from hanging forever if
    # data.riksdagen.se stops responding mid-connection.
    with urlopen(url, timeout=300) as resp:
        with ZipFile(BytesIO(resp.read())) as zf:
            zf.extractall(dir_path)

    count = len(os.listdir(dir_path))
    logger.info(f"Extracted {count} files to {dir_path}")
    return dir_path
||||||
|
|
||||||
|
|
||||||
|
def get_unsummarized_dates() -> list[str]:
    """
    Fetch the dates from ArangoDB that still have talks without a summary.

    Returns:
        list[str]: Sorted date strings, e.g. ["2026-02-10", "2026-02-11"].
    """
    from arango_client import arango

    # ttl extends the cursor's server-side lifetime to 5 minutes.
    aql = """
        FOR doc IN talks
            FILTER doc.summary == null
            RETURN DISTINCT doc.datum
        """
    cursor = arango.db.aql.execute(aql, ttl=300)
    # sorted() consumes the cursor directly; no intermediate list needed.
    dates = sorted(cursor)
    logger.info(f"Found {len(dates)} dates with unsummarized talks")
    return dates
||||||
|
|
||||||
|
|
||||||
|
def sync() -> None:
    """
    Run the full sync pipeline:
    1. Download this year's speeches
    2. Insert new speeches into ArangoDB
    3. Assign debate IDs
    4. Build embeddings for new dates
    5. Generate summaries for new dates

    Stage imports are deliberately deferred to keep startup cheap and to
    avoid importing heavy modules when an earlier stage fails.
    """
    logger.info("=== Starting daily riksdagen sync ===")

    # --- Stage 1: Download ---
    session_year = get_current_session_year()
    logger.info(f"Current session year: {session_year}/{session_year + 1}")
    dir_path = download_current_year(session_year)

    # --- Stage 2: Insert new speeches into ArangoDB ---
    # update_folder() fetches all existing _keys from the database and skips
    # them, so only new speeches are inserted.
    logger.info("Stage 2: Inserting new talks into ArangoDB...")
    from scripts.documents_to_arango import update_folder

    new_talks = update_folder(os.path.abspath(dir_path))
    logger.info(f"Stage 2 complete: {new_talks} new talks inserted")

    # --- Stage 3: Assign debate IDs ---
    # Speeches missing the 'debate' field are grouped into debates based on
    # date and whether or not they are replies.
    logger.info("Stage 3: Assigning debate IDs to talks missing them...")
    from scripts.debates import make_debate_ids

    make_debate_ids()
    logger.info("Stage 3 complete")

    # --- Stage 4: Chunk + build embeddings in ArangoDB ---
    # make_arango_embeddings() finds all speeches lacking chunks in the
    # 'chunks' collection, chunks the text, generates vectors via Ollama
    # and stores everything directly in ArangoDB. ChromaDB is not used.
    logger.info("Stage 4: Chunking and embedding new talks into ArangoDB...")
    from scripts.make_arango_embeddings import make_arango_embeddings

    total_chunks = make_arango_embeddings()
    logger.info(f"Stage 4 complete: {total_chunks} chunks created")

    # --- Stage 5: Generate summaries ---
    # process_debate_date() automatically skips speeches that already have a
    # summary, so re-running it is safe (idempotent per the comment's claim
    # in scripts.debates — confirmed only by that module, not visible here).
    new_dates = get_unsummarized_dates()
    if new_dates:
        logger.info(f"Stage 5: Generating summaries for {len(new_dates)} dates...")
        from scripts.debates import process_debate_date

        for date in new_dates:
            process_debate_date(date, SYSTEM_MESSAGE)
        logger.info(f"Stage 5 complete: summaries generated for {len(new_dates)} dates")
    else:
        logger.info("Stage 5: No unsummarized dates, skipping")

    logger.info("=== Sync complete ===")
||||||
|
|
||||||
|
|
||||||
|
# Script entry point (also invoked daily by the systemd timer unit).
if __name__ == "__main__":
    sync()
||||||
@ -0,0 +1,57 @@ |
|||||||
|
from arango_client import arango |
||||||
|
from scripts.make_arango_embeddings import process_chunk_batch |
||||||
|
from arango.collection import Collection |
||||||
|
from typing import List, Dict |
||||||
|
|
||||||
|
def test_full_make_arango_embeddings_for_one_talk() -> None:
    """
    Integration test for the full make_arango_embeddings chain:
    - Fetches a specific talk document from ArangoDB.
    - Processes its chunks to generate embeddings.
    - Inserts/updates those chunks in the 'chunks' collection.
    - Verifies that the chunks were updated in ArangoDB.

    This test requires ArangoDB and Ollama to be running and accessible.
    """
    # The _id of the talk we want to process.
    target_id: str = "talks/000004cc-b896-e611-9441-00262d0d7125"

    # Get the talks and chunks collections.
    talks_collection: Collection = arango.db.collection("talks")
    chunks_collection: Collection = arango.db.collection("chunks")

    # Fetch the talk document.
    talk: Dict = talks_collection.get(target_id)
    assert talk is not None, f"Talk with _id {target_id} not found"
    assert "chunks" in talk and talk["chunks"], "Talk has no chunks"

    # Prepare chunks for embedding.
    processed_chunks: List[Dict] = []
    for chunk in talk["chunks"]:
        key: str = chunk["chroma_id"].split("/")[-1]
        # NOTE(review): this keeps only the index part after the colon, so
        # chunk keys could collide across talks — confirm against the
        # "{talk_key}:{idx}" key format used by make_arango_embeddings.
        chunk["_key"] = key.split(":")[-1]
        chunk["parent_id"] = target_id
        chunk["collection"] = "talks"
        # Drop legacy Chroma fields not needed for embedding.
        # "chroma_collecton" (sic) matches the misspelled key actually stored
        # in older documents — do not "fix" the spelling here.
        chunk.pop("chroma_id", None)
        chunk.pop("chroma_collecton", None)
        processed_chunks.append(chunk)

    # Generate embeddings for all chunks.
    processed_chunks = process_chunk_batch(processed_chunks)

    # Insert/update chunks in the 'chunks' collection.
    chunks_collection.insert_many(processed_chunks, overwrite=True)

    # Verify that the chunks were updated in ArangoDB.
    for chunk in processed_chunks:
        db_chunk = chunks_collection.get(chunk["_key"])
        # Bug fix: the failure message previously interpolated the talk key
        # instead of the chunk key actually being checked.
        assert db_chunk is not None, f"Chunk {chunk['_key']} not found in DB"
        assert "embedding" in db_chunk, "Chunk missing embedding in DB"
        assert isinstance(db_chunk["embedding"], list), "Embedding is not a list"
        print(f"Chunk {chunk['_key']} updated with embedding of length {len(db_chunk['embedding'])}")
|
|
||||||
|
test_full_make_arango_embeddings_for_one_talk() |
||||||
Loading…
Reference in new issue