You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
443 lines
16 KiB
443 lines
16 KiB
# _base_class.py |
|
import os |
|
import re |
|
import streamlit as st |
|
from _arango import ArangoDB |
|
from _chromadb import ChromaDB |
|
|
|
|
|
class BaseClass:
    """
    Common base for objects that act on behalf of a single user.

    Holds the username, optional project/collection context, and two ArangoDB
    handles: ``user_arango`` (connected with the user's name as both user and
    database name) and ``base_arango`` (``ArangoDB()`` with its default
    arguments). Every keyword argument given to the constructor is also set as
    an instance attribute.
    """

    def __init__(self, username: str, **kwargs) -> None:
        """
        Args:
            username (str): The user this instance acts for; also used as the
                name of the user's ArangoDB database (see get_arango).
            **kwargs: Optional context. ``project_name`` is stored in
                ``self.project_name`` and ``collection_name`` in
                ``self.collection``; additionally every key is set verbatim as
                an attribute. Attributes read elsewhere but not set here
                (e.g. ``is_sci``, ``doi``, ``document_type`` used by
                set_filename) must come from kwargs or be assigned later.
        """
        self.username: str = username
        self.project_name: str = kwargs.get("project_name", None)
        self.collection: str = kwargs.get("collection_name", None)
        self.user_arango: ArangoDB = self.get_arango()
        self.base_arango: ArangoDB = self.get_arango(admin=True)
        # Expose every keyword argument as an attribute (this can override the
        # attributes assigned above if a kwarg has the same name).
        for key, value in kwargs.items():
            setattr(self, key, value)

    def get_arango(self, admin: bool = False, db_name: str = None) -> "ArangoDB":
        """
        Create an ArangoDB connection.

        Args:
            admin (bool): When True (and no ``db_name`` is given), return
                ``ArangoDB()`` constructed with its default arguments.
            db_name (str, optional): Connect to this database; takes precedence
                over ``admin``.

        Returns:
            ArangoDB: A connection to ``db_name`` if given, the default/admin
            connection if ``admin`` is set, otherwise a connection created with
            ``user`` and ``db_name`` both set to the current username.
        """
        if db_name:
            return ArangoDB(db_name=db_name)
        if admin:
            return ArangoDB()
        return ArangoDB(user=self.username, db_name=self.username)

    def get_article_collections(self) -> list:
        """
        Return the names of all article collections for the current user.

        Returns:
            list: The ``name`` field of every document in the user's
            ``article_collections`` collection.
        """
        article_collections = self.user_arango.execute_aql(
            'FOR doc IN article_collections RETURN doc["name"]'
        )
        return list(article_collections)

    def get_projects(self) -> list:
        """
        Return the names of all projects for the current user.

        Returns:
            list: The ``name`` field of each project returned by
            ``user_arango.get_projects``.
        """
        projects = self.user_arango.get_projects(username=self.username)
        return [project["name"] for project in projects]

    def get_chromadb(self):
        """Return a new ``ChromaDB()`` instance."""
        return ChromaDB()

    def get_project(self, project_name: str):
        """
        Get a project by name for the current user.

        Args:
            project_name (str): The name of the project.

        Returns:
            Whatever ``user_arango.get_project`` returns; presumably the project
            document, or None if not found (TODO confirm in _arango).
        """
        return self.user_arango.get_project(project_name, username=self.username)

    def set_filename(self, filename=None, folder="other_documents"):
        """
        Choose the local PDF path for the current document and store it in
        ``self.file_path``.

        Scientific articles (``self.is_sci`` truthy and ``self.document_type``
        not ``"notes"``) are stored as ``sci_articles/<doi>.pdf``. Every ``/``
        in the DOI is replaced by ``_`` so that the DOI does not create
        sub-directories. All other documents go to
        ``user_data/<username>/<document_type>/<filename>.pdf``. If that file
        already exists, ``_1`` is appended, or an existing trailing number is
        incremented, until the name is free.

        Args:
            filename (str, optional): Base name (without ``.pdf``) for
                non-scientific documents.
            folder (str): Not used by the current implementation; kept for
                backward compatibility.

        Returns:
            bool | str: For scientific articles, whether the file already
            exists. Otherwise, the chosen path WITHOUT the ``.pdf`` extension
            (``self.file_path`` holds the full path including it).
        """
        if self.is_sci and self.document_type != "notes":
            # Sanitize only the DOI. Replacing "/" in the whole path would also
            # flatten the "sci_articles/" directory prefix.
            safe_doi = str(self.doi).replace("/", "_")
            self.file_path = f"sci_articles/{safe_doi}.pdf"
            return os.path.exists(self.file_path)

        download_folder = f"user_data/{self.username}/{self.document_type}"
        file_path = f"{download_folder}/{filename}"
        while os.path.exists(file_path + ".pdf"):
            if re.search(r"_\d+$", file_path):
                # Already has a "_N" suffix: bump N.
                file_path = re.sub(
                    r"\d+$", lambda match: str(int(match.group()) + 1), file_path
                )
            else:
                file_path += "_1"
        self.file_path = file_path + ".pdf"
        return file_path

    def remove_thinking(self, response):
        """
        Drop everything up to and including the first ``</think>`` tag.

        Args:
            response: An object exposing a ``content`` attribute, or any value,
                which is converted with ``str()``.

        Returns:
            The text after the first ``</think>`` tag, stripped of surrounding
            whitespace, or the unmodified text/content when no tag is present.
        """
        response_text = response.content if hasattr(response, "content") else str(response)
        if "</think>" in response_text:
            # maxsplit=1 keeps the whole answer even if it contains another
            # "</think>"; a plain split(...)[1] would truncate it there.
            return response_text.split("</think>", 1)[1].strip()
        return response_text
