You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

443 lines
16 KiB

# _base_class.py
import os
import re
import streamlit as st
from _arango import ArangoDB
from _chromadb import ChromaDB
class BaseClass:
    """
    Shared base for user-scoped helpers.

    Holds the username, the current project/collection names, an ArangoDB
    handle for the user's own database (``self.user_arango``) and a handle
    created with ``ArangoDB()`` (``self.base_arango``). Any extra keyword
    arguments passed to ``__init__`` become instance attributes; for example
    ``document_type``, ``is_sci`` and ``doi`` are read by ``set_filename``.
    """

    def __init__(self, username: str, **kwargs) -> None:
        """
        Args:
            username (str): Current user; also the name of the user's ArangoDB
                database (see ``get_arango``).
            **kwargs: ``project_name`` and ``collection_name`` seed
                ``self.project_name`` and ``self.collection``. Every keyword
                (including those two) is then also set as an attribute, after
                the database handles have been created.
        """
        self.username: str = username
        self.project_name: str = kwargs.get("project_name")
        self.collection: str = kwargs.get("collection_name")
        self.user_arango: ArangoDB = self.get_arango()
        self.base_arango: ArangoDB = self.get_arango(admin=True)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def get_arango(self, admin: bool = False, db_name: str = None) -> "ArangoDB":
        """
        Create an ArangoDB handle.

        Args:
            admin (bool): When True (and ``db_name`` is not given), return
                ``ArangoDB()`` constructed without arguments.
            db_name (str, optional): Database to connect to; takes precedence
                over ``admin``.

        Returns:
            ArangoDB: By default, a handle for the current user on the database
            named after the user.
        """
        if db_name:
            return ArangoDB(db_name=db_name)
        elif admin:
            return ArangoDB()
        else:
            return ArangoDB(user=self.username, db_name=self.username)

    def get_article_collections(self) -> list:
        """
        Gets the names of all article collections for the current user.

        Returns:
            list: A list of article collection names.
        """
        article_collections = self.user_arango.execute_aql(
            'FOR doc IN article_collections RETURN doc["name"]'
        )
        return list(article_collections)

    def get_projects(self) -> list:
        """
        Gets the names of all projects for the current user.

        Returns:
            list: A list of project names.
        """
        projects = self.user_arango.get_projects(username=self.username)
        return [project["name"] for project in projects]

    def get_chromadb(self) -> "ChromaDB":
        """Return a new ``ChromaDB()`` instance."""
        return ChromaDB()

    def get_project(self, project_name: str):
        """
        Get a project by name for the current user.

        Args:
            project_name (str): The name of the project.

        Returns:
            dict: The project document if found, None otherwise.
        """
        return self.user_arango.get_project(project_name, username=self.username)

    def set_filename(self, filename=None, folder="other_documents"):
        """
        Choose the on-disk PDF path for the current document and store it in
        ``self.file_path``.

        Scientific articles (``self.is_sci`` truthy and ``self.document_type``
        not ``"notes"``) go to ``sci_articles/<doi>.pdf``, with every ``/`` in
        the DOI replaced by ``_``. Everything else goes to
        ``user_data/<username>/<document_type>/<filename>.pdf``. If that file
        already exists, a numeric suffix is appended (``_1``) or incremented
        (``_1`` -> ``_2``) until the name is unused.

        Args:
            filename (str): Base name without the ``.pdf`` extension. Only used
                for non-scientific documents.
            folder (str): Sub-folder used when ``self.document_type`` is missing
                or empty. Defaults to "other_documents".

        Sets:
            self.file_path (str): The chosen path, including ``.pdf``.

        Returns:
            bool: For scientific articles, whether the file already exists.
            str: Otherwise, the unique path *without* the ``.pdf`` extension.
        """
        document_type = getattr(self, "document_type", None)
        if self.is_sci and document_type != "notes":
            # Sanitise only the DOI. Replacing "/" in the whole string would
            # also destroy the "sci_articles/" directory separator.
            self.file_path = f"sci_articles/{str(self.doi).replace('/', '_')}.pdf"
            return os.path.exists(self.file_path)

        download_folder = f"user_data/{self.username}/{document_type or folder}"
        stem = f"{download_folder}/{filename}"
        while os.path.exists(stem + ".pdf"):
            suffix = re.search(r"_(\d+)$", stem)
            if suffix:
                # Bump the existing numeric suffix: name_1 -> name_2.
                stem = stem[: suffix.start(1)] + str(int(suffix.group(1)) + 1)
            else:
                stem += "_1"
        self.file_path = stem + ".pdf"
        return stem

    def remove_thinking(self, response):
        """
        Remove a ``...</think>`` reasoning section from a model response.

        Args:
            response: An object with a ``content`` attribute, or any value that
                is converted with ``str()``.

        Returns:
            str: Everything after the first ``</think>`` tag, stripped of
            surrounding whitespace. If there is no such tag, the text is
            returned unchanged.
        """
        response_text = response.content if hasattr(response, "content") else str(response)
        if "</think>" in response_text:
            # partition keeps *all* text after the first closing tag; a plain
            # split(...)[1] would drop anything after a second "</think>".
            return response_text.partition("</think>")[2].strip()
        return response_text
