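# NOTE: the following lines appear to be residue from the repository web page this file was copied from; they are not part of the module.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 949 lines
# 30 KiB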

import os
from datetime import datetime
from dotenv import load_dotenv
from arango import ArangoClient
from arango.collection import StandardCollection as ArangoCollection
from models import UnifiedDataChunk, UnifiedSearchResults
from utils import fix_key
# Bootstrap environment variables through the project-local env_manager unless
# "INFO" is already present in os.environ (presumably a marker that the
# environment has been configured already — confirm with env_manager).
if "INFO" not in os.environ:
    import env_manager
    env_manager.set_env()
# NOTE(review): the original indentation was lost; load_dotenv() is assumed to
# run at module level (not only inside the branch above) — confirm.
load_dotenv() # Install with pip install python-dotenv
# Collection names expected to live in the shared "base" database (going by
# the name; this module reads "sci_articles" from "base" — verify other callers).
COLLECTIONS_IN_BASE = ["sci_articles"]
class ArangoDB:
    """
    ArangoDB client wrapper.

    Wraps ``ArangoClient`` to simplify working with ArangoDB databases and
    collections in a scientific document management context. It handles
    authentication and the database connection, and provides high-level
    methods for common operations.

    Key features:
    - Database and collection management
    - Document CRUD operations (Create, Read, Update, Delete)
    - AQL query execution
    - Scientific article storage and retrieval
    - Project and note management
    - Chat history storage
    - Settings management

    Usage example::

        arango = ArangoDB(user="admin", password="password")
        # Create a collection
        arango.create_collection("my_collection")
        # Insert a document (insert_document requires a _key or _id)
        doc = arango.insert_document(
            "my_collection", {"_key": "test", "name": "Test Document"}
        )
        # Query documents
        results = arango.execute_aql("FOR doc IN my_collection RETURN doc")

    Environment variables:
        ARANGO_HOST: The ArangoDB host URL.
        ARANGO_PASSWORD: The default password for authentication.
        ARANGO_USER: Username used by ``get_document_metadata`` to read
            "sci_articles" from the "base" database (falls back to "admin").
    """

    def __init__(self, user="admin", password=None, db_name="base"):
        """
        Initialize a connection to an ArangoDB database.

        Parameters
        ----------
        user : str, optional
            Username for authentication. Defaults to "admin". When ``db_name``
            is "base" and ``user`` is not "admin", both the username and the
            database name become ``user`` (the user's own database).
        password : str, optional
            Password for authentication. If not provided (or empty), it is read
            from the ARANGO_PASSWORD environment variable.
        db_name : str, optional
            Name of the database to connect to. Defaults to "base". If it is
            not "base", it is also used as the username.

        Notes
        -----
        - The host URL is always read from the ARANGO_HOST environment variable.
        - Resulting (username, database) pairs:
            * db_name != "base"                  -> (db_name, db_name)
            * db_name == "base", user == "admin" -> ("admin", "base")
            * db_name == "base", user != "admin" -> (user, user)

        Attributes
        ----------
        user : str
            The username used for authentication.
        password : str
            The password used for authentication.
        db_name : str
            The name of the connected database.
        client : ArangoClient
            The ArangoDB client instance.
        db : StandardDatabase
            The database handle used for all operations.
        """
        host = os.getenv("ARANGO_HOST")
        # An explicit password wins; otherwise fall back to the environment.
        self.password = password or os.getenv("ARANGO_PASSWORD")
        if db_name != "base":
            self.user = db_name
            self.db_name = db_name
        elif user == "admin":
            # Default user for the base database.
            self.user = "admin"
            self.db_name = "base"
        else:
            self.user = user
            self.db_name = user
        self.client = ArangoClient(hosts=host)
        self.db = self.client.db(
            self.db_name, username=self.user, password=self.password
        )

    def fix_key(self, _key):
        """Return ``_key`` normalized by the module-level ``utils.fix_key`` helper."""
        return fix_key(_key)

    # Collection operations
    def get_collection(self, collection_name: str) -> ArangoCollection:
        """
        Get a collection by name.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            ArangoCollection: The collection object.
        """
        return self.db.collection(collection_name)

    def has_collection(self, collection_name: str) -> bool:
        """
        Check whether a collection exists.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return self.db.has_collection(collection_name)

    def create_collection(self, collection_name: str) -> ArangoCollection:
        """
        Create a new collection.

        Args:
            collection_name (str): The name of the collection to create.

        Returns:
            ArangoCollection: The created collection.
        """
        return self.db.create_collection(collection_name)

    def delete_collection(self, collection_name: str) -> bool:
        """
        Delete a collection if it exists.

        Args:
            collection_name (str): The name of the collection to delete.

        Returns:
            bool: The driver's result if the collection existed, else False.
        """
        if self.has_collection(collection_name):
            return self.db.delete_collection(collection_name)
        return False

    def truncate_collection(self, collection_name: str) -> bool:
        """
        Remove all documents from a collection if it exists.

        Args:
            collection_name (str): The name of the collection to truncate.

        Returns:
            bool: The driver's result if the collection existed, else False.
        """
        if self.has_collection(collection_name):
            return self.db.collection(collection_name).truncate()
        return False

    # Document operations
    def get_document(self, document_id: str):
        """
        Get a document by its full ID ("collection/key").

        Args:
            document_id (str): The ID of the document to get.

        Returns:
            dict | None: The document, or None if it is missing or the lookup
            fails.
        """
        try:
            return self.db.document(document_id)
        except Exception:
            # Best-effort lookup: malformed IDs and driver errors yield None.
            return None

    def has_document(self, collection_name: str, document_key: str) -> bool:
        """
        Check whether a document exists in a collection.

        Args:
            collection_name (str): The name of the collection.
            document_key (str): The key of the document.

        Returns:
            bool: True if the document exists, False otherwise.
        """
        return self.db.collection(collection_name).has(document_key)

    def insert_document(
        self,
        collection_name: str,
        document: dict,
        overwrite: bool = False,
        overwrite_mode: str = "update",
        keep_none: bool = False,
    ):
        """
        Insert a document into a collection.

        If ``document`` has a ``_key`` but no ``_id``, an ``_id`` of the form
        "<collection_name>/<_key>" is added to the caller's dict in place.

        Args:
            collection_name (str): The name of the collection.
            document (dict): The document to insert. Must contain ``_id`` or ``_key``.
            overwrite (bool, optional): Whether to overwrite an existing document.
                Defaults to False.
            overwrite_mode (str, optional): The mode for overwriting (e.g.
                "replace" or "update"). Defaults to "update".
            keep_none (bool, optional): Whether to keep None values.
                Defaults to False.

        Returns:
            dict: The inserted document's metadata (_id, _key, etc.).

        Raises:
            AssertionError: If the document has neither ``_id`` nor ``_key``.
        """
        assert '_id' in document or '_key' in document, "Document must have either _id or _key"
        if '_id' not in document:
            document['_id'] = f"{collection_name}/{document['_key']}"
        return self.db.collection(collection_name).insert(
            document,
            overwrite=overwrite,
            overwrite_mode=overwrite_mode,
            keep_none=keep_none,
        )

    def update_document(
        self, document: dict, check_rev: bool = False, silent: bool = False
    ):
        """
        Update a document that already has ``_id`` (or ``_key``).

        Args:
            document (dict): The document to update.
            check_rev (bool, optional): Whether to check the document revision.
                Defaults to False.
            silent (bool, optional): If True, ask the driver not to return
                document metadata. Defaults to False.

        Returns:
            dict: The update result when ``silent`` is False.
        """
        return self.db.update_document(document, check_rev=check_rev, silent=silent)

    def update_document_by_match(
        self, collection_name: str, filters: dict, body: dict, merge: bool = True
    ):
        """
        Update all documents in a collection that match ``filters``.

        Args:
            collection_name (str): The name of the collection.
            filters (dict): Attribute/value pairs documents must match.
            body (dict): The update to apply.
            merge (bool, optional): Whether to merge the update with existing
                data. Defaults to True.

        Returns:
            The value returned by python-arango's ``update_match`` (the number
            of updated documents).
        """
        return self.db.collection(collection_name).update_match(
            filters=filters, body=body, merge=merge
        )

    def delete_document(self, collection_name: str, document_key: str):
        """
        Delete a document from a collection.

        Args:
            collection_name (str): The name of the collection.
            document_key (str): The key of the document to delete.

        Returns:
            The value returned by python-arango's ``delete``.
        """
        return self.db.collection(collection_name).delete(document_key)

    def delete_document_by_match(self, collection_name: str, filters: dict):
        """
        Delete all documents in a collection that match ``filters``.

        Args:
            collection_name (str): The name of the collection.
            filters (dict): Attribute/value pairs documents must match.

        Returns:
            The value returned by python-arango's ``delete_match`` (the number
            of deleted documents).
        """
        return self.db.collection(collection_name).delete_match(filters=filters)

    # Query operations
    def execute_aql(self, query: str, bind_vars: dict = None):
        """
        Execute an AQL query.

        Args:
            query (str): The AQL query to execute.
            bind_vars (dict, optional): Bind variables for the query.
                Defaults to None.

        Returns:
            Cursor: A cursor over the query results.
        """
        return self.db.aql.execute(query, bind_vars=bind_vars)

    def get_all_documents(self, collection_name: str):
        """
        Get all documents from a collection.

        Args:
            collection_name (str): The name of the collection.

        Returns:
            list: All documents in the collection.
        """
        return list(self.db.collection(collection_name).all())

    # Database operations
    def _system_db(self):
        """Return a ``_system`` database handle using this instance's credentials.

        Listing, creating and dropping databases are exposed by python-arango
        on database objects (not on ``ArangoClient``), and ArangoDB only allows
        them from within the ``_system`` database.
        """
        return self.client.db("_system", username=self.user, password=self.password)

    def has_database(self, db_name: str) -> bool:
        """
        Check whether a database exists.

        Args:
            db_name (str): The name of the database.

        Returns:
            bool: True if the database exists, False otherwise.
        """
        return self._system_db().has_database(db_name)

    def create_database(self, db_name: str, users: list = None) -> bool:
        """
        Create a new database.

        Args:
            db_name (str): The name of the database to create.
            users (list, optional): User definitions to grant access to the
                database. Defaults to None.

        Returns:
            bool: True if the database was created successfully.
        """
        return self._system_db().create_database(db_name, users=users)

    def delete_database(self, db_name: str) -> bool:
        """
        Delete a database if it exists.

        Args:
            db_name (str): The name of the database to delete.

        Returns:
            bool: The driver's result if the database existed, else False.
        """
        sys_db = self._system_db()
        if sys_db.has_database(db_name):
            return sys_db.delete_database(db_name)
        return False

    # Domain-specific operations
    # Scientific Articles
    def get_article(
        self,
        article_key: str,
        db_name: str = None,
        collection_name: str = "sci_articles",
    ):
        """
        Get a scientific article by key.

        Args:
            article_key (str): The key of the article.
            db_name (str, optional): Accepted for backward compatibility;
                currently unused. The lookup always uses this instance's database.
            collection_name (str, optional): Collection to read from.
                Defaults to "sci_articles".

        Returns:
            dict | None: The article document, or None if no document has that key.

        Raises:
            Exception: Any driver error is printed and re-raised.
        """
        try:
            return self.db.collection(collection_name).get(article_key)
        except Exception as e:
            print(f"Error retrieving article {article_key}: {e}")
            raise

    def get_article_by_doi(self, doi: str):
        """
        Get a scientific article by DOI (``metadata.doi``).

        Args:
            doi (str): The DOI of the article.

        Returns:
            dict | None: The first matching article, or None if none matches.
        """
        query = """
        FOR doc IN sci_articles
            FILTER doc.metadata.doi == @doi
            RETURN doc
        """
        cursor = self.db.aql.execute(query, bind_vars={"doi": doi})
        return next(cursor, None)

    def get_document_text(
        self, _id: str = None, _key: str = None, collection: str = None
    ):
        """
        Get a document's text by joining the ``text`` of its chunks with newlines.

        If ``_key`` is used, ``collection`` must be provided.
        * Use base_arango for sci_articles and user_arango for other collections. *

        Args:
            _id (str, optional): The ID of the document. Defaults to None.
            _key (str, optional): The key of the document. Defaults to None.
            collection (str, optional): The name of the collection. Defaults to None.

        Returns:
            str | None: The joined chunk texts, or None if the document is
            missing, has no chunk text, or the lookup fails.

        Raises:
            AssertionError: If the instance's database does not match the
                requested collection (sci_articles must be read from "base").
        """
        # `_id` may be None when only `_key` + `collection` are passed.
        wants_sci = collection == "sci_articles" or (
            _id is not None and _id.startswith("sci_articles")
        )
        if wants_sci:
            assert (
                self.db_name == "base"
            ), "If requesting sci_articles base_arango must be used"
        else:
            assert (
                self.db_name != "base"
            ), "If not requesting sci_articles user_arango must be used"
        try:
            if _id:
                doc = self.db.document(_id)
            elif _key:
                assert (
                    collection is not None
                ), "Collection name must be provided if _key is used"
                doc = self.db.collection(collection).get(_key)
            else:
                doc = None
        except Exception as e:
            print(f"Error retrieving text for document {_id or _key}: {e}")
            return None
        if not doc:
            return None
        # Skip chunks without a "text" field so join() never receives None.
        texts = [
            chunk.get("text")
            for chunk in (doc.get("chunks") or [])
            if chunk.get("text") is not None
        ]
        return "\n".join(texts) if texts else None

    def store_article_chunks(
        self, article_data: dict, chunks: list, document_key: str = None
    ):
        """
        Store an article and its text chunks in the "sci_articles" collection.

        Each chunk is stored as ``{"text", "pages", "id"}``. The document is
        inserted with ``overwrite=True, overwrite_mode="update"``, so an
        existing document with the same key is updated.

        Args:
            article_data (dict): Article data; "metadata", "summary" and "doi"
                are read from it.
            chunks (list): Chunk dicts with optional "text" and "pages" keys.
            document_key (str, optional): The ``_key`` to store the article under.
                NOTE(review): insert_document builds ``_id`` from this key, so
                passing None likely fails — confirm callers always provide it.

        Returns:
            tuple: (document_id, database_name, document_doi)
        """
        collection = "sci_articles"
        arango_chunks = [
            {
                "text": chunk.get("text", ""),
                "pages": chunk.get("pages", []),
                "id": f"{document_key}_{index}" if document_key else f"chunk_{index}",
            }
            for index, chunk in enumerate(chunks)
        ]
        arango_document = {
            "_key": document_key,
            "chunks": arango_chunks,
            "metadata": article_data.get("metadata", {}),
        }
        if article_data.get("summary"):
            arango_document["summary"] = article_data.get("summary")
        if article_data.get("doi"):
            # Flag articles that came with a DOI.
            arango_document["crossref"] = True
        doc = self.insert_document(
            collection_name=collection,
            document=arango_document,
            overwrite=True,
            overwrite_mode="update",
            keep_none=False,
        )
        return doc["_id"], self.db_name, article_data.get("doi")

    def add_article_to_collection(self, article_id: str, collection_name: str):
        """
        Append an article ID to every "article_collections" entry with the given name.

        Args:
            article_id (str): The ID of the article.
            collection_name (str): The name of the user's collection.

        Returns:
            bool: True if at least one collection was updated.
        """
        query = """
        FOR collection IN article_collections
            FILTER collection.name == @collection_name
            UPDATE collection WITH {
                articles: PUSH(collection.articles, @article_id)
            } IN article_collections
            RETURN NEW
        """
        cursor = self.db.aql.execute(
            query,
            bind_vars={"collection_name": collection_name, "article_id": article_id},
        )
        return next(cursor, None) is not None

    def remove_article_from_collection(self, article_id: str, collection_name: str):
        """
        Remove an article ID from every "article_collections" entry with the given name.

        Args:
            article_id (str): The ID of the article.
            collection_name (str): The name of the user's collection.

        Returns:
            bool: True if at least one collection was updated.
        """
        query = """
        FOR collection IN article_collections
            FILTER collection.name == @collection_name
            UPDATE collection WITH {
                articles: REMOVE_VALUE(collection.articles, @article_id)
            } IN article_collections
            RETURN NEW
        """
        cursor = self.db.aql.execute(
            query,
            bind_vars={"collection_name": collection_name, "article_id": article_id},
        )
        return next(cursor, None) is not None

    # Projects
    def get_projects(self, username: str = None):
        """
        Get all projects in this instance's database.

        Args:
            username (str, optional): Not used for filtering. When truthy, the
                projects are returned sorted by name; otherwise in storage order.

        Returns:
            list: A list of project documents.
        """
        if username:
            query = """
            FOR p IN projects
                SORT p.name ASC
                RETURN p
            """
            return list(self.db.aql.execute(query))
        return self.get_all_documents("projects")

    def get_project(self, project_name: str, username: str = None):
        """
        Get a project by name.

        Args:
            project_name (str): The name of the project.
            username (str, optional): Accepted for backward compatibility;
                currently unused.

        Returns:
            dict | None: The first matching project, or None if not found.
        """
        query = """
        FOR p IN projects
            FILTER p.name == @project_name
            RETURN p
        """
        cursor = self.db.aql.execute(query, bind_vars={"project_name": project_name})
        return next(cursor, None)

    def create_project(self, project_data: dict):
        """
        Create a new project.

        Args:
            project_data (dict): The project data (must contain ``_key`` or ``_id``).

        Returns:
            dict: The inserted project's metadata.
        """
        return self.insert_document("projects", project_data)

    def update_project(self, project_data: dict):
        """
        Update an existing project.

        Args:
            project_data (dict): The project data, including ``_id``/``_key``.

        Returns:
            dict: The update result.
        """
        return self.update_document(project_data, check_rev=False)

    def delete_project(self, project_name: str, username: str = None):
        """
        Delete projects matching a name (and optionally a ``username`` field).

        Args:
            project_name (str): The name of the project.
            username (str, optional): If given, also match on the ``username``
                attribute. Defaults to None.

        Returns:
            The value returned by ``delete_document_by_match``.
        """
        filters = {"name": project_name}
        if username:
            filters["username"] = username
        return self.delete_document_by_match("projects", filters)

    def get_project_notes(self, project_name: str, username: str = None):
        """
        Get notes for a project, newest first (by ``timestamp``).

        Args:
            project_name (str): The name of the project.
            username (str, optional): If given, only notes with this
                ``username`` are returned. Defaults to None.

        Returns:
            list: A list of note documents.
        """
        # Only a fixed clause is interpolated; all values go through bind vars.
        condition = "note.project == @project_name"
        bind_vars = {"project_name": project_name}
        if username:
            condition += " AND note.username == @username"
            bind_vars["username"] = username
        query = f"""
        FOR note IN notes
            FILTER {condition}
            SORT note.timestamp DESC
            RETURN note
        """
        return list(self.db.aql.execute(query, bind_vars=bind_vars))

    def add_note_to_project(self, note_data: dict):
        """
        Add a note to a project.

        Args:
            note_data (dict): The note data (must contain ``_key`` or ``_id``).

        Returns:
            dict: The inserted note's metadata.
        """
        return self.insert_document("notes", note_data)

    def fetch_notes_tool(
        self, project_name: str, username: str = None
    ) -> UnifiedSearchResults:
        """
        Fetch notes for a project and return them in the unified search format.

        Args:
            project_name (str): The name of the project.
            username (str, optional): The username. Defaults to None.

        Returns:
            UnifiedSearchResults: One chunk per note, with parallel source IDs.
        """
        notes = self.get_project_notes(project_name, username)
        chunks = [
            UnifiedDataChunk(
                content=note.get("content", ""),
                metadata={
                    "title": note.get("title", "No title"),
                    "timestamp": note.get("timestamp", ""),
                },
                source_type="note",
            )
            for note in notes
        ]
        source_ids = [note.get("_id", "unknown_id") for note in notes]
        return UnifiedSearchResults(chunks=chunks, source_ids=source_ids)

    # Chat operations
    def get_chat(self, chat_key: str):
        """
        Get a chat by key.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            dict | None: The chat document, or None if missing or on error.
        """
        try:
            return self.db.collection("chats").get(chat_key)
        except Exception:
            # Best-effort lookup: driver errors yield None.
            return None

    def create_or_update_chat(self, chat_data: dict):
        """
        Create a chat, or update it if one with the same key exists.

        Args:
            chat_data (dict): The chat data (must contain ``_key`` or ``_id``).

        Returns:
            dict: The inserted/updated chat's metadata.
        """
        return self.insert_document("chats", chat_data, overwrite=True)

    def get_chats_for_project(self, project_name: str, username: str = None):
        """
        Get all chats for a project, newest first (by ``timestamp``).

        Args:
            project_name (str): The name of the project.
            username (str, optional): If given, only chats with this
                ``username`` are returned. Defaults to None.

        Returns:
            list: A list of chat documents.
        """
        # Only a fixed clause is interpolated; all values go through bind vars.
        condition = "chat.project == @project_name"
        bind_vars = {"project_name": project_name}
        if username:
            condition += " AND chat.username == @username"
            bind_vars["username"] = username
        query = f"""
        FOR chat IN chats
            FILTER {condition}
            SORT chat.timestamp DESC
            RETURN chat
        """
        return list(self.db.aql.execute(query, bind_vars=bind_vars))

    def delete_chat(self, chat_key: str):
        """
        Delete a chat.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            The value returned by ``delete_document``.
        """
        return self.delete_document("chats", chat_key)

    def delete_old_chats(self, days: int = 30):
        """
        Delete chats whose ``timestamp`` is more than ``days`` days in the past.

        Args:
            days (int, optional): The age threshold in days. Defaults to 30.

        Returns:
            int: The number of deleted chats.
        """
        query = """
        FOR chat IN chats
            FILTER DATE_DIFF(chat.timestamp, DATE_NOW(), "d") > @days
            REMOVE chat IN chats
            RETURN OLD
        """
        cursor = self.db.aql.execute(query, bind_vars={"days": days})
        return sum(1 for _ in cursor)

    # Settings operations
    def get_settings(self):
        """
        Get the settings document ("settings/settings").

        Returns:
            dict | None: The settings document, or None if missing or on error.
        """
        try:
            return self.db.document("settings/settings")
        except Exception:
            # Best-effort lookup: driver errors yield None.
            return None

    def initialize_settings(self, settings_data: dict):
        """
        Create the settings document.

        Sets ``settings_data["_key"] = "settings"`` in place before inserting.

        Args:
            settings_data (dict): The settings data.

        Returns:
            dict: The inserted settings document's metadata.
        """
        settings_data["_key"] = "settings"
        return self.insert_document("settings", settings_data)

    def update_settings(self, settings_data: dict):
        """
        Merge ``settings_data`` into the settings document.

        Args:
            settings_data (dict): The settings fields to update.

        Returns:
            The value returned by ``update_document_by_match``.
        """
        return self.update_document_by_match(
            collection_name="settings", filters={"_key": "settings"}, body=settings_data
        )

    def get_document_metadata(self, document_id: str) -> dict:
        """
        Retrieve a document's ``metadata``, with its ``user_notes`` merged in.

        "sci_articles/..." IDs are read from the "base" database; all other
        IDs are read from this instance's database.

        Args:
            document_id (str): The document ID to retrieve metadata for.

        Returns:
            dict: The metadata dict, or an empty dict if the ID is empty, the
            document is not found, or an error occurs.
        """
        if not document_id:
            return {}
        try:
            if document_id.startswith("sci_articles") and self.db_name != "base":
                # Science articles live in the base database; fall back to the
                # same default user as __init__ when ARANGO_USER is unset.
                db_to_use = self.client.db(
                    "base",
                    username=os.getenv("ARANGO_USER", "admin"),
                    password=os.getenv("ARANGO_PASSWORD"),
                )
            else:
                db_to_use = self.db
            arango_doc = db_to_use.document(document_id)
            if not arango_doc:
                return {}
            # `or {}` also covers a stored `metadata: null`.
            arango_metadata = arango_doc.get("metadata") or {}
            if "user_notes" in arango_doc:
                arango_metadata["user_notes"] = arango_doc["user_notes"]
            return arango_metadata
        except Exception as e:
            print(f"Error retrieving metadata for document {document_id}: {e}")
            return {}

    def summarise_chunks(self, document: dict, is_sci=False):
        """
        Summarise and tag each chunk of ``document`` with the LLM, then save it.

        Chunks that already have a "summary" are left untouched. For each other
        chunk, the LLM response is parsed as ``ArticleChunk`` and its summary,
        lowercased tags and a ``summary_meta`` (model, date) are stored on the
        chunk. Failures are printed and the chunk is kept unchanged.

        Args:
            document (dict): Document with "_id" and a "chunks" list of dicts
                holding "text".
            is_sci (bool, optional): Use the scientific-article prompt, which
                also asks for references. Defaults to False.

        Raises:
            AssertionError: If ``document`` has no "_id".
        """
        from _llm import LLM
        from models import ArticleChunk
        assert "_id" in document, "Document must have an _id field"
        if is_sci:
            system_message = """You are a science assistant summarizing scientific articles.
You will get an article chunk by chunk, and you have three tasks for each chunk:
1. Summarize the content of the chunk.
2. Tag the chunk with relevant tags.
3. Extract the scientific references from the chunk.
"""
        else:
            system_message = """You are a general assistant summarizing articles.
You will get an article chunk by chunk, and you have two tasks for each chunk:
1. Summarize the content of the chunk.
2. Tag the chunk with relevant tags.
"""
        system_message += """\nPlease make use of the previous chunks you have already seen to understand the current chunk in context and make the summary stand for itself. But remember, *it is the current chunk you are summarizing*
ONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."""
        llm = LLM(system_message=system_message)
        chunks = []
        for chunk in document["chunks"]:
            if "summary" in chunk:
                # Already summarised: keep as-is.
                chunks.append(chunk)
                continue
            prompt = f"""Summarize the following text to make it stand on its own:\n
'''
{chunk['text']}
'''\n
Your tasks are:
1. Summarize the content of the chunk. Make sure to include all relevant details!
2. Tag the chunk with relevant tags.
"""
            if is_sci:
                prompt += "\n3. Extract the scientific references mentioned in this specific chunk. If there is a DOI reference, include that in the reference. Sometimes the reference is only a number in brackets, like [1], so make sure to include that as well (in brackets)."
            prompt += "\nONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."
            try:
                response = llm.generate(prompt, format=ArticleChunk.model_json_schema())
                structured_response = ArticleChunk.model_validate_json(response.content)
                chunk["summary"] = structured_response.summary
                chunk["tags"] = [i.lower() for i in structured_response.tags]
                # NOTE(review): references are requested when is_sci is True but
                # only summary/tags are stored here — confirm whether ArticleChunk
                # carries references that should be saved too.
                chunk["summary_meta"] = {
                    "model": llm.model,
                    "date": datetime.now().strftime("%Y-%m-%d"),
                }
            except Exception as e:
                print(f"Error processing chunk: {e}")
            chunks.append(chunk)
        document["chunks"] = chunks
        self.update_document(document, check_rev=False)
if __name__ == "__main__":
    # Manual smoke test: connect as user "lasse" (which, per ArangoDB.__init__,
    # also selects the "lasse" database) and print one "other_documents" entry.
    arango = ArangoDB(user="lasse")
    random_doc = arango.db.aql.execute(
        "FOR doc IN other_documents LIMIT 1 RETURN doc"
    )
    # A default avoids a StopIteration traceback when the collection is empty.
    print(next(random_doc, None))