You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
949 lines
30 KiB
949 lines
30 KiB
import os |
|
from datetime import datetime |
|
|
|
from dotenv import load_dotenv |
|
from arango import ArangoClient |
|
from arango.collection import StandardCollection as ArangoCollection |
|
|
|
from models import UnifiedDataChunk, UnifiedSearchResults |
|
from utils import fix_key |
|
|
|
# Bootstrap the environment through env_manager only when it has not been
# configured yet; the presence of the INFO variable is used as the marker.
if os.environ.get("INFO") is None:
    import env_manager

    env_manager.set_env()

# Load any values from a local .env file (requires: pip install python-dotenv).
load_dotenv()

# Collections that presumably live in the shared "base" database rather than in
# a per-user database (sci_articles is accessed via the base connection below).
COLLECTIONS_IN_BASE = ["sci_articles"]
|
|
|
|
|
class ArangoDB:
    """
    ArangoDB client wrapper.

    Wraps ``ArangoClient`` to simplify working with ArangoDB databases and
    collections in a scientific-document management context. It handles
    authentication and the database connection, and exposes high-level methods
    for common operations.

    Key features:
        - Database and collection management
        - Document CRUD operations (Create, Read, Update, Delete)
        - AQL query execution
        - Scientific article storage and retrieval
        - Project and note management
        - Chat history storage
        - Settings management

    Usage example:
        arango = ArangoDB(user="admin", password="password")
        # Create a collection
        arango.create_collection("my_collection")
        # Insert a document (insert_document requires a _key or _id)
        doc = arango.insert_document("my_collection", {"_key": "doc1", "name": "Test Document"})
        # Query documents
        results = arango.execute_aql("FOR doc IN my_collection RETURN doc")

    Environment variables:
        ARANGO_HOST: The ArangoDB host URL.
        ARANGO_PASSWORD: The default password for authentication.
        ARANGO_USER: Username used by get_document_metadata to open the "base"
            database when reading sci_articles documents.
    """

    def __init__(self, user="admin", password=None, db_name="base"):
        """
        Connect to an ArangoDB database.

        The host URL is always read from the ARANGO_HOST environment variable.
        The username and database name are resolved as follows:

        - ``db_name != "base"``: username and database are both ``db_name``
          (``user`` is ignored).
        - ``db_name == "base"`` and ``user == "admin"``: username "admin",
          database "base".
        - ``db_name == "base"`` and any other ``user``: username and database
          are both ``user`` (the user's own database).

        Args:
            user (str, optional): Username. Defaults to "admin".
            password (str, optional): Password. If not provided (or empty), the
                ARANGO_PASSWORD environment variable is used.
            db_name (str, optional): Database name. Defaults to "base".

        Attributes:
            user (str): The username used for authentication.
            password (str): The password used for authentication.
            db_name (str): The name of the connected database.
            client (ArangoClient): The ArangoDB client instance.
            db: The python-arango database handle used for all operations.
        """
        host = os.getenv("ARANGO_HOST")
        # An explicitly passed password used to be dropped, leaving
        # self.password unset and making the connection below raise.
        self.password = password or os.getenv("ARANGO_PASSWORD")

        if db_name != "base":
            self.user = db_name
            self.db_name = db_name
        elif user == "admin":
            # This is the default user for the base database
            self.user = "admin"
            self.db_name = "base"
        else:
            self.user = user
            self.db_name = user

        self.client = ArangoClient(hosts=host)
        self.db = self.client.db(
            self.db_name, username=self.user, password=self.password
        )

    def fix_key(self, _key):
        """Delegate to the module-level ``utils.fix_key`` helper and return its result."""
        return fix_key(_key)

    def _sys_db(self):
        """
        Return a handle to the ``_system`` database using this instance's credentials.

        python-arango exposes has_database/create_database/delete_database on a
        database handle (and the server only serves them from ``_system``), not
        on ``ArangoClient``.
        """
        return self.client.db("_system", username=self.user, password=self.password)

    # Collection operations
    def get_collection(self, collection_name: str) -> ArangoCollection:
        """
        Get a collection by name.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            ArangoCollection: The collection object.
        """
        return self.db.collection(collection_name)

    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a collection exists.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return self.db.has_collection(collection_name)

    def create_collection(self, collection_name: str) -> ArangoCollection:
        """
        Create a new collection.

        Args:
            collection_name (str): The name of the collection to create.

        Returns:
            ArangoCollection: The created collection.
        """
        return self.db.create_collection(collection_name)

    def delete_collection(self, collection_name: str) -> bool:
        """
        Delete a collection if it exists.

        Args:
            collection_name (str): The name of the collection to delete.

        Returns:
            bool: The python-arango result if the collection existed, else False.
        """
        if self.has_collection(collection_name):
            return self.db.delete_collection(collection_name)
        return False

    def truncate_collection(self, collection_name: str) -> bool:
        """
        Truncate a collection (remove all documents) if it exists.

        Args:
            collection_name (str): The name of the collection to truncate.

        Returns:
            bool: The python-arango result if the collection existed, else False.
        """
        if self.has_collection(collection_name):
            return self.db.collection(collection_name).truncate()
        return False

    # Document operations
    def get_document(self, document_id: str):
        """
        Get a document by its full ID ("collection/key").

        Args:
            document_id (str): The ID of the document to get.

        Returns:
            dict: The document, or None if it is missing or the lookup fails.
        """
        try:
            return self.db.document(document_id)
        except Exception:
            # Deliberate best-effort lookup; bare ``except:`` would also have
            # swallowed KeyboardInterrupt/SystemExit.
            return None

    def has_document(self, collection_name: str, document_key: str) -> bool:
        """
        Check if a document exists in a collection.

        Args:
            collection_name (str): The name of the collection.
            document_key (str): The key of the document.

        Returns:
            bool: True if the document exists, False otherwise.
        """
        return self.db.collection(collection_name).has(document_key)

    def insert_document(
        self,
        collection_name: str,
        document: dict,
        overwrite: bool = False,
        overwrite_mode: str = "update",
        keep_none: bool = False,
    ):
        """
        Insert a document into a collection.

        If ``document`` has a ``_key`` but no ``_id``, an ``_id`` of the form
        "<collection_name>/<_key>" is added to the caller's dict in place.

        Args:
            collection_name (str): The name of the collection.
            document (dict): The document to insert; must contain _id or _key.
            overwrite (bool, optional): Whether to overwrite an existing document. Defaults to False.
            overwrite_mode (str, optional): The mode for overwriting ('replace' or 'update'). Defaults to "update".
            keep_none (bool, optional): Whether to keep None values. Defaults to False.

        Returns:
            dict: Metadata of the inserted document (_id, _key, _rev).

        Raises:
            AssertionError: If the document has neither _id nor _key.
        """
        assert '_id' in document or '_key' in document, "Document must have either _id or _key"
        if '_id' not in document:
            document['_id'] = f"{collection_name}/{document['_key']}"
        return self.db.collection(collection_name).insert(
            document,
            overwrite=overwrite,
            overwrite_mode=overwrite_mode,
            keep_none=keep_none,
        )

    def update_document(
        self, document: dict, check_rev: bool = False, silent: bool = False
    ):
        """
        Update a document that already has _id or _key.

        Args:
            document (dict): The document to update.
            check_rev (bool, optional): Whether to check document revision. Defaults to False.
            silent (bool, optional): If True, do not return document metadata. Defaults to False.

        Returns:
            The python-arango update result (document metadata unless silent).
        """
        return self.db.update_document(document, check_rev=check_rev, silent=silent)

    def update_document_by_match(
        self, collection_name: str, filters: dict, body: dict, merge: bool = True
    ):
        """
        Update all documents in a collection that match ``filters``.

        Args:
            collection_name (str): The name of the collection.
            filters (dict): Field/value pairs that documents must match.
            body (dict): The update to apply.
            merge (bool, optional): Whether to merge sub-objects with existing data. Defaults to True.

        Returns:
            The value returned by python-arango's ``update_match`` (the number
            of updated documents).
        """
        return self.db.collection(collection_name).update_match(
            filters=filters, body=body, merge=merge
        )

    def delete_document(self, collection_name: str, document_key: str):
        """
        Delete a document from a collection.

        Args:
            collection_name (str): The name of the collection.
            document_key (str): The key of the document to delete.

        Returns:
            The python-arango deletion result.
        """
        return self.db.collection(collection_name).delete(document_key)

    def delete_document_by_match(self, collection_name: str, filters: dict):
        """
        Delete all documents in a collection that match ``filters``.

        Args:
            collection_name (str): The name of the collection.
            filters (dict): Field/value pairs that documents must match.

        Returns:
            The value returned by python-arango's ``delete_match`` (the number
            of deleted documents).
        """
        return self.db.collection(collection_name).delete_match(filters=filters)

    # Query operations
    def execute_aql(self, query: str, bind_vars: dict = None):
        """
        Execute an AQL query.

        Args:
            query (str): The AQL query to execute.
            bind_vars (dict, optional): Bind variables for the query. Defaults to None.

        Returns:
            Cursor: A cursor to the query results.
        """
        return self.db.aql.execute(query, bind_vars=bind_vars)

    def get_all_documents(self, collection_name: str):
        """
        Get all documents from a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            list: All documents in the collection.
        """
        return list(self.db.collection(collection_name).all())

    # Database operations
    def has_database(self, db_name: str) -> bool:
        """
        Check if a database exists (queried via the ``_system`` database).

        Args:
            db_name (str): The name of the database.

        Returns:
            bool: True if the database exists, False otherwise.
        """
        return self._sys_db().has_database(db_name)

    def create_database(self, db_name: str, users: list = None) -> bool:
        """
        Create a new database (via the ``_system`` database).

        Args:
            db_name (str): The name of the database to create.
            users (list, optional): List of user objects with access to the database. Defaults to None.

        Returns:
            bool: True if the database was created successfully.
        """
        return self._sys_db().create_database(db_name, users=users)

    def delete_database(self, db_name: str) -> bool:
        """
        Delete a database if it exists (via the ``_system`` database).

        Args:
            db_name (str): The name of the database to delete.

        Returns:
            bool: The python-arango result if the database existed, else False.
        """
        sys_db = self._sys_db()
        if sys_db.has_database(db_name):
            return sys_db.delete_database(db_name)
        return False

    # Domain-specific operations

    # Scientific Articles
    def get_article(
        self,
        article_key: str,
        db_name: str = None,
        collection_name: str = "sci_articles",
    ):
        """
        Get a scientific article by key.

        Args:
            article_key (str): The key of the article.
            db_name (str, optional): Currently ignored; the connected database is
                always used. Kept for backward compatibility.
            collection_name (str, optional): Collection to read from. Defaults to "sci_articles".

        Returns:
            dict: The article document if found, None otherwise.

        Raises:
            Exception: Any error from the lookup is printed and re-raised.
        """
        try:
            # Previously hardcoded to "sci_articles", silently ignoring the argument.
            return self.db.collection(collection_name).get(article_key)
        except Exception as e:
            print(f"Error retrieving article {article_key}: {e}")
            raise

    def get_article_by_doi(self, doi: str):
        """
        Get a scientific article by DOI (matched against ``metadata.doi``).

        Args:
            doi (str): The DOI of the article.

        Returns:
            dict: The first matching article document, or None if none match.
        """
        query = """
        FOR doc IN sci_articles
            FILTER doc.metadata.doi == @doi
            RETURN doc
        """
        cursor = self.db.aql.execute(query, bind_vars={"doi": doi})
        return next(cursor, None)

    def get_document_text(
        self, _id: str = None, _key: str = None, collection: str = None
    ):
        """
        Get the text content of a document by joining its chunk texts with newlines.

        If _key is used, collection must be provided.
        * Use base_arango for sci_articles and user_arango for other collections. *

        Args:
            _id (str, optional): The ID of the document. Defaults to None.
            _key (str, optional): The key of the document. Defaults to None.
            collection (str, optional): The name of the collection. Defaults to None.

        Returns:
            str: The text content of the document, or None if the document is
            missing, has no chunk text, or the lookup fails.

        Raises:
            AssertionError: If the connected database does not match the
                requested collection (base for sci_articles, a user DB otherwise).
        """
        # `_id or ""` avoids an AttributeError when only _key is given.
        is_sci = collection == "sci_articles" or (_id or "").startswith("sci_articles")
        if is_sci:
            assert (
                self.db_name == "base"
            ), "If requesting sci_articles base_arango must be used"
        else:
            assert (
                self.db_name != "base"
            ), "If not requesting sci_articles user_arango must be used"

        try:
            if _id:
                doc = self.db.document(_id)
            elif _key:
                assert (
                    collection is not None
                ), "Collection name must be provided if _key is used"
                doc = self.db.collection(collection).get(_key)
            else:
                # Neither identifier given: nothing to look up.
                return None
            if doc is None:
                return None
            # Skip chunks without text so join() cannot fail on None.
            text = [
                chunk.get("text")
                for chunk in doc.get("chunks", [])
                if chunk.get("text") is not None
            ]
        except Exception as e:
            print(f"Error retrieving text for document {_id or _key}: {e}")
            return None
        return "\n".join(text) if text else None

    def store_article_chunks(
        self, article_data: dict, chunks: list, document_key: str = None
    ):
        """
        Store article chunks (plus metadata/summary) in the sci_articles collection.

        The document is upserted (overwrite=True, overwrite_mode="update").

        Args:
            article_data (dict): The article data; "metadata", "summary" and
                "doi" keys are read if present.
            chunks (list): Chunk dicts; "text" and "pages" keys are read.
            document_key (str, optional): The key to use for the document. Defaults to None.

        Returns:
            tuple: (document_id, database_name, document_doi)
        """
        collection = "sci_articles"

        arango_chunks = [
            {
                "text": chunk.get("text", ""),
                "pages": chunk.get("pages", []),
                "id": f"{document_key}_{index}" if document_key else f"chunk_{index}",
            }
            for index, chunk in enumerate(chunks)
        ]

        # NOTE(review): with document_key=None the document gets _key None and
        # insert_document derives _id "sci_articles/None" — confirm callers
        # always pass a key.
        arango_document = {
            "_key": document_key,
            "chunks": arango_chunks,
            "metadata": article_data.get("metadata", {}),
        }

        if article_data.get("summary"):
            arango_document["summary"] = article_data.get("summary")

        if article_data.get("doi"):
            arango_document["crossref"] = True

        doc = self.insert_document(
            collection_name=collection,
            document=arango_document,
            overwrite=True,
            overwrite_mode="update",
            keep_none=False,
        )

        return doc["_id"], self.db_name, article_data.get("doi")

    def add_article_to_collection(self, article_id: str, collection_name: str):
        """
        Append an article ID to the "articles" list of a user's article collection.

        Args:
            article_id (str): The ID of the article.
            collection_name (str): The "name" of the document in article_collections.

        Returns:
            bool: True if a matching collection document was updated.
        """
        query = """
        FOR collection IN article_collections
            FILTER collection.name == @collection_name
            UPDATE collection WITH {
                articles: PUSH(collection.articles, @article_id)
            } IN article_collections
            RETURN NEW
        """
        cursor = self.db.aql.execute(
            query,
            bind_vars={"collection_name": collection_name, "article_id": article_id},
        )
        return next(cursor, None) is not None

    def remove_article_from_collection(self, article_id: str, collection_name: str):
        """
        Remove an article ID from the "articles" list of a user's article collection.

        Args:
            article_id (str): The ID of the article.
            collection_name (str): The "name" of the document in article_collections.

        Returns:
            bool: True if a matching collection document was updated.
        """
        query = """
        FOR collection IN article_collections
            FILTER collection.name == @collection_name
            UPDATE collection WITH {
                articles: REMOVE_VALUE(collection.articles, @article_id)
            } IN article_collections
            RETURN NEW
        """
        cursor = self.db.aql.execute(
            query,
            bind_vars={"collection_name": collection_name, "article_id": article_id},
        )
        return next(cursor, None) is not None

    # Projects
    def get_projects(self, username: str = None):
        """
        Get all projects in the connected database.

        Args:
            username (str, optional): Not used for filtering; when given, the
                projects are returned sorted by name. Defaults to None.

        Returns:
            list: A list of project documents.
        """
        if not username:
            return self.get_all_documents("projects")
        query = """
        FOR p IN projects
            SORT p.name ASC
            RETURN p
        """
        return list(self.db.aql.execute(query))

    def get_project(self, project_name: str, username: str = None):
        """
        Get a project by name.

        Args:
            project_name (str): The name of the project.
            username (str, optional): Currently unused; kept for backward compatibility.

        Returns:
            dict: The first project with that name, or None if not found.
        """
        # Both former branches ran this exact query; they are merged here.
        query = """
        FOR p IN projects
            FILTER p.name == @project_name
            RETURN p
        """
        cursor = self.db.aql.execute(query, bind_vars={"project_name": project_name})
        return next(cursor, None)

    def create_project(self, project_data: dict):
        """
        Create a new project.

        Args:
            project_data (dict): The project data; must contain _id or _key.

        Returns:
            dict: Metadata of the created project document.
        """
        return self.insert_document("projects", project_data)

    def update_project(self, project_data: dict):
        """
        Update an existing project.

        Args:
            project_data (dict): The project data, including its _id or _key.

        Returns:
            The python-arango update result.
        """
        return self.update_document(project_data, check_rev=False)

    def delete_project(self, project_name: str, username: str = None):
        """
        Delete projects by name (and username, if given).

        Args:
            project_name (str): The name of the project.
            username (str, optional): Restrict to this username. Defaults to None.

        Returns:
            The value returned by python-arango's ``delete_match``.
        """
        filters = {"name": project_name}
        if username:
            filters["username"] = username

        return self.delete_document_by_match("projects", filters)

    def _list_by_project(
        self, collection_name: str, project_name: str, username: str = None
    ):
        """Return docs of ``collection_name`` for a project (and username), newest first."""
        conditions = ["doc.project == @project_name"]
        bind_vars = {"@collection": collection_name, "project_name": project_name}
        if username:
            conditions.append("doc.username == @username")
            bind_vars["username"] = username
        # Only fixed condition strings are interpolated; all values are bound.
        condition = " AND ".join(conditions)
        query = f"""
        FOR doc IN @@collection
            FILTER {condition}
            SORT doc.timestamp DESC
            RETURN doc
        """
        return list(self.db.aql.execute(query, bind_vars=bind_vars))

    def get_project_notes(self, project_name: str, username: str = None):
        """
        Get notes for a project, newest first.

        Args:
            project_name (str): The name of the project.
            username (str, optional): Restrict to this username. Defaults to None.

        Returns:
            list: A list of note documents.
        """
        return self._list_by_project("notes", project_name, username)

    def add_note_to_project(self, note_data: dict):
        """
        Add a note to a project.

        Args:
            note_data (dict): The note data; must contain _id or _key.

        Returns:
            dict: Metadata of the created note document.
        """
        return self.insert_document("notes", note_data)

    def fetch_notes_tool(
        self, project_name: str, username: str = None
    ) -> UnifiedSearchResults:
        """
        Fetch notes for a project and return them in a unified format.

        Args:
            project_name (str): The name of the project.
            username (str, optional): The username. Defaults to None.

        Returns:
            UnifiedSearchResults: One chunk per note, with the note IDs as source_ids.
        """
        notes = self.get_project_notes(project_name, username)
        chunks = []
        source_ids = []

        for note in notes:
            chunk = UnifiedDataChunk(
                content=note.get("content", ""),
                metadata={
                    "title": note.get("title", "No title"),
                    "timestamp": note.get("timestamp", ""),
                },
                source_type="note",
            )
            chunks.append(chunk)
            source_ids.append(note.get("_id", "unknown_id"))

        return UnifiedSearchResults(chunks=chunks, source_ids=source_ids)

    # Chat operations
    def get_chat(self, chat_key: str):
        """
        Get a chat by key.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            dict: The chat document if found, None otherwise (including on lookup errors).
        """
        try:
            return self.db.collection("chats").get(chat_key)
        except Exception:
            return None

    def create_or_update_chat(self, chat_data: dict):
        """
        Create or update (upsert) a chat.

        Args:
            chat_data (dict): The chat data; must contain _id or _key.

        Returns:
            dict: Metadata of the created or updated chat document.
        """
        return self.insert_document("chats", chat_data, overwrite=True)

    def get_chats_for_project(self, project_name: str, username: str = None):
        """
        Get all chats for a project, newest first.

        Args:
            project_name (str): The name of the project.
            username (str, optional): Restrict to this username. Defaults to None.

        Returns:
            list: A list of chat documents.
        """
        return self._list_by_project("chats", project_name, username)

    def delete_chat(self, chat_key: str):
        """
        Delete a chat.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            The python-arango deletion result.
        """
        return self.delete_document("chats", chat_key)

    def delete_old_chats(self, days: int = 30):
        """
        Delete chats whose timestamp is more than ``days`` days before now.

        Args:
            days (int, optional): The number of days. Defaults to 30.

        Returns:
            int: The number of deleted chats.
        """
        query = """
        FOR chat IN chats
            FILTER DATE_DIFF(chat.timestamp, DATE_NOW(), "d") > @days
            REMOVE chat IN chats
            RETURN OLD
        """
        cursor = self.db.aql.execute(query, bind_vars={"days": days})
        return len(list(cursor))

    # Settings operations
    def get_settings(self):
        """
        Get the settings document ("settings/settings").

        Returns:
            dict: The settings document if found, None otherwise (including on lookup errors).
        """
        try:
            return self.db.document("settings/settings")
        except Exception:
            return None

    def initialize_settings(self, settings_data: dict):
        """
        Create the settings document; sets ``_key`` to "settings" on the caller's dict.

        Args:
            settings_data (dict): The settings data.

        Returns:
            dict: Metadata of the created settings document.
        """
        settings_data["_key"] = "settings"
        return self.insert_document("settings", settings_data)

    def update_settings(self, settings_data: dict):
        """
        Update the settings document.

        Args:
            settings_data (dict): The fields to update.

        Returns:
            The value returned by python-arango's ``update_match`` (the number
            of updated documents).
        """
        return self.update_document_by_match(
            collection_name="settings", filters={"_key": "settings"}, body=settings_data
        )

    def get_document_metadata(self, document_id: str) -> dict:
        """
        Retrieve document metadata with merged user notes if available.

        Documents whose ID starts with "sci_articles" are read from the "base"
        database (connecting with ARANGO_USER / ARANGO_PASSWORD); all others are
        read from this instance's database. If the document has "user_notes",
        they are added to the returned metadata under the same key.

        Args:
            document_id (str): The document ID to retrieve metadata for.

        Returns:
            dict: The document metadata dictionary, or an empty dict if not found
            or on error.
        """
        if not document_id:
            return {}

        try:
            if document_id.startswith("sci_articles"):
                # Science articles are in the base database
                db_to_use = self.client.db(
                    "base",
                    username=os.getenv("ARANGO_USER"),
                    password=os.getenv("ARANGO_PASSWORD"),
                )
                arango_doc = db_to_use.document(document_id)
            else:
                # User documents are in the user's database
                arango_doc = self.db.document(document_id)

            if not arango_doc:
                return {}

            # `or {}` also covers a stored metadata value of None.
            arango_metadata = arango_doc.get("metadata") or {}
            if "user_notes" in arango_doc:
                arango_metadata["user_notes"] = arango_doc["user_notes"]

            return arango_metadata
        except Exception as e:
            print(f"Error retrieving metadata for document {document_id}: {e}")
            return {}

    def summarise_chunks(self, document: dict, is_sci=False):
        """
        Summarise and tag each chunk of ``document`` with an LLM and save the result.

        Chunks that already have a "summary" key are kept unchanged. For the
        others the LLM is asked for a structured ``ArticleChunk`` response; on
        success the chunk gains "summary", lower-cased "tags" and "summary_meta"
        (model name and date). If generation or validation fails, the error is
        printed and the chunk is kept without a summary. The document (mutated
        in place) is then written back with update_document.

        Args:
            document (dict): Must contain "_id" and a "chunks" list of dicts with "text".
            is_sci (bool, optional): If True, the prompts also ask for scientific
                references. Defaults to False.

        Raises:
            AssertionError: If the document has no _id.
        """
        from _llm import LLM
        from models import ArticleChunk

        assert "_id" in document, "Document must have an _id field"

        if is_sci:
            system_message = """You are a science assistant summarizing scientific articles.
You will get an article chunk by chunk, and you have three tasks for each chunk:
1. Summarize the content of the chunk.
2. Tag the chunk with relevant tags.
3. Extract the scientific references from the chunk.
"""
        else:
            system_message = """You are a general assistant summarizing articles.
You will get an article chunk by chunk, and you have two tasks for each chunk:
1. Summarize the content of the chunk.
2. Tag the chunk with relevant tags.
"""

        system_message += """\nPlease make use of the previous chunks you have already seen to understand the current chunk in context and make the summary stand for itself. But remember, *it is the current chunk you are summarizing*
ONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."""

        llm = LLM(system_message=system_message)
        chunks = []
        for chunk in document["chunks"]:
            if "summary" in chunk:
                chunks.append(chunk)
                continue
            prompt = f"""Summarize the following text to make it stand on its own:\n
'''
{chunk['text']}
'''\n
Your tasks are:
1. Summarize the content of the chunk. Make sure to include all relevant details!
2. Tag the chunk with relevant tags.
"""
            if is_sci:
                prompt += "\n3. Extract the scientific references mentioned in this specific chunk. If there is a DOI reference, include that in the reference. Sometimes the reference is only a number in brackets, like [1], so make sure to include that as well (in brackets)."
            prompt += "\nONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."

            try:
                response = llm.generate(prompt, format=ArticleChunk.model_json_schema())
                structured_response = ArticleChunk.model_validate_json(response.content)
                chunk["summary"] = structured_response.summary
                chunk["tags"] = [i.lower() for i in structured_response.tags]
                # NOTE(review): references requested for is_sci are not stored — confirm intended.
                chunk["summary_meta"] = {
                    "model": llm.model,
                    "date": datetime.now().strftime("%Y-%m-%d"),
                }
            except Exception as e:
                print(f"Error processing chunk: {e}")
            chunks.append(chunk)
        document["chunks"] = chunks
        self.update_document(document, check_rev=False)
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: connect as user "lasse" (which resolves to that user's
    # own database) and print one document from other_documents.
    demo_db = ArangoDB(user="lasse")
    sample_query = "FOR doc IN other_documents LIMIT 1 RETURN doc"
    sample_cursor = demo_db.db.aql.execute(sample_query)
    print(next(sample_cursor))
|
|
|
|