class StreamlitBaseClass(BaseClass):
    """
    Base class for Streamlit pages.

    Adds user settings, session-state syncing and project/collection helpers
    on top of ``BaseClass``. Settings are stored in the document with
    ``_key == "settings"`` in the user's ``settings`` collection and mirrored
    in ``st.session_state["settings"]``.

    Methods:
        get_settings(field=None):
            Load the user's settings, creating defaults if none exist; return
            one field or the whole document.
        update_settings(key, value):
            Persist one setting and mirror it in the session state.
        update_session_state(page_name=None):
            Copy the instance's simple-typed attributes into
            ``st.session_state[page_name]``.
        update_current_page(page_name):
            Record the current page in the session state and the settings.
        choose_collection(text=...):
            Show a select box for picking an article collection.
        choose_project(text=...):
            Show a select box for picking a project.
        add_article_to_collection / remove_article_from_collection,
        get_project_notes / add_note_to_project,
        create_project / update_project / delete_project,
        get_chat / create_or_update_chat / get_chats_for_project / delete_chat:
            Thin wrappers around the matching ``self.user_arango`` calls that
            fill in the current user, project or collection.
    """

    def __init__(self, username: str, **kwargs) -> None:
        """Initialise via ``BaseClass.__init__``; no extra state is added here."""
        super().__init__(username, **kwargs)

    def get_settings(self, field: str = None):
        """
        Retrieve or initialize user settings from the database.

        If the user has no settings document yet, one is created with
        ``current_collection`` and ``current_page`` set to None. The settings
        are also stored in ``st.session_state["settings"]``.

        Args:
            field (str, optional): A specific field to return. If omitted, the
                entire settings document is returned.

        Returns:
            dict or any: The whole settings document, or the value of ``field``
            (None if absent).
        """
        settings = self.user_arango.get_settings()
        if not settings:
            default_settings = {
                "_key": "settings",
                "current_collection": None,
                "current_page": None,
            }
            self.user_arango.initialize_settings(default_settings)
            settings = default_settings
        # Guarantee these keys exist even if the stored document lacks them.
        for required in ("current_collection", "current_page"):
            if required not in settings:
                settings[required] = None
        st.session_state["settings"] = settings
        if field:
            return settings.get(field)
        return settings

    def update_settings(self, key, value) -> None:
        """
        Update a specific setting in the database and in the session state.

        Args:
            key (str): The setting to update.
            value (Any): The new value.
        """
        self.user_arango.db.collection("settings").update_match(
            filters={"_key": "settings"},
            body={key: value},
            merge=True,
        )
        # The session mirror is created by get_settings(). Load it if this
        # session has not done so yet, rather than failing with a KeyError.
        if "settings" not in st.session_state:
            self.get_settings()
        st.session_state["settings"][key] = value

    def update_session_state(self, page_name=None):
        """
        Copy this instance's attributes into the page's session state.

        Only attributes whose values are str, int, float, list, dict or bool
        are copied. Other objects (e.g. the ArangoDB handles) and None values
        are skipped.

        Args:
            page_name (str, optional): Session-state key of the page. Defaults
                to ``st.session_state["current_page"]``. The entry
                ``st.session_state[page_name]`` must already exist.
        """
        if not page_name:
            page_name = st.session_state.get("current_page")
        # NOTE(review): because None is skipped, clearing an attribute (e.g.
        # ``self.project = None`` in choose_collection) leaves the previously
        # stored value for this page in place. Confirm that is intended.
        for attr, value in self.__dict__.items():
            if isinstance(value, (str, int, float, list, dict, bool)):
                st.session_state[page_name][attr] = value

    def update_current_page(self, page_name):
        """
        Update the current page in the session state and the settings.

        The write only happens when ``page_name`` differs from the value
        currently in the session state.

        Args:
            page_name (str): The page to record as current.
        """
        if st.session_state.get("current_page") != page_name:
            st.session_state["current_page"] = page_name
            self.update_settings("current_page", page_name)

    def choose_collection(self, text="Select a collection of favorite articles") -> str:
        """
        Show a select box of the user's article collections.

        Args:
            text (str): Label for the select box.

        Returns:
            str: The selected collection name, or None if nothing is selected.

        Side Effects:
            When a collection is selected: sets ``self.project`` to None and
            ``self.collection`` to the selection, stores it as
            "current_collection" in the settings, and updates the session state.
        """
        collections = self.get_article_collections()
        collection = st.selectbox(text, collections, index=None)
        if collection:
            self.project = None
            self.collection = collection
            self.update_settings("current_collection", collection)
            self.update_session_state()
        return collection

    def choose_project(self, text="Select a project") -> "Project | None":
        """
        Show a select box of the user's projects.

        ``self.project_name`` is preselected when it is one of the projects.

        Args:
            text (str): Label for the select box.

        Returns:
            Project or None: The selected project. If nothing is selected, the
            previously set ``self.project``, or None if there is none.

        Side Effects:
            When a project is selected: sets ``self.project``,
            ``self.project_name`` and ``self.collection`` (to None), stores
            "current_project" in the settings, updates the session state and
            prints the chosen project name to the console.
        """
        projects = self.get_projects()
        project = st.selectbox(
            text,
            projects,
            index=(
                projects.index(self.project_name)
                if self.project_name in projects
                else None
            ),
        )
        if project:
            # Local import, presumably to avoid a circular import with
            # projects_page. TODO confirm.
            from projects_page import Project

            self.project = Project(self.username, project, self.user_arango)
            # Keep the name in sync so that methods defaulting to
            # self.project_name (notes, chats) use the chosen project.
            self.project_name = project
            self.collection = None
            self.update_settings("current_project", self.project.name)
            self.update_session_state()
            print("CHOOSEN PROJECT:", self.project.name)
        # `project` is only an attribute if kwargs or an earlier choice set it.
        return getattr(self, "project", None)

    def add_article_to_collection(self, article_id: str, collection_name: str = None):
        """
        Add an article to a user's collection.

        Args:
            article_id (str): The ID of the article.
            collection_name (str, optional): The collection name. Defaults to
                ``self.collection``.

        Returns:
            bool: True if the article was added successfully.
        """
        if collection_name is None:
            collection_name = self.collection
        return self.user_arango.add_article_to_collection(article_id, collection_name)

    def remove_article_from_collection(
        self, article_id: str, collection_name: str = None
    ):
        """
        Remove an article from a user's collection.

        Args:
            article_id (str): The ID of the article.
            collection_name (str, optional): The collection name. Defaults to
                ``self.collection``.

        Returns:
            bool: True if the article was removed successfully.
        """
        if collection_name is None:
            collection_name = self.collection
        return self.user_arango.remove_article_from_collection(
            article_id, collection_name
        )

    def get_project_notes(self, project_name: str = None):
        """
        Get notes for a project.

        Args:
            project_name (str, optional): The project name. Defaults to
                ``self.project_name``.

        Returns:
            list: A list of note documents.
        """
        if project_name is None:
            project_name = self.project_name
        return self.user_arango.get_project_notes(project_name, username=self.username)

    def add_note_to_project(self, note_data: dict):
        """
        Add a note to a project.

        Missing "project" and "username" keys are filled in with
        ``self.project_name`` and ``self.username``. This mutates the
        ``note_data`` dict that is passed in.

        Args:
            note_data (dict): The note data.

        Returns:
            dict: The created note document.
        """
        note_data.setdefault("project", self.project_name)
        note_data.setdefault("username", self.username)
        return self.user_arango.add_note_to_project(note_data)

    def create_project(self, project_data: dict):
        """
        Create a new project for the current user.

        A missing "username" key is filled in with ``self.username``. This
        mutates the ``project_data`` dict that is passed in.

        Args:
            project_data (dict): The project data. Should include a name field.

        Returns:
            dict: The created project document.
        """
        project_data.setdefault("username", self.username)
        return self.user_arango.create_project(project_data)

    def update_project(self, project_data: dict):
        """
        Update an existing project.

        Args:
            project_data (dict): The project data. Must include _key.

        Returns:
            dict: The updated project document.
        """
        return self.user_arango.update_project(project_data)

    def delete_project(self, project_name: str):
        """
        Delete a project for the current user.

        Args:
            project_name (str): The name of the project.

        Returns:
            bool: True if the project was deleted successfully.
        """
        return self.user_arango.delete_project(project_name, username=self.username)

    def get_chat(self, chat_key: str):
        """
        Get a chat by key.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            dict: The chat document if found, None otherwise.
        """
        return self.user_arango.get_chat(chat_key)

    def create_or_update_chat(self, chat_data: dict):
        """
        Create or update a chat.

        A missing "username" key is filled in with ``self.username``. This
        mutates the ``chat_data`` dict that is passed in.

        Args:
            chat_data (dict): The chat data.

        Returns:
            dict: The created or updated chat document.
        """
        chat_data.setdefault("username", self.username)
        return self.user_arango.create_or_update_chat(chat_data)

    def get_chats_for_project(self, project_name: str = None):
        """
        Get all chats for a project.

        Args:
            project_name (str, optional): The project name. Defaults to
                ``self.project_name``.

        Returns:
            list: A list of chat documents.
        """
        if project_name is None:
            project_name = self.project_name
        return self.user_arango.get_chats_for_project(
            project_name, username=self.username
        )

    def delete_chat(self, chat_key: str):
        """
        Delete a chat.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            dict: The deletion result.
        """
        return self.user_arango.delete_chat(chat_key)