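"""Chat and Bot classes for the research-assistant chatbot.

`Chat` stores a conversation (history, name, role, project) and persists it to
ArangoDB; `Bot` wires an answering LLM to a query-building helper model, a
tool-selection model, and the retrieval tools registered in `ToolRegistry`.
"""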
from datetime import datetime
import streamlit as st
import uuid
from _base_class import StreamlitBaseClass, BaseClass
from _llm import LLM
from _arango import ArangoDB
from prompts import *
from colorprinter.print_color import *
from llm_tools import ToolRegistry
from streamlit_chatbot import StreamlitBot, PodBot, EditorBot, ResearchAssistantBot
class Chat(StreamlitBaseClass):
    def __init__(self, username=None, **kwargs):
        super().__init__(username=username, **kwargs)
        self.name = kwargs.get("name", None)
        self.chat_history = kwargs.get("chat_history", [])
        self.role = kwargs.get("role", "Research Assistant")
        self._key = kwargs.get("_key", str(uuid.uuid4()))
        self.saved = kwargs.get("saved", False)
        self.last_updated = kwargs.get("last_updated", datetime.now().isoformat())
        self.message_attachments = None
        self.project = kwargs.get("project", None)

    def add_message(self, role, content):
        self.chat_history.append(
            {
                "role": role,
                "content": content.strip().strip('"'),
                "role_type": self.role,
            }
        )

    def to_dict(self):
        return {
            "_key": self._key,
            "name": self.name,
            "chat_history": self.chat_history,
            "role": self.role,
            "username": self.username,
            "project": self.project,
            "last_updated": self.last_updated,
            "saved": self.saved,
        }

    def update_in_arango(self):
        """Update chat in ArangoDB using the new API"""
        self.last_updated = datetime.now().isoformat()
        # Use the create_or_update_chat method from the new API
        self.user_arango.create_or_update_chat(self.to_dict())

    def set_name(self, user_input):
        llm = LLM(
            model="small",
            max_length_answer=50,
            temperature=0.4,
            system_message="You are a chatbot who will be chatting with a user",
        )
        prompt = (
            f'Give a short name to the chat based on this user input: "{user_input}" '
            "No more than 30 characters. Answer ONLY with the name of the chat."
        )
        name = llm.generate(prompt).content.strip('"')
        name = f'{name} - {datetime.now().strftime("%B %d")}'
        # Check for existing chat with the same name
        existing_chat = self.user_arango.execute_aql(
            """
            FOR chat IN chats
                FILTER chat.name == @name AND chat.username == @username
                RETURN chat
            """,
            bind_vars={"name": name, "username": self.username},
        )
        if list(existing_chat):
            name = f'{name} ({datetime.now().strftime("%H:%M")})'
        name += f" - [{self.role}]"
        self.name = name
        return name

    def show_chat_history(self):
        """Display chat history in the Streamlit UI"""
        for message in self.chat_history:
            with st.chat_message(
                name="assistant" if message["role"] == "assistant" else "user",
                avatar=self.get_avatar(role=message["role"]),
            ):
                st.write(message["content"])

    def get_avatar(self, role):
        """Get avatar for a role"""
        if role == "user":
            return None
        elif role == "Host":
            return "🎙"
        elif role == "Guest":
            return "🎤"
        elif role == "assistant":
            if self.role == "Research Assistant":
                return "🔬"
            elif self.role == "Editor":
                return "📝"
            else:
                return "🤖"
        return None

    @classmethod
    def from_dict(cls, data):
        return cls(
            username=data.get("username"),
            name=data.get("name"),
            chat_history=data.get("chat_history", []),
            role=data.get("role", "Research Assistant"),
            _key=data.get("_key"),
            project=data.get("project"),
            last_updated=data.get("last_updated"),
            saved=data.get("saved", False),
        )
    def chat_history2bot(self, n_messages: int = None, remove_system: bool = False):
        """Return the chat history as {"role", "content"} dicts suitable for an LLM."""
        history = [
            {"role": m["role"], "content": m["content"]} for m in self.chat_history
        ]
        if n_messages and len(history) > n_messages:
            history = history[-n_messages:]
        # Drop a leading system message (when requested) or a leading assistant message.
        if history and (
            (remove_system and history[0]["role"] == "system")
            or history[0]["role"] == "assistant"
        ):
            history = history[1:]
        return history
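
# `Bot` combines three LLM instances: `chatbot` answers the user, `helperbot`
# turns recent history into retrieval queries, and `toolbot` only selects which
# registered tool(s) should handle a given message.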
class Bot(BaseClass):
    def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
        super().__init__(username=username, **kwargs)
        # Use the passed-in chat or create a new Chat
        self.chat = chat if chat else Chat(username=username, role="Research Assistant")
        print_yellow("Chat:", chat, type(chat))
        # Store or set up project/collection if available
        self.project = kwargs.get("project", None)
        self.collection = kwargs.get("collection", None)
        if self.collection and not isinstance(self.collection, list):
            self.collection = [self.collection]
        # Load articles in the collections using the new API
        self.arango_ids = []
        if self.collection:
            for c in self.collection:
                # Use execute_aql from the new API
                article_ids = self.user_arango.execute_aql(
                    """
                    FOR doc IN article_collections
                        FILTER doc.name == @collection
                        FOR article IN doc.articles
                            RETURN article
                    """,
                    bind_vars={"collection": c},
                )
                for _id in article_ids:
                    self.arango_ids.append(_id)
        # A standard LLM for normal chat
        self.chatbot = LLM(messages=self.chat.chat_history2bot())
        # A helper bot for generating queries or short prompts
        self.helperbot = LLM(
            temperature=0,
            model="small",
            max_length_answer=500,
            system_message=get_query_builder_system_message(),
            messages=self.chat.chat_history2bot(n_messages=4, remove_system=True),
        )
        # A specialized LLM picking which tool to use
        self.toolbot = LLM(
            temperature=0,
            system_message="""
            You are an assistant bot helping an answering bot to answer a user's messages.
            Your task is to choose one or multiple tools that will help the answering bot to provide the user with the best possible answer.
            You should NEVER directly answer the user. You MUST choose a tool.
            """,
            chat=False,
            model="small",
        )
        # Load or register the passed-in tools
        if tools:
            self.tools = ToolRegistry.get_tools(tools=tools)
        else:
            self.tools = ToolRegistry.get_tools()
        # Store other kwargs
        for arg in kwargs:
            setattr(self, arg, kwargs[arg])
    def get_chunks(
        self,
        user_input,
        collections=["sci_articles", "other_documents"],
        n_results=7,
        n_sources=4,
        filter=True,
    ):
        """Retrieve, deduplicate, enrich, and group document chunks relevant to user_input."""
        # Basic version without Streamlit calls
        query = self.helperbot.generate(
            get_generate_vector_query_prompt(user_input, self.chat.role)
        ).content.strip('"')
        combined_chunks = []
        if collections:
            for collection in collections:
                where_filter = {"_id": {"$in": self.arango_ids}} if filter else {}
                chunks = self.get_chromadb().query(
                    query=query,
                    collection=collection,
                    n_results=n_results,
                    n_sources=n_sources,
                    where=where_filter,
                    max_retries=3,
                )
                for doc, meta, dist in zip(
                    chunks["documents"][0],
                    chunks["metadatas"][0],
                    chunks["distances"][0],
                ):
                    combined_chunks.append(
                        {"document": doc, "metadata": meta, "distance": dist}
                    )
        combined_chunks.sort(key=lambda x: x["distance"])
        # Keep the best chunks according to n_sources
        sources = set()
        closest_chunks = []
        for chunk in combined_chunks:
            source_id = chunk["metadata"].get("_id", "no_id")
            if source_id not in sources:
                sources.add(source_id)
                closest_chunks.append(chunk)
                if len(sources) >= n_sources:
                    break
        if len(closest_chunks) < n_results:
            remaining_chunks = [
                c for c in combined_chunks if c not in closest_chunks
            ]
            closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)])
        # Now fetch real metadata from Arango using the new API
        for chunk in closest_chunks:
            _id = chunk["metadata"].get("_id")
            if not _id:
                continue
            try:
                # Determine which database to use based on collection name
                if _id.startswith("sci_articles"):
                    # Use base_arango for common documents
                    arango_doc = self.base_arango.get_document(_id)
                else:
                    # Use user_arango for user-specific documents
                    arango_doc = self.user_arango.get_document(_id)
                if arango_doc:
                    arango_metadata = arango_doc.get("metadata", {})
                    # Possibly merge notes
                    if "user_notes" in arango_doc:
                        arango_metadata["user_notes"] = arango_doc["user_notes"]
                    chunk["metadata"] = arango_metadata
            except Exception as e:
                print_red(f"Error fetching document {_id}: {e}")
        # Group by article title
        grouped_chunks = {}
        article_number = 1
        for chunk in closest_chunks:
            title = chunk["metadata"].get("title", "No title")
            if title not in grouped_chunks:
                grouped_chunks[title] = {
                    "article_number": article_number,
                    "chunks": [],
                }
                article_number += 1
            # Number each chunk with its article's number (not the running counter)
            chunk["article_number"] = grouped_chunks[title]["article_number"]
            grouped_chunks[title]["chunks"].append(chunk)
        return grouped_chunks
    def answer_tool_call(self, response, user_input):
        bot_responses = []
        # This method returns / stores responses (no Streamlit calls)
        if not response.get("tool_calls"):
            return ""
        for tool in response.get("tool_calls"):
            function_name = tool.function.get('name')
            arguments = tool.function.arguments
            arguments["query"] = user_input
            if hasattr(self, function_name):
                if function_name in [
                    "fetch_other_documents_tool",
                    "fetch_science_articles_tool",
                    "fetch_science_articles_and_other_documents_tool",
                ]:
                    chunks = getattr(self, function_name)(**arguments)
                    bot_responses.append(
                        self.generate_from_chunks(user_input, chunks).strip('"')
                    )
                elif function_name == "fetch_notes_tool":
                    notes = getattr(self, function_name)()
                    bot_responses.append(
                        self.generate_from_notes(user_input, notes).strip('"')
                    )
                elif function_name == "conversational_response_tool":
                    bot_responses.append(
                        getattr(self, function_name)(user_input).strip('"')
                    )
        return "\n\n".join(bot_responses)
    def process_user_input(self, user_input, content_attachment=None):
        # Add user message
        self.chat.add_message("user", user_input)
        if not content_attachment:
            prompt = get_tools_prompt(user_input)
            response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
            if response.get("tool_calls"):
                bot_response = self.answer_tool_call(response, user_input)
            else:
                # Just respond directly
                bot_response = response.content.strip('"')
        else:
            # If there's an attachment, do something minimal
            bot_response = "Content attachment received (Base Bot)."
        # Add assistant message
        if self.chat.chat_history[-1]["role"] != "assistant":
            self.chat.add_message("assistant", bot_response)
        # Update in Arango
        self.chat.update_in_arango()
        return bot_response
    def generate_from_notes(self, user_input, notes):
        # No Streamlit calls
        notes_string = ""
        for note in notes:
            notes_string += f"\n# {note.get('title', 'No title')}\n{note.get('text', '')}\n---\n"
        prompt = get_chat_prompt(user_input, content_string=notes_string, role=self.chat.role)
        return self.chatbot.generate(prompt, stream=True)

    def generate_from_chunks(self, user_input, chunks):
        # No Streamlit calls
        chunks_string = ""
        for title, group in chunks.items():
            user_notes_string = ""
            if "user_notes" in group["chunks"][0]["metadata"]:
                notes = group["chunks"][0]["metadata"]["user_notes"]
                user_notes_string = f'\n\nUser notes:\n"""\n{notes}\n"""\n\n'
            docs = "\n(...)\n".join([c["document"] for c in group["chunks"]])
            chunks_string += (
                f"\n# {title}\n## Article #{group['article_number']}\n{user_notes_string}{docs}\n---\n"
            )
        prompt = get_chat_prompt(user_input, content_string=chunks_string, role=self.chat.role)
        return self.chatbot.generate(prompt, stream=True)
    def run(self):
        # Base Bot has no Streamlit run loop
        pass

    def get_notes(self):
        # Get project notes using the new API
        if self.project and hasattr(self.project, "name"):
            notes = self.user_arango.get_project_notes(
                project_name=self.project.name,
                username=self.username,
            )
            return list(notes)
        return []
    @ToolRegistry.register
    def fetch_science_articles_tool(self, query: str, n_documents: int):
        """
        Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles.

        Parameters:
            query (str): The search query to find relevant scientific articles.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.

        Returns:
            dict: Chunks from the fetched scientific articles, grouped by article title.
        """
        print_purple('Query:', query)
        # Clamp the document count to the allowed range
        n_documents = max(3, min(int(n_documents), 10))
        return self.get_chunks(
            query, collections=["sci_articles"], n_results=n_documents
        )
    @ToolRegistry.register
    def fetch_other_documents_tool(self, query: str, n_documents: int):
        """
        Fetches information from other documents based on the user's query.

        This method retrieves information from various types of documents such as reports, news articles, and other texts. It should be used only when it is clear that the user is not seeking scientific articles.

        Args:
            query (str): The search query provided by the user.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.

        Returns:
            dict: Document chunks that match the query, grouped by title.
        """
        assert isinstance(self, Bot), "The first argument must be a Bot object."
        # Clamp the document count to the allowed range
        n_documents = max(2, min(int(n_documents), 10))
        return self.get_chunks(
            query,
            collections=[f"{self.username}__other_documents"],
            n_results=n_documents,
        )
    @ToolRegistry.register
    def fetch_science_articles_and_other_documents_tool(
        self, query: str, n_documents: int
    ):
        """
        Fetches information from both scientific articles and other documents.

        This method is often used when the user hasn't specified what kind of sources they are interested in.

        Args:
            query (str): The search query to fetch information for.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.

        Returns:
            dict: Document chunks that match the search query, grouped by title.
        """
        assert isinstance(self, Bot), "The first argument must be a Bot object."
        # Clamp the document count to the allowed range
        n_documents = max(3, min(int(n_documents), 10))
        return self.get_chunks(
            query,
            collections=["sci_articles", f"{self.username}__other_documents"],
            n_results=n_documents,
        )
    @ToolRegistry.register
    def fetch_notes_tool(bot):
        """
        Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools! No arguments needed.

        Returns:
            list: A list of notes.
        """
        assert isinstance(bot, Bot), "The first argument must be a Bot object."
        return bot.get_notes()
    @ToolRegistry.register
    def conversational_response_tool(self, query: str):
        """
        Generate a conversational response to a user's query.

        This method is designed to provide a short and conversational response
        without fetching additional data. It should be used only when it is clear
        that the user is engaging in small talk (like saying 'hi') and not seeking detailed information.

        Args:
            query (str): The user's message to which the bot should respond.

        Returns:
            str: The generated conversational response.
        """
        query = f"""
        User message: "{query}".
        Make your answer short and conversational.
        This is perhaps not a conversation about a journalistic project, so try not to be too informative.
        Don't answer with anything you're not sure of!
        """
        result = (
            self.chatbot.generate(query, stream=True)
            if self.chatbot
            else self.llm.generate(query, stream=True)
        )
        return result
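
# Minimal usage sketch (illustrative only, not part of the app): it assumes the
# LLM and ArangoDB backends set up by BaseClass/StreamlitBaseClass are
# configured, and the username and collection name below are hypothetical.
if __name__ == "__main__":
    chat = Chat(username="demo_user", role="Research Assistant")
    bot = Bot(username="demo_user", chat=chat, collection="my_articles")
    reply = bot.process_user_input("What does recent research say about microplastics?")
    print(reply)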