You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
235 lines
10 KiB
235 lines
10 KiB
|
|
import streamlit as st |
|
from datetime import datetime, timedelta |
|
from colorprinter.print_color import * |
|
|
|
|
|
from _base_class import StreamlitBaseClass |
|
from _rss import RSSReader |
|
from projects_page import Project |
|
from streamlit_chatbot import StreamlitChat, StreamlitBot |
|
|
|
|
|
class BotChatPage(StreamlitBaseClass):
    """
    BotChatPage - A Streamlit interface for chatting with various AI assistants.

    This class provides a user interface for interacting with different types of AI bots
    (Research Assistant, Editor, Podcast) that can access and work with the user's
    collections and projects.

    Attributes:
        username (str): The username of the current user.
        collection: The article collection chosen in the sidebar (the value returned by
            ``choose_collection``); falsy when nothing is selected.
        collection_name: Mirrors ``collection``; kept for backward compatibility.
        project_name (str): Name of the selected project; initialized to None and not
            otherwise assigned in this class.
        project (Project): The project chosen in the sidebar (the value returned by
            ``choose_project``); falsy when nothing is selected.
        chat (StreamlitChat): Chat instance for maintaining conversation history.
        role (str): The selected bot persona, default is "Research Assistant".
        page_name (str): Name of the current page ("Bot Chat"); also the session-state
            key under which this page's state is persisted.
        chat_key (str): Unique identifier for the current chat session.
        bot (StreamlitBot): Instance of the selected bot type.

    Methods:
        run(): Main method to render the chat interface and handle interactions.
        get_chat(role, new_chat): Retrieves existing chat or creates a new one.
        sidebar_actions(): Renders sidebar elements for selecting collections, projects,
            and chat options.
        remove_old_unsaved_chats(): Cleans up unsaved chats older than two weeks.
    """

    def __init__(self, username):
        super().__init__(username=username)
        self.collection = None
        self.collection_name = None
        self.project_name = None
        self.project: Project = None
        self.chat = None
        self.role = "Research Assistant"  # Default persona
        self.page_name = "Bot Chat"
        self.chat_key = None
        self.bot: StreamlitBot = None

        # Restore attributes persisted by a previous run() under this page's
        # session-state key (collection, project, chat, role).
        if self.page_name in st.session_state:
            for k, v in st.session_state[self.page_name].items():
                setattr(self, k, v)

    def run(self):
        """
        Render the Bot Chat page.

        Cleans up stale unsaved chats, draws the sidebar, then builds the bot that
        matches the selected role and runs it. When neither a collection nor a
        project is selected, a plain conversational bot is used instead.
        """
        from streamlit_chatbot import EditorBot, ResearchAssistantBot, PodBot, StreamlitBot

        self.bot: StreamlitBot = None
        self.update_current_page("Bot Chat")
        self.remove_old_unsaved_chats()
        self.sidebar_actions()

        if self.collection or self.project:
            print_purple("Collection:", self.collection, "Project:", self.project)
            # Load the current chat, or create one if none exists yet.
            self.chat = self.get_chat(role=self.role)

            # Create a Bot instance with the Chat object
            if self.role == "Research Assistant":
                print_blue("Creating Research Assistant Bot")
                self.bot = ResearchAssistantBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                    tools=[
                        "fetch_other_documents_tool",
                        "fetch_science_articles_tool",
                        "fetch_science_articles_and_other_documents_tool",
                        "conversational_response_tool",
                    ],
                )

            elif self.role == "Editor":
                self.bot = EditorBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                    tools=[
                        "fetch_other_documents_tool",
                        "fetch_notes_tool",
                        "conversational_response_tool",
                    ],
                )

            elif self.role == "Podcast":
                st.session_state["make_podcast"] = True
                with st.sidebar:
                    with st.form("make_podcast_form"):
                        instructions = st.text_area(
                            "What should the podcast be about? Give a brief description, as if you were the producer."
                        )
                        start = st.form_submit_button("Make Podcast!")
                    if start:
                        # NOTE(review): the PodBot is not stored on self.bot, so the
                        # self.bot.run() call below does not run it; presumably PodBot
                        # does its work when constructed — confirm.
                        # NOTE(review): this assumes self.project exposes `.name`, while
                        # sidebar_actions formats it as a string — confirm its type.
                        PodBot(
                            subject=self.project.name,
                            username=self.username,
                            chat=self.chat,
                            collection=self.collection,
                            project=self.project,
                            instructions=instructions,
                        )

            # Save updated chat state to session state (restored in __init__).
            st.session_state[self.page_name] = {
                "collection": self.collection,
                "project": self.project,
                "chat": self.chat,
                "role": self.role,
            }

            # Run the bot (this will display chat history and process user input)
            if self.bot:
                self.bot.run()

        else:  # If no collection or project is selected, use the conversational response bot
            print_yellow("No collection or project selected. Using conversational response bot.")
            self.bot = StreamlitBot(
                username=self.username,
                chat=self.get_chat(),
                tools=["conversational_response_tool"],
            )
            self.bot.run()

    def get_chat(self, role="Research Assistant", new_chat=False):
        """
        Retrieves or creates a chat session.

        Parameters:
        -----------
        role : str, optional
            The role assigned to a newly created chat (default is "Research Assistant").
        new_chat : bool, optional
            If True, creates a new chat regardless of existing sessions (default is False).

        Returns:
        --------
        StreamlitChat
            A chat instance either newly created or retrieved from the database.

        Notes:
        ------
        - If new_chat is True, a new chat is always created.
        - Otherwise, if session state holds a ``chat_key`` whose document still exists
          in the ``chats`` collection, that chat is loaded.
        - If there is no key, or the stored document is gone (for example, removed by
          remove_old_unsaved_chats), a new chat is created. In every creation case,
          ``st.session_state['chat_key']`` is updated to the new chat's key.
        """
        print_blue('CHAT TYPE:', role)
        chat_key = st.session_state.get('chat_key')
        if not new_chat and chat_key is not None:
            print_blue("Old chat:", chat_key)
            chat_data = self.user_arango.db.collection("chats").get(chat_key)
            if chat_data is not None:
                return StreamlitChat.from_dict(chat_data)
            # The stored key points at a deleted chat; fall through and start a new one.
            print_yellow("Chat not found, creating a new one:", chat_key)

        chat = StreamlitChat(username=self.username, role=role)
        st.session_state['chat_key'] = chat._key
        print_blue("Creating new chat:", st.session_state['chat_key'])
        return chat

    def sidebar_actions(self):
        """
        Render the sidebar controls.

        The first form chooses the collection and/or project. When something is
        selected (or that form was submitted), a second form is shown. It lets the
        user pick the bot role (only when a project is selected) and either continue
        an earlier chat or start a new one. Starting a chat triggers ``st.rerun()``.
        """
        with st.sidebar:
            with st.form("select_chat"):
                self.collection = self.choose_collection("Article collection to use for chat:")
                self.project = self.choose_project("Project to use for chat:")
                submitted = st.form_submit_button("Select Collection/Project")
            self.collection_name = self.collection  # keep the legacy attribute in sync

            # Only render the settings form when it can get a submit button;
            # an empty form would otherwise be displayed.
            if not (submitted or self.collection or self.project):
                return

            with st.form("chat_settings"):
                if self.project:
                    self.role = st.selectbox(
                        "Choose Bot Role",
                        options=["Research Assistant", "Editor", "Podcast"],
                        index=0,
                    )
                elif self.collection:
                    self.role = "Research Assistant"

                # Load existing chats for the selected project/collection from the database
                chats = {i["name"]: i["_key"] for i in self._load_chat_history()}
                selected_chat = st.selectbox(
                    "Continue another chat", options=[""] + list(chats.keys()), index=None
                )

                if not self.role:
                    self.role = "Research Assistant"

                start_chat = st.form_submit_button("Start Chat")
                if start_chat:
                    if selected_chat:
                        st.session_state["chat_key"] = chats[selected_chat]
                        self.chat = self.get_chat()
                    else:
                        self.chat = self.get_chat(role=self.role, new_chat=True)
                    st.rerun()

    def _load_chat_history(self):
        """Return ``[{"_key": ..., "name": ...}, ...]`` for chats of the selected project, else collection."""
        if self.project:
            attribute, value = "project", self.project
        elif self.collection:
            attribute, value = "collection", self.collection
        else:
            return []
        # `attribute` is one of two hard-coded literals. The user-facing name is passed
        # as a bind variable so quotes in it cannot break the query. str() matches
        # the former f-string interpolation of the value.
        query = (
            f'FOR doc IN chats FILTER doc["{attribute}"] == @value '
            'RETURN {"_key": doc["_key"], "name": doc["name"]}'
        )
        return list(self.user_arango.db.aql.execute(query, bind_vars={"value": str(value)}))

    def remove_old_unsaved_chats(self):
        """Delete chats with ``saved == false`` whose ``last_updated`` is older than two weeks."""
        two_weeks_ago = datetime.now() - timedelta(weeks=2)
        # NOTE(review): last_updated is compared as a string against a naive ISO
        # timestamp; this assumes chats store last_updated in the same format — confirm.
        old_chats = list(
            self.user_arango.db.aql.execute(
                'FOR doc IN chats FILTER doc.saved == false AND doc.last_updated < @cutoff '
                'RETURN {"_id": doc._id, "_key": doc._key}',
                bind_vars={"cutoff": two_weeks_ago.isoformat()},
            )
        )
        chats_collection = self.user_arango.db.collection("chats")
        for chat in old_chats:
            print_red(chat["_id"])
            chats_collection.delete(chat["_key"])
|
|
|
|