From 62b68c37176ac09b36a8d2eb5c163579764e5c12 Mon Sep 17 00:00:00 2001
From: lasseedfast <>
Date: Fri, 30 May 2025 21:05:50 +0200
Subject: [PATCH] Add models, testing scripts, and result viewing functionality

- Implemented Pydantic models for article processing and summarization.
- Created `test_and_view.py` for testing LLM server document summarization.
- Developed `test_llm_server.py` for unit testing summarization functionality.
- Added `test_server.py` for additional testing of document and chunk summarization.
- Introduced `view_latest_results.py` to display the latest summaries from the LLM server.
- Established a structured plan for handling document chunks and their metadata.
- Enhanced error handling and user feedback in testing scripts.
---
_arango.py | 967 ++++++++++++++++++++++--
_base_class.py | 257 ++++++-
_bots.py | 800 --------------------
_bots_dont_use.py | 497 +++++++++++++
_chromadb.py | 299 +++++++-
_llm.py | 574 --------------
_llmOLD.py | 581 +++++++++++++++
agent_research.py | 1448 ++++++++++++++++++++++++++++++++++++
article2db.py | 335 +++++++--
bot_tools.py | 0
info.py | 1 +
llm_queries.py | 5 +
llm_server.py | 373 +++++++++-
manage_users.py | 19 +-
models.py | 334 +++++++++
ollama_response_classes.py | 6 -
projects_page.py | 239 +++---
research_page.py | 249 ++++---
streamlit_app.py | 36 +-
streamlit_chatbot.py | 737 +++++++++++-------
test.py | 37 +-
test_ tortoise.py | 31 -
test_and_view.py | 209 ++++++
test_fairseq.py | 51 --
test_highlight.py | 91 ---
test_llm_server.py | 191 +++++
test_ollama_client.py | 38 -
test_ollama_image.py | 9 -
test_research.py | 206 -----
test_server.py | 123 +++
test_tts.py | 45 --
test_tts_call_server.py | 22 -
tts_save_speaker.py | 33 -
utils.py | 94 ++-
view_latest_results.py | 111 +++
35 files changed, 6481 insertions(+), 2567 deletions(-)
delete mode 100644 _bots.py
create mode 100644 _bots_dont_use.py
delete mode 100644 _llm.py
create mode 100644 _llmOLD.py
create mode 100644 agent_research.py
create mode 100644 bot_tools.py
create mode 100644 models.py
delete mode 100644 ollama_response_classes.py
delete mode 100644 test_ tortoise.py
create mode 100644 test_and_view.py
delete mode 100644 test_fairseq.py
delete mode 100644 test_highlight.py
create mode 100644 test_llm_server.py
delete mode 100644 test_ollama_client.py
delete mode 100644 test_ollama_image.py
delete mode 100644 test_research.py
create mode 100644 test_server.py
delete mode 100644 test_tts.py
delete mode 100644 test_tts_call_server.py
delete mode 100644 tts_save_speaker.py
create mode 100644 view_latest_results.py
diff --git a/_arango.py b/_arango.py
index 661ddce..c0907d9 100644
--- a/_arango.py
+++ b/_arango.py
@@ -1,75 +1,950 @@
-import re
-from arango import ArangoClient
-from dotenv import load_dotenv
import os
+from datetime import datetime
+
+from dotenv import load_dotenv
+from arango import ArangoClient
+from arango.collection import StandardCollection as ArangoCollection
+
+from models import UnifiedDataChunk, UnifiedSearchResults
+from utils import fix_key
+
if "INFO" not in os.environ:
import env_manager
+
env_manager.set_env()
load_dotenv() # Install with pip install python-dotenv
+COLLECTIONS_IN_BASE = [
+ "sci_articles",
+]
+
class ArangoDB:
- def __init__(self, user=None, password=None, db_name=None):
+ """
+ ArangoDB Client Wrapper
+ This class provides a wrapper around the ArangoClient to simplify working with ArangoDB databases
+ and collections in a scientific document management context. It handles authentication, database
+ connections, and provides high-level methods for common operations.
+ Key features:
+ - Database and collection management
+ - Document CRUD operations (Create, Read, Update, Delete)
+ - AQL query execution
+ - Scientific article storage and retrieval
+ - Project and note management
+ - Chat history storage
+ - Settings management
+ Usage example:
+ arango = ArangoDB(user="admin", password="password")
+ # Create a collection
+ arango.create_collection("my_collection")
+ # Insert a document
+ doc = arango.insert_document("my_collection", {"name": "Test Document"})
+ # Query documents
+ results = arango.execute_aql("FOR doc IN my_collection RETURN doc")
+ Environment variables:
+ ARANGO_HOST: The ArangoDB host URL
+ ARANGO_PASSWORD: The default password for authentication
+ """
+ def __init__(self, user="admin", password=None, db_name="base"):
"""
- Initializes an instance of the ArangoClass.
-
- Args:
- db_name (str): The name of the database.
- username (str): The username for authentication.
- password (str): The password for authentication.
+ Initialize a connection to an ArangoDB database.
+ This constructor establishes a connection to an ArangoDB instance using the provided
+ credentials and database name. It uses environment variables for host and password
+ if not explicitly provided.
+ Parameters
+ ----------
+ user : str, optional
+ Username for database authentication. Defaults to "admin".
+ If db_name is not "base", then user will be set to db_name.
+ password : str, optional
+ Password for database authentication. If not provided,
+ the password will be retrieved from the ARANGO_PASSWORD environment variable.
+ db_name : str, optional
+ Name of the database to connect to. Defaults to "base".
+ If not "base", this value will also be used as the username.
+ Notes
+ -----
+ - The host URL is always retrieved from the ARANGO_HOST environment variable.
+ - For the "base" database, the username will be either "admin" or the provided user.
+ - For other databases, the username will be the same as the database name.
+ Attributes
+ ----------
+ user : str
+ The username used for authentication.
+ password : str
+ The password used for authentication.
+ db_name : str
+ The name of the connected database.
+ client : ArangoClient
+ The ArangoDB client instance.
+ db : Database
+ The database instance for executing operations.
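+ Examples
+ --------
+ A minimal sketch of the two connection modes described above:
+ >>> arango = ArangoDB() # "admin" on the shared "base" database
+ >>> arango = ArangoDB(db_name="alice") # user database; user becomes "alice"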
"""
-
+
host = os.getenv("ARANGO_HOST")
if not password:
- password = os.getenv("ARANGO_PASSWORD")
- if not db_name:
- if user:
- db_name = user
- else:
- db_name = os.getenv("ARANGO_DB")
- if not user:
- user = os.getenv("ARANGO_USER")
+ self.password = os.getenv("ARANGO_PASSWORD")
+ else:
+ self.password = password
+ # This is the default user for the base database
+ if db_name != "base":
+ self.user = db_name
+ self.db_name = db_name
+
+ elif user == "admin":
+ self.user = "admin"
+ self.db_name = "base"
+ else:
+ self.user = user
+ self.db_name = user
self.client = ArangoClient(hosts=host)
- if user=='lasse': #! This need to be fixed to work with all users!
- password = os.getenv("ARANGO_PWD_LASSE")
- self.db = self.client.db(db_name, username=user, password=password)
+ self.db = self.client.db(
+ self.db_name, username=self.user, password=self.password
+ )
+
+ def fix_key(self, _key):
+ return fix_key(_key)
+
+ # Collection operations
+ def get_collection(self, collection_name: str) -> ArangoCollection:
+ """
+ Get a collection by name.
+ Args:
+ collection_name (str): The name of the collection.
- def fix_key(self, _key):
+ Returns:
+ ArangoCollection: The collection object.
+ """
+ return self.db.collection(collection_name)
+
+ def has_collection(self, collection_name: str) -> bool:
+ """
+ Check if a collection exists.
+
+ Args:
+ collection_name (str): The name of the collection.
+
+ Returns:
+ bool: True if the collection exists, False otherwise.
+ """
+ return self.db.has_collection(collection_name)
+
+ def create_collection(self, collection_name: str) -> ArangoCollection:
+ """
+ Create a new collection.
+
+ Args:
+ collection_name (str): The name of the collection to create.
+
+ Returns:
+ ArangoCollection: The created collection.
+ """
+ return self.db.create_collection(collection_name)
+
+ def delete_collection(self, collection_name: str) -> bool:
+ """
+ Delete a collection.
+
+ Args:
+ collection_name (str): The name of the collection to delete.
+
+ Returns:
+ bool: True if the collection was deleted successfully.
+ """
+ if self.has_collection(collection_name):
+ return self.db.delete_collection(collection_name)
+ return False
+
+ def truncate_collection(self, collection_name: str) -> bool:
+ """
+ Truncate a collection (remove all documents).
+
+ Args:
+ collection_name (str): The name of the collection to truncate.
+
+ Returns:
+ bool: True if the collection was truncated successfully.
+ """
+ if self.has_collection(collection_name):
+ return self.db.collection(collection_name).truncate()
+ return False
+
+ # Document operations
+ def get_document(self, document_id: str):
+ """
+ Get a document by ID.
+
+ Args:
+ document_id (str): The ID of the document to get.
+
+ Returns:
+ dict: The document if found, None otherwise.
+ """
+ try:
+ return self.db.document(document_id)
+ except Exception:
+ return None
+
+ def has_document(self, collection_name: str, document_key: str) -> bool:
+ """
+ Check if a document exists in a collection.
+
+ Args:
+ collection_name (str): The name of the collection.
+ document_key (str): The key of the document.
+
+ Returns:
+ bool: True if the document exists, False otherwise.
+ """
+ return self.db.collection(collection_name).has(document_key)
+
+ def insert_document(
+ self,
+ collection_name: str,
+ document: dict,
+ overwrite: bool = False,
+ overwrite_mode: str = "update",
+ keep_none: bool = False,
+ ):
+ """
+ Insert a document into a collection.
+
+ Args:
+ collection_name (str): The name of the collection.
+ document (dict): The document to insert.
+ overwrite (bool, optional): Whether to overwrite an existing document. Defaults to False.
+ overwrite_mode (str, optional): The mode for overwriting ('replace' or 'update'). Defaults to "update".
+ keep_none (bool, optional): Whether to keep None values. Defaults to False.
+
+ Returns:
+ dict: The inserted document with its metadata (_id, _key, etc.)
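+
+ Example (a minimal sketch; collection and key are illustrative):
+ doc = arango.insert_document("my_collection", {"_key": "doc1", "name": "Test"})
+ # doc["_id"] == "my_collection/doc1"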
+ """
+ assert '_id' in document or '_key' in document, "Document must have either _id or _key"
+ if '_id' not in document:
+ document['_id'] = f"{collection_name}/{document['_key']}"
+
+ return self.db.collection(collection_name).insert(
+ document,
+ overwrite=overwrite,
+ overwrite_mode=overwrite_mode,
+ keep_none=keep_none,
+ )
+
+ def update_document(
+ self, document: dict, check_rev: bool = False, silent: bool = False
+ ):
+ """
+ Update a document that already has _id or _key.
+
+ Args:
+ document (dict): The document to update.
+ check_rev (bool, optional): Whether to check document revision. Defaults to False.
+ silent (bool, optional): Whether to return the updated document. Defaults to False.
+
+ Returns:
+ dict: The updated document if silent is False.
+ """
+ return self.db.update_document(document, check_rev=check_rev, silent=silent)
+
+ def update_document_by_match(
+ self, collection_name: str, filters: dict, body: dict, merge: bool = True
+ ):
+ """
+ Update documents that match a filter.
+
+ Args:
+ collection_name (str): The name of the collection.
+ filters (dict): The filter to match documents.
+ body (dict): The update to apply.
+ merge (bool, optional): Whether to merge the update with existing data. Defaults to True.
+
+ Returns:
+ dict: The result of the update operation.
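+
+ Example (illustrative names):
+ arango.update_document_by_match(
+ "projects", filters={"name": "Old name"}, body={"name": "New name"}
+ )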
+ """
+ return self.db.collection(collection_name).update_match(
+ filters=filters, body=body, merge=merge
+ )
+
+ def delete_document(self, collection_name: str, document_key: str):
+ """
+ Delete a document from a collection.
+
+ Args:
+ collection_name (str): The name of the collection.
+ document_key (str): The key of the document to delete.
+
+ Returns:
+ dict: The deletion result.
+ """
+ return self.db.collection(collection_name).delete(document_key)
+
+ def delete_document_by_match(self, collection_name: str, filters: dict):
+ """
+ Delete documents that match a filter.
+
+ Args:
+ collection_name (str): The name of the collection.
+ filters (dict): The filter to match documents.
+
+ Returns:
+ dict: The deletion result.
+ """
+ return self.db.collection(collection_name).delete_match(filters=filters)
+
+ # Query operations
+ def execute_aql(self, query: str, bind_vars: dict = None):
+ """
+ Execute an AQL query.
+
+ Args:
+ query (str): The AQL query to execute.
+ bind_vars (dict, optional): Bind variables for the query. Defaults to None.
+
+ Returns:
+ Cursor: A cursor to the query results.
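+
+ Example (value binds use @name; collection binds use @@name):
+ cursor = arango.execute_aql(
+ "FOR doc IN @@col FILTER doc.name == @name RETURN doc",
+ bind_vars={"@col": "my_collection", "name": "Test"},
+ )
+ results = list(cursor)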
+ """
+ return self.db.aql.execute(query, bind_vars=bind_vars)
+
+ def get_all_documents(self, collection_name: str):
+ """
+ Get all documents from a collection.
+
+ Args:
+ collection_name (str): The name of the collection.
+
+ Returns:
+ list: All documents in the collection.
+ """
+ return list(self.db.collection(collection_name).all())
+
+ # Database operations
+ def has_database(self, db_name: str) -> bool:
+ """
+ Check if a database exists.
+
+ Args:
+ db_name (str): The name of the database.
+
+ Returns:
+ bool: True if the database exists, False otherwise.
+ """
+ return self.client.has_database(db_name)
+
+ def create_database(self, db_name: str, users: list = None) -> bool:
+ """
+ Create a new database.
+
+ Args:
+ db_name (str): The name of the database to create.
+ users (list, optional): List of user objects with access to the database. Defaults to None.
+
+ Returns:
+ bool: True if the database was created successfully.
+ """
+ return self.client.create_database(db_name, users=users)
+
+ def delete_database(self, db_name: str) -> bool:
+ """
+ Delete a database.
+
+ Args:
+ db_name (str): The name of the database to delete.
+
+ Returns:
+ bool: True if the database was deleted successfully.
+ """
+ if self.client.has_database(db_name):
+ return self.client.delete_database(db_name)
+ return False
+
+ # Domain-specific operations
+
+ # Scientific Articles
+ def get_article(
+ self,
+ article_key: str,
+ db_name: str = None,
+ collection_name: str = "sci_articles",
+ ):
+ """
+ Get a scientific article by key.
+
+ Args:
+ article_key (str): The key of the article.
+ db_name (str, optional): The database name to search in. Defaults to current database.
+ collection_name (str, optional): The collection to read from. Defaults to "sci_articles".
+
+ Returns:
+ dict: The article document if found, None otherwise.
+ """
+ try:
+ return self.db.collection(collection_name).get(article_key)
+ except Exception as e:
+ print(f"Error retrieving article {article_key}: {e}")
+ return None
+
+ def get_article_by_doi(self, doi: str):
+ """
+ Get a scientific article by DOI.
+
+ Args:
+ doi (str): The DOI of the article.
+
+ Returns:
+ dict: The article document if found, None otherwise.
+ """
+ query = """
+ FOR doc IN sci_articles
+ FILTER doc.metadata.doi == @doi
+ RETURN doc
+ """
+ cursor = self.db.aql.execute(query, bind_vars={"doi": doi})
+ try:
+ return next(cursor)
+ except StopIteration:
+ return None
+
+ def get_document_text(
+ self, _id: str = None, _key: str = None, collection: str = None
+ ):
+ """
+ Get the text content of a document. If _key is used, collection must be provided.
+ * Use base_arango for sci_articles and user_arango for other collections. *
+ Args:
+ _id (str, optional): The ID of the document. Defaults to None.
+ _key (str, optional): The key of the document. Defaults to None.
+ collection (str, optional): The name of the collection. Defaults to None.
+ Returns:
+ str: The text content of the document, or None if not found.
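+ Example (sketch; document keys are illustrative):
+ text = base_arango.get_document_text(_id="sci_articles/abc123")
+ text = user_arango.get_document_text(_key="doc1", collection="other_documents")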
+ """
+ if collection == "sci_articles" or _id.startswith("sci_articles"):
+ assert (
+ self.db_name == "base"
+ ), "If requesting sci_articles base_arango must be used"
+ else:
+ assert (
+ self.db_name != "base"
+ ), "If not requesting sci_articles user_arango must be used"
+
+ try:
+ if _id:
+ doc = self.db.document(_id)
+ elif _key:
+ assert (
+ collection is not None
+ ), "Collection name must be provided if _key is used"
+ doc = self.db.collection(collection).get(_key)
+
+ text = [chunk.get("text") for chunk in doc.get("chunks", [])]
+ except Exception as e:
+ print(f"Error retrieving text for document {_id or _key}: {e}")
+ return None
+ return "\n".join(text) if text else None
+
+ def store_article_chunks(
+ self, article_data: dict, chunks: list, document_key: str = None
+ ):
+ """
+ Store article chunks in the database.
+
+ Args:
+ article_data (dict): The article metadata.
+ chunks (list): The chunks of text from the article.
+ document_key (str, optional): The key to use for the document. Defaults to None.
+
+ Returns:
+ tuple: (document_id, database_name, document_doi)
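+
+ Example (illustrative; chunk shape inferred from the loop below):
+ chunks = [{"text": "Introduction ...", "pages": [1, 2]}]
+ _id, db_name, doi = arango.store_article_chunks(article_data, chunks, "abc123")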
"""
- Sanitize a given key by replacing all characters that are not alphanumeric,
- underscore, hyphen, dot, at symbol, parentheses, plus, equals, semicolon,
- dollar sign, asterisk, single quote, percent, or colon with an underscore.
+ collection = "sci_articles"
+
+ arango_chunks = []
+ for index, chunk in enumerate(chunks):
+ chunk_id = f"{document_key}_{index}" if document_key else f"chunk_{index}"
+ page_numbers = chunk.get("pages", [])
+ text = chunk.get("text", "")
+ arango_chunks.append({"text": text, "pages": page_numbers, "id": chunk_id})
+
+ arango_document = {
+ "_key": document_key,
+ "chunks": arango_chunks,
+ "metadata": article_data.get("metadata", {}),
+ }
+
+ if article_data.get("summary"):
+ arango_document["summary"] = article_data.get("summary")
+
+ if article_data.get("doi"):
+ arango_document["crossref"] = True
+
+ doc = self.insert_document(
+ collection_name=collection,
+ document=arango_document,
+ overwrite=True,
+ overwrite_mode="update",
+ keep_none=False,
+ )
+
+ return doc["_id"], self.db_name, article_data.get("doi")
+
+ def add_article_to_collection(self, article_id: str, collection_name: str):
+ """
+ Add an article to a user's article collection.
Args:
- _key (str): The key to be sanitized.
+ article_id (str): The ID of the article.
+ collection_name (str): The name of the user's collection.
Returns:
- str: The sanitized key with disallowed characters replaced by underscores.
+ bool: True if the article was added successfully.
+ """
+ query = """
+ FOR collection IN article_collections
+ FILTER collection.name == @collection_name
+ UPDATE collection WITH {
+ articles: PUSH(collection.articles, @article_id)
+ } IN article_collections
+ RETURN NEW
+ """
+ cursor = self.db.aql.execute(
+ query,
+ bind_vars={"collection_name": collection_name, "article_id": article_id},
+ )
+ try:
+ return next(cursor) is not None
+ except StopIteration:
+ return False
+
+ def remove_article_from_collection(self, article_id: str, collection_name: str):
+ """
+ Remove an article from a user's article collection.
+
+ Args:
+ article_id (str): The ID of the article.
+ collection_name (str): The name of the user's collection.
+
+ Returns:
+ bool: True if the article was removed successfully.
+ """
+ query = """
+ FOR collection IN article_collections
+ FILTER collection.name == @collection_name
+ UPDATE collection WITH {
+ articles: REMOVE_VALUE(collection.articles, @article_id)
+ } IN article_collections
+ RETURN NEW
+ """
+ cursor = self.db.aql.execute(
+ query,
+ bind_vars={"collection_name": collection_name, "article_id": article_id},
+ )
+ try:
+ return next(cursor) is not None
+ except StopIteration:
+ return False
+
+ # Projects
+ def get_projects(self, username: str = None):
"""
+ Get all projects for a user.
+
+ Args:
+ username (str, optional): If given, projects are returned sorted by name.
+
+ Returns:
+ list: A list of project documents.
+ """
+ if username:
+ query = """
+ FOR p IN projects
+ SORT p.name ASC
+ RETURN p
+ """
+ return list(self.db.aql.execute(query))
+ else:
+ return self.get_all_documents("projects")
+
+ def get_project(self, project_name: str, username: str = None):
+ """
+ Get a project by name.
+
+ Args:
+ project_name (str): The name of the project.
+
+ Returns:
+ dict: The project document if found, None otherwise.
+ """
+ # Each user connects to their own database, so the same lookup applies
+ # whether or not a username is given.
+ query = """
+ FOR p IN projects
+ FILTER p.name == @project_name
+ RETURN p
+ """
+ cursor = self.db.aql.execute(
+ query, bind_vars={"project_name": project_name}
+ )
+ try:
+ return next(cursor)
+ except StopIteration:
+ return None
+
+ def create_project(self, project_data: dict):
+ """
+ Create a new project.
+
+ Args:
+ project_data (dict): The project data.
+
+ Returns:
+ dict: The created project document.
+ """
+ return self.insert_document("projects", project_data)
+
+ def update_project(self, project_data: dict):
+ """
+ Update an existing project.
+
+ Args:
+ project_data (dict): The project data.
+
+ Returns:
+ dict: The updated project document.
+ """
+ return self.update_document(project_data, check_rev=False)
+
+ def delete_project(self, project_name: str, username: str = None):
+ """
+ Delete a project.
+
+ Args:
+ project_name (str): The name of the project.
+ username (str, optional): The username. Defaults to None.
+
+ Returns:
+ bool: True if the project was deleted successfully.
+ """
+ filters = {"name": project_name}
+ if username:
+ filters["username"] = username
+
+ return self.delete_document_by_match("projects", filters)
+
+ def get_project_notes(self, project_name: str, username: str = None):
+ """
+ Get notes for a project.
+
+ Args:
+ project_name (str): The name of the project.
+ username (str, optional): The username. Defaults to None.
+
+ Returns:
+ list: A list of note documents.
+ """
+ query = """
+ FOR note IN notes
+ FILTER note.project == @project_name
+ """
+
+ if username:
+ query += " AND note.username == @username"
+
+ query += """
+ SORT note.timestamp DESC
+ RETURN note
+ """
+
+ bind_vars = {"project_name": project_name}
+ if username:
+ bind_vars["username"] = username
+
+ return list(self.db.aql.execute(query, bind_vars=bind_vars))
+
+ def add_note_to_project(self, note_data: dict):
+ """
+ Add a note to a project.
+
+ Args:
+ note_data (dict): The note data.
+
+ Returns:
+ dict: The created note document.
+ """
+ return self.insert_document("notes", note_data)
+
+ def fetch_notes_tool(
+ self, project_name: str, username: str = None
+ ) -> UnifiedSearchResults:
+ """
+ Fetch notes for a project and return them in a unified format.
+
+ Args:
+ project_name (str): The name of the project.
+ username (str, optional): The username. Defaults to None.
+
+ Returns:
+ UnifiedSearchResults: A unified representation of the notes.
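+
+ Example (sketch; field names follow models.UnifiedSearchResults):
+ results = arango.fetch_notes_tool("my_project")
+ for chunk, source_id in zip(results.chunks, results.source_ids):
+ print(source_id, chunk.content)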
+ """
+ notes = self.get_project_notes(project_name, username)
+ chunks = []
+ source_ids = []
+
+ for note in notes:
+ chunk = UnifiedDataChunk(
+ content=note.get("content", ""),
+ metadata={
+ "title": note.get("title", "No title"),
+ "timestamp": note.get("timestamp", ""),
+ },
+ source_type="note",
+ )
+ chunks.append(chunk)
+ source_ids.append(note.get("_id", "unknown_id"))
+
+ return UnifiedSearchResults(chunks=chunks, source_ids=source_ids)
+
+ # Chat operations
+ def get_chat(self, chat_key: str):
+ """
+ Get a chat by key.
+
+ Args:
+ chat_key (str): The key of the chat.
+
+ Returns:
+ dict: The chat document if found, None otherwise.
+ """
+ try:
+ return self.db.collection("chats").get(chat_key)
+ except Exception:
+ return None
+
+ def create_or_update_chat(self, chat_data: dict):
+ """
+ Create or update a chat.
+
+ Args:
+ chat_data (dict): The chat data.
+
+ Returns:
+ dict: The created or updated chat document.
+ """
+ return self.insert_document("chats", chat_data, overwrite=True)
+
+ def get_chats_for_project(self, project_name: str, username: str = None):
+ """
+ Get all chats for a project.
+
+ Args:
+ project_name (str): The name of the project.
+ username (str, optional): The username. Defaults to None.
+
+ Returns:
+ list: A list of chat documents.
+ """
+ query = """
+ FOR chat IN chats
+ FILTER chat.project == @project_name
+ """
+
+ if username:
+ query += " AND chat.username == @username"
+
+ query += """
+ SORT chat.timestamp DESC
+ RETURN chat
+ """
+
+ bind_vars = {"project_name": project_name}
+ if username:
+ bind_vars["username"] = username
+
+ return list(self.db.aql.execute(query, bind_vars=bind_vars))
+
+ def delete_chat(self, chat_key: str):
+ """
+ Delete a chat.
+
+ Args:
+ chat_key (str): The key of the chat.
+
+ Returns:
+ dict: The deletion result.
+ """
+ return self.delete_document("chats", chat_key)
+
+ def delete_old_chats(self, days: int = 30):
+ """
+ Delete chats older than a certain number of days.
+
+ Args:
+ days (int, optional): The number of days. Defaults to 30.
+
+ Returns:
+ int: The number of deleted chats.
+ """
+ query = """
+ FOR chat IN chats
+ FILTER DATE_DIFF(chat.timestamp, DATE_NOW(), "d") > @days
+ REMOVE chat IN chats
+ RETURN OLD
+ """
+ cursor = self.db.aql.execute(query, bind_vars={"days": days})
+ return len(list(cursor))
+
+ # Settings operations
+ def get_settings(self):
+ """
+ Get settings document.
+
+ Returns:
+ dict: The settings document if found, None otherwise.
+ """
+ try:
+ return self.db.document("settings/settings")
+ except Exception:
+ return None
+
+ def initialize_settings(self, settings_data: dict):
+ """
+ Initialize settings.
+
+ Args:
+ settings_data (dict): The settings data.
+
+ Returns:
+ dict: The created settings document.
+ """
+ settings_data["_key"] = "settings"
+ return self.insert_document("settings", settings_data)
+
+ def update_settings(self, settings_data: dict):
+ """
+ Update settings.
+
+ Args:
+ settings_data (dict): The settings data.
+
+ Returns:
+ dict: The updated settings document.
+ """
+ return self.update_document_by_match(
+ collection_name="settings", filters={"_key": "settings"}, body=settings_data
+ )
+
+ def get_document_metadata(self, document_id: str) -> dict:
+ """
+ Retrieve document metadata with merged user notes if available.
+
+ This method determines the appropriate database based on the document ID,
+ retrieves the document, and enriches its metadata with any user notes.
+
+ Args:
+ document_id (str): The document ID to retrieve metadata for
+
+ Returns:
+ dict: The document metadata dictionary, or empty dict if not found
+ """
+ if not document_id:
+ return {}
+
+ try:
+ # Determine which database to use based on document ID prefix
+ if document_id.startswith("sci_articles"):
+ # Science articles are in the base database
+ db_to_use = self.client.db(
+ "base",
+ username=os.getenv("ARANGO_USER"),
+ password=os.getenv("ARANGO_PASSWORD"),
+ )
+ arango_doc = db_to_use.document(document_id)
+ else:
+ # User documents are in the user's database
+ arango_doc = self.db.document(document_id)
+
+ if not arango_doc:
+ return {}
+
+ # Get metadata and merge user notes if available
+ arango_metadata = arango_doc.get("metadata", {})
+ if "user_notes" in arango_doc:
+ arango_metadata["user_notes"] = arango_doc["user_notes"]
+
+ return arango_metadata
+ except Exception as e:
+ print(f"Error retrieving metadata for document {document_id}: {e}")
+ return {}
+
+ def summarise_chunks(self, document: dict, is_sci=False):
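+ """
+ Summarize and tag each chunk of a document in place, then persist it.
+
+ Chunks that already have a "summary" are skipped; every processed chunk
+ gains "summary", "tags", and "summary_meta" fields.
+
+ Args:
+ document (dict): An ArangoDB document with an _id and a "chunks" list.
+ is_sci (bool, optional): If True, also extract scientific references
+ from each chunk. Defaults to False.
+ """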
+ from _llmOLD import LLM
+ from models import ArticleChunk
+
+ assert "_id" in document, "Document must have an _id field"
+
+ if is_sci:
+ system_message = """You are a science assistant summarizing scientific articles.
+ You will get an article chunk by chunk, and you have three tasks for each chunk:
+ 1. Summarize the content of the chunk.
+ 2. Tag the chunk with relevant tags.
+ 3. Extract the scientific references from the chunk.
+ """
+ else:
+ system_message = """You are a general assistant summarizing articles.
+ You will get an article chunk by chunk, and you have two tasks for each chunk:
+ 1. Summarize the content of the chunk.
+ 2. Tag the chunk with relevant tags.
+ """
+
+ system_message += """\nPlease make use of the previous chunks you have already seen to understand the current chunk in context and make the summary stand for itself. But remember, *it is the current chunk you are summarizing*
+ ONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."""
+
+ llm = LLM(system_message=system_message)
+ chunks = []
+ for chunk in document["chunks"]:
+ if "summary" in chunk:
+ chunks.append(chunk)
+ continue
+ prompt = f"""Summarize the following text to make it stand on its own:\n
+ '''
+ {chunk['text']}
+ '''\n
+ Your tasks are:
+ 1. Summarize the content of the chunk. Make sure to include all relevant details!
+ 2. Tag the chunk with relevant tags.
+ """
+ if is_sci:
+ prompt += "\n3. Extract the scientific references mentioned in this specific chunk. If there is a DOI reference, include that in the reference. Sometimes the reference is only a number in brackets, like [1], so make sure to include that as well (in brackets)."
+ prompt += "\nONLY use the information in the chunks to make the summary, and do not add any information that is not in the chunks."
- return re.sub(r"[^A-Za-z0-9_\-\.@()+=;$!*\'%:]", "_", _key)
+ try:
+ response = llm.generate(prompt, format=ArticleChunk.model_json_schema())
+ structured_response = ArticleChunk.model_validate_json(response.content)
+ chunk["summary"] = structured_response.summary
+ chunk["tags"] = [i.lower() for i in structured_response.tags]
+ chunk["summary_meta"] = {
+ "model": llm.model,
+ "date": datetime.now().strftime("%Y-%m-%d"),
+ }
+ except Exception as e:
+ print(f"Error processing chunk: {e}")
+ chunks.append(chunk)
+ document["chunks"] = chunks
+ self.update_document(document, check_rev=False)
if __name__ == "__main__":
+ arango = ArangoDB(user='lasse')
+ random_doc = arango.db.aql.execute(
+ "FOR doc IN other_documents LIMIT 1 RETURN doc"
+ )
+ print(next(random_doc))
- arango = ArangoDB(db_name='base')
- articles = arango.db.collection('sci_articles').all()
- for article in articles:
- if 'metadata' in article and article['metadata']:
- if 'abstract' in article['metadata']:
- abstract = article['metadata']['abstract']
- if isinstance(abstract, str):
- # Remove text within <> brackets and the brackets themselves
- article['metadata']['abstract'] = re.sub(r'<[^>]*>', '', abstract)
- arango.db.collection('sci_articles').update_match(
- filters={'_key': article['_key']},
- body={'metadata': article['metadata']},
- merge=True
- )
- print(f"Updated abstract for {article['_key']}")
-
-
diff --git a/_base_class.py b/_base_class.py
index cf40829..a150329 100644
--- a/_base_class.py
+++ b/_base_class.py
@@ -5,17 +5,17 @@ import streamlit as st
from _arango import ArangoDB
from _chromadb import ChromaDB
+
class BaseClass:
def __init__(self, username: str, **kwargs) -> None:
self.username: str = username
- self.project_name: str = kwargs.get('project_name', None)
- self.collection: str = kwargs.get('collection_name', None)
+ self.project_name: str = kwargs.get("project_name", None)
+ self.collection: str = kwargs.get("collection_name", None)
self.user_arango: ArangoDB = self.get_arango()
self.base_arango: ArangoDB = self.get_arango(admin=True)
for key, value in kwargs.items():
setattr(self, key, value)
-
def get_arango(self, admin: bool = False, db_name: str = None) -> ArangoDB:
if db_name:
return ArangoDB(db_name=db_name)
@@ -25,29 +25,41 @@ class BaseClass:
return ArangoDB(user=self.username, db_name=self.username)
def get_article_collections(self) -> list:
- article_collections = self.user_arango.db.aql.execute(
+ """
+ Gets the names of all article collections for the current user.
+
+ Returns:
+ list: A list of article collection names.
+ """
+ article_collections = self.user_arango.execute_aql(
'FOR doc IN article_collections RETURN doc["name"]'
)
return list(article_collections)
def get_projects(self) -> list:
- projects = self.user_arango.db.aql.execute(
- 'FOR doc IN projects RETURN doc["name"]'
- )
- return list(projects)
+ """
+ Gets the names of all projects for the current user.
+ Returns:
+ list: A list of project names.
+ """
+ projects = self.user_arango.get_projects(username=self.username)
+ return [project["name"] for project in projects]
def get_chromadb(self):
return ChromaDB()
def get_project(self, project_name: str):
- doc = self.user_arango.db.aql.execute(
- f'FOR doc IN projects FILTER doc["name"] == "{project_name}" RETURN doc',
- count=True,
- )
- if doc:
- return doc.next()
+ """
+ Get a project by name for the current user.
+
+ Args:
+ project_name (str): The name of the project.
+ Returns:
+ dict: The project document if found, None otherwise.
+ """
+ return self.user_arango.get_project(project_name, username=self.username)
def set_filename(self, filename=None, folder="other_documents"):
"""
@@ -77,6 +89,12 @@ class BaseClass:
self.file_path = file_path + ".pdf"
return file_path
+ def remove_thinking(self, response):
+ """Remove the thinking section from the response"""
+ response_text = response.content if hasattr(response, "content") else str(response)
+ if "" in response_text:
+ return response_text.split("")[1].strip()
+ return response_text
class StreamlitBaseClass(BaseClass):
"""
@@ -98,10 +116,11 @@ class StreamlitBaseClass(BaseClass):
Displays a select box for choosing a collection of favorite articles. Updates the current collection in the session state and the database.
choose_project(text="Select a project") -> str:
Displays a select box for choosing a project. Updates the current project in the session state and the database.
- """
+ """
+
def __init__(self, username: str, **kwargs) -> None:
super().__init__(username, **kwargs)
-
+
def get_settings(self, field: str = None):
"""
Retrieve or initialize user settings from the database.
@@ -112,24 +131,31 @@ class StreamlitBaseClass(BaseClass):
are then stored in the Streamlit session state.
Args:
- field (str, optional): The specific field to retrieve from the settings.
+ field (str, optional): The specific field to retrieve from the settings.
If not provided, the entire settings document is returned.
Returns:
- dict or any: The entire settings document if no field is specified,
+ dict or any: The entire settings document if no field is specified,
otherwise the value of the specified field.
"""
- settings = self.user_arango.db.document("settings/settings")
+ settings = self.user_arango.get_settings()
if not settings:
- self.user_arango.db.collection("settings").insert(
- {"_key": "settings", "current_collection": None, "current_page": None}
- )
+ default_settings = {
+ "_key": "settings",
+ "current_collection": None,
+ "current_page": None,
+ }
+ self.user_arango.initialize_settings(default_settings)
+ settings = default_settings
+
+ # Ensure required fields exist
for i in ["current_collection", "current_page"]:
if i not in settings:
settings[i] = None
+
st.session_state["settings"] = settings
if field:
- return settings[field]
+ return settings.get(field)
return settings
def update_settings(self, key, value) -> None:
@@ -189,7 +215,6 @@ class StreamlitBaseClass(BaseClass):
st.session_state["current_page"] = page_name
self.update_settings("current_page", page_name)
-
def choose_collection(self, text="Select a collection of favorite articles") -> str:
"""
Prompts the user to select a collection of favorite articles from a list.
@@ -214,7 +239,7 @@ class StreamlitBaseClass(BaseClass):
self.update_settings("current_collection", collection)
self.update_session_state()
return collection
-
+
def choose_project(self, text="Select a project") -> str:
"""
Prompts the user to select a project from a list of available projects.
@@ -231,16 +256,188 @@ class StreamlitBaseClass(BaseClass):
- Prints the chosen project name to the console.
"""
projects = self.get_projects()
- print('projects', projects)
+ print("projects", projects)
print(self.project_name)
-
- project = st.selectbox(text, projects, index=projects.index(self.project_name) if self.project_name in projects else None)
- print('Choosing project...')
+
+ project = st.selectbox(
+ text,
+ projects,
+ index=(
+ projects.index(self.project_name)
+ if self.project_name in projects
+ else None
+ ),
+ )
+ print("Choosing project...")
if project:
from projects_page import Project
+
self.project = Project(self.username, project, self.user_arango)
self.collection = None
self.update_settings("current_project", self.project.name)
self.update_session_state()
- print('CHOOSEN PROJECT:', self.project.name)
+ print("CHOOSEN PROJECT:", self.project.name)
return self.project
+
+ def add_article_to_collection(self, article_id: str, collection_name: str = None):
+ """
+ Add an article to a user's collection.
+
+ Args:
+ article_id (str): The ID of the article.
+ collection_name (str, optional): The name of the collection. Defaults to current collection.
+
+ Returns:
+ bool: True if the article was added successfully.
+ """
+ if collection_name is None:
+ collection_name = self.collection
+
+ return self.user_arango.add_article_to_collection(article_id, collection_name)
+
+ def remove_article_from_collection(
+ self, article_id: str, collection_name: str = None
+ ):
+ """
+ Remove an article from a user's collection.
+
+ Args:
+ article_id (str): The ID of the article.
+ collection_name (str, optional): The name of the collection. Defaults to current collection.
+
+ Returns:
+ bool: True if the article was removed successfully.
+ """
+ if collection_name is None:
+ collection_name = self.collection
+
+ return self.user_arango.remove_article_from_collection(
+ article_id, collection_name
+ )
+
+ def get_project_notes(self, project_name: str = None):
+ """
+ Get notes for a project.
+
+ Args:
+ project_name (str, optional): The name of the project. Defaults to current project.
+
+ Returns:
+ list: A list of note documents.
+ """
+ if project_name is None:
+ project_name = self.project_name
+
+ return self.user_arango.get_project_notes(project_name, username=self.username)
+
+ def add_note_to_project(self, note_data: dict):
+ """
+ Add a note to a project.
+
+ Args:
+ note_data (dict): The note data. Should contain project, username, and timestamp.
+
+ Returns:
+ dict: The created note document.
+ """
+ if "project" not in note_data:
+ note_data["project"] = self.project_name
+ if "username" not in note_data:
+ note_data["username"] = self.username
+
+ return self.user_arango.add_note_to_project(note_data)
+
+ def create_project(self, project_data: dict):
+ """
+ Create a new project for the current user.
+
+ Args:
+ project_data (dict): The project data. Should include a name field.
+
+ Returns:
+ dict: The created project document.
+ """
+ if "username" not in project_data:
+ project_data["username"] = self.username
+
+ return self.user_arango.create_project(project_data)
+
+ def update_project(self, project_data: dict):
+ """
+ Update an existing project.
+
+ Args:
+ project_data (dict): The project data. Must include _key.
+
+ Returns:
+ dict: The updated project document.
+ """
+ return self.user_arango.update_project(project_data)
+
+ def delete_project(self, project_name: str):
+ """
+ Delete a project for the current user.
+
+ Args:
+ project_name (str): The name of the project.
+
+ Returns:
+ bool: True if the project was deleted successfully.
+ """
+ return self.user_arango.delete_project(project_name, username=self.username)
+
+ def get_chat(self, chat_key: str):
+ """
+ Get a chat by key.
+
+ Args:
+ chat_key (str): The key of the chat.
+
+ Returns:
+ dict: The chat document if found, None otherwise.
+ """
+ return self.user_arango.get_chat(chat_key)
+
+ def create_or_update_chat(self, chat_data: dict):
+ """
+ Create or update a chat.
+
+ Args:
+ chat_data (dict): The chat data.
+
+ Returns:
+ dict: The created or updated chat document.
+ """
+ if "username" not in chat_data:
+ chat_data["username"] = self.username
+
+ return self.user_arango.create_or_update_chat(chat_data)
+
+ def get_chats_for_project(self, project_name: str = None):
+ """
+ Get all chats for a project.
+
+ Args:
+ project_name (str, optional): The name of the project. Defaults to current project.
+
+ Returns:
+ list: A list of chat documents.
+ """
+ if project_name is None:
+ project_name = self.project_name
+
+ return self.user_arango.get_chats_for_project(
+ project_name, username=self.username
+ )
+
+ def delete_chat(self, chat_key: str):
+ """
+ Delete a chat.
+
+ Args:
+ chat_key (str): The key of the chat.
+
+ Returns:
+ dict: The deletion result.
+ """
+ return self.user_arango.delete_chat(chat_key)
diff --git a/_bots.py b/_bots.py
deleted file mode 100644
index f172b43..0000000
--- a/_bots.py
+++ /dev/null
@@ -1,800 +0,0 @@
-from datetime import datetime
-import streamlit as st
-from _base_class import StreamlitBaseClass, BaseClass
-from _llm import LLM
-from prompts import *
-from colorprinter.print_color import *
-from llm_tools import ToolRegistry
-
-class Chat(StreamlitBaseClass):
- def __init__(self, username=None, **kwargs):
- super().__init__(username=username, **kwargs)
- self.name = kwargs.get("name", None)
- self.chat_history = kwargs.get("chat_history", [])
-
-
- def add_message(self, role, content):
- self.chat_history.append(
- {
- "role": role,
- "content": content.strip().strip('"'),
- "role_type": self.role,
- }
- )
-
- def to_dict(self):
- return {
- "_key": self._key,
- "name": self.name,
- "chat_history": self.chat_history,
- "role": self.role,
- "username": self.username,
- }
-
- def update_in_arango(self):
- self.last_updated = datetime.now().isoformat()
- self.user_arango.db.collection("chats").insert(
- self.to_dict(), overwrite=True, overwrite_mode="update"
- )
-
- def set_name(self, user_input):
- llm = LLM(
- model="small",
- max_length_answer=50,
- temperature=0.4,
- system_message="You are a chatbot who will be chatting with a user",
- )
- prompt = (
- f'Give a short name to the chat based on this user input: "{user_input}" '
- "No more than 30 characters. Answer ONLY with the name of the chat."
- )
- name = llm.generate(prompt).content.strip('"')
- name = f'{name} - {datetime.now().strftime("%B %d")}'
- existing_chat = self.user_arango.db.aql.execute(
- f'FOR doc IN chats FILTER doc.name == "{name}" RETURN doc', count=True
- )
- if existing_chat.count() > 0:
- name = f'{name} ({datetime.now().strftime("%H:%M")})'
- name += f" - [{self.role}]"
- self.name = name
- return name
-
- @classmethod
- def from_dict(cls, data):
- return cls(
- username=data.get("username"),
- name=data.get("name"),
- chat_history=data.get("chat_history", []),
- role=data.get("role", "Research Assistant"),
- _key=data.get("_key"),
- )
-
- def chat_history2bot(self, n_messages: int = None, remove_system: bool = False):
- history = [
- {"role": m["role"], "content": m["content"]} for m in self.chat_history
- ]
- if n_messages and len(history) > n_messages:
- history = history[-n_messages:]
- if (
- all([history[0]["role"] == "system", remove_system])
- or history[0]["role"] == "assistant"
- ):
- history = history[1:]
- return history
-
-
-class Bot(BaseClass):
- def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
- super().__init__(username=username, **kwargs)
-
- # Use the passed in chat or create a new Chat
- self.chat = chat if chat else Chat(username=username, role="Research Assistant")
- print_yellow(f"Chat:", chat, type(chat))
- # Store or set up project/collection if available
- self.project = kwargs.get("project", None)
- self.collection = kwargs.get("collection", None)
- if self.collection and not isinstance(self.collection, list):
- self.collection = [self.collection]
-
- # Load articles in the collections
- self.arango_ids = []
- if self.collection:
- for c in self.collection:
- for _id in self.user_arango.db.aql.execute(
- """
- FOR doc IN article_collections
- FILTER doc.name == @collection
- FOR article IN doc.articles
- RETURN article._id
- """,
- bind_vars={"collection": c},
- ):
- self.arango_ids.append(_id)
-
- # A standard LLM for normal chat
- self.chatbot = LLM(messages=self.chat.chat_history2bot())
- # A helper bot for generating queries or short prompts
- self.helperbot = LLM(
- temperature=0,
- model="small",
- max_length_answer=500,
- system_message=get_query_builder_system_message(),
- messages=self.chat.chat_history2bot(n_messages=4, remove_system=True),
- )
- # A specialized LLM picking which tool to use
- self.toolbot = LLM(
- temperature=0,
- system_message="""
- You are an assistant bot helping an answering bot to answer a user's messages.
- Your task is to choose one or multiple tools that will help the answering bot to provide the user with the best possible answer.
- You should NEVER directly answer the user. You MUST choose a tool.
- """,
- chat=False,
- model="small",
- )
-
- # Load or register the passed-in tools
- if tools:
- self.tools = ToolRegistry.get_tools(tools=tools)
- else:
- self.tools = ToolRegistry.get_tools()
-
- # Store other kwargs
- for arg in kwargs:
- setattr(self, arg, kwargs[arg])
-
-
-
-
- def get_chunks(
- self,
- user_input,
- collections=["sci_articles", "other_documents"],
- n_results=7,
- n_sources=4,
- filter=True,
- ):
- # Basic version without Streamlit calls
- query = self.helperbot.generate(
- get_generate_vector_query_prompt(user_input, self.chat.role)
- ).content.strip('"')
-
- combined_chunks = []
- if collections:
- for collection in collections:
- where_filter = {"_id": {"$in": self.arango_ids}} if filter else {}
- chunks = self.get_chromadb().query(
- query=query,
- collection=collection,
- n_results=n_results,
- n_sources=n_sources,
- where=where_filter,
- max_retries=3,
- )
- for doc, meta, dist in zip(
- chunks["documents"][0],
- chunks["metadatas"][0],
- chunks["distances"][0],
- ):
- combined_chunks.append(
- {"document": doc, "metadata": meta, "distance": dist}
- )
- combined_chunks.sort(key=lambda x: x["distance"])
-
- # Keep the best chunks according to n_sources
- sources = set()
- closest_chunks = []
- for chunk in combined_chunks:
- source_id = chunk["metadata"].get("_id", "no_id")
- if source_id not in sources:
- sources.add(source_id)
- closest_chunks.append(chunk)
- if len(sources) >= n_sources:
- break
- if len(closest_chunks) < n_results:
- remaining_chunks = [
- c for c in combined_chunks if c not in closest_chunks
- ]
- closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)])
-
- # Now fetch real metadata from Arango
- for chunk in closest_chunks:
- _id = chunk["metadata"].get("_id")
- if not _id:
- continue
- if _id.startswith("sci_articles"):
- arango_doc = self.base_arango.db.document(_id)
- else:
- arango_doc = self.user_arango.db.document(_id)
- if arango_doc:
- arango_metadata = arango_doc.get("metadata", {})
- # Possibly merge notes
- if "user_notes" in arango_doc:
- arango_metadata["user_notes"] = arango_doc["user_notes"]
- chunk["metadata"] = arango_metadata
-
- # Group by article title
- grouped_chunks = {}
- article_number = 1
- for chunk in closest_chunks:
- title = chunk["metadata"].get("title", "No title")
- chunk["article_number"] = article_number
- if title not in grouped_chunks:
- grouped_chunks[title] = {
- "article_number": article_number,
- "chunks": [],
- }
- article_number += 1
- grouped_chunks[title]["chunks"].append(chunk)
- return grouped_chunks
-
- def answer_tool_call(self, response, user_input):
- bot_responses = []
- # This method returns / stores responses (no Streamlit calls)
- if not response.get("tool_calls"):
- return ""
-
- for tool in response.get("tool_calls"):
- function_name = tool.function.get('name')
- arguments = tool.function.arguments
- arguments["query"] = user_input
-
- if hasattr(self, function_name):
- if function_name in [
- "fetch_other_documents_tool",
- "fetch_science_articles_tool",
- "fetch_science_articles_and_other_documents_tool",
- ]:
- chunks = getattr(self, function_name)(**arguments)
- bot_responses.append(
- self.generate_from_chunks(user_input, chunks).strip('"')
- )
- elif function_name == "fetch_notes_tool":
- notes = getattr(self, function_name)()
- bot_responses.append(
- self.generate_from_notes(user_input, notes).strip('"')
- )
- elif function_name == "conversational_response_tool":
- bot_responses.append(
- getattr(self, function_name)(user_input).strip('"')
- )
- return "\n\n".join(bot_responses)
-
- def process_user_input(self, user_input, content_attachment=None):
- # Add user message
- self.chat.add_message("user", user_input)
-
- if not content_attachment:
- prompt = get_tools_prompt(user_input)
- response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
- if response.get("tool_calls"):
- bot_response = self.answer_tool_call(response, user_input)
- else:
- # Just respond directly
- bot_response = response.content.strip('"')
- else:
- # If there's an attachment, do something minimal
- bot_response = "Content attachment received (Base Bot)."
-
- # Add assistant message
- if self.chat.chat_history[-1]["role"] != "assistant":
- self.chat.add_message("assistant", bot_response)
-
- # Update in Arango
- self.chat.update_in_arango()
- return bot_response
-
- def generate_from_notes(self, user_input, notes):
- # No Streamlit calls
- notes_string = ""
- for note in notes:
- notes_string += f"\n# {note.get('title','No title')}\n{note.get('content','')}\n---\n"
- prompt = get_chat_prompt(user_input, content_string=notes_string, role=self.chat.role)
- return self.chatbot.generate(prompt, stream=True)
-
- def generate_from_chunks(self, user_input, chunks):
- # No Streamlit calls
- chunks_string = ""
- for title, group in chunks.items():
- user_notes_string = ""
- if "user_notes" in group["chunks"][0]["metadata"]:
- notes = group["chunks"][0]["metadata"]["user_notes"]
- user_notes_string = f'\n\nUser notes:\n"""\n{notes}\n"""\n\n'
- docs = "\n(...)\n".join([c["document"] for c in group["chunks"]])
- chunks_string += (
- f"\n# {title}\n## Article #{group['article_number']}\n{user_notes_string}{docs}\n---\n"
- )
- prompt = get_chat_prompt(user_input, content_string=chunks_string, role=self.chat.role)
- return self.chatbot.generate(prompt, stream=True)
-
- def run(self):
- # Base Bot has no Streamlit run loop
- pass
-
- def get_notes(self):
- # Minimal note retrieval
- notes = self.user_arango.db.aql.execute(
- f'FOR doc IN notes FILTER doc.project == "{self.project.name if self.project else ""}" RETURN doc'
- )
- return list(notes)
-
- @ToolRegistry.register
- def fetch_science_articles_tool(self, query: str, n_documents: int):
- """
- "Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles."
-
- Parameters:
- query (str): The search query to find relevant scientific articles.
- n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
-
- Returns:
- list: A list of chunks containing information from the fetched scientific articles.
- """
- print_purple('Query:', query)
-
- n_documents = int(n_documents)
- if n_documents < 3:
- n_documents = 3
- elif n_documents > 10:
- n_documents = 10
- return self.get_chunks(
- query, collections=["sci_articles"], n_results=n_documents
- )
-
- @ToolRegistry.register
- def fetch_other_documents_tool(self, query: str, n_documents: int):
- """
- Fetches information from other documents based on the user's query.
-
- This method retrieves information from various types of documents such as reports, news articles, and other texts. It should be used only when it is clear that the user is not seeking scientific articles.
-
- Args:
- query (str): The search query provided by the user.
- n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.
-
- Returns:
- list: A list of document chunks that match the query.
- """
- assert isinstance(self, Bot), "The first argument must be a Bot object."
- n_documents = int(n_documents)
- if n_documents < 2:
- n_documents = 2
- elif n_documents > 10:
- n_documents = 10
- return self.get_chunks(
- query,
- collections=[f"{self.username}__other_documents"],
- n_results=n_documents,
- )
-
- @ToolRegistry.register
- def fetch_science_articles_and_other_documents_tool(
- self, query: str, n_documents: int
- ):
- """
- Fetches information from both scientific articles and other documents.
-
- This method is often used when the user hasn't specified what kind of sources they are interested in.
-
- Args:
- query (str): The search query to fetch information for.
- n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
-
- Returns:
- list: A list of document chunks that match the search query.
- """
- assert isinstance(self, Bot), "The first argument must be a Bot object."
- n_documents = int(n_documents)
- if n_documents < 3:
- n_documents = 3
- elif n_documents > 10:
- n_documents = 10
- return self.get_chunks(
- query,
- collections=["sci_articles", f"{self.username}__other_documents"],
- n_results=n_documents,
- )
-
- @ToolRegistry.register
- def fetch_notes_tool(bot):
- """
- Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools! No arguments needed.
-
- Returns:
- list: A list of notes.
- """
- assert isinstance(bot, Bot), "The first argument must be a Bot object."
- return bot.get_notes()
-
- @ToolRegistry.register
- def conversational_response_tool(self, query: str):
- """
- Generate a conversational response to a user's query.
-
- This method is designed to provide a short and conversational response
- without fetching additional data. It should be used only when it is clear
- that the user is engaging in small talk (like saying 'hi') and not seeking detailed information.
-
- Args:
- query (str): The user's message to which the bot should respond.
-
- Returns:
- str: The generated conversational response.
- """
- query = f"""
- User message: "{query}".
- Make your answer short and conversational.
- This is perhaps not a conversation about a journalistic project, so try not to be too informative.
- Don't answer with anything you're not sure of!
- """
-
- result = (
- self.chatbot.generate(query, stream=True)
- if self.chatbot
- else self.llm.generate(query, stream=True)
- )
- return result
-
-class StreamlitBot(Bot):
- def __init__(self, username: str, chat: StreamlitChat = None, tools: list = None, **kwargs):
- print_purple("StreamlitBot init chat:", chat)
- super().__init__(username=username, chat=chat, tools=tools, **kwargs)
-
- # For Streamlit, we can override or add attributes
- if 'llm_chosen_backend' not in st.session_state:
- st.session_state['llm_chosen_backend'] = None
-
- self.chatbot.chosen_backend = st.session_state['llm_chosen_backend']
- if not st.session_state['llm_chosen_backend']:
- st.session_state['llm_chosen_backend'] = self.chatbot.chosen_backend
-
- def run(self):
- # Example Streamlit run loop
- self.chat.show_chat_history()
- if user_input := st.chat_input("Write your message here...", accept_file=True):
- text_input = user_input.text.replace('"""', "---")
- if len(user_input.files) > 1:
- st.error("Please upload only one file at a time.")
- return
- attached_file = user_input.files[0] if user_input.files else None
-
- content_attachment = None
- if attached_file:
- if attached_file.type == "application/pdf":
- import fitz
- pdf_document = fitz.open(stream=attached_file.read(), filetype="pdf")
- pdf_text = ""
- for page_num in range(len(pdf_document)):
- page = pdf_document.load_page(page_num)
- pdf_text += page.get_text()
- content_attachment = pdf_text
- elif attached_file.type in ["image/png", "image/jpeg"]:
- self.chat.message_attachments = "image"
- content_attachment = attached_file.read()
- with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
- st.image(content_attachment)
-
- with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
- st.write(text_input)
-
- if not self.chat.name:
- self.chat.set_name(text_input)
- self.chat.last_updated = datetime.now().isoformat()
- self.chat.saved = False
- self.user_arango.db.collection("chats").insert(
- self.chat.to_dict(), overwrite=True, overwrite_mode="update"
- )
-
- self.process_user_input(text_input, content_attachment)
-
- def process_user_input(self, user_input, content_attachment=None):
- # We override to show messages in Streamlit instead of just storing
- self.chat.add_message("user", user_input)
- if not content_attachment:
- prompt = get_tools_prompt(user_input)
- response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
- if response.get("tool_calls"):
- bot_response = self.answer_tool_call(response, user_input)
- else:
- bot_response = response.content.strip('"')
- with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
- st.write(bot_response)
- else:
- with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
- with st.spinner("Reading the content..."):
- if self.chat.message_attachments == "image":
- prompt = get_chat_prompt(user_input, role=self.chat.role, image_attachment=True)
- bot_resp = self.chatbot.generate(prompt, stream=False, images=[content_attachment], model="vision")
- st.write(bot_resp)
- bot_response = bot_resp
- else:
- prompt = get_chat_prompt(user_input, content_attachment=content_attachment, role=self.chat.role)
- response = self.chatbot.generate(prompt, stream=True)
- bot_response = st.write_stream(response)
-
- if self.chat.chat_history[-1]["role"] != "assistant":
- self.chat.add_message("assistant", bot_response)
-
- self.chat.update_in_arango()
-
- def answer_tool_call(self, response, user_input):
- bot_responses = []
- for tool in response.get("tool_calls", []):
- function_name = tool.function.get('name')
- arguments = tool.function.arguments
- arguments["query"] = user_input
-
- with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
- if function_name in [
- "fetch_other_documents_tool",
- "fetch_science_articles_tool",
- "fetch_science_articles_and_other_documents_tool",
- ]:
- chunks = getattr(self, function_name)(**arguments)
- response_text = self.generate_from_chunks(user_input, chunks)
- bot_response = st.write_stream(response_text).strip('"')
- if chunks:
- sources = "###### Sources:\n"
- for title, group in chunks.items():
- j = group["chunks"][0]["metadata"].get("journal", "No Journal")
- d = group["chunks"][0]["metadata"].get("published_date", "No Date")
- sources += f"[{group['article_number']}] **{title}** :gray[{j} ({d})]\n"
- st.markdown(sources)
- bot_response += f"\n\n{sources}"
- bot_responses.append(bot_response)
-
- elif function_name == "fetch_notes_tool":
- notes = getattr(self, function_name)()
- response_text = self.generate_from_notes(user_input, notes)
- bot_responses.append(st.write_stream(response_text).strip('"'))
-
- elif function_name == "conversational_response_tool":
- response_text = getattr(self, function_name)(user_input)
- bot_responses.append(st.write_stream(response_text).strip('"'))
-
- return "\n\n".join(bot_responses)
-
- def generate_from_notes(self, user_input, notes):
- with st.spinner("Reading project notes..."):
- return super().generate_from_notes(user_input, notes)
-
- def generate_from_chunks(self, user_input, chunks):
- # For reading articles with a spinner
- magazines = set()
- for group in chunks.values():
- j = group["chunks"][0]["metadata"].get("journal", "No Journal")
- magazines.add(f"*{j}*")
- s = (
- f"Reading articles from {', '.join(list(magazines)[:-1])} and {list(magazines)[-1]}..."
- if len(magazines) > 1
- else "Reading articles..."
- )
- with st.spinner(s):
- return super().generate_from_chunks(user_input, chunks)
-
- def sidebar_content(self):
- with st.sidebar:
- st.write("---")
- st.markdown(f'#### {self.chat.name if self.chat.name else ""}')
- st.button("Delete this chat", on_click=self.delete_chat)
-
- def delete_chat(self):
- self.user_arango.db.collection("chats").delete_match(
- filters={"name": self.chat.name}
- )
- self.chat = Chat()
-
- def get_notes(self):
- # We can show a spinner or messages too
- with st.spinner("Fetching notes..."):
- return super().get_notes()
-
-
-class EditorBot(StreamlitBot):
- def __init__(self, chat: Chat, username: str, **kwargs):
- print_blue("EditorBot init chat:", chat)
- super().__init__(chat=chat, username=username, **kwargs)
- self.role = "Editor"
- self.tools = ToolRegistry.get_tools()
- self.chatbot = LLM(
- system_message=get_editor_prompt(kwargs.get("project")),
- messages=self.chat.chat_history2bot(),
- chosen_backend=kwargs.get("chosen_backend"),
- )
-
-
-class ResearchAssistantBot(StreamlitBot):
- def __init__(self, chat: Chat, username: str, **kwargs):
- super().__init__(chat=chat, username=username, **kwargs)
- self.role = "Research Assistant"
- self.chatbot = LLM(
- system_message=get_assistant_prompt(),
- temperature=0.1,
- messages=self.chat.chat_history2bot(),
- )
- self.tools = [
- self.fetch_science_articles_tool,
- self.fetch_science_articles_and_other_documents_tool,
- ]
-
-
-class PodBot(StreamlitBot):
- """Two LLM agents construct a conversation using material from science articles."""
-
- def __init__(
- self,
- chat: Chat,
- subject: str,
- username: str,
- instructions: str = None,
- **kwargs,
- ):
- super().__init__(chat=chat, username=username, **kwargs)
- self.subject = subject
- self.instructions = instructions
- self.guest_name = kwargs.get("name_guest", "Merit")
- self.hostbot = HostBot(
- Chat(username=self.username, role="Host"),
- subject,
- username,
- instructions=instructions,
- **kwargs,
- )
- self.guestbot = GuestBot(
- Chat(username=self.username, role="Guest"),
- subject,
- username,
- name_guest=self.guest_name,
- **kwargs,
- )
-
- def run(self):
- notes = self.get_notes()
- notes_string = ""
- if self.instructions:
- instructions_string = f'''
- These are the instructions for the podcast from the producer:
- """
- {self.instructions}
- """
- '''
- else:
- instructions_string = ""
-
- for note in notes:
- notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"
-        a = f'''You will conduct a podcast interview with {self.guest_name}, an expert on "{self.subject}".
- {instructions_string}
- Below are notes on the subject that you can use to ask relevant questions:
- """
- {notes_string}
- """
- Say hello to the expert and start the interview. Remember to keep the interview to the subject of {self.subject} throughout the conversation.
- '''
-
- # Stop button for the podcast
- with st.sidebar:
- stop = st.button("Stop podcast", on_click=self.stop_podcast)
-
- while st.session_state["make_podcast"]:
-            # Stop the podcast once the chat history reaches 14 messages
- self.chat.show_chat_history()
- if len(self.chat.chat_history) == 14:
- result = self.hostbot.generate(
- "The interview has ended. Say thank you to the expert and end the conversation."
- )
- self.chat.add_message("Host", result)
- with st.chat_message(
- "assistant", avatar=self.chat.get_avatar(role="assistant")
- ):
- st.write(result.strip('"'))
- st.stop()
-
- _q = self.hostbot.toolbot.generate(
- query=f"{self.guest_name} has answered: {a}. You have to choose a tool to help the host continue the interview.",
- tools=self.hostbot.tools,
- temperature=0.6,
- stream=False,
- )
- if "tool_calls" in _q:
- q = self.hostbot.answer_tool_call(_q, a)
- else:
- q = _q
-
- self.chat.add_message("Host", q)
-
- _a = self.guestbot.toolbot.generate(
- f'The podcast host has asked: "{q}" Choose a tool to help the expert answer with relevant facts and information.',
- tools=self.guestbot.tools,
- )
- if "tool_calls" in _a:
- print_yellow("Tool call response (guest)", _a)
- print_yellow(self.guestbot.chat.role)
- a = self.guestbot.answer_tool_call(_a, q)
- else:
- a = _a
- self.chat.add_message("Guest", a)
-
- self.update_session_state()
-
- def stop_podcast(self):
- st.session_state["make_podcast"] = False
- self.update_session_state()
- self.chat.show_chat_history()
-
-
-class HostBot(StreamlitBot):
- def __init__(
- self, chat: Chat, subject: str, username: str, instructions: str, **kwargs
- ):
- super().__init__(chat=chat, username=username, **kwargs)
- self.chat.role = kwargs.get("role", "Host")
- self.tools = ToolRegistry.get_tools(
- tools=[
- self.fetch_notes_tool,
- self.conversational_response_tool,
- # "fetch_other_documents", #TODO Should this be included?
- ]
- )
- self.instructions = instructions
- self.llm = LLM(
- system_message=f'''
- You are the host of a podcast and an expert on {subject}. You will ask one question at a time about the subject, and then wait for the guest to answer.
- Don't ask the guest to talk about herself/himself, only about the subject.
-            Make your questions short and clear; add brief context only when necessary.
- These are the instructions for the podcast from the producer:
- """
- {self.instructions}
- """
-            If the expert's answer is complicated, try to make a very brief summary of it for the audience to understand. You can also ask follow-up questions to clarify the answer, or ask for examples.
- ''',
- messages=self.chat.chat_history2bot()
- )
- self.toolbot = LLM(
- temperature=0,
- system_message="""
- You are assisting a podcast host in asking questions to an expert.
- Choose one or many tools to use in order to assist the host in asking relevant questions.
- Often "conversational_response_tool" is enough, but sometimes project notes are needed.
- Make sure to read the description of the tools carefully!""",
- chat=False,
- model="small",
- )
-
- def generate(self, query):
- return self.llm.generate(query)
-
-
-class GuestBot(StreamlitBot):
- def __init__(self, chat: Chat, subject: str, username: str, **kwargs):
- super().__init__(chat=chat, username=username, **kwargs)
- self.chat.role = kwargs.get("role", "Guest")
- self.tools = ToolRegistry.get_tools(
- tools=[
- self.fetch_notes_tool,
- self.fetch_science_articles_tool,
- ]
- )
-
- self.llm = LLM(
- system_message=f"""
- You are {kwargs.get('name', 'Merit')}, an expert on {subject}.
- Today you are a guest in a podcast about {subject}. A host will ask you questions about the subject and you will answer by using scientific facts and information.
-            When answering, don't say things like "based on the documents" or the like, as neither the host nor the audience can see the documents. Act just as if you were talking to someone in a conversation.
-            Try to be concise when answering, and remember that the podcast audience are not experts on the subject, so don't complicate things too much.
-            It's very important that you answer in a "spoken" way, as if you were talking to someone in a conversation. That means avoiding scientific jargon, complex terms, too many figures and abstract concepts.
-            Avoid lists as well; instead use "for the first reason", "secondly", and so on.
-            Use "..." to indicate a pause and "-" to indicate a break in the sentence, as if you were speaking.
- """,
- messages=self.chat.chat_history2bot()
- )
- self.toolbot = LLM(
- temperature=0,
- system_message=f"You are an assistant to an expert on {subject}. Choose one or many tools to use in order to assist the expert in answering questions. Make sure to read the description of the tools carefully.",
- chat=False,
- model="small",
- )
-
- def generate(self, query):
- return self.llm.generate(query)
diff --git a/_bots_dont_use.py b/_bots_dont_use.py
new file mode 100644
index 0000000..e1fd68f
--- /dev/null
+++ b/_bots_dont_use.py
@@ -0,0 +1,497 @@
+from datetime import datetime
+import streamlit as st
+import uuid
+
+from _base_class import StreamlitBaseClass, BaseClass
+from _llm import LLM
+from _arango import ArangoDB
+from prompts import *
+from colorprinter.print_color import *
+from llm_tools import ToolRegistry
+from streamlit_chatbot import StreamlitBot, PodBot, EditorBot, ResearchAssistantBot
+
+class Chat(StreamlitBaseClass):
+ def __init__(self, username=None, **kwargs):
+ super().__init__(username=username, **kwargs)
+ self.name = kwargs.get("name", None)
+ self.chat_history = kwargs.get("chat_history", [])
+ self.role = kwargs.get("role", "Research Assistant")
+ self._key = kwargs.get("_key", str(uuid.uuid4()))
+ self.saved = kwargs.get("saved", False)
+ self.last_updated = kwargs.get("last_updated", datetime.now().isoformat())
+ self.message_attachments = None
+ self.project = kwargs.get("project", None)
+
+ def add_message(self, role, content):
+ self.chat_history.append(
+ {
+ "role": role,
+ "content": content.strip().strip('"'),
+ "role_type": self.role,
+ }
+ )
+
+ def to_dict(self):
+ return {
+ "_key": self._key,
+ "name": self.name,
+ "chat_history": self.chat_history,
+ "role": self.role,
+ "username": self.username,
+ "project": self.project,
+ "last_updated": self.last_updated,
+ "saved": self.saved,
+ }
+
+ def update_in_arango(self):
+ """Update chat in ArangoDB using the new API"""
+ self.last_updated = datetime.now().isoformat()
+
+ # Use the create_or_update_chat method from the new API
+ self.user_arango.create_or_update_chat(self.to_dict())
+
+ def set_name(self, user_input):
+ llm = LLM(
+ model="small",
+ max_length_answer=50,
+ temperature=0.4,
+ system_message="You are a chatbot who will be chatting with a user",
+ )
+ prompt = (
+ f'Give a short name to the chat based on this user input: "{user_input}" '
+ "No more than 30 characters. Answer ONLY with the name of the chat."
+ )
+ name = llm.generate(prompt).content.strip('"')
+ name = f'{name} - {datetime.now().strftime("%B %d")}'
+
+ # Check for existing chat with the same name
+ existing_chat = self.user_arango.execute_aql(
+ """
+ FOR chat IN chats
+ FILTER chat.name == @name AND chat.username == @username
+ RETURN chat
+ """,
+ bind_vars={"name": name, "username": self.username}
+ )
+
+ if list(existing_chat):
+ name = f'{name} ({datetime.now().strftime("%H:%M")})'
+ name += f" - [{self.role}]"
+ self.name = name
+ return name
+
+ def show_chat_history(self):
+ """Display chat history in the Streamlit UI"""
+ for message in self.chat_history:
+ with st.chat_message(
+ name="assistant" if message["role"] == "assistant" else "user",
+ avatar=self.get_avatar(role=message["role"])
+ ):
+ st.write(message["content"])
+
+ def get_avatar(self, role):
+ """Get avatar for a role"""
+ if role == "user":
+ return None
+ elif role == "Host":
+ return "🎙️"
+ elif role == "Guest":
+ return "🎤"
+ elif role == "assistant":
+ if self.role == "Research Assistant":
+ return "🔬"
+ elif self.role == "Editor":
+ return "📝"
+ else:
+ return "🤖"
+ return None
+
+ @classmethod
+ def from_dict(cls, data):
+ return cls(
+ username=data.get("username"),
+ name=data.get("name"),
+ chat_history=data.get("chat_history", []),
+ role=data.get("role", "Research Assistant"),
+ _key=data.get("_key"),
+ project=data.get("project"),
+ last_updated=data.get("last_updated"),
+ saved=data.get("saved", False),
+ )
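+    # Round-trip sketch (illustrative, not part of the original file): to_dict and
+    # from_dict are inverses for the persisted fields, which update_in_arango relies on.
+    #
+    #     restored = Chat.from_dict(chat.to_dict())
+    #     assert restored.name == chat.name and restored._key == chat._key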
+
+ def chat_history2bot(self, n_messages: int = None, remove_system: bool = False):
+ history = [
+ {"role": m["role"], "content": m["content"]} for m in self.chat_history
+ ]
+ if n_messages and len(history) > n_messages:
+ history = history[-n_messages:]
+        if (
+            (history[0]["role"] == "system" and remove_system)
+            or history[0]["role"] == "assistant"
+        ):
+ history = history[1:]
+ return history
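+    # Illustrative sketch (not from the original file): how chat_history2bot trims
+    # history for the LLM. Assumes a populated Chat instance.
+    #
+    #     chat.chat_history = [
+    #         {"role": "system", "content": "You are an assistant.", "role_type": "Editor"},
+    #         {"role": "user", "content": "Hi", "role_type": "Editor"},
+    #         {"role": "assistant", "content": "Hello!", "role_type": "Editor"},
+    #     ]
+    #     chat.chat_history2bot(n_messages=2, remove_system=True)
+    #     # -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]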
+
+
+class Bot(BaseClass):
+ def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
+ super().__init__(username=username, **kwargs)
+
+ # Use the passed in chat or create a new Chat
+ self.chat = chat if chat else Chat(username=username, role="Research Assistant")
+ print_yellow(f"Chat:", chat, type(chat))
+
+ # Store or set up project/collection if available
+ self.project = kwargs.get("project", None)
+ self.collection = kwargs.get("collection", None)
+ if self.collection and not isinstance(self.collection, list):
+ self.collection = [self.collection]
+
+ # Load articles in the collections using the new API
+ self.arango_ids = []
+ if self.collection:
+ for c in self.collection:
+ # Use execute_aql from the new API
+ article_ids = self.user_arango.execute_aql(
+ """
+ FOR doc IN article_collections
+ FILTER doc.name == @collection
+ FOR article IN doc.articles
+ RETURN article
+ """,
+ bind_vars={"collection": c}
+ )
+ for _id in article_ids:
+ self.arango_ids.append(_id)
+
+ # A standard LLM for normal chat
+ self.chatbot = LLM(messages=self.chat.chat_history2bot())
+ # A helper bot for generating queries or short prompts
+ self.helperbot = LLM(
+ temperature=0,
+ model="small",
+ max_length_answer=500,
+ system_message=get_query_builder_system_message(),
+ messages=self.chat.chat_history2bot(n_messages=4, remove_system=True),
+ )
+ # A specialized LLM picking which tool to use
+ self.toolbot = LLM(
+ temperature=0,
+ system_message="""
+ You are an assistant bot helping an answering bot to answer a user's messages.
+ Your task is to choose one or multiple tools that will help the answering bot to provide the user with the best possible answer.
+ You should NEVER directly answer the user. You MUST choose a tool.
+ """,
+ chat=False,
+ model="small",
+ )
+
+ # Load or register the passed-in tools
+ if tools:
+ self.tools = ToolRegistry.get_tools(tools=tools)
+ else:
+ self.tools = ToolRegistry.get_tools()
+
+ # Store other kwargs
+ for arg in kwargs:
+ setattr(self, arg, kwargs[arg])
+
+ def get_chunks(
+ self,
+ user_input,
+ collections=["sci_articles", "other_documents"],
+ n_results=7,
+ n_sources=4,
+ filter=True,
+ ):
+ # Basic version without Streamlit calls
+ query = self.helperbot.generate(
+ get_generate_vector_query_prompt(user_input, self.chat.role)
+ ).content.strip('"')
+
+ combined_chunks = []
+ if collections:
+ for collection in collections:
+ where_filter = {"_id": {"$in": self.arango_ids}} if filter else {}
+ chunks = self.get_chromadb().query(
+ query=query,
+ collection=collection,
+ n_results=n_results,
+ n_sources=n_sources,
+ where=where_filter,
+ max_retries=3,
+ )
+ for doc, meta, dist in zip(
+ chunks["documents"][0],
+ chunks["metadatas"][0],
+ chunks["distances"][0],
+ ):
+ combined_chunks.append(
+ {"document": doc, "metadata": meta, "distance": dist}
+ )
+ combined_chunks.sort(key=lambda x: x["distance"])
+
+ # Keep the best chunks according to n_sources
+ sources = set()
+ closest_chunks = []
+ for chunk in combined_chunks:
+ source_id = chunk["metadata"].get("_id", "no_id")
+ if source_id not in sources:
+ sources.add(source_id)
+ closest_chunks.append(chunk)
+ if len(sources) >= n_sources:
+ break
+ if len(closest_chunks) < n_results:
+ remaining_chunks = [
+ c for c in combined_chunks if c not in closest_chunks
+ ]
+ closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)])
+
+ # Now fetch real metadata from Arango using the new API
+ for chunk in closest_chunks:
+ _id = chunk["metadata"].get("_id")
+ if not _id:
+ continue
+
+ try:
+ # Determine which database to use based on collection name
+ if _id.startswith("sci_articles"):
+ # Use base_arango for common documents
+ arango_doc = self.base_arango.get_document(_id)
+ else:
+ # Use user_arango for user-specific documents
+ arango_doc = self.user_arango.get_document(_id)
+
+ if arango_doc:
+ arango_metadata = arango_doc.get("metadata", {})
+ # Possibly merge notes
+ if "user_notes" in arango_doc:
+ arango_metadata["user_notes"] = arango_doc["user_notes"]
+ chunk["metadata"] = arango_metadata
+ except Exception as e:
+ print_red(f"Error fetching document {_id}: {e}")
+
+ # Group by article title
+ grouped_chunks = {}
+ article_number = 1
+ for chunk in closest_chunks:
+ title = chunk["metadata"].get("title", "No title")
+ chunk["article_number"] = article_number
+ if title not in grouped_chunks:
+ grouped_chunks[title] = {
+ "article_number": article_number,
+ "chunks": [],
+ }
+ article_number += 1
+ grouped_chunks[title]["chunks"].append(chunk)
+ return grouped_chunks
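+    # Shape of the mapping returned by get_chunks (an illustrative sketch; the
+    # values are invented for the example):
+    #
+    #     {
+    #         "Some article title": {
+    #             "article_number": 1,
+    #             "chunks": [
+    #                 {"document": "...", "metadata": {...}, "distance": 0.31, "article_number": 1},
+    #             ],
+    #         },
+    #     }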
+
+ def answer_tool_call(self, response, user_input):
+ bot_responses = []
+ # This method returns / stores responses (no Streamlit calls)
+ if not response.get("tool_calls"):
+ return ""
+
+ for tool in response.get("tool_calls"):
+ function_name = tool.function.get('name')
+ arguments = tool.function.arguments
+ arguments["query"] = user_input
+
+ if hasattr(self, function_name):
+ if function_name in [
+ "fetch_other_documents_tool",
+ "fetch_science_articles_tool",
+ "fetch_science_articles_and_other_documents_tool",
+ ]:
+ chunks = getattr(self, function_name)(**arguments)
+ bot_responses.append(
+ self.generate_from_chunks(user_input, chunks).strip('"')
+ )
+ elif function_name == "fetch_notes_tool":
+ notes = getattr(self, function_name)()
+ bot_responses.append(
+ self.generate_from_notes(user_input, notes).strip('"')
+ )
+ elif function_name == "conversational_response_tool":
+ bot_responses.append(
+ getattr(self, function_name)(user_input).strip('"')
+ )
+ return "\n\n".join(bot_responses)
+
+ def process_user_input(self, user_input, content_attachment=None):
+ # Add user message
+ self.chat.add_message("user", user_input)
+
+ if not content_attachment:
+ prompt = get_tools_prompt(user_input)
+ response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
+ if response.get("tool_calls"):
+ bot_response = self.answer_tool_call(response, user_input)
+ else:
+ # Just respond directly
+ bot_response = response.content.strip('"')
+ else:
+ # If there's an attachment, do something minimal
+ bot_response = "Content attachment received (Base Bot)."
+
+ # Add assistant message
+ if self.chat.chat_history[-1]["role"] != "assistant":
+ self.chat.add_message("assistant", bot_response)
+
+ # Update in Arango
+ self.chat.update_in_arango()
+ return bot_response
+
+ def generate_from_notes(self, user_input, notes):
+ # No Streamlit calls
+ notes_string = ""
+ for note in notes:
+ notes_string += f"\n# {note.get('title','No title')}\n{note.get('text','')}\n---\n"
+ prompt = get_chat_prompt(user_input, content_string=notes_string, role=self.chat.role)
+ return self.chatbot.generate(prompt, stream=True)
+
+ def generate_from_chunks(self, user_input, chunks):
+ # No Streamlit calls
+ chunks_string = ""
+ for title, group in chunks.items():
+ user_notes_string = ""
+ if "user_notes" in group["chunks"][0]["metadata"]:
+ notes = group["chunks"][0]["metadata"]["user_notes"]
+ user_notes_string = f'\n\nUser notes:\n"""\n{notes}\n"""\n\n'
+ docs = "\n(...)\n".join([c["document"] for c in group["chunks"]])
+ chunks_string += (
+ f"\n# {title}\n## Article #{group['article_number']}\n{user_notes_string}{docs}\n---\n"
+ )
+ prompt = get_chat_prompt(user_input, content_string=chunks_string, role=self.chat.role)
+ return self.chatbot.generate(prompt, stream=True)
+
+ def run(self):
+ # Base Bot has no Streamlit run loop
+ pass
+
+ def get_notes(self):
+ # Get project notes using the new API
+ if self.project and hasattr(self.project, "name"):
+ notes = self.user_arango.get_project_notes(
+ project_name=self.project.name,
+ username=self.username
+ )
+ return list(notes)
+ return []
+
+ @ToolRegistry.register
+ def fetch_science_articles_tool(self, query: str, n_documents: int):
+ """
+ "Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles."
+
+ Parameters:
+ query (str): The search query to find relevant scientific articles.
+ n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
+
+ Returns:
+ list: A list of chunks containing information from the fetched scientific articles.
+ """
+ print_purple('Query:', query)
+
+ n_documents = int(n_documents)
+ if n_documents < 3:
+ n_documents = 3
+ elif n_documents > 10:
+ n_documents = 10
+ return self.get_chunks(
+ query, collections=["sci_articles"], n_results=n_documents
+ )
+
+ @ToolRegistry.register
+ def fetch_other_documents_tool(self, query: str, n_documents: int):
+ """
+ Fetches information from other documents based on the user's query.
+
+ This method retrieves information from various types of documents such as reports, news articles, and other texts. It should be used only when it is clear that the user is not seeking scientific articles.
+
+ Args:
+ query (str): The search query provided by the user.
+ n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.
+
+ Returns:
+ list: A list of document chunks that match the query.
+ """
+ assert isinstance(self, Bot), "The first argument must be a Bot object."
+ n_documents = int(n_documents)
+ if n_documents < 2:
+ n_documents = 2
+ elif n_documents > 10:
+ n_documents = 10
+ return self.get_chunks(
+ query,
+ collections=[f"{self.username}__other_documents"],
+ n_results=n_documents,
+ )
+
+ @ToolRegistry.register
+ def fetch_science_articles_and_other_documents_tool(
+ self, query: str, n_documents: int
+ ):
+ """
+ Fetches information from both scientific articles and other documents.
+
+ This method is often used when the user hasn't specified what kind of sources they are interested in.
+
+ Args:
+ query (str): The search query to fetch information for.
+ n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
+
+ Returns:
+ list: A list of document chunks that match the search query.
+ """
+ assert isinstance(self, Bot), "The first argument must be a Bot object."
+ n_documents = int(n_documents)
+ if n_documents < 3:
+ n_documents = 3
+ elif n_documents > 10:
+ n_documents = 10
+ return self.get_chunks(
+ query,
+ collections=["sci_articles", f"{self.username}__other_documents"],
+ n_results=n_documents,
+ )
+
+ @ToolRegistry.register
+ def fetch_notes_tool(bot):
+ """
+ Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools! No arguments needed.
+
+ Returns:
+ list: A list of notes.
+ """
+ assert isinstance(bot, Bot), "The first argument must be a Bot object."
+ return bot.get_notes()
+
+ @ToolRegistry.register
+ def conversational_response_tool(self, query: str):
+ """
+ Generate a conversational response to a user's query.
+
+ This method is designed to provide a short and conversational response
+ without fetching additional data. It should be used only when it is clear
+ that the user is engaging in small talk (like saying 'hi') and not seeking detailed information.
+
+ Args:
+ query (str): The user's message to which the bot should respond.
+
+ Returns:
+ str: The generated conversational response.
+ """
+ query = f"""
+ User message: "{query}".
+ Make your answer short and conversational.
+ This is perhaps not a conversation about a journalistic project, so try not to be too informative.
+ Don't answer with anything you're not sure of!
+ """
+
+ result = (
+ self.chatbot.generate(query, stream=True)
+ if self.chatbot
+ else self.llm.generate(query, stream=True)
+ )
+ return result
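+
+# Minimal usage sketch (illustrative; assumes the ArangoDB/ChromaDB backends and
+# LLM server are configured for the given user):
+#
+#     bot = Bot(username="alice", collection="my_collection")
+#     answer = bot.process_user_input("What does recent research say about open science?")
+#     print(answer)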
diff --git a/_chromadb.py b/_chromadb.py
index d73dec0..47413bc 100644
--- a/_chromadb.py
+++ b/_chromadb.py
@@ -1,8 +1,13 @@
import chromadb
import os
+from typing import Union, List, Dict, Tuple, Any
+import re
+
from chromadb.config import Settings
from dotenv import load_dotenv
from colorprinter.print_color import *
+from models import ChunkSearchResults
+
load_dotenv(".env")
@@ -20,6 +25,7 @@ class ChromaDB:
)
self.db = chromadb.HttpClient(
host=host,
+ #database=db,
settings=Settings(
chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
chroma_client_auth_credentials=credentials,
@@ -63,14 +69,20 @@ class ChromaDB:
col = self.db.get_collection(collection)
sources = []
n = 0
-
+ print('Collection', collection)
result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
+
while True:
n += 1
if n > max_retries:
break
if where == {}:
- where = None
+ where = None
+
+ print_rainbow(kwargs)
+ print('N_results:', n_results)
+ print('Sources:', sources)
+ print('Query:', query)
r = col.query(
query_texts=query,
n_results=n_results - len(sources),
@@ -79,6 +91,7 @@ class ChromaDB:
)
if r["ids"][0] == []:
if result["ids"][0] == []:
+ print_rainbow(r)
print_red("No results found in vector database.")
else:
print_red("No more results found in vector database.")
@@ -123,6 +136,210 @@ class ChromaDB:
break
return result
+ def search(
+ self,
+ query: str,
+ collection: str,
+ n_results: int = 6,
+ n_sources: int = 3,
+ where: dict = None,
+ format_results: bool = False,
+ **kwargs,
+    ) -> Union[dict, List[dict]]:
+ """
+ An enhanced search method that provides a cleaner interface for querying and processing results.
+
+ Args:
+ query (str): The search query
+ collection (str): Collection name to search in
+ n_results (int): Maximum number of results to return
+ n_sources (int): Maximum number of unique sources to include
+ where (dict, optional): Additional filtering criteria
+            format_results (bool): Whether to return the results as a formatted list of chunk dictionaries
+ **kwargs: Additional arguments to pass to the query
+
+ Returns:
+            Union[dict, List[dict]]: Raw query results, or a list of chunk dictionaries when format_results is True
+ """
+ # Get raw query results with existing query method
+ result = self.query(
+ query=query,
+ collection=collection,
+ n_results=n_results,
+ n_sources=n_sources,
+ where=where,
+ **kwargs,
+ )
+
+ # If no formatting requested, return raw results
+ if not format_results:
+ return result
+
+ # Process results into dictionary format
+ combined_chunks = []
+ for doc, meta, dist, _id in zip(
+ result["documents"][0],
+ result["metadatas"][0],
+ result["distances"][0],
+ result["ids"][0],
+ ):
+ combined_chunks.append(
+ {"document": doc, "metadata": meta, "distance": dist, "id": _id}
+ )
+
+ return combined_chunks
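+    # Illustrative usage (a sketch; assumes a reachable ChromaDB server and an
+    # existing "sci_articles" collection):
+    #
+    #     chroma = ChromaDB()
+    #     hits = chroma.search(
+    #         query="open science",
+    #         collection="sci_articles",
+    #         n_results=4,
+    #         format_results=True,
+    #     )
+    #     for hit in hits:
+    #         print(hit["distance"], hit["metadata"].get("title"))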
+
+ def clean_result_text(self, documents: list) -> list:
+ """
+ Clean text in document results by removing footnote references.
+
+ Args:
+ documents (list): List of document dictionaries
+
+ Returns:
+ list: Documents with cleaned text
+ """
+
+ for doc in documents:
+ if "document" in doc:
+ doc["document"] = re.sub(r"\[\d+\]", "", doc["document"])
+ return documents
+
+ def filter_by_unique_sources(
+ self, results: list, n_sources: int, source_key: str = "_id"
+ ) -> Tuple[List, List]:
+ """
+ Filters search results to keep only a specified number of unique sources.
+
+ Args:
+ results (list): List of documents from search
+ n_sources (int): Maximum number of unique sources to include
+ source_key (str): The key in metadata that identifies the source
+
+ Returns:
+ tuple: (filtered_results, remaining_results)
+ """
+ sources = set()
+ filtered_results = []
+ remaining_results = []
+
+ for item in results:
+ source_id = item["metadata"].get(source_key, "no_id")
+ if source_id not in sources and len(sources) < n_sources:
+ sources.add(source_id)
+ filtered_results.append(item)
+ else:
+ remaining_results.append(item)
+
+ return filtered_results, remaining_results
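+    # Worked example (illustrative): with n_sources=1 only the first chunk per
+    # source is kept; later chunks go to remaining_results for backfilling.
+    #
+    #     results = [
+    #         {"metadata": {"_id": "a"}, "distance": 0.1},
+    #         {"metadata": {"_id": "a"}, "distance": 0.2},
+    #         {"metadata": {"_id": "b"}, "distance": 0.3},
+    #     ]
+    #     kept, rest = self.filter_by_unique_sources(results, n_sources=1)
+    #     # kept -> [chunk "a" at 0.1]; rest -> [chunk "a" at 0.2, chunk "b" at 0.3]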
+
+ def backfill_results(
+ self, filtered_results: list, remaining_results: list, n_results: int
+ ) -> list:
+ """
+ Adds additional results from remaining_results to filtered_results
+ until n_results is reached.
+
+ Args:
+ filtered_results (list): Initial filtered results
+ remaining_results (list): Other results that can be added
+ n_results (int): Target number of total results
+
+ Returns:
+ list: Combined results up to n_results
+ """
+ if len(filtered_results) >= n_results:
+ return filtered_results[:n_results]
+
+ needed = n_results - len(filtered_results)
+ return filtered_results + remaining_results[:needed]
+
+ def search_chunks(
+ self,
+ query: str,
+ collections: List[str],
+ n_results: int = 7,
+ n_sources: int = 4,
+ where: dict = None,
+ **kwargs,
+    ) -> List[dict]:
+ """
+ Complete pipeline for processing chunks: search, filter, clean, and format.
+
+ Args:
+ query (str): The search query
+ collections (List[str]): List of collection names to search
+ n_results (int): Maximum number of results to return
+ n_sources (int): Maximum number of unique sources to include
+ where (dict, optional): Additional filtering criteria
+ **kwargs: Additional arguments to pass to search
+
+ Returns:
+            List[dict]: Cleaned, distance-sorted chunks with their Chroma IDs
+ """
+ combined_chunks = []
+
+ if isinstance(collections, str):
+ collections = [collections]
+
+        # Search all collections
+        for collection in collections:
+ chunks = self.search(
+ query=query,
+ collection=collection,
+ n_results=n_results,
+ n_sources=n_sources,
+ where=where,
+ format_results=True,
+ **kwargs,
+ )
+
+ for chunk in chunks:
+ combined_chunks.append({
+ "document": chunk["document"],
+ "metadata": chunk["metadata"],
+ "distance": chunk["distance"],
+ "id": chunk["id"],
+ })
+
+ # Sort and filter results
+ combined_chunks.sort(key=lambda x: x["distance"])
+
+ # Filter by unique sources and backfill
+ closest_chunks, remaining_chunks = self.filter_by_unique_sources(
+ combined_chunks, n_sources
+ )
+ closest_chunks = self.backfill_results(
+ closest_chunks, remaining_chunks, n_results
+ )
+
+ # Clean text
+ closest_chunks = self.clean_result_text(closest_chunks)
+ return closest_chunks
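+    # End-to-end sketch (illustrative; assumes the same server and collection
+    # setup as above):
+    #
+    #     chroma = ChromaDB()
+    #     chunks = chroma.search_chunks(
+    #         query="what is open science?",
+    #         collections=["sci_articles"],
+    #         n_results=5,
+    #         n_sources=3,
+    #     )
+    #     # chunks: distance-sorted dicts with "document", "metadata", "distance"
+    #     # and "id", favouring at most 3 unique sources, backfilled to 5 results,
+    #     # with footnote markers like [12] stripped from the text.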
+
+
+ def add_document(self, _id, collection: str, document: str, metadata: dict = None):
+ """
+ Adds a single document to a specified collection in the database.
+
+ Args:
+ _id (str): Arango ID for the document, used as a unique identifier.
+ collection (str): The name of the collection to add the document to.
+ document (str): The document text to be added.
+ metadata (dict, optional): Metadata to be associated with the document. Defaults to None.
+
+ Returns:
+ None
+ """
+ col = self.db.get_or_create_collection(collection)
+ if metadata is None:
+ metadata = {}
+ col.add(ids=[_id], documents=[document], metadatas=[metadata])
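+    # Illustrative call (a sketch; assumes the ArangoDB _id doubles as the Chroma id):
+    #
+    #     chroma.add_document(
+    #         _id="sci_articles/12345",
+    #         collection="sci_articles",
+    #         document="Full text of the article...",
+    #         metadata={"title": "Example article"},
+    #     )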
+
def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
"""
Adds chunks to a specified collection in the database.
@@ -148,18 +365,88 @@ class ChromaDB:
ids.append(f"{_key}_{number}")
col.add(ids=ids, metadatas=metadatas, documents=chunks)
+ def get_collection(self, collection: str) -> chromadb.Collection:
+ """
+ Retrieves a collection from the database.
+
+ Args:
+ collection (str): The name of the collection to retrieve.
+
+ Returns:
+ chromadb.Collection: The requested collection.
+ """
+ return self.db.get_or_create_collection(collection)
+
+def is_reference_chunk(text: str) -> bool:
+ """
+ Determine if a text chunk primarily consists of academic references.
+
+ Args:
+ text (str): Text chunk to analyze
+
+ Returns:
+ bool: True if the chunk appears to be mainly references
+ """
+ # Count significant reference indicators
+ indicators = 0
+
+ # Check for DOI links (very strong indicator)
+ doi_matches = len(re.findall(r'https?://doi\.org/10\.\d+/\S+', text))
+ if doi_matches >= 2: # Multiple DOIs almost certainly means references
+ return True
+ elif doi_matches == 1:
+ indicators += 3
+
+ # Check for citation patterns with year, volume, pages (e.g., 2018;178:551–60)
+ citation_patterns = len(re.findall(r'\d{4};\d+:\d+[-–]\d+', text))
+ indicators += citation_patterns * 2
+
+ # Check for year patterns in brackets [YYYY]
+ year_brackets = len(re.findall(r'\[\d{4}\]', text))
+ indicators += year_brackets
+
+ # Check for multiple lines starting with author name patterns
+ lines = [line.strip() for line in text.split('\n') if line.strip()]
+ author_started_lines = 0
+
+ for line in lines:
+ # Common pattern in references: starts with Author Name(s)
+ if re.match(r'^\s*[A-Z][a-z]+\s+[A-Z][a-z]+', line):
+ author_started_lines += 1
+
+ # If multiple lines start with author names (common in reference lists)
+ if author_started_lines >= 2:
+ indicators += 2
+
+ # Check for academic reference terms
+ if re.search(r'\bet al\b|\bet al\.\b', text, re.IGNORECASE):
+ indicators += 1
+
+ # Return True if we have sufficient indicators
+ return indicators >= 4 # Adjust threshold as needed
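+# Quick sanity check (illustrative): two DOI links alone classify a chunk as
+# references, since doi_matches >= 2 returns True immediately.
+#
+#     text = (
+#         "Smith J, Doe A. Some paper. 2018;178:551-60. https://doi.org/10.1000/xyz\n"
+#         "Brown K, et al. Another paper. https://doi.org/10.1000/abc"
+#     )
+#     is_reference_chunk(text)  # -> True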
if __name__ == "__main__":
from colorprinter.print_color import *
+
chroma = ChromaDB()
print(chroma.db.list_collections())
- exit()
- result = chroma.query(
+ print('DB', chroma.db.database)
+ print('SETTINGS', chroma.db.get_version())
+
+ result = chroma.search_chunks(
query="What is Open Science)",
- collection="sci_articles",
+ collections="lasse__other_documents",
n_results=2,
n_sources=3,
max_retries=4,
)
- print_rainbow(result["metadatas"][0])
+
+ collection = chroma.db.get_or_create_collection("lasse__other_documents")
+ result = collection.query(
+ query_texts="What is Open Science?",
+ n_results=2,
+ )
+ from pprint import pprint
+ pprint(result)
+ #print_rainbow(result["metadatas"][0])
diff --git a/_llm.py b/_llm.py
deleted file mode 100644
index 05ebe67..0000000
--- a/_llm.py
+++ /dev/null
@@ -1,574 +0,0 @@
-import os
-import base64
-import re
-from typing import Literal, Optional
-import requests
-import tiktoken
-from ollama import (
- Client,
- AsyncClient,
- ResponseError,
- ChatResponse,
- Tool,
- Options,
-)
-
-import env_manager
-from colorprinter.print_color import *
-
-env_manager.set_env()
-
-tokenizer = tiktoken.get_encoding("cl100k_base")
-
-
-class LLM:
- """
- LLM class for interacting with an instance of Ollama.
-
- Attributes:
- model (str): The model to be used for response generation.
- system_message (str): The system message to be used in the chat.
- options (dict): Options for the model, such as temperature.
- messages (list): List of messages in the chat.
- max_length_answer (int): Maximum length of the generated answer.
- chat (bool): Whether the chat mode is enabled.
- chosen_backend (str): The chosen backend server for the API.
- client (Client): The client for synchronous API calls.
- async_client (AsyncClient): The client for asynchronous API calls.
- tools (list): List of tools to be used in generating the response.
-
- Methods:
- __init__(self, system_message, temperature, model, max_length_answer, messages, chat, chosen_backend):
- Initializes the LLM class with the provided parameters.
-
- get_model(self, model_alias):
- Retrieves the model name based on the provided alias.
-
- count_tokens(self):
- Counts the number of tokens in the messages.
-
- get_least_conn_server(self):
- Retrieves the least connected server from the backend.
-
- generate(self, query, user_input, context, stream, tools, images, model, temperature):
- Generates a response based on the provided query and options.
-
- make_summary(self, text):
- Generates a summary of the provided text.
-
- read_stream(self, response):
- Handles streaming responses.
-
- async_generate(self, query, user_input, context, stream, tools, images, model, temperature):
- Asynchronously generates a response based on the provided query and options.
-
- prepare_images(self, images, message):
- """
-
- def __init__(
- self,
- system_message: str = "You are an assistant.",
- temperature: float = 0.01,
- model: Optional[
- Literal["small", "standard", "vision", "reasoning", "tools"]
- ] = "standard",
- max_length_answer: int = 4096,
- messages: list[dict] = None,
- chat: bool = True,
- chosen_backend: str = None,
- tools: list = None,
- ) -> None:
- """
- Initialize the assistant with the given parameters.
-
- Args:
- system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
- temperature (float): The temperature setting for the model, affecting randomness. Defaults to 0.01.
- model (Optional[Literal["small", "standard", "vision", "reasoning"]]): The model type to use. Defaults to "standard".
- max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
- messages (list[dict], optional): A list of initial messages. Defaults to None.
- chat (bool): Whether the assistant is in chat mode. Defaults to True.
- chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
-
- Returns:
- None
- """
-
- self.model = self.get_model(model)
-        self.call_model = (
-            self.model
-        )  # Set per call to record which model was actually used
- self.system_message = system_message
- self.options = {"temperature": temperature}
- self.messages = messages or [{"role": "system", "content": self.system_message}]
- self.max_length_answer = max_length_answer
- self.chat = chat
-
- if not chosen_backend:
- chosen_backend = self.get_least_conn_server()
- self.chosen_backend = chosen_backend
-
-
- headers = {
- "Authorization": f"Basic {self.get_credentials()}",
- "X-Chosen-Backend": self.chosen_backend,
- }
- self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
- self.host_url = 'http://192.168.1.12:3300' #! Change back when possible
- self.client: Client = Client(host=self.host_url, headers=headers, timeout=120)
- self.async_client: AsyncClient = AsyncClient()
-
- def get_credentials(self):
- # Initialize the client with the host and default headers
- credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}"
- return base64.b64encode(credentials.encode()).decode()
-
- def get_model(self, model_alias):
-
- models = {
- "standard": "LLM_MODEL",
- "small": "LLM_MODEL_SMALL",
- "vision": "LLM_MODEL_VISION",
- "standard_64k": "LLM_MODEL_LARGE",
- "reasoning": "LLM_MODEL_REASONING",
- "tools": "LLM_MODEL_TOOLS",
- }
- model = os.getenv(models.get(model_alias, "LLM_MODEL"))
- self.model = model
- return model
-
- def count_tokens(self):
- num_tokens = 0
- for i in self.messages:
- for k, v in i.items():
- if k == "content":
- if not isinstance(v, str):
- v = str(v)
- tokens = tokenizer.encode(v)
- num_tokens += len(tokens)
- return int(num_tokens)
-
- def get_least_conn_server(self):
- try:
- response = requests.get("http://192.168.1.12:5000/least_conn")
- response.raise_for_status()
- # Extract the least connected server from the response
- least_conn_server = response.headers.get("X-Upstream-Address")
- return least_conn_server
- except requests.RequestException as e:
- print_red("Error getting least connected server:", e)
- return None
-
- def generate(
- self,
- query: str = None,
- user_input: str = None,
- context: str = None,
- stream: bool = False,
- tools: list = None,
- images: list = None,
- model: Optional[
- Literal["small", "standard", "vision", "reasoning", "tools"]
- ] = None,
- temperature: float = None,
- messages: list[dict] = None,
- format = None,
- think = False
- ):
- """
- Generate a response based on the provided query and context.
- Parameters:
- query (str): The query string from the user.
- user_input (str): Additional user input to be appended to the last message.
- context (str): Contextual information to be used in generating the response.
- stream (bool): Whether to stream the response.
- tools (list): List of tools to be used in generating the response.
- images (list): List of images to be included in the response.
- model (Optional[Literal["small", "standard", "vision", "tools"]]): The model type to be used.
- temperature (float): The temperature setting for the model.
- messages (list[dict]): List of previous messages in the conversation.
- format (Optional[BaseModel]): The format of the response.
- think (bool): Whether to use the reasoning model.
-
- Returns:
- str: The generated response or an error message if an exception occurs.
- """
- print_yellow(stream)
- print_yellow("GENERATE")
- # Prepare the model and temperature
-
- model = self.get_model(model) if model else self.model
- # if model == self.get_model('tools'):
- # stream = False
- temperature = temperature if temperature else self.options["temperature"]
-
- if messages:
- messages = [
- {"role": i["role"], "content": re.sub(r"\s*\n\s*", "\n", i["content"])}
- for i in messages
- ]
- message = messages.pop(-1)
- query = message["content"]
- self.messages = messages
- else:
- # Normalize whitespace and add the query to the messages
- query = re.sub(r"\s*\n\s*", "\n", query)
- message = {"role": "user", "content": query}
-
- # Handle images if any
- if images:
- message = self.prepare_images(images, message)
- model = self.get_model("vision")
-
- self.messages.append(message)
-
- # Prepare headers
- headers = {"Authorization": f"Basic {self.get_credentials()}"}
- if self.chosen_backend and model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]: #TODO Maybe reasoning shouldn't be here.
- headers["X-Chosen-Backend"] = self.chosen_backend
-
- if model == self.get_model("small"):
- headers["X-Model-Type"] = "small"
- if model == self.get_model("tools"):
- headers["X-Model-Type"] = "tools"
-
- reasoning_models = ['qwen3', 'deepseek'] #TODO Add more reasoning models here when added to ollama
- if any([model_name in model for model_name in reasoning_models]):
- if think:
- query = f"/think\n{query}"
- else:
- query = f"/no_think\n{query}"
-
- # Prepare options
- options = Options(**self.options)
- options.temperature = temperature
-
- print_yellow("Stream the answer?", stream)
-
- # Call the client.chat method
- try:
- self.call_model = model
- self.client: Client = Client(host=self.host_url, headers=headers, timeout=300) #!
- #print_rainbow(self.client._client.__dict__)
- print_yellow("Model used in call:", model)
- # if headers:
- # self.client.headers.update(headers)
-
- response = self.client.chat(
- model=model,
- messages=self.messages,
- tools=tools,
- stream=stream,
- options=options,
- keep_alive=3600 * 24 * 7,
- format=format
- )
-
- except ResponseError as e:
- print_red("Error!")
- print(e)
- return "An error occurred."
- # print_rainbow(response.__dict__)
- # If user_input is provided, update the last message
-
- if user_input:
- if context:
- if len(context) > 2000:
- context = self.make_summary(context)
- user_input = (
- f"{user_input}\n\nUse the information below to answer the question.\n"
- f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
- )
-            system_message_info = "\nSometimes messages in the chat history are summarised; when that happens, it is clearly indicated in the message."
- if system_message_info not in self.messages[0]["content"]:
- self.messages[0]["content"] += system_message_info
- self.messages[-1] = {"role": "user", "content": user_input}
-
- # self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")
-
- # Handle streaming response
- if stream:
- print_purple("STREAMING")
- return self.read_stream(response)
- else:
- print_purple("NOT STREAMING")
- # Process the response
- if isinstance(response, ChatResponse):
- result = response.message.content.strip('"')
-                if '</think>' in result:
-                    result = result.split('</think>')[-1]
- self.messages.append(
- {"role": "assistant", "content": result.strip('"')}
- )
- if tools and not response.message.get("tool_calls"):
- print_yellow("No tool calls in response".upper())
- if not self.chat:
- self.messages = [self.messages[0]]
-
- if not think:
- response.message.content = remove_thinking(response.message.content)
- return response.message
- else:
- print_red("Unexpected response type")
- return "An error occurred."
-
- def make_summary(self, text):
- # Implement your summary logic using self.client.chat()
- summary_message = {
- "role": "user",
- "content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
- }
- messages = [
- {
- "role": "system",
- "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information.",
- },
- summary_message,
- ]
- try:
- response = self.client.chat(
- model=self.get_model("small"),
- messages=messages,
- options=Options(temperature=0.01),
- keep_alive=3600 * 24 * 7,
- )
- summary = response.message.content.strip()
- print_blue("Summary:", summary)
- return summary
- except ResponseError as e:
- print_red("Error generating summary:", e)
- return "Summary generation failed."
-
- def read_stream(self, response):
- """
- Yields tuples of (chunk_type, text). The first tuple is ('thinking', ...)
-        if in_thinking is True and stops at </think>. After that, yields ('normal', ...)
- for the rest of the text.
- """
- thinking_buffer = ""
- in_thinking = self.call_model == self.get_model("reasoning")
- first_chunk = True
- prev_content = None
-
- for chunk in response:
- if not chunk:
- continue
- content = chunk.message.content
-
- # Remove leading quote if it's the first chunk
- if first_chunk and content.startswith('"'):
- content = content[1:]
- first_chunk = False
-
- if in_thinking:
- thinking_buffer += content
- if "" in thinking_buffer:
- end_idx = thinking_buffer.index("") + len("")
- yield ("thinking", thinking_buffer[:end_idx])
- remaining = thinking_buffer[end_idx:].strip('"')
- if chunk.done and remaining:
- yield ("normal", remaining)
- break
- else:
- prev_content = remaining
- in_thinking = False
- else:
- if prev_content:
- yield ("normal", prev_content)
- prev_content = content
-
- if chunk.done:
- if prev_content and prev_content.endswith('"'):
- prev_content = prev_content[:-1]
- if prev_content:
- yield ("normal", prev_content)
- break
-
- self.messages.append({"role": "assistant", "content": ""})
-
- async def async_generate(
- self,
- query: str = None,
- user_input: str = None,
- context: str = None,
- stream: bool = False,
- tools: list = None,
- images: list = None,
- model: Optional[Literal["small", "standard", "vision"]] = None,
- temperature: float = None,
- ):
- """
- Asynchronously generates a response based on the provided query and other parameters.
-
- Args:
- query (str, optional): The query string to generate a response for.
- user_input (str, optional): Additional user input to be included in the response.
- context (str, optional): Context information to be used in generating the response.
- stream (bool, optional): Whether to stream the response. Defaults to False.
- tools (list, optional): List of tools to be used in generating the response. Will set the model to 'tools'.
- images (list, optional): List of images to be included in the response.
- model (Optional[Literal["small", "standard", "vision", "tools"]], optional): The model to be used for generating the response.
- temperature (float, optional): The temperature setting for the model.
-
- Returns:
- str: The generated response or an error message if an exception occurs.
-
- Raises:
- ResponseError: If an error occurs during the response generation.
-
- Notes:
- - The function prepares the model and temperature settings.
- - It normalizes whitespace in the query and handles images if provided.
- - It prepares headers and options for the request.
- - It adjusts options for long messages and calls the async client's chat method.
- - If user_input is provided, it updates the last message.
- - It updates the chosen backend based on the response headers.
- - It handles streaming responses and processes the response accordingly.
-            - It's not necessary to set model to 'tools' if you provide tools as an argument.
- """
- print_yellow("ASYNC GENERATE")
-        # Normalize whitespace and add the query to the messages
- query = re.sub(r"\s*\n\s*", "\n", query)
- message = {"role": "user", "content": query}
- self.messages.append(message)
-
- # Prepare the model and temperature
- model = self.get_model(model) if model else self.model
- temperature = temperature if temperature else self.options["temperature"]
-
- # Prepare options
- options = Options(**self.options)
- options.temperature = temperature
-
- # Prepare headers
- headers = {}
-
- # Set model depending on the input
- if images:
- message = self.prepare_images(images, message)
- model = self.get_model("vision")
- elif tools:
- model = self.get_model("tools")
- headers["X-Model-Type"] = "tools"
- tools = [Tool(**tool) if isinstance(tool, dict) else tool for tool in tools]
- elif self.chosen_backend and model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]:
- headers["X-Chosen-Backend"] = self.chosen_backend
- elif model == self.get_model("small"):
- headers["X-Model-Type"] = "small"
-
- # Adjust options for long messages
- if self.chat or len(self.messages) > 15000:
- num_tokens = self.count_tokens() + self.max_length_answer // 2
- if num_tokens > 8000 and model not in [
- self.get_model("vision"),
- self.get_model("tools"),
- ]:
- model = self.get_model("standard_64k")
- headers["X-Model-Type"] = "large"
-
- # Call the async client's chat method
- try:
- response = await self.async_client.chat(
- model=model,
- messages=self.messages,
- headers=headers,
- tools=tools,
- stream=stream,
- options=options,
- keep_alive=3600 * 24 * 7,
- )
- except ResponseError as e:
- print_red("Error!")
- print(e)
- return "An error occurred."
-
- # If user_input is provided, update the last message
- if user_input:
- if context:
- if len(context) > 2000:
- context = self.make_summary(context)
- user_input = (
- f"{user_input}\n\nUse the information below to answer the question.\n"
- f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
- )
-            system_message_info = "\nSometimes messages in the chat history are summarised; when that happens, it is clearly indicated in the message."
- if system_message_info not in self.messages[0]["content"]:
- self.messages[0]["content"] += system_message_info
- self.messages[-1] = {"role": "user", "content": user_input}
-
- print_red(self.async_client.last_response.headers.get("X-Chosen-Backend", "No backend"))
- # Update chosen_backend
- if model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]:
- self.chosen_backend = self.async_client.last_response.headers.get(
- "X-Chosen-Backend"
- )
-
- # Handle streaming response
- if stream:
- return self.read_stream(response)
- else:
- # Process the response
- if isinstance(response, ChatResponse):
- result = response.message.content.strip('"')
- self.messages.append(
- {"role": "assistant", "content": result.strip('"')}
- )
- if tools and not response.message.get("tool_calls"):
- print_yellow("No tool calls in response".upper())
- if not self.chat:
- self.messages = [self.messages[0]]
- return result
- else:
- print_red("Unexpected response type")
- return "An error occurred."
-
- def prepare_images(self, images, message):
- """
- Prepares a list of images by converting them to base64 encoded strings and adds them to the provided message dictionary.
- Args:
- images (list): A list of images, where each image can be a file path (str), a base64 encoded string (str), or bytes.
- message (dict): A dictionary to which the base64 encoded images will be added under the key "images".
- Returns:
- dict: The updated message dictionary with the base64 encoded images added under the key "images".
- Raises:
- ValueError: If an image is not a string or bytes.
- """
- import base64
-
- base64_images = []
- base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
-
- for image in images:
- if isinstance(image, str):
- if base64_pattern.match(image):
- base64_images.append(image)
- else:
- with open(image, "rb") as image_file:
- base64_images.append(
- base64.b64encode(image_file.read()).decode("utf-8")
- )
- elif isinstance(image, bytes):
- base64_images.append(base64.b64encode(image).decode("utf-8"))
- else:
- print_red("Invalid image type")
-
- message["images"] = base64_images
- # Use the vision model
-
- return message
-
-def remove_thinking(response):
- """Remove the thinking section from the response"""
- response_text = response.content if hasattr(response, "content") else str(response)
- if "" in response_text:
- return response_text.split("")[1].strip()
- return response_text
-
-if __name__ == "__main__":
-
- llm = LLM()
-
- result = llm.generate(
- query="I want to add 2 and 2",
- )
- print(result.content)
diff --git a/_llmOLD.py b/_llmOLD.py
new file mode 100644
index 0000000..93ea770
--- /dev/null
+++ b/_llmOLD.py
@@ -0,0 +1,581 @@
+from _llm import LLM
+
+
+if __name__ == "__main__":
+ llm = LLM()
+
+ result = llm.generate(
+ query="I want to add 2 and 2",
+ think=True,
+ )
+ print(result)
+# import os
+# import base64
+# import re
+# from typing import Literal, Optional
+# from pydantic import BaseModel
+# import requests
+# import tiktoken
+# from ollama import (
+# Client,
+# AsyncClient,
+# ResponseError,
+# ChatResponse,
+# Tool,
+# Options,
+# )
+
+# import env_manager
+# from colorprinter.print_color import *
+
+# env_manager.set_env()
+
+# tokenizer = tiktoken.get_encoding("cl100k_base")
+
+
+# class LLM:
+# """
+# LLM class for interacting with an instance of Ollama.
+
+# Attributes:
+# model (str): The model to be used for response generation.
+# system_message (str): The system message to be used in the chat.
+# options (dict): Options for the model, such as temperature.
+# messages (list): List of messages in the chat.
+# max_length_answer (int): Maximum length of the generated answer.
+# chat (bool): Whether the chat mode is enabled.
+# chosen_backend (str): The chosen backend server for the API.
+# client (Client): The client for synchronous API calls.
+# async_client (AsyncClient): The client for asynchronous API calls.
+# tools (list): List of tools to be used in generating the response.
+
+# Methods:
+# __init__(self, system_message, temperature, model, max_length_answer, messages, chat, chosen_backend):
+# Initializes the LLM class with the provided parameters.
+
+# get_model(self, model_alias):
+# Retrieves the model name based on the provided alias.
+
+# count_tokens(self):
+# Counts the number of tokens in the messages.
+
+# get_least_conn_server(self):
+# Retrieves the least connected server from the backend.
+
+# generate(self, query, user_input, context, stream, tools, images, model, temperature):
+# Generates a response based on the provided query and options.
+
+# make_summary(self, text):
+# Generates a summary of the provided text.
+
+# read_stream(self, response):
+# Handles streaming responses.
+
+# async_generate(self, query, user_input, context, stream, tools, images, model, temperature):
+# Asynchronously generates a response based on the provided query and options.
+
+# prepare_images(self, images, message):
+# """
+
+# def __init__(
+# self,
+# system_message: str = "You are an assistant.",
+# temperature: float = 0.01,
+# model: Optional[
+# Literal["small", "standard", "vision", "reasoning", "tools"]
+# ] = "standard",
+# max_length_answer: int = 4096,
+# messages: list[dict] = None,
+# chat: bool = True,
+# chosen_backend: str = None,
+# tools: list = None,
+# ) -> None:
+# """
+# Initialize the assistant with the given parameters.
+
+# Args:
+# system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
+# temperature (float): The temperature setting for the model, affecting randomness. Defaults to 0.01.
+# model (Optional[Literal["small", "standard", "vision", "reasoning"]]): The model type to use. Defaults to "standard".
+# max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
+# messages (list[dict], optional): A list of initial messages. Defaults to None.
+# chat (bool): Whether the assistant is in chat mode. Defaults to True.
+# chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
+
+# Returns:
+# None
+# """
+
+# self.model = self.get_model(model)
+# self.call_model = (
+# self.model
+# ) # This is set per call to decide what model that was actually used
+# self.system_message = system_message
+# self.options = {"temperature": temperature}
+# self.messages = messages or [{"role": "system", "content": self.system_message}]
+# self.max_length_answer = max_length_answer
+# self.chat = chat
+
+# if not chosen_backend:
+# chosen_backend = self.get_least_conn_server()
+# self.chosen_backend = chosen_backend
+
+
+# headers = {
+# "Authorization": f"Basic {self.get_credentials()}",
+# "X-Chosen-Backend": self.chosen_backend,
+# }
+# self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
+# self.host_url = 'http://192.168.1.12:3300' #! Change back when possible
+# self.client: Client = Client(host=self.host_url, headers=headers, timeout=240)
+# self.async_client: AsyncClient = AsyncClient()
+
+# def get_credentials(self):
+# # Initialize the client with the host and default headers
+# credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}"
+# return base64.b64encode(credentials.encode()).decode()
+
+# def get_model(self, model_alias):
+
+# models = {
+# "standard": "LLM_MODEL",
+# "small": "LLM_MODEL_SMALL",
+# "vision": "LLM_MODEL_VISION",
+# "standard_64k": "LLM_MODEL_LARGE",
+# "reasoning": "LLM_MODEL_REASONING",
+# "tools": "LLM_MODEL_TOOLS",
+# }
+# model = os.getenv(models.get(model_alias, "LLM_MODEL"))
+# self.model = model
+# return model
+
+# def count_tokens(self):
+# num_tokens = 0
+# for i in self.messages:
+# for k, v in i.items():
+# if k == "content":
+# if not isinstance(v, str):
+# v = str(v)
+# tokens = tokenizer.encode(v)
+# num_tokens += len(tokens)
+# return int(num_tokens)
+
+# def get_least_conn_server(self):
+# try:
+# response = requests.get("http://192.168.1.12:5000/least_conn")
+# response.raise_for_status()
+# # Extract the least connected server from the response
+# least_conn_server = response.headers.get("X-Upstream-Address")
+# return least_conn_server
+# except requests.RequestException as e:
+# print_red("Error getting least connected server:", e)
+# return None
+
+# def generate(
+# self,
+# query: str = None,
+# user_input: str = None,
+# context: str = None,
+# stream: bool = False,
+# tools: list = None,
+# images: list = None,
+# model: Optional[
+# Literal["small", "standard", "vision", "reasoning", "tools"]
+# ] = None,
+# temperature: float = None,
+# messages: list[dict] = None,
+# format: BaseModel = None,
+# think: bool = False
+# ):
+# """
+# Generate a response based on the provided query and context.
+# Parameters:
+# query (str): The query string from the user.
+# user_input (str): Additional user input to be appended to the last message.
+# context (str): Contextual information to be used in generating the response.
+# stream (bool): Whether to stream the response.
+# tools (list): List of tools to be used in generating the response.
+# images (list): List of images to be included in the response.
+# model (Optional[Literal["small", "standard", "vision", "tools"]]): The model type to be used.
+# temperature (float): The temperature setting for the model.
+# messages (list[dict]): List of previous messages in the conversation.
+# format (Optional[BaseModel]): The format of the response.
+# think (bool): Whether to use the reasoning model.
+
+# Returns:
+# str: The generated response or an error message if an exception occurs.
+# """
+
+# # Prepare the model and temperature
+
+# model = self.get_model(model) if model else self.model
+# # if model == self.get_model('tools'):
+# # stream = False
+# temperature = temperature if temperature else self.options["temperature"]
+
+# if messages:
+# messages = [
+# {"role": i["role"], "content": re.sub(r"\s*\n\s*", "\n", i["content"])}
+# for i in messages
+# ]
+# message = messages.pop(-1)
+# query = message["content"]
+# self.messages = messages
+# else:
+# # Normalize whitespace and add the query to the messages
+# query = re.sub(r"\s*\n\s*", "\n", query)
+# message = {"role": "user", "content": query}
+
+# # Handle images if any
+# if images:
+# message = self.prepare_images(images, message)
+# model = self.get_model("vision")
+
+# self.messages.append(message)
+
+# # Prepare headers
+# headers = {"Authorization": f"Basic {self.get_credentials()}"}
+# if self.chosen_backend and model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]: #TODO Maybe reasoning shouldn't be here.
+# headers["X-Chosen-Backend"] = self.chosen_backend
+
+# if model == self.get_model("small"):
+# headers["X-Model-Type"] = "small"
+# if model == self.get_model("tools"):
+# headers["X-Model-Type"] = "tools"
+
+# reasoning_models = ['qwen3', 'deepseek'] #TODO Add more reasoning models here when added to ollama
+# if any([model_name in model for model_name in reasoning_models]):
+# if think:
+# self.messages[-1]['content'] = f"/think\n{self.messages[-1]['content']}"
+# else:
+# self.messages[-1]['content'] = f"/no_think\n{self.messages[-1]['content']}"
+
+# # Prepare options
+# options = Options(**self.options)
+# options.temperature = temperature
+
+# # Call the client.chat method
+# try:
+# self.call_model = model
+# self.client: Client = Client(host=self.host_url, headers=headers, timeout=300) #!
+# #print_rainbow(self.client._client.__dict__)
+# print_yellow(f"🤖 Generating using {model}...")
+# # if headers:
+# # self.client.headers.update(headers)
+# response = self.client.chat(
+# model=model,
+# messages=self.messages,
+# tools=tools,
+# stream=stream,
+# options=options,
+# keep_alive=3600 * 24 * 7,
+# format=format
+# )
+
+# except ResponseError as e:
+# print_red("Error!")
+# print(e)
+# return "An error occurred."
+# # print_rainbow(response.__dict__)
+# # If user_input is provided, update the last message
+
+# if user_input:
+# if context:
+# if len(context) > 2000:
+# context = self.make_summary(context)
+# user_input = (
+# f"{user_input}\n\nUse the information below to answer the question.\n"
+# f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
+# )
+# system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
+# if system_message_info not in self.messages[0]["content"]:
+# self.messages[0]["content"] += system_message_info
+# self.messages[-1] = {"role": "user", "content": user_input}
+
+# # self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")
+
+# # Handle streaming response
+# if stream:
+# print_purple("STREAMING")
+# return self.read_stream(response)
+# else:
+# # Process the response
+# if isinstance(response, ChatResponse):
+# result = response.message.content.strip('"')
+#                 if '</think>' in result:
+#                     result = result.split('</think>')[-1]
+# self.messages.append(
+# {"role": "assistant", "content": result.strip('"')}
+# )
+# if tools and not response.message.get("tool_calls"):
+# print_yellow("No tool calls in response".upper())
+# if not self.chat:
+# self.messages = [self.messages[0]]
+
+# if not think:
+# response.message.content = remove_thinking(response.message.content)
+# return response.message
+# else:
+# print_red("Unexpected response type")
+# return "An error occurred."
+
+# def make_summary(self, text):
+# # Implement your summary logic using self.client.chat()
+# summary_message = {
+# "role": "user",
+# "content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
+# }
+# messages = [
+# {
+# "role": "system",
+# "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information.",
+# },
+# summary_message,
+# ]
+# try:
+# response = self.client.chat(
+# model=self.get_model("small"),
+# messages=messages,
+# options=Options(temperature=0.01),
+# keep_alive=3600 * 24 * 7,
+# )
+# summary = response.message.content.strip()
+# print_blue("Summary:", summary)
+# return summary
+# except ResponseError as e:
+# print_red("Error generating summary:", e)
+# return "Summary generation failed."
+
+# def read_stream(self, response):
+# """
+# Yields tuples of (chunk_type, text). The first tuple is ('thinking', ...)
+#         if in_thinking is True and stops at </think>. After that, yields ('normal', ...)
+# for the rest of the text.
+# """
+# thinking_buffer = ""
+# in_thinking = self.call_model == self.get_model("reasoning")
+# first_chunk = True
+# prev_content = None
+
+# for chunk in response:
+# if not chunk:
+# continue
+# content = chunk.message.content
+
+# # Remove leading quote if it's the first chunk
+# if first_chunk and content.startswith('"'):
+# content = content[1:]
+# first_chunk = False
+
+# if in_thinking:
+# thinking_buffer += content
+#             if "</think>" in thinking_buffer:
+#                 end_idx = thinking_buffer.index("</think>") + len("</think>")
+# yield ("thinking", thinking_buffer[:end_idx])
+# remaining = thinking_buffer[end_idx:].strip('"')
+# if chunk.done and remaining:
+# yield ("normal", remaining)
+# break
+# else:
+# prev_content = remaining
+# in_thinking = False
+# else:
+# if prev_content:
+# yield ("normal", prev_content)
+# prev_content = content
+
+# if chunk.done:
+# if prev_content and prev_content.endswith('"'):
+# prev_content = prev_content[:-1]
+# if prev_content:
+# yield ("normal", prev_content)
+# break
+
+# self.messages.append({"role": "assistant", "content": ""})
+
+# async def async_generate(
+# self,
+# query: str = None,
+# user_input: str = None,
+# context: str = None,
+# stream: bool = False,
+# tools: list = None,
+# images: list = None,
+# model: Optional[Literal["small", "standard", "vision"]] = None,
+# temperature: float = None,
+# ):
+# """
+# Asynchronously generates a response based on the provided query and other parameters.
+
+# Args:
+# query (str, optional): The query string to generate a response for.
+# user_input (str, optional): Additional user input to be included in the response.
+# context (str, optional): Context information to be used in generating the response.
+# stream (bool, optional): Whether to stream the response. Defaults to False.
+# tools (list, optional): List of tools to be used in generating the response. Will set the model to 'tools'.
+# images (list, optional): List of images to be included in the response.
+# model (Optional[Literal["small", "standard", "vision", "tools"]], optional): The model to be used for generating the response.
+# temperature (float, optional): The temperature setting for the model.
+
+# Returns:
+# str: The generated response or an error message if an exception occurs.
+
+# Raises:
+# ResponseError: If an error occurs during the response generation.
+
+# Notes:
+# - The function prepares the model and temperature settings.
+# - It normalizes whitespace in the query and handles images if provided.
+# - It prepares headers and options for the request.
+# - It adjusts options for long messages and calls the async client's chat method.
+# - If user_input is provided, it updates the last message.
+# - It updates the chosen backend based on the response headers.
+# - It handles streaming responses and processes the response accordingly.
+#         - It's not necessary to set model to 'tools' if you provide tools as an argument.
+# """
+# print_yellow("ASYNC GENERATE")
+#         # Normalize whitespace and add the query to the messages
+# query = re.sub(r"\s*\n\s*", "\n", query)
+# message = {"role": "user", "content": query}
+# self.messages.append(message)
+
+# # Prepare the model and temperature
+# model = self.get_model(model) if model else self.model
+# temperature = temperature if temperature else self.options["temperature"]
+
+# # Prepare options
+# options = Options(**self.options)
+# options.temperature = temperature
+
+# # Prepare headers
+# headers = {}
+
+# # Set model depending on the input
+# if images:
+# message = self.prepare_images(images, message)
+# model = self.get_model("vision")
+# elif tools:
+# model = self.get_model("tools")
+# headers["X-Model-Type"] = "tools"
+# tools = [Tool(**tool) if isinstance(tool, dict) else tool for tool in tools]
+# elif self.chosen_backend and model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]:
+# headers["X-Chosen-Backend"] = self.chosen_backend
+# elif model == self.get_model("small"):
+# headers["X-Model-Type"] = "small"
+
+# # Adjust options for long messages
+# if self.chat or len(self.messages) > 15000:
+# num_tokens = self.count_tokens() + self.max_length_answer // 2
+# if num_tokens > 8000 and model not in [
+# self.get_model("vision"),
+# self.get_model("tools"),
+# ]:
+# model = self.get_model("standard_64k")
+# headers["X-Model-Type"] = "large"
+
+# # Call the async client's chat method
+# try:
+# response = await self.async_client.chat(
+# model=model,
+# messages=self.messages,
+# headers=headers,
+# tools=tools,
+# stream=stream,
+# options=options,
+# keep_alive=3600 * 24 * 7,
+# )
+# except ResponseError as e:
+# print_red("Error!")
+# print(e)
+# return "An error occurred."
+
+# # If user_input is provided, update the last message
+# if user_input:
+# if context:
+# if len(context) > 2000:
+# context = self.make_summary(context)
+# user_input = (
+# f"{user_input}\n\nUse the information below to answer the question.\n"
+# f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
+# )
+# system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
+# if system_message_info not in self.messages[0]["content"]:
+# self.messages[0]["content"] += system_message_info
+# self.messages[-1] = {"role": "user", "content": user_input}
+
+# print_red(self.async_client.last_response.headers.get("X-Chosen-Backend", "No backend"))
+# # Update chosen_backend
+# if model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]:
+# self.chosen_backend = self.async_client.last_response.headers.get(
+# "X-Chosen-Backend"
+# )
+
+# # Handle streaming response
+# if stream:
+# return self.read_stream(response)
+# else:
+# # Process the response
+# if isinstance(response, ChatResponse):
+# result = response.message.content.strip('"')
+# self.messages.append(
+# {"role": "assistant", "content": result.strip('"')}
+# )
+# if tools and not response.message.get("tool_calls"):
+# print_yellow("No tool calls in response".upper())
+# if not self.chat:
+# self.messages = [self.messages[0]]
+# return result
+# else:
+# print_red("Unexpected response type")
+# return "An error occurred."
+
+# def prepare_images(self, images, message):
+# """
+# Prepares a list of images by converting them to base64 encoded strings and adds them to the provided message dictionary.
+# Args:
+# images (list): A list of images, where each image can be a file path (str), a base64 encoded string (str), or bytes.
+# message (dict): A dictionary to which the base64 encoded images will be added under the key "images".
+# Returns:
+# dict: The updated message dictionary with the base64 encoded images added under the key "images".
+# Raises:
+# ValueError: If an image is not a string or bytes.
+# """
+# import base64
+
+# base64_images = []
+# base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
+
+# for image in images:
+# if isinstance(image, str):
+# if base64_pattern.match(image):
+# base64_images.append(image)
+# else:
+# with open(image, "rb") as image_file:
+# base64_images.append(
+# base64.b64encode(image_file.read()).decode("utf-8")
+# )
+# elif isinstance(image, bytes):
+# base64_images.append(base64.b64encode(image).decode("utf-8"))
+# else:
+# print_red("Invalid image type")
+
+# message["images"] = base64_images
+# # Use the vision model
+
+# return message
+
+# def remove_thinking(response):
+# """Remove the thinking section from the response"""
+# response_text = response.content if hasattr(response, "content") else str(response)
+#     if "</think>" in response_text:
+#         return response_text.split("</think>")[1].strip()
+# return response_text
+
+# if __name__ == "__main__":
+
+# llm = LLM()
+
+# result = llm.generate(
+# query="I want to add 2 and 2",
+# )
+# print(result.content)
diff --git a/agent_research.py b/agent_research.py
new file mode 100644
index 0000000..3310fe8
--- /dev/null
+++ b/agent_research.py
@@ -0,0 +1,1448 @@
+from _llm import LLM, remove_thinking
+from streamlit_chatbot import Bot
+from typing import Dict, List, Tuple, Optional, Any
+from colorprinter.print_color import *
+from projects_page import Project
+from _base_class import BaseClass
+from prompts import get_tools_prompt
+import time
+import traceback
+import json
+from datetime import datetime
+
+import llm_queries
+from models import EvaluateFormat, Plan, ChunkSearchResults, UnifiedSearchResults, UnifiedDataChunk, UnifiedToolResponse
+
+
+class ResearchReport:
+ """Class for tracking and logging decisions and data access during research"""
+
+ def __init__(self, question, username, project_name=None):
+ self.report = {
+ "metadata": {
+ "question": question,
+ "username": username,
+ "project_name": project_name,
+ "started_at": datetime.now().isoformat(),
+ "finished_at": None,
+ },
+ "plan": {"original_text": None, "structured": None, "subquestions": []},
+ "steps": {},
+ "evaluation": None,
+ "final_report": None,
+ "statistics": {
+ "tools_used": {},
+ "sources_accessed": [],
+ "total_time": None,
+ },
+ }
+ # Track current context for easier logging
+ self.current_step = None
+ self.current_task = None
+
+ def log_plan(self, original_plan, structured_plan=None):
+ """Log the research plan"""
+ self.report["plan"]["original_text"] = original_plan
+ if structured_plan:
+ self.report["plan"]["structured"] = structured_plan
+
+ def start_step(self, step_name):
+ """Mark the beginning of a new step"""
+ self.current_step = step_name
+ if step_name not in self.report["steps"]:
+ self.report["steps"][step_name] = {
+ "started_at": datetime.now().isoformat(),
+ "finished_at": None,
+ "tasks": {},
+ "tools_used": [],
+ "information_gathered": [],
+ "summary": None,
+ "evaluation": None,
+ }
+ return self.current_step
+
+ def start_task(self, task_name, task_description):
+ """Mark the beginning of a new task within the current step"""
+ if not self.current_step:
+ raise ValueError("Cannot start task without active step")
+
+ self.current_task = task_name
+ self.report["steps"][self.current_step]["tasks"][task_name] = {
+ "description": task_description,
+ "started_at": datetime.now().isoformat(),
+ "finished_at": None,
+ "tools_used": [],
+ "information_gathered": [],
+ }
+ return self.current_task
+
+ def log_tool_use(self, tool_name, tool_args):
+ """Log when a tool is used"""
+ if not self.current_step:
+ raise ValueError("Cannot log tool use without active step")
+
+ # Add to step level
+ self.report["steps"][self.current_step]["tools_used"].append(
+ {
+ "tool": tool_name,
+ "args": tool_args,
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+
+ # Add to task level if we have an active task
+ if self.current_task:
+ self.report["steps"][self.current_step]["tasks"][self.current_task][
+ "tools_used"
+ ].append(
+ {
+ "tool": tool_name,
+ "args": tool_args,
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+
+ # Update global statistics
+ if tool_name in self.report["statistics"]["tools_used"]:
+ self.report["statistics"]["tools_used"][tool_name] += 1
+ else:
+ self.report["statistics"]["tools_used"][tool_name] = 1
+
+ def log_information(self, information):
+ """Log information gathered from tools"""
+ if not self.current_step:
+ raise ValueError("Cannot log information without active step")
+
+ # Process information to extract sources
+ sources = self._extract_sources(information)
+
+ # Add unique sources to global statistics
+ for source in sources:
+ if source not in self.report["statistics"]["sources_accessed"]:
+ self.report["statistics"]["sources_accessed"].append(source)
+
+ # Add to step level
+ self.report["steps"][self.current_step]["information_gathered"].append(
+ {
+ "data": information,
+ "sources": sources,
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+
+ # Add to task level if we have an active task
+ if self.current_task:
+ self.report["steps"][self.current_step]["tasks"][self.current_task][
+ "information_gathered"
+ ].append(
+ {
+ "data": information,
+ "sources": sources,
+ "timestamp": datetime.now().isoformat(),
+ }
+ )
+
+ def _extract_sources(self, information):
+ """Extract source information from gathered data"""
+ sources = []
+
+ # Handle different result formats
+ for item in information:
+ try:
+ if "result" in item and "content" in item["result"]:
+ if isinstance(item["result"]["content"], dict):
+ # Handle structured content like chunks
+ for title, group in item["result"]["content"].items():
+ if "chunks" in group:
+ for chunk in group["chunks"]:
+ metadata = chunk.get("metadata", {})
+ source = f"{metadata.get('title', 'Unknown')}"
+ if metadata.get("journal"):
+ source += f" ({metadata.get('journal')})"
+ if source not in sources:
+ sources.append(source)
+ except Exception as e:
+ print_yellow(f"Error extracting sources: {e}")
+ sources.append("No source")
+
+ return sources
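+
+    # Item shape handled above (illustrative sketch of one tool result):
+    #   {"result": {"content": {"<article title>": {"chunks": [
+    #       {"metadata": {"title": "...", "journal": "..."}}, ...]}}}}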
+
+ def update_step_summary(self, summary):
+ """Log summary of gathered information"""
+ if not self.current_step:
+ raise ValueError("Cannot log summary without active step")
+
+ self.report["steps"][self.current_step]["summary"] = remove_thinking(summary)
+
+ def log_evaluation(self, evaluation):
+ """Log evaluation of gathered information"""
+ if not self.current_step:
+ raise ValueError("Cannot log evaluation without active step")
+
+ self.report["steps"][self.current_step]["evaluation"] = evaluation
+
+ def finish_task(self):
+ """Mark the end of the current task"""
+ if not self.current_step or not self.current_task:
+ raise ValueError("No active task to finish")
+
+ self.report["steps"][self.current_step]["tasks"][self.current_task][
+ "finished_at"
+ ] = datetime.now().isoformat()
+ self.current_task = None
+
+ def finish_step(self):
+ """Mark the end of the current step"""
+ if not self.current_step:
+ raise ValueError("No active step to finish")
+
+ self.report["steps"][self.current_step][
+ "finished_at"
+ ] = datetime.now().isoformat()
+ self.current_step = None
+
+ def log_plan_evaluation(self, evaluation):
+ """Log the overall plan evaluation"""
+ self.report["evaluation"] = evaluation
+
+ def log_final_report(self, report):
+ """Log the final generated report"""
+ self.report["final_report"] = report
+ self.report["metadata"]["finished_at"] = datetime.now().isoformat()
+ # Calculate total time
+ start = datetime.fromisoformat(self.report["metadata"]["started_at"])
+ end = datetime.fromisoformat(self.report["metadata"]["finished_at"])
+ self.report["statistics"]["total_time"] = (end - start).total_seconds()
+
+ def get_full_report(self):
+ """Get the complete report data"""
+ return self.report
+
+ def get_markdown_report(self):
+ """Get the report formatted as markdown for easy viewing"""
+ md = f"# Research Report: {self.report['metadata']['question']}\n\n"
+
+ # Metadata
+ md += "## Metadata \n"
+ md += f"- **Project**: {self.report['metadata']['project_name'] or 'None'} \n"
+ md += f"- **User**: {self.report['metadata']['username']} \n"
+ md += f"- **Started**: {self.report['metadata']['started_at']} \n"
+ if self.report["metadata"]["finished_at"]:
+ md += f"- **Finished**: {self.report['metadata']['finished_at']} \n"
+ md += f"- **Total time**: {self.report['statistics']['total_time']:.2f} seconds \n"
+
+ # Statistics
+ md += "\n## Statistics \n"
+ md += "### Tools Used \n"
+ for tool, count in self.report["statistics"]["tools_used"].items():
+ md += f"- {tool}: {count} times \n"
+
+ md += "\n### Sources Accessed \n"
+ for source in self.report["statistics"]["sources_accessed"]:
+ md += f"- {source} \n"
+
+ # Research plan
+ md += "\n## Research Plan \n"
+ if self.report["plan"]["original_text"]:
+ md += f"\n{self.report['plan']['original_text']} \n\n"
+
+ # Steps
+ md += "\n## Research Steps \n"
+ for step_name, step_data in self.report["steps"].items():
+ md += f"### {step_name}\n"
+
+ if step_data.get("summary"):
+ md += f"**Summary**: {step_data['summary']} \n\n"
+
+ md += "**Tools used**: \n"
+ for tool in step_data["tools_used"]:
+ md += f"- {tool['tool']} with query: _{tool['args'].get('query', 'No query').replace('_', ' ')}_\n"
+
+ md += "\n**Tasks**:\n"
+ for task_name, task_data in step_data.get("tasks", {}).items():
+ md += f"- {task_name}: {task_data['description']}\n"
+
+ # Final report
+ if self.report["final_report"]:
+ md += "\n## Final Report \n"
+ md += self.report["final_report"]
+
+ return md
+
+ def save_to_file(self, filepath=None):
+ """Save the report to a file"""
+ if not filepath:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ filepath = f"research_report_{timestamp}.json"
+
+ # Create a deep copy of the report that's JSON serializable
+ def make_json_serializable(obj):
+ """Convert any non-JSON serializable objects to dictionaries"""
+ if hasattr(obj, "model_dump"): # Check if it's a Pydantic model
+ return obj.model_dump() # Convert Pydantic models to dict
+ elif isinstance(obj, dict):
+ return {k: make_json_serializable(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [make_json_serializable(item) for item in obj]
+ else:
+ return obj
+
+ # Create a JSON-serializable version of the report
+ json_report = make_json_serializable(self.report)
+
+ with open(filepath, "w") as f:
+ json.dump(json_report, f, indent=2)
+
+ print_green(f"Report saved to {filepath}")
+ return filepath
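+
+    # Illustrative usage (a sketch; question/username values are hypothetical):
+    #   report = ResearchReport("How does X affect Y?", username="alice")
+    #   report.start_step("Background")
+    #   report.log_tool_use("fetch_notes_tool", {"query": "background on X"})
+    #   report.finish_step()
+    #   report.log_final_report("...")   # stamps finished_at and total_time
+    #   report.save_to_file()            # -> research_report_<timestamp>.json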
+
+
+class ResearchBase(Bot):
+ """Base class for all research agents with improved integration with Bot functionality"""
+
+ def __init__(
+ self,
+ username: str,
+ model: str = "standard",
+ chat: bool = True,
+ report=None,
+ **kwargs,
+ ):
+ super().__init__(username=username, **kwargs)
+ self.model: str = model
+ self.llm = LLM(
+ system_message="You are a research assistant.",
+ model=model,
+ chat=chat,
+ messages=[],
+ )
+
+ # Tracking for research flow
+ self.research_state = {
+ "current_step": None,
+ "current_task": None,
+ "start_time": time.time(),
+ "steps_completed": 0,
+ "tasks_completed": 0,
+ }
+
+ self.report: ResearchReport = report
+
+ # Define available tool functions
+ self.available_functions = {
+ "fetch_science_articles_tool": self.fetch_science_articles_tool,
+ "fetch_notes_tool": self.fetch_notes_tool,
+ "fetch_other_documents_tool": self.fetch_other_documents_tool,
+ "fetch_science_articles_and_other_documents_tool": self.fetch_science_articles_and_other_documents_tool,
+ "analyze_tool": self.analyze_tool,
+ }
+
+        # Iterating a dict yields its keys (all strings here), so the isinstance
+        # branch was dead; collect the function objects directly.
+        self.tools = list(self.available_functions.values())
+
+ def update_research_state(self, **kwargs):
+ """Update the research state with new information"""
+ self.research_state.update(kwargs)
+ current_time = time.time()
+ elapsed = current_time - self.research_state["start_time"]
+
+ # Log progress info
+ if "current_step" in kwargs or "current_task" in kwargs:
+ print_yellow(
+ f"Progress: Step {self.research_state.get('steps_completed', 0)}, "
+ f"Time elapsed: {elapsed:.1f}s"
+ )
+
+ # Update report if available
+ if self.report:
+ if "current_step" in kwargs and kwargs["current_step"]:
+ self.report.start_step(kwargs["current_step"])
+ if (
+ "current_task" in kwargs
+ and kwargs["current_task"]
+ and self.report.current_step
+ ):
+ # For simplicity, we're using the task description as the name too
+ self.report.start_task(
+ kwargs["current_task"], kwargs["current_task"]
+ )
+
+ def use_tools(self, tool_calls, task_description, task_description_as_query=True) -> UnifiedToolResponse:
+ """Execute the selected tools to gather information"""
+ self.update_research_state(current_task=f"Gathering information with tools")
+
+ gathered_information = UnifiedToolResponse()
+
+ for tool_call in tool_calls:
+ tool_name = tool_call.function.name
+ tool_args = tool_call.function.arguments
+
+ print_green(f"Using tool: {tool_name} with args: {tool_args}")
+
+            # Use the task description as the query: overwrite an existing query when
+            # task_description_as_query is set, otherwise only fill in a missing one.
+            if "query" in tool_args:
+                if task_description_as_query:
+                    tool_args["query"] = task_description
+            elif task_description:
+                tool_args["query"] = task_description
+
+ # Log tool use in report
+ if self.report:
+ self.report.log_tool_use(tool_name, tool_args)
+
+ try:
+ # Call the tool function
+ function_to_call = self.available_functions.get(tool_name)
+ if function_to_call:
+ result = function_to_call(**tool_args)
+ else:
+ result = f"Unknown tool: {tool_name}"
+
+ # Process the result
+ if isinstance(result, UnifiedSearchResults):
+ # Convert to a unified format
+ gathered_information.extend_search_results(result)
+ gathered_information.extend_tool_name(tool_name)
+ elif isinstance(result, str):
+ # Already in the correct format
+ gathered_information.extend_text_results(result)
+ gathered_information.extend_tool_name(tool_name)
+
+ # Log gathered information in report
+ # if self.report:
+ # self.report.log_information([gathered_info])
+
+ except Exception as e:
+ print_red(f"Error executing tool {tool_name}: {e}")
+ traceback.print_exc()
+ import sys
+
+ sys.exit(1)
+
+ return gathered_information
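+
+    # Each tool_call consumed above follows the ollama-style shape (illustrative):
+    #   tool_call.function.name      -> e.g. "fetch_notes_tool"
+    #   tool_call.function.arguments -> dict of arguments, e.g. {"query": "..."}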
+
+ # Tool function definitions
+    def fetch_science_articles_tool(
+        self, query: str, n_documents: int = 6
+    ) -> UnifiedSearchResults:
+ """
+ Fetches information from scientific articles.
+
+ Parameters:
+ query (str): The search query to find relevant scientific articles in a vector database.
+ n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
+
+ Returns:
+            UnifiedSearchResults: A structured result containing the matching article chunks.
+ """
+
+ where_filter = {}
+ if hasattr(self, "chroma_ids_retrieved") and len(self.chroma_ids_retrieved) > 0:
+ where_filter = {"_id": {"$in": self.chroma_ids_retrieved}}
+
+
+ found_chunks = self.get_chunks(
+ user_input=query,
+ collections=["sci_articles"],
+ n_results=n_documents,
+ n_sources=max(n_documents, 4)
+ )
+
+ # Standardize the chunks using UnifiedDataChunk
+ unified_chunks = [
+ UnifiedDataChunk(
+ content=chunk.content,
+ metadata=chunk.metadata.model_dump(),
+ source_type="other_documents",
+ )
+ for chunk in found_chunks.chunks
+ ]
+ # Return the unified search results
+ return UnifiedSearchResults(chunks=unified_chunks, source_ids=[])
+
+
+ def fetch_notes_tool(self, **argv) -> UnifiedSearchResults:
+ """
+ Fetches project notes as a list for the researcher to understand what's important in the project.
+ This tool is useful for getting a quick overview of the project's key points.
+ Takes no arguments!
+
+ Returns:
+ UnifiedSearchResults: A structured result containing notes as data chunks.
+ """
+ chunks = []
+ for i, note in enumerate(self.get_notes()):
+ # Create a unified data chunk for each note
+ unified_chunk = UnifiedDataChunk(
+ content=note,
+ metadata={
+ "title": f"Note {i+1}", # Use a default string instead of None
+ "source": "Project Notes" # Add more metadata for better identification
+ },
+ source_type="note",
+ )
+ chunks.append(unified_chunk)
+ return UnifiedSearchResults(chunks=chunks, source_ids=[])
+
+
+class MasterAgent(ResearchBase):
+ """A large and reasoning (if not specified not to be) LLM that handles the complex thinking tasks and coordinates other agents"""
+
+ def __init__(
+        self, username: str, project: Project = None, model: str = "reasoning", tools: list = None, **kwargs
+    ):
+        # Avoid a shared mutable default for tools
+        tools = tools or []
+        # Configure for reasoning model
+ kwargs["model_config"] = {
+ "system_message": "You are an assistant helping a journalist writing a report based on extensive research.",
+ "temperature": 0.3,
+ "model": model,
+ "chat": True,
+ }
+ super().__init__(username=username, project=project, tools=tools, **kwargs)
+ self.model = model
+ self.available_sources = {}
+
+ # Initialize sub-agents
+ self.structure_agent = StructureAgent(
+ username=username, model="small", report=self.report
+ )
+ self.tool_agent = ToolAgent(
+ username=username,
+ model="tools",
+ system_message=f"You are an assistant with some tools. The tools you can choose from are {tools} Always choose a tool to help with the task. Your task is to choose one or multiple tools to answer a user's query. DON'T come up with your own tools, only use the ones provided.",
+ report=self.report,
+ project=project,
+ chat=True,
+ )
+ self.archive_agent = ArchiveAgent(
+ username=username,
+ report=self.report,
+ project=project,
+ system_message="""
+ You are an assistant specialized in reading and summarizing research information.
+ You are helping a researcher with a research divided in many steps, where you will get information for each step.
+ Your goal is to provide clear, accurate summaries that capture the essential points while maintaining context.
+            If your summary is deemed insufficient to complete the step, you will be given more information and asked for an updated summary. Please then include the previous information you got in your new summary so that all information is taken into account.
+ """,
+ chat=True,
+ )
+ self.assistant_agent = AssistantAgent(
+ username=username,
+ report=self.report,
+ project=project,
+ system_message="""
+ You are an assistant specialized in summarizing steps and keeping track of the research process.
+ Your job is to maintain a structured record of what has been done and help the researcher navigate
+ through the research process by providing clear summaries of steps completed.
+ """,
+ chat=True,
+ )
+
+ # Track execution results
+ self.execution_results = {}
+
+ def check_available_sources(self):
+ """
+ Check the available sources in the database and update the report.
+
+ This method iterates through the arango_ids stored in the instance and
+ counts the number of documents by type based on their ID prefixes.
+ It then updates the self.available_sources dictionary with counts for:
+ - other_documents: Documents with IDs starting with "other_documents"
+ - sci_articles: Scientific articles with IDs starting with "sci_articles"
+ - notes: Notes with IDs starting with "note"
+ - interviews: Interviews with IDs starting with "interview"
+
+ Returns:
+ None
+ """
+ #! Update when more sources are added!
+ other_documents = 0
+ science_articles = 0
+ notes = 0
+ interviews = 0
+
+ for id in self.arango_ids:
+ if id.startswith("other_documents"):
+ other_documents += 1
+ elif id.startswith("sci_articles"):
+ science_articles += 1
+ elif id.startswith("note"):
+ notes += 1
+ elif id.startswith("interview"):
+ interviews += 1
+
+        self.available_sources["other_documents"] = other_documents
+        self.available_sources["sci_articles"] = science_articles
+        self.available_sources["notes"] = notes
+        self.available_sources["interviews"] = interviews
+
+ def make_plan(self, question):
+ """Generate a research plan for answering the question/exploring the topic"""
+
+ self.update_research_state(
+ current_step="Plan Creation", current_task="Splitting into questions."
+ )
+
+ query = llm_queries.create_plan_questions(self, question)
+
+ response = self.llm.generate(query=query, model=self.model, think=True)
+ print_purple(response.content)
+ subquestions = [i for i in remove_thinking(response).split("\n") if "?" in i]
+
+ self.report.report["plan"]["subquestions"] = subquestions
+
+ self.update_research_state(
+ current_step="Plan Creation", current_task="Creating initial research plan"
+ )
+
+ # TODO Update the available resources in the query when more resources are added!
+ make_plan_query = llm_queries.create_plan(self, question)
+
+ # Generate the plan and handle potential formatting issues
+ try:
+ response = self.llm.generate(query=make_plan_query, model=self.model, think=True)
+ plan = self.structure_agent.make_structured(response.content, question)
+
+
+ print("THIS IS THE PLAN\n")
+ print_rainbow(plan.__dict__)
+ self.update_research_state(steps_completed=1)
+ return plan
+
+ except Exception as e:
+ print_red(f"Error creating research plan: {e}")
+ traceback.print_exc()
+ return f"Error creating research plan: {str(e)}"
+
+ def process_step(self, step_name, step_tasks, max_attempts=3):
+ """
+ Process a research step with multiple tasks using various agents to gather and organize information.
+ This function handles the complete workflow for a research step:
+ 1. Determines required tools for all tasks
+ 2. Gathers information using the selected tools
+ 3. Summarizes the gathered information
+ 4. Evaluates if the information is sufficient
+ 5. Iteratively gathers more information if needed (up to 3 attempts)
+ 6. Finalizes the step results
+ """
+ print_purple(f"\nProcessing Step: {step_name}")
+ self.update_research_state(
+ current_step=step_name, current_task="Processing entire step"
+ )
+
+ # 1. Determine tools needed for all tasks in the step
+ print_blue("Determining tools for all tasks in step...")
+ all_tasks_description = f"## Step: {step_name}\n"
+ for task in step_tasks:
+ all_tasks_description += (
+ f"- {task['task_name']}: {task['task_description']}\n"
+ )
+ all_tasks_description += "\nWhat tools should I use to gather all necessary information for these tasks efficiently?"
+
+ tool_calls = self.tool_agent.task_tools(all_tasks_description)
+ print_purple("Tools selected for the entire step:")
+ for i, tool_call in enumerate(tool_calls, 1):
+ args_dict = tool_call.function.arguments
+ # Format the arguments as a comma-separated list of key=value pairs
+ args_formatted = ", ".join([f"{k}={v}" for k, v in args_dict.items()])
+ print_purple(f"{i}. {tool_call.function.name} ({args_formatted})")
+
+ # 2. Gather data according to the selected tools
+ print_blue("Gathering data for all tasks...")
+ gathered_info = self.archive_agent.use_tools(tool_calls, all_tasks_description)
+ self.archive_agent.chroma_ids_retrieved += gathered_info.get_chroma_ids
+
+ # 3. Summarize the gathered data
+ print_blue("Summarizing gathered information...")
+ self.archive_agent.reset_chat_history()
+ print_yellow("Step description:")
+ print_yellow(all_tasks_description)
+ print()
+ print_yellow("Summarizing all gathered information...")
+ print("Gathered information:")
+ for info in gathered_info:
+ print(info)
+ print_yellow("Summarizing all gathered information...")
+ summary = self.archive_agent.read_and_summarize(
+ gathered_info, all_tasks_description
+ )
+ summary = remove_thinking(summary)
+ print_green("Step Information Summary:")
+ print(summary)
+
+ # Have the assistant agent track this step
+ self.assistant_agent.summarize_step(step_name, summary)
+
+ # 4. Evaluate if the data is sufficient for all tasks
+ print_blue("Evaluating if information is sufficient for all tasks...")
+ evaluation = self.evaluate_step_completeness(step_tasks, summary)
+
+ self.report.log_evaluation(evaluation)
+
+ # 5. If not enough information, gather more
+ attempt = 1
+
+ while not evaluation["status"] and attempt < max_attempts:
+ print_yellow(
+ f"Information not sufficient. Attempt {attempt}/{max_attempts} to gather more..."
+ )
+
+ # Create a query focusing on missing information
+ additional_query = f"For step '{step_name}', I need additional information on:\n{evaluation['missing_info']}\n\nWhat tools should I use to fill these gaps?"
+
+ # Get additional tools
+ additional_tool_calls = self.tool_agent.task_tools(additional_query)
+
+ # Use additional tools to gather more information
+ additional_info = self.tool_agent.use_tools(
+ additional_tool_calls, additional_query
+ )
+
+ # Add to gathered information
+ # TODO Is it better to append or or make the LLM use the chat history?
+ # gathered_info.extend(additional_info)
+ gathered_info = additional_info
+
+ # Update summary with all information
+ updated_summary = self.archive_agent.read_and_summarize(
+ gathered_info, all_tasks_description
+ )
+ summary = remove_thinking(updated_summary)
+ print_green("Updated Summary:")
+ print(summary)
+
+ # Update assistant agent with new summary
+ self.assistant_agent.summarize_step(step_name, summary)
+
+ # Re-evaluate
+ evaluation = self.evaluate_step_completeness(step_tasks, summary)
+ attempt += 1
+
+ self.report.update_step_summary(summary)
+
+ # 6. Let the MasterAgent use the gathered data to finalize the step
+ print_blue("Finalizing step results...")
+ step_result = self.finalize_step_result(step_name, step_tasks, summary)
+
+ # Pack and store results
+ step_result = {
+ "step_name": step_name,
+ "tasks": [
+ {
+ "task_name": task["task_name"],
+ "task_description": task["task_description"],
+ }
+ for task in step_tasks
+ ],
+ "information_gathered": gathered_info,
+ "summary": summary,
+ "evaluation": evaluation,
+ "result": step_result,
+ }
+
+ self.execution_results[step_name] = step_result
+
+ def execute_research_plan(self, structured_plan):
+ """Execute the structured research plan step by step"""
+ # Execute the plan step by step
+ print_blue("\n--- EXECUTING RESEARCH PLAN ---")
+
+ for step_name, tasks in structured_plan.steps.items():
+ print_blue(f"\n### Processing Step: {step_name}")
+ self.report.start_step(step_name)
+
+ # Collect all task descriptions in this step
+ step_tasks = [
+ {"task_name": task_name, "task_description": task_description}
+ for task_name, task_description in tasks
+ ]
+
+ # Process the entire step
+ self.archive_agent.reset_chroma_ids()
+ self.process_step(step_name, step_tasks)
+
+ # Finish the step in report
+ self.report.finish_step()
+
+ # Evaluate if more steps are needed
+ print_blue("\n--- EVALUATING RESEARCH PLAN ---")
+ plan_evaluation = self.evaluate_plan(self.execution_results)
+ self.report.log_plan_evaluation(plan_evaluation)
+ print_yellow("Plan Evaluation:")
+ print(plan_evaluation["explanation"])
+
+ return self.execution_results
+
+ def evaluate_step_completeness(self, step_tasks, summary):
+ """Evaluate if the information is sufficient for all tasks in the step"""
+
+ self.update_research_state(current_task="Evaluating step completeness")
+
+ # Create a query to evaluate all tasks
+ step_description = "\n".join(
+ [
+ f"- {task['task_name']}: {task['task_description']}"
+ for task in step_tasks
+ ]
+ )
+
+ query = f"""
+ You are evaluating if the gathered information is sufficient for completing ALL the following tasks:
+
+ {step_description}
+
+ Information gathered:
+ """
+ {summary}
+ """
+
+ Is this information sufficient to complete ALL tasks in this step?
+
+ First, analyze each task individually and determine if the information is sufficient.
+ Then, provide an overall assessment where "status" is True if all tasks are complete and False if not.
+ Explain why the information is sufficient or not in the "explanation" field.
+ If ANY task has insufficient information, specify exactly what additional information is needed.
+ """
+
+ response = self.llm.generate(
+ query=query,
+ format=EvaluateFormat.model_json_schema(),
+ model="standard",
+ think=True,
+ )
+ structured_response = EvaluateFormat.model_validate_json(response.content)
+ # Add None for additional_info if it's not present
+ if not hasattr(structured_response, "additional_info"):
+ structured_response.additional_info = None
+
+ if structured_response.status:
+ print_green(
+                f'\nEVALUATION PASSED\nStep: {step_description}\n{structured_response.explanation}'
+            )
+        else:
+ print_red(f"EVALUATION FAILED\n{structured_response.explanation}")
+
+ return {
+ "status": structured_response.status,
+ "explanation": structured_response.explanation,
+ "missing_info": structured_response.additional_info,
+ }
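+
+    # Based on the fields accessed above, EvaluateFormat (models.py) is assumed
+    # to expose roughly: status (bool), explanation (str), additional_info (str or None).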
+
+ def evaluate_step(self, information, task_description):
+ """Evaluate if the information is sufficient for the current step/task"""
+
+ self.update_research_state(current_task=f"Evaluating '{task_description}'")
+
+ query = f'''
+ You are evaluating if the gathered information is sufficient for completing this research task.
+
+ Task: {task_description}
+
+ Information gathered:
+ """
+ {information}
+ """
+
+ Is this information sufficient to complete the task? Respond in the format requested.
+ If insufficient, explain exactly what additional information would be needed.
+ '''
+
+ response = self.llm.generate(
+ query=query, format=EvaluateFormat.model_json_schema()
+ )
+ structured_response = EvaluateFormat.model_validate_json(response.content)
+ if structured_response.status:
+ print_green(
+                f'\nEVALUATION PASSED\nTask: {task_description}\n{structured_response.explanation}'
+ )
+        else:
+ print_red("EVALUATION FAILED")
+ print_yellow(f"Task: {task_description}")
+ print_rainbow(structured_response.__dict__)
+ return {
+ "status": structured_response.status,
+ "explanation": structured_response.explanation,
+ "additional_info": structured_response.additional_info,
+ }
+
+ def evaluate_plan(self, execution_results):
+ """Evaluate if more research steps are needed"""
+ self.update_research_state(
+ current_step="Plan Evaluation",
+ current_task="Evaluating overall research progress",
+ )
+
+ # Create a summary of completed research
+ steps_summary = ""
+ for step_name, step_data in execution_results.items():
+ steps_summary += f"\n## {step_name} \n"
+ # Add the step's summary
+ steps_summary += f"{step_data.get('summary', 'No summary available')} \n"
+
+ # If you want to include individual tasks
+ for task in step_data.get('tasks', []):
+ task_name = task.get('task_name', 'Unnamed task')
+ steps_summary += f"- {task_name} \n"
+
+
+ query = f'''
+ Based on the research that has been conducted so far, determine if additional steps are needed
+ to create a comprehensive report.
+
+ Research completed:
+ """
+ {steps_summary}
+ """"""
+
+ Original question to answer: {self.research_state.get('original_question', 'No question provided')}
+
+ Are additional research steps needed? Respond with COMPLETE or INCOMPLETE,
+ followed by a brief explanation. If INCOMPLETE, suggest what additional steps would be valuable.
+ '''
+
+ response = self.llm.generate(query=query, think=True)
+ evaluation = response.content if hasattr(response, "content") else str(response)
+
+ if "COMPLETE" in evaluation.upper().split(" "):
+            print_green(f'\nEVALUATION PASSED\nEvaluation: {evaluation}')
+ return {
+ "status": "no more information is needed",
+ "explanation": evaluation,
+ }
+ else:
+            print_red(f'\nEVALUATION FAILED\nEvaluation: {evaluation}')
+ return {"status": "more information is needed", "explanation": evaluation}
+
+ def finalize_step_result(self, step_name, step_tasks, summary):
+ """Generate a comprehensive result for the entire step using all gathered information"""
+ self.update_research_state(
+ current_task=f"Finalizing results for step: {step_name}"
+ )
+
+ tasks_description = "\n".join(
+ [
+ f"- {task['task_name']}: {task['task_description']}"
+ for task in step_tasks
+ ]
+ )
+
+ query = f"""
+ Based on the following information gathered for step "{step_name}",
+ create a comprehensive analysis that addresses all the tasks in this step.
+
+ Step tasks:
+ {tasks_description}
+
+ Information gathered:
+ {summary}
+
+ Your response should:
+ 1. Be structured with clear sections for each aspect of the analysis
+ 2. Draw connections between different pieces of information
+ 3. Highlight key insights relevant to the original research question
+ 4. Provide a comprehensive understanding of this step's contribution to the overall research
+
+ Sometimes the information is limited, if so do not make up information, but rather say that the information is limited and write a shorter response.
+ """
+
+ response = self.llm.generate(query=query)
+ step_result = remove_thinking(response)
+ print_green("Step Result:")
+ print(step_result)
+
+ self.update_research_state(
+ steps_completed=self.research_state.get("steps_completed", 0) + 1
+ )
+ return step_result
+
+ def write_report(self, execution_results):
+ """Generate the final report based on the collected information"""
+ self.update_research_state(
+ current_step="Report Writing", current_task="Generating final report"
+ )
+
+ # Prepare all the gathered information in a structured way
+ gathered_info = ""
+ for step_name, step_data in execution_results.items():
+ gathered_info += f"\n## {step_name}\n"
+ # Add the step's summary
+ gathered_info += f"Step Summary: {step_data.get('summary', 'No summary available')}\n\n"
+
+ # Add information about tasks
+ for task in step_data.get('tasks', []):
+ task_name = task.get('task_name', 'Unnamed task')
+ task_description = task.get('task_description', 'No description')
+ gathered_info += f"### {task_name}\n"
+ gathered_info += f"Description: {task_description}\n\n"
+
+ # Include sources when available
+ sources = []
+ for info in step_data.get('information_gathered', []):
+ if isinstance(info, dict) and "result" in info and "content" in info["result"]:
+ if (isinstance(info["result"]["content"], dict) and
+ "chunks" in info["result"]["content"]):
+ for chunk in info["result"]["content"].get("chunks", []):
+ metadata = chunk.get("metadata", {})
+ source = f"{metadata.get('title', 'Unknown')}"
+ if metadata.get("journal"):
+ source += f" ({metadata.get('journal')})"
+ if source not in sources:
+ sources.append(source)
+
+ if sources:
+ gathered_info += "\n### Sources:\n"
+ for i, source in enumerate(sources):
+ gathered_info += f"- [{i+1}] {source}\n"
+
+ print_blue("\n\nGathered information:\n".upper())
+ print(gathered_info, "\n")
+
+ query = f'''
+        Based on the following research information, write an extensive report that answers the question in detail:
+ "{self.research_state.get('original_question', 'No question provided').replace('"', "'")}"
+
+ Research Information:
+ """
+ {gathered_info}
+ """
+
+ The report should be well-structured with appropriate headings, present the information
+ accurately, and highlight key insights. Cite sources using [number] notation when referencing specific information.
+        As the report is for journalistic research, please be generous with details and cases that can be used when reporting on the subject!
+ '''
+
+ response = self.llm.generate(query=query)
+ report = response.content if hasattr(response, "content") else str(response)
+ report = remove_thinking(report)
+
+ self.update_research_state(
+ steps_completed=self.research_state.get("steps_completed", 0) + 1
+ )
+ return report
+
+
+class StructureAgent(ResearchBase):
+ """A small LLM for structuring text as JSON"""
+
+ def __init__(self, username, model: str = "standard", **kwargs):
+
+ super().__init__(username=username, **kwargs)
+ self.model = model
+ self.system_message = """You are helping a researcher to structure a text. You will get a text and make it into structured data.
+ Make sure not to change the meaning of the text and keeps all the details in the subtasks.
+ The content and/or of each step and task should be understandable by itself. Therefore, if a task seems to refer to something that is not mentioned in the step, you should include the necessary information in the task itself. Example: if a task consists of "Collect relevant information for the subject", you should include what the subject is in the task itself.
+ """
+ self.llm = LLM(
+ system_message="You are a research assistant.",
+ model=self.model,
+ chat=False,
+ messages=[],
+ )
+
+ def make_structured(self, text, question=None):
+ """Convert the research plan into a structured format"""
+ self.update_research_state(
+ current_step="Plan Structuring",
+ current_task="Converting plan to structured format",
+ )
+
+ # Prepare query based on whether a question is provided
+ if question:
+ query = f'''This is a proposed plan for how to write a report on "{question}":\n"""{text}"""\nPlease make the plan into structured data with subtasks. Make sure to keep all the details in the subtasks.'''
+ else:
+ query = f'''This is a proposed plan for how to write a report:\n"""{text}"""\nPlease make the plan into structured data with subtasks. Make sure to keep all the details in the subtasks.'''
+
+ # Generate the structured plan
+ try:
+ response = self.llm.generate(
+ query=query, format=Plan.model_json_schema(), model=self.model
+ )
+            response_content = (
+                response.content if hasattr(response, "content") else str(response)
+            )
+            structured_response = Plan.model_validate_json(response_content)
+
+            self.update_research_state(
+                steps_completed=self.research_state.get("steps_completed", 0) + 1
+            )
+            return structured_response
+
+ except Exception as e:
+ print_red(f"Error structuring plan: {e}")
+ traceback.print_exc()
+ # Create a basic fallback structure
+ import sys
+
+ sys.exit(1)
+
+
+class ToolAgent(ResearchBase):
+ """An LLM specialized in choosing tools based on information needs"""
+
+ def __init__(self, username, **kwargs):
+ # Initialize the LLM configuration
+ kwargs["model_config"] = {
+ "system_message": kwargs.get(
+ "system_message",
+ f"""
+ You are a helpful assistant with tools.
+ Your task is to choose one or multiple tools to answer a user's query.
+ DON'T come up with your own tools, only use the ones provided.
+ """,
+ ),
+ "temperature": 0.1,
+ "model": "tools",
+ "chat": kwargs.get("chat", True),
+ }
+
+ super().__init__(username=username, **kwargs)
+
+ def task_tools(self, task_description):
+ """Determine which tools to use for a task"""
+ self.update_research_state(current_task=f"Selecting tools for task")
+
+ query = f'''Research task description:
+ """
+ {task_description}
+ """
+        You have to choose one or many tools in order to fetch the information necessary to complete the task.
+        It's important that you think about what information is needed, and choose the right tool for the job based on the tools' descriptions.
+        Make sure to read the descriptions of the tools carefully before choosing!
+        You can ONLY choose a tool you are provided with, don't make up a tool!
+        You HAVE TO CHOOSE A TOOL, even if you think you can answer without it. Don't answer the question without choosing a tool.
+ '''
+
+ response = self.llm.generate(query=query, tools=self.tools, model="tools")
+
+ # Extract tool calls from the response
+ tool_calls = response.tool_calls if hasattr(response, "tool_calls") else []
+ return tool_calls
+
+
+class AssistantAgent(ResearchBase):
+ """A small LLM agent for summarizing steps and keeping track of the research process by managing "research notes"/common memory.
+ This agent is designed to work with smaller language models and maintain a structured
+ record of the research process through its Notes system.
+ """
+
+ class Notes:
+ """
+ A class for storing and retrieving notes related to different steps in a process.
+ This class allows adding notes with step name, information, and summary,
+ and retrieving notes for specific steps.
+ Attributes:
+ step_notes (list): A list of dictionaries containing step notes.
+ Each dictionary has keys 'step_name', 'step_information', and 'summary'.
+ """
+
+ def __init__(self):
+ self.step_notes = []
+
+ def add_step_note(self, step_name, step_information, step_summary):
+ """Add a note for a specific step"""
+
+ self.step_notes.append(
+ {
+ "step_name": step_name,
+ "step_information": step_information,
+ "summary": step_summary,
+ }
+ )
+
+ def get_step_notes(self, step_name):
+ """Get notes for a specific step.
+
+ Arguments:
+ step_name (str): The name of the step to retrieve notes for.
+ Returns:
+ list: A list of notes for the specified step.
+ """
+ return [note for note in self.step_notes if note["step_name"] == step_name]
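+
+        # Illustrative usage (sketch):
+        #   notes = AssistantAgent.Notes()
+        #   notes.add_step_note("Background", "task list...", "summary text")
+        #   notes.get_step_notes("Background")
+        #   # -> [{"step_name": "Background", "step_information": "task list...",
+        #   #      "summary": "summary text"}]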
+
+ def __init__(self, username: str, system_message: str, **kwargs):
+ # Configure for small model
+ kwargs["model_config"] = {
+ "temperature": 0.1,
+ "system_message": system_message,
+ "model": "small",
+ "chat": kwargs.get("chat", True),
+ }
+ super().__init__(username=username, **kwargs)
+ self.system_message = system_message
+ self.notes = self.Notes()
+
+ def summarize_step(self, step_name, step_tasks):
+ """Summarize the results of a step"""
+ self.update_research_state(current_task=f"Summarizing step '{step_name}'")
+
+ # Create a query to summarize the step
+ query = f"""
+ You are summarizing the results of this research step:
+
+ Step name: {step_name}
+
+ Tasks:
+ {step_tasks}
+
+ Summarize the results of the tasks in a clear and concise manner. Focus on the facts, and mention sources for reference.
+ """
+
+ response = self.llm.generate(query=query)
+ summary = response.content if hasattr(response, "content") else str(response)
+ self.notes.add_step_note(
+ step_name=step_name,
+ step_information=step_tasks,
+ step_summary=summary,
+ )
+
+
+class ArchiveAgent(ResearchBase):
+ """A small LLM for summarizing large amounts of text"""
+
+ def __init__(self, username: str, system_message: str, **kwargs):
+ # Configure for small model
+ kwargs["model_config"] = {
+ "temperature": 0.1,
+ "system_message": system_message,
+ "model": "small",
+ "chat": kwargs.get("chat", True),
+ }
+ super().__init__(username=username, **kwargs)
+ self.system_message = system_message
+ self.chroma_ids_retrieved = []
+
+ def reset_chroma_ids(self):
+ self.chroma_ids_retrieved = []
+
+    def read_and_summarize(self, information: UnifiedToolResponse, step_information):
+        """Summarize the information gathered by the tools"""
+        self.update_research_state(current_task="Summarizing gathered information")
+
+ # Check if there are full articles to process
+ full_articles_to_process = []
+ if information.search_results and information.search_results.chunks:
+ for chunk in information.search_results.chunks:
+ if chunk.source_type == "full_article" and chunk.metadata.get("requires_full_summary", False):
+ full_articles_to_process.append(chunk.metadata)
+
+ # If we have full articles to process, summarize them
+ if full_articles_to_process:
+ article_summaries = []
+ question = None
+ if hasattr(self, "research_state"):
+ question = self.research_state.get("original_question", "")
+
+ for article_meta in full_articles_to_process:
+ article_id = article_meta.get("article_id")
+ if article_id:
+ summary = self.fetch_and_summarize_full_article_tool(article_id, question)
+ article_summaries.append(summary)
+
+ # If we've processed full articles, create a unified summary
+ if article_summaries:
+ info_text = "\n\n---\n\n".join(article_summaries)
+ else:
+ # If no full articles were successfully processed, fall back to regular text
+ info_text = information.to_text
+ else:
+ # Process chunks as usual
+ info_text = information.to_text
+
+ print_purple(f"INFO TEXT for summarization:\n{info_text}\n")
+
+ query = f'''
+ Below is the description of the current research step. *It is only for your information, nothing you should reference in the summary*.
+
+ """
+ {step_information}
+ """
+
+ Please read the following information to make a summary for the researcher.
+
+ """
+ {info_text}
+ """
+
+ Focus on *information and facts* that are important and relevant for the research step.
+ Ensure no important details are lost.
+ Reference the sources in the summary so it's clear where each piece of information comes from.
+        Some pieces of information might not be of use in this research step; if so, just ignore them without mentioning them.
+ You are only allowed to use the sources provided in the information, don't make up any sources.
+ You should only make a summary of the information, not the step itself or any kind of evaluation!
+ '''
+
+ response = self.llm.generate(query=query)
+ summary = response.content if hasattr(response, "content") else str(response)
+
+ return summary
+
+ def reset_chat_history(self):
+ self.llm.messages = [{"role": "system", "content": self.system_message}]
+
+
+def main(question, username="lasse", project_name="Electric Cars"):
+ """Main function to execute the research workflow"""
+
+ # Initialize base and project
+ base = BaseClass(username=username)
+ project: Project = Project(
+ username=username, project_name=project_name, user_arango=base.get_arango()
+ )
+
+    # Map what kinds of sources there are in bot.arango_ids
+ number_of_documents = {'other_documents': 0, 'science_articles': 0, 'notes': 0, 'interviews': 0}
+
+ bot = Bot(username=username, project=project, user_arango=base.get_arango(), tools="all")
+    for arango_id in bot.arango_ids:
+        if arango_id.startswith("other_documents"):
+            number_of_documents['other_documents'] += 1
+        elif arango_id.startswith("sci_articles"):
+            number_of_documents['science_articles'] += 1
+        elif arango_id.startswith("note"):
+            number_of_documents['notes'] += 1
+        elif arango_id.startswith("interview"):
+            number_of_documents['interviews'] += 1
+
+
+ tool_sources = {
+ "fetch_other_documents_tool": ["other_documents"],
+ "fetch_science_articles_tool": ["science_articles"],
+ "fetch_science_articles_and_other_documents_tool": ["other_documents", "science_articles"],
+ "fetch_notes_tool": ["notes"]
+ }
+
+ # Create a list of tools that is to be given to the bots
+ bot_tools: list = [tool.__name__ for tool in bot.tools if callable(tool)]
+    for tool, sources in tool_sources.items():
+        documents = sum(number_of_documents[source] for source in sources)
+        if documents == 0:
+            print_yellow(f"Removing {tool} from bot, as there are no documents in the sources: {sources}")
+            if tool in bot_tools:
+                bot_tools.remove(tool)
+
+
+ # Initialize report tracking
+ report: ResearchReport = ResearchReport(
+ question=question, username=username, project_name=project_name
+ )
+
+ # Initialize agents with report tracking
+ master_agent = MasterAgent(
+ username=username, project=project, report=report, chat=True, tools=bot_tools
+ )
+
+
+ # Track the research state in the master agent
+ master_agent.research_state["original_question"] = question
+
+ # Execute workflow with proper error handling and progress tracking
+ print_blue(f"Starting research on: {question}")
+
+ # Research plan creation
+ print_blue("\n--- CREATING RESEARCH PLAN ---")
+ research_plan = master_agent.make_plan(question)
+
+ # Log the plan in the report
+ report.log_plan(research_plan)
+
+# research_plan = '''
+# plan = """## Step 1: Review the journalist's notes
+# - Task1: Identify and extract information from the journalist's notes that directly relates to lithium mining's social, technical, and economic aspects.
+# - Task2: Summarize the extracted information into a structured format, highlighting key themes (e.g., environmental impact, cost benefits, political greenwashing).
+
+# ## Step 2: Search for social impact information
+# - Task1: Use the database/LLM to search for information on the social impacts of lithium mining, such as displacement, labor conditions, health risks, and indigenous land rights.
+# - Task2: Summarize findings into a structured format, focusing on how lithium mining affects local communities and indigenous populations.
+
+# ## Step 3: Search for technical challenges
+# - Task1: Use the database/LLM to search for technical challenges of lithium mining, including environmental degradation, water usage, energy consumption, and ecosystem impacts.
+# - Task2: Summarize findings into a structured format, emphasizing technical risks and environmental consequences.
+
+# ## Step 4: Search for economic aspects
+# - Task1: Use the database/LLM to search for economic challenges of lithium mining, such as production costs, market volatility, profitability, and local economic impacts.
+# - Task2: Summarize findings into a structured format, highlighting economic trade-offs and long-term sustainability.
+
+# ## Step 5: Cross-reference and compile findings
+# - Task1: Compare information from Steps 2–4 to identify overlaps, contradictions, or gaps in the data.
+# - Task2: Compile all findings into a cohesive summary, ensuring each aspect (social, technical, economic) is addressed with evidence from the sources.
+
+# ## Step 6: Analyze long-term risks and sustainability
+# - Task1: Use the database/LLM to search for information on long-term risks of lithium mining, such as resource depletion, pollution, and water scarcity.
+# - Task2: Summarize findings into a structured format, linking long-term risks to social, technical, and economic aspects.
+# '''
+
+
+ # Execute the plan step by step
+ execution_results = master_agent.execute_research_plan(research_plan)
+
+ # Write the final report
+ print_blue("\n--- WRITING FINAL REPORT ---")
+ final_report = master_agent.write_report(execution_results)
+ report.log_final_report(final_report)
+ print_green("Final Report:")
+ print(final_report)
+
+ # Save the full research report
+ report_path = report.save_to_file(
+ f"/home/lasse/sci/reports/research_report_{username}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+ )
+
+ # Create a more readable markdown version
+ markdown_report = report.get_markdown_report()
+ markdown_path = report_path.replace(".json", ".md")
+ with open(markdown_path, "w") as f:
+ f.write(markdown_report)
+ print_green(f"Markdown report saved to {markdown_path}")
+
+ return {
+ "question": question,
+ "research_plan": research_plan,
+ "structured_plan": research_plan,
+ "execution_results": execution_results,
+ "final_report": final_report,
+ "full_report": report.get_full_report(),
+ "report_path": report_path,
+ "markdown_path": markdown_path,
+ }
+
+
+def remove_thinking(response):
+ """Remove the thinking section from the response"""
+ response_text = response.content if hasattr(response, "content") else str(response)
+ if "" in response_text:
+ return response_text.split("")[1].strip()
+ return response_text
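+# Illustrative example (hypothetical text): for models that emit
+# <think>...</think> blocks, everything up to the closing tag is dropped:
+#   remove_thinking("<think>planning the steps...</think>Final answer")
+#   # -> "Final answer"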
+
+
+if __name__ == "__main__":
+ question = "What are the problems around lithium mining? I'm interested in social, technical and economical aspects."
+ result = main(
+ question, username="lasse", project_name="Electric Cars"
+ ) # Use these parameters to test the code, don't change!
diff --git a/article2db.py b/article2db.py
index 2df4ef8..ecf1e89 100644
--- a/article2db.py
+++ b/article2db.py
@@ -18,13 +18,14 @@ import xml.etree.ElementTree as ET
from streamlit.runtime.uploaded_file_manager import UploadedFile
import streamlit as st
-from _arango import ArangoDB
+from _arango import ArangoDB, COLLECTIONS_IN_BASE
from _chromadb import ChromaDB
from _llm import LLM
from colorprinter.print_color import *
-from utils import fix_key
+from utils import fix_key, is_reference_chunk
import semantic_schoolar
+from models import ArticleMetadataResponse
class Document:
def __init__(
@@ -39,6 +40,7 @@ class Document:
_key: str = None,
arango_db_name: str = None,
arango_collection: str = None,
+ arango_doc: dict = None
):
self.filename = filename
self.pdf_file = pdf_file
@@ -50,6 +52,7 @@ class Document:
self.arango_db_name = arango_db_name
self.arango_collection = arango_collection
self.text = text
+ self.arango_doc: dict = arango_doc
self.chunks = []
self.pdf = None
@@ -61,6 +64,8 @@ class Document:
self.download_folder = None
self.document_type = None
+ if self._key:
+ self._key = fix_key(self._key)
if self.pdf_file:
self.open_pdf(self.pdf_file)
@@ -71,9 +76,8 @@ class Document:
if not self._id:
return
data = {
- "text": self.text,
+ "arango_doc": self.arango_doc,
"arango_db_name": self.arango_db_name,
- "arango_id": self._id,
"is_sci": self.is_sci,
}
@@ -132,7 +136,13 @@ class Document:
else:
better_chunks.append(chunk.strip())
- self.chunks = better_chunks
+            # Filter out chunks that are mainly academic references
+            self.chunks = []
+ for chunk in better_chunks:
+ if not is_reference_chunk(chunk):
+ self.chunks.append(chunk)
+ else:
+ print_yellow(f"Chunk is mainly academic references, skipping it.\n{chunk[:100]}...")
def get_title(self, only_meta=False):
"""
@@ -238,7 +248,84 @@ class Document:
class Processor:
+ """
+ Processor class for handling scientific and non-scientific document ingestion, metadata extraction, and storage.
+ This class provides a comprehensive pipeline for processing documents (primarily PDFs), extracting metadata (such as DOI, title, authors, journal, etc.), verifying and enriching metadata using external APIs (CrossRef, Semantic Scholar, DOAJ), chunking document text, and storing both the document and its chunks in vector and document databases (ChromaDB and ArangoDB).
+ Key Features:
+ -------------
+ - Extracts DOI from filenames and document text using regex and LLM fallback.
+ - Retrieves and verifies metadata from CrossRef, Semantic Scholar, and DOAJ.
+ - Handles both scientific articles and other document types, with appropriate collection routing.
+ - Chunks document text for vector storage and search.
+ - Stores documents and chunks in ArangoDB (document DB) and ChromaDB (vector DB).
+ - Manages user access and open access flags.
+ - Supports background summary generation for scientific articles.
+ - Provides PDF download utilities from open access sources.
+ - Designed for extensibility and robust error handling.
+ Parameters:
+ -----------
+ document : Document
+ The document object to be processed.
+ filename : str, optional
+ The filename of the document (default: None).
+ chroma_db : str, optional
+ Name of the ChromaDB database to use (default: "sci_articles").
+ len_chunks : int, optional
+ Length of text chunks for vector storage (default: 2200).
+ local_chroma_deployment : bool, optional
+ Whether to use a local ChromaDB deployment (default: False).
+ process : bool, optional
+ Whether to immediately process the document upon initialization (default: True).
+ document_type : str, optional
+ Type of the document for collection routing (default: None).
+ username : str, optional
+ Username for access control and database routing (default: None).
+ Methods:
+ get_arango(db_name=None, document_type=None)
+ extract_doi(text, multi=False)
+ Extract DOI(s) from text using regex and LLM fallback.
+ chunks2chroma(_id, key)
+ Add document chunks to ChromaDB vector database.
+ chunks2arango()
+ Add document chunks and metadata to ArangoDB document database.
+ llm2metadata()
+ Extract metadata from a scientific article using an LLM.
+ get_crossref(doi)
+ Retrieve and parse metadata from CrossRef by DOI.
+ check_doaj(doi)
+ Check if a DOI is listed in DOAJ and retrieve metadata.
+ get_semantic_scholar_by_doi(doi)
+ Retrieve and verify metadata from Semantic Scholar by DOI.
+ get_semantic_scholar_by_title(title)
+ Retrieve and verify metadata from Semantic Scholar by title.
+ process_document()
+ Main pipeline for processing, extracting, chunking, and storing the document.
+ dl_pyppeteer(doi, url)
+ Download a PDF using a headless browser (async).
+ doi2pdf(doi)
+ Download a PDF for a DOI from open access sources or retrieve from database.
+ Attributes:
+ -----------
+ document : Document
+ The document being processed.
+ chromadb : ChromaDB
+ The ChromaDB instance for vector storage.
+ len_chunks : int
+ Length of text chunks for vector storage.
+ document_type : str
+ Type of the document for collection routing.
+ filename : str
+ Filename of the document.
+ username : str
+ Username for access control and database routing.
+ _id : str
+ Internal document ID after processing.
+ Usage:
+ ------
+ processor = Processor(document, filename="paper.pdf")
+ """
def __init__(
self,
document: Document,
filename: str = None,
@@ -249,6 +336,31 @@ class Processor:
document_type: str = None,
username: str = None,
):
+ """
+ Initializes the class with the provided document and configuration parameters.
+
+ Args:
+ document (Document): The document object to be processed and stored.
+ filename (str, optional): The filename associated with the document. Defaults to None.
+ chroma_db (str, optional): The name of the ChromaDB database to use. Defaults to "sci_articles".
+ len_chunks (int, optional): The length of text chunks for processing. Defaults to 2200.
+ local_chroma_deployment (bool, optional): Whether to use a local ChromaDB deployment. Defaults to False.
+ process (bool, optional): Whether to process the document upon initialization. Defaults to True.
+ document_type (str, optional): The type/category of the document. Defaults to None.
+ username (str, optional): The username associated with the document. If not provided, uses document.username. Defaults to None.
+
+ Attributes:
+ document (Document): The document object.
+ chromadb (ChromaDB): The ChromaDB instance for database operations.
+ len_chunks (int): The length of text chunks for processing.
+ document_type (str): The type/category of the document.
+ filename (str): The filename associated with the document.
+ username (str): The username associated with the document.
+ _id: Internal identifier for the document.
+
+ Side Effects:
+ If process is True, calls self.process_document() to process the document.
+ """
self.document = document
self.chromadb = ChromaDB(local_deployment=local_chroma_deployment, db=chroma_db)
self.len_chunks = len_chunks
@@ -258,28 +370,47 @@ class Processor:
self.username = username if username else document.username
self._id = None
+ self._key = None
if process:
self.process_document()
def get_arango(self, db_name=None, document_type=None):
- if db_name and document_type:
- arango = ArangoDB(db_name=db_name)
- arango_collection = arango.db.collection(document_type)
+ """
+ Get an ArangoDB collection based on document type and context.
+
+ This method determines the appropriate ArangoDB collection to use based on the
+ document type and the document's properties.
+
+ Args:
+ db_name (str, optional): The name of the database to connect to.
+ Defaults to None, in which case the default database is used.
+ document_type (str, optional): The type of document, which maps to a collection name.
+ Defaults to None, in which case the method attempts to determine the appropriate collection.
+
+ Returns:
+ Collection: An ArangoDB collection object.
+
+ Raises:
+ AssertionError: If document_type is not provided for non-sci articles, or
+ if username is not provided for non-sci articles.
+
+ Notes:
+ - For document types in COLLECTIONS_IN_BASE, returns the corresponding collection.
+ - For scientific articles (document.is_sci == True), returns the "sci_articles" collection.
+ - For other documents, requires both document_type and document.username to be specified.
+ """
+
+ if document_type in COLLECTIONS_IN_BASE:
+ return ArangoDB().get_collection(document_type)
elif self.document.is_sci:
- arango = ArangoDB(db_name="base")
- arango_collection = arango.db.collection("sci_articles")
- elif self.document.open_access:
- arango = ArangoDB(db_name="base")
- arango_collection = arango.db.collection("other_documents")
+ return ArangoDB().get_collection("sci_articles")
else:
- arango = ArangoDB(db_name=self.document.username)
- arango_collection: ArangoCollection = arango.db.collection(
- self.document_type
- )
- self.document.arango_db_name = arango.db.name
- self.arango_collection = arango_collection
- return arango_collection
+ assert document_type, "Document type must be provided for non-sci articles."
+ assert self.document.username, "Username must be provided for non-sci articles."
+            return ArangoDB(db_name=self.document.username).get_collection(document_type)
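+            # Illustrative routing (hypothetical user "alice"):
+            #   get_arango(document_type="sci_articles")    -> shared "sci_articles" collection
+            #   get_arango(document_type="other_documents") -> "other_documents" in alice's database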
+
def extract_doi(self, text, multi=False):
"""
@@ -360,7 +491,7 @@ class Processor:
ids.append(id)
metadata = {
- "_key": id,
+ "_key": self.document._key,
"file": self.document.file_path,
"chunk_nr": i,
"pages": ",".join([str(i) for i in page_numbers]),
@@ -378,6 +509,11 @@ class Processor:
"sci_articles"
)
else:
+            print_yellow(f"Using ChromaDB collection: {self.username}__other_documents")
chroma_collection = self.chromadb.db.get_or_create_collection(
f"{self.username}__other_documents"
)
@@ -385,6 +521,31 @@ class Processor:
chroma_collection.add(ids=ids, documents=documents, metadatas=metadatas)
def chunks2arango(self):
+ """
+ Adds document chunks to an ArangoDB database.
+
+ This method processes the document and its chunks to store them in the ArangoDB.
+ It handles scientific and non-scientific documents differently, applies access control,
+ and manages document metadata.
+
+ Prerequisites:
+ - Document must have a 'text' attribute
+ - Scientific documents must have 'doi' and 'metadata' attributes
+ - Non-scientific documents must have either '_key' attribute or DOI
+
+ The method:
+ 1. Validates document attributes
+ 2. Gets ArangoDB collection
+ 3. Processes document chunks with page information
+ 4. Manages user access permissions
+ 5. Creates the ArangoDB document with all necessary fields
+ 6. Handles special processing for scientific documents with abstracts
+ 7. Inserts the document into ArangoDB with update capabilities
+ 8. Initiates background summary generation if needed
+
+        Returns:
+            dict: The inserted/updated ArangoDB document, including '_id' and '_key'.
+ """
st.write("Adding to document database...")
assert self.document.text, "Document must have 'text' attribute."
if self.document.is_sci:
@@ -397,7 +558,7 @@ class Processor:
getattr(self.document, "_key", None) or self.document.doi
), "Document must have '_key' attribute or DOI."
- arango_collection = self.get_arango()
+ arango_collection = self.get_arango(document_type=self.document.arango_collection)
if self.document.doi:
key = self.document.doi
@@ -435,7 +596,7 @@ class Processor:
if self.document.open_access:
user_access = None
- arango_document = {
+ self.document.arango_doc = {
"_key": fix_key(self.document._key),
"file": self.document.file_path,
"chunks": arango_chunks,
@@ -446,6 +607,7 @@ class Processor:
"metadata": self.document.metadata,
"filename": self.document.filename,
}
+ print_purple('Number of chunks:', len(self.document.arango_doc['chunks']))
if self.document.metadata and self.document.is_sci:
if "abstract" in self.document.metadata:
@@ -453,8 +615,8 @@ class Processor:
self.document.metadata["abstract"] = re.sub(
r"<[^>]*>", "", self.document.metadata["abstract"]
)
- arango_document["metadata"] = self.document.metadata
- arango_document["summary"] = {
+ self.document.arango_doc["metadata"] = self.document.metadata
+ self.document.arango_doc["summary"] = {
"text_sum": (
self.document.metadata["abstract"]["text_sum"]
if "text_sum" in self.document.metadata["abstract"]
@@ -463,20 +625,49 @@ class Processor:
"meta": {"model": "from_metadata"},
}
- arango_document["crossref"] = True
+ self.document.arango_doc["crossref"] = True
- doc = arango_collection.insert(
- arango_document, overwrite=True, overwrite_mode="update", keep_none=False
+ arango = ArangoDB(db_name=self.document.arango_db_name)
+        print_purple("Target collection/db:", self.document.arango_collection, self.document.arango_db_name)
+ inserted_document = arango.insert_document(
+ collection_name=self.document.arango_collection,
+ document=self.document.arango_doc,
+ overwrite=True,
+ overwrite_mode="update",
+ keep_none=False
)
- self.document._id = doc["_id"]
+ print_green("ArangoDB document inserted:", inserted_document['_id'])
+
+ self.document.arango_doc = arango.db.collection(
+ self.document.arango_collection
+ ).get(self.document._key)
+ self.document._id = self.document.arango_doc["_id"]
- if "summary" not in arango_document:
+ if "summary" not in self.document.arango_doc:
# Make a summary in the background
+ print_yellow("No summary found in the document, generating in background...")
self.document.make_summary_in_background()
-
- return doc["_id"], key
+ else:
+ print_green("Summary already exists in the document.")
+ print(self.document.arango_doc['summary'])
+ return self.document.arango_doc
def llm2metadata(self):
+ """
+        Extract metadata from a scientific article PDF using an LLM.
+ Uses the first page (or first two pages for multi-page documents) of the PDF
+ to extract the title, publication date, and journal name via LLM.
+ Returns:
+ dict: A dictionary containing the extracted metadata with the following keys:
+ - "title": The article title (str)
+ - "published_date": The publication date (str)
+ - "journal": The journal name (str)
+ - "published_year": The publication year (int or None if not parseable)
+ Note:
+ Default values are provided for any metadata that cannot be extracted.
+ The published_year is extracted from published_date when possible.
+ """
st.write("Extracting metadata using LLM...")
llm = LLM(
temperature=0.01,
@@ -499,38 +690,27 @@ class Processor:
"""
Answer ONLY with the information requested.
- I want to know the published date on the form "YYYY-MM-DD".
- I want the full title of the article.
- I want the name of the journal/paper/outlet where the article was published.
- Be sure to answer on the form "published_date;title;journal" as the answer will be used in a CSV.
- If you can't find the information, answer "not_found".
'''
- result = llm.generate(prompt)
- print_blue(result)
- if result == "not_found":
- return None
- else:
- parts = result.content.split(";", 2)
- if len(parts) != 3:
- return None
- published_date, title, journal = parts
- if published_date == "not_found":
- published_date = "[Unknown date]"
- else:
- try:
- published_year = int(published_date.split("-")[0])
- except:
- published_year = None
- if title == "not_found":
- title = "[Unknown title]"
- if journal == "not_found":
- journal = "[Unknown publication]"
- return {
- "published_date": published_date,
- "published_year": published_year,
- "title": title,
- "journal": journal,
- }
+ result = llm.generate(prompt, format=ArticleMetadataResponse.model_json_schema())
+ structured_response = ArticleMetadataResponse.model_validate_json(result.content)
+
+ # Extract and process metadata with defaults and safer type conversion
+ metadata = {
+ "title": structured_response.title or "[Unknown title]",
+ "published_date": structured_response.published_date or "[Unknown date]",
+ "journal": structured_response.journal or "[Unknown publication]",
+ "published_year": None
+ }
+
+ # Parse year from date if available
+ if metadata["published_date"] and metadata["published_date"] != "[Unknown date]":
+ try:
+ metadata["published_year"] = int(metadata["published_date"].split("-")[0])
+ except (ValueError, IndexError):
+ pass
+
+ return metadata
def get_crossref(self, doi):
try:
@@ -903,7 +1083,7 @@ class Processor:
assert self.document.pdf_file or self.document.pdf, "PDF file must be provided."
if not self.document.pdf:
self.document.open_pdf(self.document.pdf_file)
-
+
if self.document.is_image:
return pymupdf4llm.to_markdown(
self.document.pdf, page_chunks=False, show_progress=False
@@ -940,11 +1120,10 @@ class Processor:
if not self.document.metadata and self.document.title:
self.document.metadata = self.get_semantic_scholar_by_title(self.document.title)
-
- # Continue with the rest of the method...
- arango_collection = self.get_arango()
-
- # ... rest of the method remains the same ...
+ if self.document.is_sci:
+ arango_collection = self.get_arango(document_type='sci_articles')
+ else:
+ arango_collection = self.get_arango(document_type='other_documents')
doc = arango_collection.get(self.document._key) if self.document.doi else None
@@ -975,6 +1154,7 @@ class Processor:
arango_collection.update(self.document.doc)
return doc["_id"], arango_collection.db_name, self.document.doi
+ # If no document found, create a new one
else:
self.document.doc = (
{"doi": self.document.doi, "_key": fix_key(self.document.doi)}
@@ -1021,7 +1201,8 @@ class Processor:
print_yellow(f"Document key: {_key}")
print(self.document.doi, self.document.title, self.document.get_title())
self.document.doc["_key"] = fix_key(_key)
- self.document._key = fix_key(_key)
+ self.document._key = self.document.doc["_key"]
+
self.document.metadata = self.document.doc["metadata"]
if not self.document.text:
self.document.extract_text()
@@ -1035,8 +1216,16 @@ class Processor:
self.document.make_chunks()
- _id, key = self.chunks2arango()
- self.chunks2chroma(_id=_id, key=key)
+ if not self.document.is_sci and not self.document.doi:
+ self.document.arango_collection = "other_documents"
+ self.document.arango_db_name = self.username
+
+ print_purple("Not a scientific article, using 'other_articles' collection.")
+
+ arango_doc = self.chunks2arango()
+ _id = arango_doc["_id"]
+ _key = arango_doc["_key"]
+ self.chunks2chroma(_id=_id, key=_key)
self._id = _id
return _id, arango_collection.db_name, self.document.doi
@@ -1224,6 +1413,8 @@ class PDFProcessor(Processor):
return False, None, None, False
+
+
if __name__ == "__main__":
doi = "10.1007/s10584-019-02646-9"
print(f"Processing article with DOI: {doi}")
diff --git a/bot_tools.py b/bot_tools.py
new file mode 100644
index 0000000..e69de29
diff --git a/info.py b/info.py
index 90eff92..64b8a35 100644
--- a/info.py
+++ b/info.py
@@ -190,3 +190,4 @@ country_emojis = {
"ro": "🇷🇴",
"rs": "🇷🇸",
}
+
diff --git a/llm_queries.py b/llm_queries.py
index 41b3c24..e347dcc 100644
--- a/llm_queries.py
+++ b/llm_queries.py
@@ -80,6 +80,11 @@ def create_plan(agent, question):
'''
The example above is just an example, you can use other steps and tasks that are more relevant for the question.
+
+ Again: The research will be done in a restricted context, with only the available sources and tools. Therefore:
+ - DO NOT include any steps that require access to the internet or external databases.
+ - DO NOT include any steps that require cross-referencing sources.
+ - DO NOT include any steps to find new sources or tools.
"""
return query
\ No newline at end of file
diff --git a/llm_server.py b/llm_server.py
index 3f81bc8..3d6341b 100644
--- a/llm_server.py
+++ b/llm_server.py
@@ -1,26 +1,223 @@
from fastapi import FastAPI, BackgroundTasks, Request
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse, HTMLResponse
import logging
+from datetime import datetime
+import json
+import os
+from typing import Dict, Any
from prompts import get_summary_prompt
from _llm import LLM
from _arango import ArangoDB
+from models import ArticleChunk
+from _chromadb import ChromaDB
+
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
+# Storage for the latest processed document
+latest_result: Dict[str, Any] = {}
+latest_result_file = os.path.join(os.path.dirname(__file__), "latest_summary_result.json")
+
+# Load any previously saved result on startup
+try:
+ if os.path.exists(latest_result_file):
+ with open(latest_result_file, 'r') as f:
+ latest_result = json.load(f)
+ logger.info(f"Loaded previous result from {latest_result_file}")
+except Exception as e:
+ logger.warning(f"Could not load previous result: {e}")
+
+# Function to save the latest result to disk
+def save_latest_result(result: Dict[str, Any]):
+ global latest_result
+ latest_result = result
+ try:
+ # Save sanitized version (remove internal fields if needed)
+ result_to_save = {k: v for k, v in result.items() if not k.startswith('_') or k == '_id'}
+ with open(latest_result_file, 'w') as f:
+ json.dump(result_to_save, f, indent=2)
+ logger.info(f"Saved latest result to {latest_result_file}")
+ except Exception as e:
+ logger.error(f"Error saving latest result: {e}")
+
+# New endpoint to get the latest summarized document
+@app.get("/latest_result")
+async def get_latest_result():
+ """
+ Get the latest summarized document result.
+
+ Returns the most recently processed document summary and chunk information.
+ If no document has been processed yet, returns an empty object.
+
+ Returns
+ -------
+ dict
+ The latest processed document with summaries
+ """
+ if not latest_result:
+ return {"message": "No documents have been processed yet"}
+ return latest_result
+
+@app.get("/view_results")
+async def view_results():
+ """
+ View the latest summarization results in a more readable format.
+
+ Returns a formatted response with document summary and chunks.
+
+ Returns
+ -------
+ dict
+ A formatted representation of the latest summarized document
+ """
+ if not latest_result:
+ return {"message": "No documents have been processed yet"}
+
+ # Extract the key information
+ formatted_result = {
+ "document_id": latest_result.get("_id", "Unknown"),
+ "timestamp": datetime.now().isoformat(),
+ "summary": latest_result.get("summary", {}).get("text_sum", "No summary available"),
+ "model": latest_result.get("summary", {}).get("meta", {}).get("model", "Unknown model"),
+ }
+
+ # Format chunks information if available
+ chunks = latest_result.get("chunks", [])
+ if chunks:
+ formatted_chunks = []
+ for i, chunk in enumerate(chunks):
+ chunk_data = {
+ "chunk_number": i + 1,
+ "summary": chunk.get("summary", "No summary available"),
+ "tags": chunk.get("tags", [])
+ }
+ # Add references for scientific articles if available
+ if "references" in chunk:
+ chunk_data["references"] = chunk.get("references", [])
+ formatted_chunks.append(chunk_data)
+
+ formatted_result["chunks"] = formatted_chunks
+ formatted_result["chunk_count"] = len(chunks)
+
+ return formatted_result
+
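+# Example client calls (assuming the default uvicorn port 8000; adjust to your
+# deployment):
+#   curl http://localhost:8000/latest_result   # raw stored document
+#   curl http://localhost:8000/view_results    # condensed summary and chunks
+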
+@app.get("/html_results", response_class=HTMLResponse)
+async def html_results():
+ """
+ View the latest summarization results in a human-readable HTML format.
+ """
+ if not latest_result:
+ return """
+
+
+ No Results Available
+
+
+
+
No Documents Have Been Processed Yet
+
Submit a document for summarization first.
+
+
+ """
+
+ # Get the document ID and summary
+ doc_id = latest_result.get("_id", "Unknown")
+ summary = latest_result.get("summary", {}).get("text_sum", "No summary available")
+ model = latest_result.get("summary", {}).get("meta", {}).get("model", "Unknown model")
+
+ # Format chunks
+ chunks_html = ""
+ chunks = latest_result.get("chunks", [])
+ for i, chunk in enumerate(chunks):
+ chunk_summary = chunk.get("summary", "No summary available")
+ tags = chunk.get("tags", [])
+ tags_html = ", ".join(tags) if tags else "None"
+
+ references_html = ""
+ if "references" in chunk and chunk["references"]:
+ references_html = "
References:
"
+ for ref in chunk["references"]:
+ references_html += f"