|
|
|
class StreamlitBaseClass(BaseClass):
    """
    Base class for Streamlit pages. It adds user-settings and session-state
    handling plus project/collection helpers on top of BaseClass.

    Methods:
        __init__(username: str, **kwargs) -> None:
            Initializes the instance via BaseClass.
        get_settings(field: str = None):
            Loads the user's settings document (creating it with defaults if
            missing), caches it in ``st.session_state["settings"]`` and returns
            either the whole document or a single field.
        update_settings(key, value) -> None:
            Updates one setting in the database and in the session state.
        update_session_state(page_name=None):
            Copies the instance's str/int/float/list/dict/bool attributes into
            the session state of the given (or current) page.
        update_current_page(page_name):
            Records the current page in the session state and the settings.
        choose_collection(text="Select a collection of favorite articles") -> str:
            Shows a select box for an article collection and stores the choice.
        choose_project(text="Select a project"):
            Shows a select box for a project, stores the choice and returns the
            corresponding Project instance.
        add_article_to_collection / remove_article_from_collection,
        get_project_notes / add_note_to_project,
        create_project / update_project / delete_project,
        get_chat / create_or_update_chat / get_chats_for_project / delete_chat:
            Wrappers around the ``user_arango`` methods of the same name that
            fill in the current user, project or collection where omitted.
    """

    def __init__(self, username: str, **kwargs) -> None:
        super().__init__(username, **kwargs)

    def get_settings(self, field: str = None):
        """
        Retrieve (or initialize) the user's settings and cache them in the
        Streamlit session state.

        If ``user_arango.get_settings()`` returns nothing, a default document
        with ``_key="settings"``, ``current_collection=None`` and
        ``current_page=None`` is created via ``initialize_settings``. The
        result is stored in ``st.session_state["settings"]``.

        Args:
            field (str, optional): A single field to return. If not provided,
                the entire settings document is returned.

        Returns:
            dict or any: The settings document, or the value of ``field``
            (None if the field is absent).
        """
        settings = self.user_arango.get_settings()
        if not settings:
            default_settings = {
                "_key": "settings",
                "current_collection": None,
                "current_page": None,
            }
            self.user_arango.initialize_settings(default_settings)
            settings = default_settings

        # The stored document may lack these keys; default them to None so
        # callers can rely on their presence.
        for required in ("current_collection", "current_page"):
            settings.setdefault(required, None)

        st.session_state["settings"] = settings
        if field:
            return settings.get(field)
        return settings

    def update_settings(self, key, value) -> None:
        """
        Update a single setting in the database and in the session state.

        If the settings have not been loaded into the session state yet, they
        are loaded first via get_settings(). That also creates the settings
        document when it is missing, so the update below has a document to
        match and the session-state write cannot fail with KeyError.

        Args:
            key (str): The key of the setting to update.
            value (Any): The new value for the setting.

        Returns:
            None
        """
        if "settings" not in st.session_state:
            self.get_settings()
        self.user_arango.db.collection("settings").update_match(
            filters={"_key": "settings"},
            body={key: value},
            merge=True,
        )
        st.session_state["settings"][key] = value

    def update_session_state(self, page_name=None):
        """
        Copy the instance's simple-typed attributes into a page's session state.

        Args:
            page_name (str, optional): Key of the page dict in
                ``st.session_state``. Defaults to
                ``st.session_state["current_page"]``. The page dict is created
                if it does not exist yet.

        Only attributes whose values are str, int, float, list, dict or bool are
        copied; other objects (e.g. database connections) are skipped.
        """
        if not page_name:
            page_name = st.session_state.get("current_page")
        if page_name not in st.session_state:
            st.session_state[page_name] = {}
        page_state = st.session_state[page_name]
        for attr, value in self.__dict__.items():
            if isinstance(value, (str, int, float, list, dict, bool)):
                page_state[attr] = value

    def update_current_page(self, page_name):
        """
        Set the current page in the session state and the settings.

        Args:
            page_name (str): The name of the page to set as the current page.

        Side Effects:
            Updates ``current_page`` in the session state and in the settings,
            but only if it differs from the value in the session state.
        """
        if st.session_state.get("current_page") != page_name:
            st.session_state["current_page"] = page_name
            self.update_settings("current_page", page_name)

    def choose_collection(self, text="Select a collection of favorite articles") -> str:
        """
        Let the user select one of their article collections.

        Args:
            text (str): Label of the select box.

        Returns:
            str: The selected collection name, or None if nothing is selected.

        Side Effects (only when a collection is selected):
            - Sets ``self.project`` to None.
            - Sets ``self.collection`` to the selected collection.
            - Stores it as the ``current_collection`` setting.
            - Updates the session state of the current page.
        """
        collections = self.get_article_collections()
        collection = st.selectbox(text, collections, index=None)
        if collection:
            self.project = None
            self.collection = collection
            self.update_settings("current_collection", collection)
            self.update_session_state()
        return collection

    def choose_project(self, text="Select a project"):
        """
        Let the user select one of their projects.

        The select box is preselected with ``self.project_name`` when that name
        is among the user's projects.

        Args:
            text (str): Label of the select box.

        Returns:
            Project | None: The Project for the selected project. If nothing is
            selected, the previously chosen ``self.project`` is returned, or None
            when no project was ever set.

        Side Effects (only when a project is selected):
            - Sets ``self.project``, ``self.project_name`` and clears
              ``self.collection``.
            - Stores the ``current_project`` setting.
            - Updates the session state of the current page.
            - Prints debugging output to the console.
        """
        projects = self.get_projects()
        print("projects", projects)
        print(self.project_name)

        preselected = (
            projects.index(self.project_name) if self.project_name in projects else None
        )
        project = st.selectbox(text, projects, index=preselected)
        print("Choosing project...")
        if project:
            # Local import, presumably to avoid a circular import with
            # projects_page (TODO confirm).
            from projects_page import Project

            self.project = Project(self.username, project, self.user_arango)
            # Keep the name in sync so the helpers that default to
            # self.project_name (notes, chats) and the preselection above refer
            # to the newly chosen project.
            self.project_name = project
            self.collection = None
            self.update_settings("current_project", self.project.name)
            self.update_session_state()
            print("CHOOSEN PROJECT:", self.project.name)
        # __init__ does not set self.project, so guard against AttributeError
        # when nothing has been selected yet.
        return getattr(self, "project", None)

    def add_article_to_collection(self, article_id: str, collection_name: str = None):
        """
        Add an article to one of the user's collections.

        Args:
            article_id (str): The ID of the article.
            collection_name (str, optional): The collection. Defaults to
                ``self.collection``.

        Returns:
            Whatever ``user_arango.add_article_to_collection`` returns
            (documented upstream as True on success).
        """
        if collection_name is None:
            collection_name = self.collection
        return self.user_arango.add_article_to_collection(article_id, collection_name)

    def remove_article_from_collection(
        self, article_id: str, collection_name: str = None
    ):
        """
        Remove an article from one of the user's collections.

        Args:
            article_id (str): The ID of the article.
            collection_name (str, optional): The collection. Defaults to
                ``self.collection``.

        Returns:
            Whatever ``user_arango.remove_article_from_collection`` returns
            (documented upstream as True on success).
        """
        if collection_name is None:
            collection_name = self.collection
        return self.user_arango.remove_article_from_collection(
            article_id, collection_name
        )

    def get_project_notes(self, project_name: str = None):
        """
        Get the notes of a project.

        Args:
            project_name (str, optional): The project. Defaults to
                ``self.project_name``.

        Returns:
            list: The note documents returned by ``user_arango.get_project_notes``.
        """
        if project_name is None:
            project_name = self.project_name
        return self.user_arango.get_project_notes(project_name, username=self.username)

    def add_note_to_project(self, note_data: dict):
        """
        Add a note to a project.

        Args:
            note_data (dict): The note. ``project`` and ``username`` default to
                ``self.project_name`` and ``self.username`` when missing; the
                dict is modified in place to add them.

        Returns:
            Whatever ``user_arango.add_note_to_project`` returns.
        """
        note_data.setdefault("project", self.project_name)
        note_data.setdefault("username", self.username)
        return self.user_arango.add_note_to_project(note_data)

    def create_project(self, project_data: dict):
        """
        Create a new project for the current user.

        Args:
            project_data (dict): The project; should include ``name``.
                ``username`` defaults to ``self.username`` (added in place).

        Returns:
            Whatever ``user_arango.create_project`` returns.
        """
        project_data.setdefault("username", self.username)
        return self.user_arango.create_project(project_data)

    def update_project(self, project_data: dict):
        """
        Update an existing project.

        Args:
            project_data (dict): The project data; must include ``_key``.

        Returns:
            Whatever ``user_arango.update_project`` returns.
        """
        return self.user_arango.update_project(project_data)

    def delete_project(self, project_name: str):
        """
        Delete one of the current user's projects.

        Args:
            project_name (str): The name of the project.

        Returns:
            Whatever ``user_arango.delete_project`` returns (documented upstream
            as True on success).
        """
        return self.user_arango.delete_project(project_name, username=self.username)

    def get_chat(self, chat_key: str):
        """
        Get a chat by key.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            Whatever ``user_arango.get_chat`` returns; presumably the chat
            document or None (TODO confirm).
        """
        return self.user_arango.get_chat(chat_key)

    def create_or_update_chat(self, chat_data: dict):
        """
        Create or update a chat.

        Args:
            chat_data (dict): The chat. ``username`` defaults to
                ``self.username`` (added in place).

        Returns:
            Whatever ``user_arango.create_or_update_chat`` returns.
        """
        chat_data.setdefault("username", self.username)
        return self.user_arango.create_or_update_chat(chat_data)

    def get_chats_for_project(self, project_name: str = None):
        """
        Get all chats of a project.

        Args:
            project_name (str, optional): The project. Defaults to
                ``self.project_name``.

        Returns:
            list: The chat documents returned by
            ``user_arango.get_chats_for_project``.
        """
        if project_name is None:
            project_name = self.project_name
        return self.user_arango.get_chats_for_project(
            project_name, username=self.username
        )

    def delete_chat(self, chat_key: str):
        """
        Delete a chat.

        Args:
            chat_key (str): The key of the chat.

        Returns:
            Whatever ``user_arango.delete_chat`` returns.
        """
        return self.user_arango.delete_chat(chat_key)
|
|
|