from datetime import datetime
import streamlit as st

from _llm import LLM
from prompts import *
from colorprinter.print_color import *
from projects_page import Project
from ollama._types import Message as OllamaMessage
from _base_class import StreamlitBaseClass, BaseClass
from typing import List
from models import (
    ChunkSearchResults,
    DocumentChunk,
    ChunkMetadata,
    QueryResponse,
    UnifiedSearchResults,
    UnifiedDataChunk,
)


class Chat(StreamlitBaseClass):
    """
    A class to represent a chat session in a Streamlit application.

    Attributes:
    -----------
    name : str
        The name of the chat.
    chat_history : list
        A list to store the chat history.
    role : str
        The role of the user in the chat.
    project : str
        The project associated with the chat.
    collection : str
        The collection associated with the chat.
    _key : str
        The unique key for the chat.

    Methods:
    --------
    add_message(role, content):
        Adds a message to the chat history.

    to_dict():
        Converts the chat object to a dictionary.

    update_in_arango():
        Updates the chat object in ArangoDB.

    set_name(user_input):
        Sets the name of the chat based on user input.

    show_title(title=None):
        Displays the title of the chat in the Streamlit application.

    from_dict(data):
        Creates a Chat object from a dictionary.

    chat_history2bot(n_messages=None, remove_system=False):
        Converts the chat history to a format suitable for a bot.
    """

    def __init__(
        self,
        username=None,
        role=None,
        key=None,
        project=None,
        collection=None,
        **kwargs,
    ):
        super().__init__(username=username, **kwargs)
        self.name = kwargs.get("name", None)
        self.chat_history = kwargs.get("chat_history", [])
        self.role = role
        self.project = kwargs.get("project", project)
        self.collection = kwargs.get("collection", collection)
        self._key = key

    def add_message(self, role, content):
        if isinstance(content, str):
            content = content.strip().strip('"')
        elif isinstance(content, dict):
            content = content.get("content", "").strip().strip('"')
        else:
            # Fall back to an object with a `content` attribute (e.g. an OllamaMessage);
            # leave the value untouched if there is none.
            content = getattr(content, "content", content)
            if isinstance(content, str):
                content = content.strip().strip('"')

        self.chat_history.append(
            {
                "role": role,
                "content": content,
                "role_type": self.role,
            }
        )

    def to_dict(self):
        return {
            "_key": self._key,
            "name": self.name,
            "chat_history": self.chat_history,
            "role": self.role,
            "username": self.username,
            "last_updated": getattr(self, "last_updated", None),
        }

    def update_in_arango(self):
        self.last_updated = datetime.now().isoformat()
        self.user_arango.db.collection("chats").insert(
            self.to_dict(), overwrite=True, overwrite_mode="update"
        )

    def set_name(self, user_input):
        llm = LLM(
            model="small",
            max_length_answer=50,
            temperature=0.4,
            system_message="You are a chatbot who will be chatting with a user",
        )
        prompt = (
            f'Give a short name to the chat based on this user input: "{user_input}" '
            "No more than 30 characters. Answer ONLY with the name of the chat."
        )
        name = llm.generate(prompt).content.strip('"')
        name = f'{name} - {datetime.now().strftime("%B %d")}'
        # Use a bind variable instead of interpolating the name into the AQL string.
        existing_chat = self.user_arango.db.aql.execute(
            "FOR doc IN chats FILTER doc.name == @name RETURN doc",
            bind_vars={"name": name},
            count=True,
        )
        if existing_chat.count() > 0:
            name = f'{name} ({datetime.now().strftime("%H:%M")})'
        name += f" - [{self.role}]"
        self.name = name
        return name

    def show_title(self, title=None):
        title = title or self.project or self.collection or "No title"
        st.markdown(
            f"""### Chat about *{title.strip()}* with *{self.role}*""",
        )

    @classmethod
    def from_dict(cls, data):
        return cls(
            username=data.get("username"),
            name=data.get("name"),
            chat_history=data.get("chat_history", []),
            role=data.get("role", "Research Assistant"),
            key=data.get("_key"),  # __init__ expects `key`, not `_key`
        )

    def chat_history2bot(self, n_messages: int = None, remove_system: bool = False):
        history = [
            {"role": m["role"], "content": m["content"]} for m in self.chat_history
        ]
        if n_messages and len(history) > n_messages:
            history = history[-n_messages:]
        if history and (
            (remove_system and history[0]["role"] == "system")
            or history[0]["role"] == "assistant"
        ):
            history = history[1:]
        return history

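# Minimal usage sketch for Chat (hypothetical values; assumes the ArangoDB-backed
# user set up by StreamlitBaseClass). Shows how messages accumulate and how the
# history is exported in the plain role/content format the LLM classes expect:
#
#     chat = Chat(username="alice", role="Research Assistant")
#     chat.add_message("user", "What does the lithium paper conclude?")
#     chat.add_message("assistant", '"It concludes that..."')
#     chat.chat_history2bot(n_messages=2)
#     # -> [{"role": "user", "content": "What does the lithium paper conclude?"},
#     #     {"role": "assistant", "content": "It concludes that..."}]

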
class StreamlitChat(Chat):
    """
    A class to manage chat interactions within a Streamlit application.

    Inherits from the Chat class and provides additional functionality to handle
    chat history, user roles, and avatars within a Streamlit app context.

    Attributes:
        project (str): The project associated with the chat.
        collection (str): The collection associated with the chat.
        message_attachments (None): Placeholder for message attachments.
        last_updated (str): Timestamp of the last update in ISO format.
        _key (str): Unique identifier for the chat.
        role (str): The role of the user in the chat.
        username (str): The username of the user in the chat.
        name (str): The name of the chat.
        chat_history (list): List of messages in the chat history.

    Methods:
        show_chat_history():
            Displays the chat history in the Streamlit app.
        get_avatar(message: dict = None, role=None) -> str:
            Returns the avatar image path for a message or role.
    """

    def __init__(self, username: str, role: str, _key: str = None, **kwargs):
        super().__init__(username, role, _key, **kwargs)
        self.project = kwargs.get("project", None)
        self.collection = kwargs.get("collection", None)
        self.message_attachments = None
        self.last_updated = datetime.now().isoformat()
        self._key = _key
        self.role = role

        if self._key:
            chat = self.user_arango.db.collection("chats").get(self._key)
            if chat:
                self.name = chat.get("name")
                self.chat_history = chat.get("chat_history", [])
                self.role = chat.get("role")
                self.username = chat.get("username")
        else:
            self._key = self.user_arango.db.collection("chats").insert(
                {
                    "name": self.name,
                    "chat_history": self.chat_history,
                    "role": self.role,
                    "username": self.username,
                }
            )["_key"]

    def show_chat_history(self):
        """
        Displays the chat history in the Streamlit app.

        Iterates through the chat history and displays messages from the user and assistant.
        Messages from other roles are ignored. Each message is displayed with an avatar.

        Returns:
            None
        """
        for message in self.chat_history:
            if message["role"] not in ["user", "assistant"]:
                continue
            avatar = self.get_avatar(message)
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"]:
                    st.markdown(message["content"].strip('"'))

    def get_avatar(self, message: dict = None, role=None) -> str:
        """
        Retrieves the avatar image path based on the message or role provided.

        Args:
            message (dict, optional): A dictionary containing message details, including the role.
            role (str, optional): The role of the user if the message is not provided.

        Returns:
            str: The file path to the avatar image.

        Raises:
            AssertionError: If neither message nor role is provided.
        """
        assert message or role, "Either message or role must be provided"
        if (message and message.get("role") == "user") or role == "user":
            avatar = st.session_state["settings"].get("avatar", "user")
        elif (message and message.get("role") == "assistant") or role == "assistant":
            role_type = message.get("role_type", self.role) if message else self.role
            # Map each assistant persona to its avatar image.
            avatar = {
                "Research Assistant": "img/avatar_researcher.png",
                "Editor": "img/avatar_editor.png",
                "Host": "img/avatar_host.png",
                "Guest": "img/avatar_guest.png",
            }.get(role_type)
        else:
            avatar = None
        return avatar


class Bot(BaseClass):
    """
    A chatbot class that integrates with research tools and document retrieval systems.

    The Bot class provides an interface for conversational AI that can access and process
    various document sources, including scientific articles, user notes, and other documents.
    It initializes multiple specialized language models for different tasks, including
    regular conversation, query generation, and tool selection.

    Attributes:
        username (str): The username associated with this bot instance.
        chat (Chat): Chat instance for managing conversation history.
        project (Project, optional): Associated project for document context.
        collection (list, optional): Collections of documents to search within.
        arango_ids (list): List of document IDs in ArangoDB.
        chatbot (LLM): Main language model for conversation.
        helperbot (LLM): Model for generating queries.
        toolbot (LLM): Model for selecting appropriate tools.
        tools (list): List of tool functions available to the bot.

    Methods:
        initiate_bots(): Initialize the different language model instances.
        get_chunks(): Retrieve relevant text chunks based on user input.
        answer_tool_call(): Process and execute tool calls from the AI.
        generate_from_notes(): Generate a response from user notes.
        generate_from_chunks(): Generate a response from document chunks.
        run(): Run the bot (implemented by subclasses).
        get_notes(): Retrieve notes from the database.
        fetch_science_articles_tool(): Retrieve scientific articles.
        fetch_other_documents_tool(): Retrieve non-scientific documents.
        fetch_science_articles_and_other_documents_tool(): Retrieve both document types.
        fetch_notes_tool(): Retrieve user notes.
        conversational_response_tool(): Generate a simple conversational response.
    """

    def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
        super().__init__(username=username, **kwargs)
        # Use the passed-in chat or create a new Chat
        self.chat = chat if chat else Chat(username=username, role="Research Assistant")
        # Store or set up project/collection if available
        self.project: Project = kwargs.get("project", None)

        self.collection = kwargs.get("collection", None)
        if self.collection and not isinstance(self.collection, list):
            self.collection = [self.collection]
        elif self.project:
            self.collection = self.project.collections

        # Load articles in the collections
        self.arango_ids = []

        # Bots to be initiated later
        self.chatbot = None
        self.helperbot = None
        self.toolbot = None

        if self.collection:
            for c in self.collection:
                for _id in self.user_arango.db.aql.execute(
                    """
                    FOR doc IN article_collections
                        FILTER doc.name == @collection
                        FOR article IN doc.articles
                            RETURN article._id
                    """,
                    bind_vars={"collection": c},
                ):
                    self.arango_ids.append(_id)

        # Give tools to the bot
        if tools:
            # Map tool names to functions
            tool_mapping = {
                "fetch_other_documents_tool": self.fetch_other_documents_tool,
                "fetch_science_articles_tool": self.fetch_science_articles_tool,
                "fetch_science_articles_and_other_documents_tool": self.fetch_science_articles_and_other_documents_tool,
                "fetch_notes_tool": self.fetch_notes_tool,
                "conversational_response_tool": self.conversational_response_tool,
                "analyze_tool": self.analyze_tool,
            }
            if tools == "all":
                self.tools = list(tool_mapping.values())
            else:
                self.tools = [
                    tool_mapping[tool] if isinstance(tool, str) else tool
                    for tool in tools
                ]
        else:
            self.tools = None

        self.initiate_bots()
        # Store other kwargs
        for arg in kwargs:
            setattr(self, arg, kwargs[arg])

    def initiate_bots(self):
        """
        Initialize the different bot instances used in the chatbot application.

        Creates three types of bots:
        1. chatbot: A standard LLM for normal conversation with the user
        2. helperbot: A specialized LLM with low temperature for generating concise queries or prompts
        3. toolbot: A specialized LLM for selecting which tool to use when responding to user queries
           (only created if tools are provided)

        The toolbot is configured to prefer specialized tools over conversational responses
        when the user is seeking information rather than engaging in small talk.

        Note:
            - The chatbot uses the full chat history
            - The helperbot uses a limited chat history (last 4 messages) with the system message removed
            - The toolbot uses a system message that lists all available tools
        """
        # A standard LLM for normal chat
        self.chatbot = LLM(messages=self.chat.chat_history2bot())
        # A helper bot for generating queries or short prompts
        self.helperbot = LLM(
            temperature=0,
            model="small",
            max_length_answer=500,
            system_message=get_query_builder_system_message(),
            messages=self.chat.chat_history2bot(n_messages=4, remove_system=True),
        )
        # A specialized LLM picking which tool to use
        if self.tools:
            tools_names = [tool.__name__ for tool in self.tools]
            tools_name_string = "– " + "\n– ".join(tools_names)
            self.toolbot = LLM(
                temperature=0,
                system_message=f"""
                You are a helpful assistant with tools. The tools you can choose from are:
                {tools_name_string}
                Your task is to choose one or multiple tools to answer a user's query.
                DON'T come up with your own tools, only use the ones provided.
                """,
                # system_message='Use one of the provided tools to help the answering bot to answer the user. Do not answer directly. Use the "tool_calls" field in your answer.',
                chat=True,
                model="tools",
            )
            if len(tools_names) > 1 and "conversational_response_tool" in tools_names:
                self.toolbot.system_message += (
                    "\n\nMake sure to only use the conversational response tool if the user "
                    "is engaging in small talk. If the user is asking a question or looking "
                    "for information, make sure to use one of the other tools!"
                )

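    # Construction sketch (hypothetical names; assumes a project whose collections
    # exist in ArangoDB). Tools can be passed as names, as callables, or as "all";
    # initiate_bots() then lists the chosen tool names in the toolbot's system message:
    #
    #     bot = Bot(
    #         username="alice",
    #         tools=["fetch_science_articles_tool", "conversational_response_tool"],
    #     )
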
    def get_chunks(
        self,
        user_input,
        collections=["sci_articles", "other_documents"],
        n_results=7,
        n_sources=4,
        filter=True,
        where_filter: dict = None,
        get_full_text=False,
    ) -> UnifiedSearchResults:
        """
        Retrieves relevant text chunks from the vector database based on user input.

        This method:
        1. Generates a vector query based on user input using the helper bot
        2. Searches multiple collections in the vector database
        3. Combines results and sorts them by relevance
        4. Limits results to the specified number of unique sources
        5. Cleans the text by removing footnote references
        6. Enriches the chunks with detailed metadata from ArangoDB
        7. Returns chunks as UnifiedDataChunk objects in a UnifiedSearchResults container

        Parameters:
        -----------
        user_input : str
            The user query to search for relevant documents
        collections : list, optional
            List of collection names to search in (default: ["sci_articles", "other_documents"])
        n_results : int, optional
            Maximum number of results to return (default: 7)
        n_sources : int, optional
            Maximum number of unique document sources to include (default: 4)
        filter : bool, optional
            Whether to filter results by ArangoDB IDs (default: True)
        where_filter : dict, optional
            Additional filter criteria for the search (default: None)
        get_full_text : bool, optional
            Whether to return the full text of the documents (default: False)

        Returns:
        --------
        UnifiedSearchResults
            A Pydantic model containing the search results with:
            - chunks: List of UnifiedDataChunk objects containing:
                - content: The document text
                - metadata: Document metadata
                - source_type: The type of the source
            - source_ids: List of IDs for the sources
        """
        print_blue("CHROMA FILTER:", filter)
        # Generate vector query using LLM
        response = self.helperbot.generate(
            get_generate_vector_query_prompt(user_input, self.chat.role),
            format=QueryResponse.model_json_schema(),
        )
        query = QueryResponse.model_validate_json(response.content).query
        print_purple(f"Query for vector DB:\n {query}")

        # Process chunks using ChromaDB's enhanced methods
        chromadb = self.get_chromadb()

        if filter:
            if where_filter in [None, {}]:
                where_filter = {"_id": {"$in": self.arango_ids}}
        else:
            where_filter = None

        # Get processed chunks from ChromaDB
        closest_chunks: list = chromadb.search_chunks(
            query=query,
            collections=collections,
            n_results=n_results,
            n_sources=n_sources,
            where=where_filter,
            max_retries=3,
        )

        # Fetch metadata from Arango and prepare uniform chunks
        source_ids = []
        unified_chunks = []

        for i, chunk in enumerate(closest_chunks):
            # Track IDs
            chunk_id = chunk["id"]
            arango_id = chunk["metadata"].get("_id")
            source_ids.append(chunk_id)

            # Get enhanced metadata from ArangoDB
            if arango_id:
                arango_metadata = self.user_arango.get_document_metadata(arango_id)
                if isinstance(arango_metadata, dict):
                    # Add tracking IDs to metadata
                    arango_metadata["chroma_id"] = chunk_id
                    arango_metadata["arango_id"] = arango_id
                    metadata = arango_metadata
                else:
                    # Create minimal metadata if ArangoDB doesn't return any
                    metadata = {
                        "title": "Unknown Document",
                        "journal": None,
                        "published_date": None,
                        "chroma_id": chunk_id,
                        "arango_id": arango_id,
                    }
            else:
                # Minimal metadata for chunks without arango_id
                metadata = {"title": "Unknown Document", "chroma_id": chunk_id}

            # Get the full document text if requested, otherwise use the chunk text
            if get_full_text and arango_id:
                doc = self.user_arango.db.collection("sci_articles").get(arango_id)
                document_content = (doc or {}).get("text", "")
            else:
                document_content = chunk.get("document", chunk.get("text", ""))

            # Determine source type based on collection
            source_type = (
                "science_article" if "sci_article" in collections[0] else "other_document"
            )

            # Create a UnifiedDataChunk (what the model expects)
            unified_chunk = UnifiedDataChunk(
                content=document_content,
                metadata=metadata,
                source_type=source_type,
                article_number=i + 1,  # Add article numbering
            )
            unified_chunks.append(unified_chunk)

        # Return the properly structured results
        return UnifiedSearchResults(chunks=unified_chunks, source_ids=source_ids)

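    # Usage sketch for get_chunks (hypothetical query; assumes populated Chroma and
    # ArangoDB collections). Metadata on the returned chunks is the plain dict built above:
    #
    #     results = bot.get_chunks(
    #         "effects of lithium mining on groundwater",
    #         collections=["sci_articles"],
    #         n_results=5,
    #         n_sources=3,
    #     )
    #     for chunk in results.chunks:
    #         print(chunk.metadata.get("title"), chunk.source_type, len(chunk.content))
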
    def answer_tool_call(self, response, user_input):
        """
        Process tool calls returned by the AI and execute the corresponding functions.

        This method evaluates tool calls in the AI response, executes the appropriate
        functions with the provided arguments, and collects the resulting responses.

        Parameters:
        -----------
        response : dict
            The AI response containing potential tool_calls to be executed
        user_input : str
            The original user query that will be passed to tool functions

        Returns:
        --------
        list
            A list of responses generated from executing the tool calls.
            Returns an empty list if no tool calls are present.

        Notes:
        ------
        Supported tool functions include:
        - fetch_other_documents_tool: Retrieves non-scientific documents
        - fetch_science_articles_tool: Retrieves scientific articles
        - fetch_science_articles_and_other_documents_tool: Retrieves both types of documents
        - fetch_notes_tool: Retrieves user notes
        - conversational_response_tool: Generates a conversational response
        """
        bot_responses = []
        # This method returns / stores responses (no Streamlit calls)
        if not response.get("tool_calls"):
            return []

        for tool in response.get("tool_calls"):
            function_name = tool.function.get("name")
            arguments = tool.function.arguments
            arguments["query"] = user_input

            if hasattr(self, function_name):
                print_purple("Function name:", function_name)
                if function_name in [
                    "fetch_other_documents_tool",
                    "fetch_science_articles_tool",
                    "fetch_science_articles_and_other_documents_tool",
                ]:
                    chunks = getattr(self, function_name)(**arguments)
                    bot_responses.append(self.generate_from_chunks(user_input, chunks))
                elif function_name == "fetch_notes_tool":
                    notes = getattr(self, function_name)()
                    bot_responses.append(self.generate_from_notes(user_input, notes))
                elif function_name == "conversational_response_tool":
                    response: OllamaMessage = getattr(self, function_name)(user_input)
                    print_green("Conversation response:", response)
                    bot_responses.append(response.content.strip('"'))
        return bot_responses

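    # Shape of the response this method consumes (assumed Ollama-style tool calling;
    # field access mirrors the code above): response.get("tool_calls") yields objects
    # whose .function.get("name") is a tool name and whose .function.arguments is a
    # dict of keyword arguments, e.g. {"n_documents": 5}. The user query is always
    # injected as arguments["query"] before the tool function is called.
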
    def generate_from_notes(self, user_input, notes):
        """
        Generate a response based on user input and a collection of notes.

        This method takes a user query and relevant notes, formats the notes into a string,
        creates a prompt with the formatted notes and user input, and generates a streamed response.

        Parameters
        ----------
        user_input : str
            The user's query or message to respond to
        notes : list of dict
            A list of note dictionaries, where each note has 'title' and 'content' keys

        Returns
        -------
        generator
            A generator that streams the AI-generated response

        Notes
        -----
        This method does not make any Streamlit calls and is safe to use outside of the Streamlit context.
        The notes are formatted with titles and content separated by horizontal rules.
        """
        # No Streamlit calls
        notes_string = ""
        for note in notes:
            notes_string += (
                f"\n# {note.get('title', 'No title')}\n{note.get('content', '')}\n---\n"
            )
        prompt = get_chat_prompt(
            user_input, content_string=notes_string, role=self.chat.role
        )
        return self.chatbot.generate(prompt, stream=True)

    def generate_from_chunks(self, user_input, chunks: UnifiedSearchResults):
        """
        Generate a response based on user input and retrieved document chunks.

        This method formats the retrieved document chunks into a structured string,
        combines it with the user's input in a prompt, and generates a streaming
        response using the chatbot.

        Parameters:
        -----------
        user_input : str
            The user's query or message to respond to.
        chunks : UnifiedSearchResults
            A Pydantic model containing document chunks as UnifiedDataChunk objects.

        Returns:
        --------
        generator
            A streaming generator of the chatbot's response.
        """
        # No Streamlit calls
        chunks_string = ""
        for chunk in chunks.chunks:
            user_notes_string = ""
            # Handle metadata from either a dict or object structure
            metadata = chunk.metadata if hasattr(chunk, "metadata") else {}

            # Get user notes if available
            user_notes = (
                metadata.get("user_notes")
                if isinstance(metadata, dict)
                else getattr(metadata, "user_notes", None)
            )
            if user_notes:
                user_notes_string = f'\n\nUser notes:\n"""\n{user_notes}\n"""\n\n'

            # Get title
            title = (
                metadata.get("title", "Untitled Document")
                if isinstance(metadata, dict)
                else getattr(metadata, "title", "Untitled Document")
            )

            # Get content from either 'content' or 'document'
            content = (
                chunk.content
                if hasattr(chunk, "content")
                else getattr(chunk, "document", "")
            )

            # Combine into structured format
            chunks_string += f"\n# {title}\n{user_notes_string}{content}\n---\n"

        # Create prompt and generate response
        prompt = get_chat_prompt(
            user_input, content_string=chunks_string, role=self.chat.role
        )
        return self.chatbot.generate(prompt, stream=True)

    def run(self):
        # Base Bot has no Streamlit run loop
        pass

    def get_notes(self) -> List:
        """
        Returns all project notes as a list of strings.
        """
        # Minimal note retrieval
        notes_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text",
            bind_vars={"note_ids": self.project.notes},
        )
        return list(notes_cursor)

    def fetch_science_articles_tool(
        self, query: str, n_documents: int = 6, retrieve_full_articles: bool = False
    ) -> UnifiedSearchResults:
        """
        Fetches information from scientific articles.

        Parameters:
            query (str): The search query to find relevant scientific articles.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
            retrieve_full_articles (bool): If True, returns article IDs for full article processing. Default: False.

        Returns:
            UnifiedSearchResults: A structured result containing articles with their chunks or article IDs for full retrieval.
        """
        n_documents = max(3, min(int(n_documents), 10))

        where_filter = {}
        if hasattr(self, "chroma_ids_retrieved") and len(self.chroma_ids_retrieved) > 0:
            where_filter = {"_id": {"$in": self.chroma_ids_retrieved}}

        found_chunks = self.get_chunks(
            user_input=query,
            collections=["sci_articles"],
            n_results=n_documents,
            n_sources=max(n_documents, 4),
            where_filter=where_filter,
        )

        # Collect unique article IDs if full articles are requested
        if retrieve_full_articles:
            # Get unique article IDs from the chunks (metadata is a plain dict here)
            unique_article_ids = list(
                {
                    chunk.metadata.get("arango_id")
                    for chunk in found_chunks.chunks
                    if chunk.metadata and chunk.metadata.get("arango_id")
                }
            )

            # Return article IDs and metadata for full article processing
            return UnifiedSearchResults(
                chunks=[
                    UnifiedDataChunk(
                        metadata=chunk.metadata,
                        source_type="sci_article_full",
                    )
                    for chunk in found_chunks.chunks
                ],
                source_ids=unique_article_ids,
            )
        else:
            # Chunk-based processing
            unified_chunks = [
                UnifiedDataChunk(
                    content=chunk.content,
                    metadata=chunk.metadata,
                    source_type="science_article_chunk",
                )
                for chunk in found_chunks.chunks
            ]
            return UnifiedSearchResults(
                chunks=unified_chunks, source_ids=found_chunks.source_ids
            )

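    # Usage sketch (hypothetical query). With retrieve_full_articles=True the result
    # carries the unique ArangoDB article IDs in source_ids instead of chunk texts:
    #
    #     hits = bot.fetch_science_articles_tool("solid-state batteries", n_documents=5)
    #     full = bot.fetch_science_articles_tool(
    #         "solid-state batteries", n_documents=5, retrieve_full_articles=True
    #     )
    #     print(len(hits.chunks), full.source_ids)
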
    def fetch_other_documents_tool(
        self, query: str, n_documents: int = 6
    ) -> UnifiedSearchResults:
        """
        Fetches information from other documents based on the user's query.

        Parameters:
            query (str): The search query provided by the user.
            n_documents (int): How many documents to fetch. Min: 2, Max: 10.

        Returns:
            UnifiedSearchResults: A structured result containing document chunks.
        """
        n_documents = max(2, min(int(n_documents), 10))

        found_chunks = self.get_chunks(
            user_input=query,
            collections=[f"{self.username}__other_documents"],
            n_results=n_documents,
            n_sources=max(n_documents, 4),
        )

        # Standardize the chunks using UnifiedDataChunk
        unified_chunks = [
            UnifiedDataChunk(
                content=chunk.content,
                metadata=chunk.metadata,
                source_type="other_documents",
            )
            for chunk in found_chunks.chunks
        ]

        return UnifiedSearchResults(
            chunks=unified_chunks, source_ids=found_chunks.source_ids
        )

    def fetch_science_articles_and_other_documents_tool(
        self, query: str, n_documents: int, whole_articles: bool = False
    ) -> UnifiedSearchResults:
        """
        Fetches information from both scientific articles and other documents.

        This method is often used when the user hasn't specified what kind of sources they are interested in.

        Args:
            query (str): The search query to fetch information for.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.
            whole_articles (bool): If True, fetches the entire article instead of just chunks, so that the whole article can be analyzed. Takes a lot of resources, so use this only if important. Default is False.

        Returns:
            UnifiedSearchResults: The document chunks that match the search query.
        """
        assert isinstance(self, Bot), "The first argument must be a Bot object."
        n_documents = max(3, min(int(n_documents), 10))

        found_chunks: UnifiedSearchResults = self.get_chunks(
            query,
            collections=["sci_articles", f"{self.username}__other_documents"],
            n_results=n_documents,
        )

        # Standardize the chunks using UnifiedDataChunk
        unified_chunks = []
        for chunk in found_chunks.chunks:
            unified_chunk = UnifiedDataChunk(
                content=chunk.content,
                metadata=chunk.metadata,
                source_type="other_and_sci_documents",
            )
            unified_chunks.append(unified_chunk)
        # Return the unified search results
        return UnifiedSearchResults(
            chunks=unified_chunks, source_ids=found_chunks.source_ids
        )

    def fetch_notes_tool(self) -> UnifiedSearchResults:
        """
        Fetches information from the project notes and returns it in a unified format.

        Returns:
            UnifiedSearchResults: A unified representation of the notes.
        """
        notes: list = self.get_notes()

        # Standardize the notes using UnifiedDataChunk
        unified_chunks = [
            UnifiedDataChunk(
                content=note,
                metadata={"source_type": "notes"},
                source_type="notes",
            )
            for note in notes
        ]
        return UnifiedSearchResults(chunks=unified_chunks, source_ids=[])

    def summarize_full_article_tool(
        self, article_id: str, question: str = None, arango_collection: str = "sci_articles"
    ) -> str:
        """
        Fetches a complete scientific article by ID and summarizes its content.

        This tool is useful when a comprehensive understanding of an entire article is needed.

        Parameters:
            article_id (str): The ID of the article to retrieve and summarize.
            question (str, optional): A specific question to focus the summary on.
            arango_collection (str, optional): The ArangoDB collection to fetch from. Default: "sci_articles".

        Returns:
            str: A detailed summary of the article focused on relevant information.
        """
        try:
            if arango_collection == "sci_articles":
                doc = self.base_arango.db.collection("sci_articles").get(article_id)
                full_text = self.base_arango.get_document_text(_id=article_id)
            else:
                arango_key = article_id.split("/")[-1]
                doc = self.user_arango.db.collection(arango_collection).get(article_id)
                full_text = self.base_arango.get_document_text(
                    _key=arango_key, collection=arango_collection
                )

            # Get article metadata
            metadata = {
                "title": doc.get("title", None),
                "authors": doc.get("authors", None),
                "journal": doc.get("journal", None),
                "published_date": doc.get("published_date", None),
                "doi": doc.get("doi", ""),
                "abstract": doc.get("abstract", None),
            }

            metadata_string = ""
            for k, v in metadata.items():
                if v:
                    metadata_string += f"{k.capitalize()}: {v}\n"

            # Create a prompt for summarization
            summary_prompt = f'''
            You are a research assistant helping with an investigation on:
            "{question}"

            Please read this complete scientific article and create a comprehensive memo.

            {metadata_string}

            FULL TEXT:
            """
            {full_text}
            """

            Create a structured, detailed memo of this article focusing on information relevant to
            the research question. Include key findings, methodologies, and conclusions.
            Do not answer the research question directly - just summarize the article's content. A researcher will later draw conclusions etc.
            Make sure to preserve important details and evidence from the original.
            '''

            # Use a small model for efficient summarization
            summary: OllamaMessage = self.chatbot.generate(
                query=summary_prompt, model="small", stream=False
            )
            summary_text = summary.content.strip('"')

            # Format with source information
            formatted_summary = f"{metadata_string}\n\nSUMMARY:\n{summary_text}"

            return formatted_summary

        except Exception as e:
            print_red(f"Error summarizing article {article_id}: {str(e)}")
            return f"Error processing article {article_id}: {str(e)}"

    def conversational_response_tool(self, query: str):
        """
        Generate a conversational response to a user's query.

        This method is designed to provide a short and conversational response without fetching additional data.
        It should be used ONLY when it is clear that the user is engaging in small talk (like saying 'hi').

        Args:
            query (str): The user's message to which the bot should respond.

        Returns:
            generator: A streaming generator of the conversational response.
        """
        query = f"""
        User message: "{query}".
        Make your answer short and conversational.
        Don't answer with anything you're not sure of!
        """

        return self.chatbot.generate(query, stream=True)

    def analyze_tool(self, text: str, instructions: str) -> str:
        """
        This tool is used to analyze information based on the provided instructions.

        Use it to extract insights or perform other analytical tasks.
        The instructions should be clear and specific for the information provided.

        Args:
            text (str): The text content to be analyzed.
            instructions (str): Specific instructions guiding how the analysis should be performed.

        Returns:
            str: The analysis result from the language model.
        """
        query = f'''
        Analyze the following information based on the instructions provided.

        Information:
        """
        {text}
        """

        Instructions:
        """
        {instructions}
        """
        '''

        print_blue("\nQuery for analysis:\n", query, "\n")
        response = self.chatbot.generate(query=query)
        return response.content if hasattr(response, "content") else str(response)

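    # Usage sketch for analyze_tool (hypothetical inputs):
    #
    #     summary = bot.analyze_tool(
    #         text=open("interview_transcript.txt").read(),
    #         instructions="List every claim about battery lifetime, one per line.",
    #     )
    #     print(summary)

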
class StreamlitBot(Bot):
    def __init__(
        self, username: str, chat: StreamlitChat = None, tools: list = None, **kwargs
    ):
        super().__init__(username=username, chat=chat, tools=tools, **kwargs)

        # For Streamlit, we can override or add attributes
        if "llm_chosen_backend" not in st.session_state:
            st.session_state["llm_chosen_backend"] = None

        # Reuse the backend stored in the session if there is one;
        # otherwise remember the backend the chatbot picked.
        if st.session_state["llm_chosen_backend"]:
            self.chatbot.chosen_backend = st.session_state["llm_chosen_backend"]
        else:
            st.session_state["llm_chosen_backend"] = self.chatbot.chosen_backend

        settings = self.get_settings()
        if settings.get("use_reasoning_model", False):
            self.chatbot.model = self.chatbot.get_model("reasoning")

        print_rainbow(settings)
        print("MODEL", self.chatbot.model)

    def run(self):
        # Example Streamlit run loop
        title = (
            self.project.name
            if self.project
            else self.collection.name if self.collection else None
        )
        self.chat.show_title(title=title)
        self.chat.show_chat_history()
        if user_input := st.chat_input("Write your message here...", accept_file=True):
            text_input = user_input.text.replace('"""', "---")
            if len(user_input.files) > 1:
                st.error("Please upload only one file at a time.")
                return
            attached_file = user_input.files[0] if user_input.files else None

            content_attachment = None
            if attached_file:
                if attached_file.type == "application/pdf":
                    import fitz  # PyMuPDF

                    pdf_document = fitz.open(
                        stream=attached_file.read(), filetype="pdf"
                    )
                    pdf_text = ""
                    for page_num in range(len(pdf_document)):
                        page = pdf_document.load_page(page_num)
                        pdf_text += page.get_text()
                    content_attachment = pdf_text
                elif attached_file.type in ["image/png", "image/jpeg"]:
                    self.chat.message_attachments = "image"
                    content_attachment = attached_file.read()
                    with st.chat_message(
                        "user", avatar=self.chat.get_avatar(role="user")
                    ):
                        st.image(content_attachment)

            with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
                st.write(text_input)

            if not self.chat.name:
                self.chat.set_name(text_input)
                self.chat.last_updated = datetime.now().isoformat()
                self.chat.saved = False
                self.user_arango.db.collection("chats").insert(
                    self.chat.to_dict(), overwrite=True, overwrite_mode="update"
                )

            # Use streaming tool calling - simpler and better UX
            self.process_user_input(text_input, content_attachment)

    def get_settings(self):
        return self.user_arango.db.document("settings/settings")

    def process_user_input(self, user_input, content_attachment=None):
        """
        Process user input using streaming tool calling for better UX and simpler code.

        This method allows the LLM to call tools and generate responses in a single streaming call.
        """
        # Add user message to chat history
        self.chat.add_message("user", user_input)

        # Remove the conversational response tool if there are more than 2 messages.
        # This prevents small-talk responses when users are asking research questions.
        if self.tools and len(self.tools) > 1 and len(self.chat.chat_history) > 2:
            self.tools = [
                tool
                for tool in self.tools
                if tool.__name__ != "conversational_response_tool"
            ]

        if not content_attachment:
            # Single streaming call with tools - much simpler than the old two-step approach
            prompt = get_chat_prompt(user_input, role=self.chat.role)

            with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
                # Stream response with tool calling enabled
                response_stream = self.chatbot.generate(
                    prompt,
                    tools=self.tools,
                    stream=True,
                    think=True,  # Keep thinking mode if enabled
                )

                bot_response = self.handle_streaming_with_tools(response_stream, user_input)
        else:
            # Handle attachments (images, PDFs) using existing approach
            with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
                if self.chat.message_attachments == "image":
                    prompt = get_chat_prompt(
                        user_input, role=self.chat.role, image_attachment=True
                    )
                    bot_resp = self.chatbot.generate(
                        prompt,
                        stream=False,
                        images=[content_attachment],
                        model="vision",
                    )
                    if isinstance(bot_resp, dict):
                        bot_resp = bot_resp.get("content", "")
                    elif isinstance(bot_resp, OllamaMessage):
                        bot_resp = bot_resp.content
                    st.write(bot_resp)
                    bot_response = bot_resp
                else:
                    prompt = get_chat_prompt(
                        user_input,
                        content_attachment=content_attachment,
                        role=self.chat.role,
                    )
                    response = self.chatbot.generate(prompt, stream=True)
                    bot_response = st.write_stream(response)

        # Update chat history and database
        if self.chat.chat_history[-1]["role"] != "assistant":
            self.chat.add_message("assistant", bot_response)
        self.chat.update_in_arango()

    def handle_streaming_with_tools(self, response_stream, user_input):
        """
        Handle a streaming response that may include tool calls, thinking, and content.

        This replaces the manual tool execution with real-time streaming tool calling.
        """
        thinking_content = []
        content_chunks = []
        tool_executions = []
        sources_info = ""
        final_content = ""

        # Process the streaming response
        for chunk in response_stream:
            if isinstance(chunk, tuple):
                chunk_type, content = chunk

                if chunk_type == "thinking":
                    thinking_content.append(content)

                elif chunk_type == "content":
                    content_chunks.append(content)

                elif chunk_type == "tool_execution":
                    # Show tool execution in real-time
                    st.info(f"🔍 {content}")
                    tool_executions.append(content)

            # Handle direct tool calls from Ollama (if using native tool calling)
            elif hasattr(chunk, "tool_calls") and chunk.tool_calls:
                for tool_call in chunk.tool_calls:
                    function_name = tool_call.function.name

                    # Execute the tool and get results
                    if hasattr(self, function_name):
                        st.info(
                            f"🔍 Searching {function_name.replace('_tool', '').replace('_', ' ')}..."
                        )

                        # Execute tool function
                        arguments = tool_call.function.arguments
                        arguments["query"] = user_input

                        if function_name in [
                            "fetch_science_articles_tool",
                            "fetch_other_documents_tool",
                            "fetch_science_articles_and_other_documents_tool",
                        ]:
                            chunks = getattr(self, function_name)(**arguments)
                            sources_info = self.format_sources_for_streaming(chunks)

            elif hasattr(chunk, "content") and chunk.content:
                content_chunks.append(chunk.content)

        # Display thinking content in an expandable section if present
        if thinking_content and len("".join(thinking_content).strip()) > 10:
            with st.expander("🤔 How the bot reasoned"):
                st.write("".join(thinking_content))

        # Stream the main content
        if content_chunks:
            def content_generator():
                for content in content_chunks:
                    yield content

            final_content = st.write_stream(content_generator())

        # Show sources if any were found
        if sources_info:
            st.markdown(sources_info)
            final_content += f"\n\n{sources_info}"

        return final_content

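    # The stream is assumed to yield either (chunk_type, text) tuples from the LLM
    # wrapper - chunk_type being "thinking", "content" or "tool_execution" - or raw
    # Ollama chunks exposing .tool_calls / .content. A minimal fake stream for
    # exercising this handler (hypothetical):
    #
    #     fake_stream = iter([
    #         ("thinking", "The user asks about lithium mining..."),
    #         ("tool_execution", "Searching science articles..."),
    #         ("content", "Lithium extraction affects groundwater by..."),
    #     ])
    #     bot.handle_streaming_with_tools(fake_stream, "lithium mining impacts?")
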
    def format_sources_for_streaming(self, chunks):
        """Format source information for display in streaming tool calling."""
        if not chunks or not chunks.chunks:
            return ""

        sources = "###### Sources:\n"
        for i, chunk in enumerate(chunks.chunks):
            metadata = chunk.metadata

            # Metadata can arrive as a dict or as an object
            if isinstance(metadata, dict):
                journal = metadata.get("journal", "No Journal") or "No Journal"
                date = metadata.get("published_date", "No Date") or "No Date"
                title = metadata.get("title", "Untitled") or "Untitled"
            else:
                journal = getattr(metadata, "journal", "No Journal") or "No Journal"
                date = getattr(metadata, "published_date", "No Date") or "No Date"
                title = getattr(metadata, "title", "Untitled") or "Untitled"

            article_num = getattr(chunk, "article_number", i + 1)
            sources += f"[{article_num}] **{title}** :gray[*{journal}* ({date})] \n"

        return sources

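    # Example of the markdown this produces (hypothetical values):
    #
    #     ###### Sources:
    #     [1] **Lithium brine extraction and groundwater** :gray[*Nature Water* (2023-04-01)]
    #     [2] **Direct lithium extraction methods** :gray[*Joule* (2022-11-12)]
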
    def generate_from_notes(self, user_input, notes):
        with st.spinner("Reading project notes..."):
            return super().generate_from_notes(user_input, notes)

    def generate_from_chunks(self, user_input, chunks: UnifiedSearchResults):
        # For reading articles with a spinner
        magazines = set()
        for chunk in chunks.chunks:
            if chunk.metadata:
                # Metadata can arrive as a dict or as an object
                if isinstance(chunk.metadata, dict):
                    journal = chunk.metadata.get("journal") or "No Journal"
                else:
                    journal = getattr(chunk.metadata, "journal", None) or "No Journal"
                magazines.add(f"*{journal}*")

        # Create spinner message
        if len(magazines) > 1:
            spinner_text = (
                f"Reading articles from {', '.join(list(magazines)[:-1])} "
                f"and {list(magazines)[-1]}..."
            )
        else:
            spinner_text = "Reading articles..."

        with st.spinner(spinner_text):
            return super().generate_from_chunks(user_input, chunks)

    def sidebar_content(self):
        with st.sidebar:
            st.write("---")
            st.markdown(f'#### {self.chat.name if self.chat.name else ""}')
            st.button("Delete this chat", on_click=self.delete_chat)

    def delete_chat(self):
        self.user_arango.db.collection("chats").delete_match(
            filters={"name": self.chat.name}
        )
        self.chat = Chat(username=self.username)

    def get_notes(self):
        # We can show a spinner or messages too
        with st.spinner("Fetching notes..."):
            return super().get_notes()


class EditorBot(StreamlitBot):
    def __init__(self, username: str, chat: Chat, **kwargs):
        super().__init__(username=username, chat=chat, **kwargs)
        self.role = "Editor"
        self.tools = [self.fetch_notes_tool, self.fetch_other_documents_tool]
        # self.chatbot = LLM(
        #     system_message=get_editor_prompt(kwargs.get("project")),
        #     messages=self.chat.chat_history2bot(),
        #     chosen_backend=kwargs.get("chosen_backend"),
        # )
        print_purple("MODEL FOR EDITOR BOT:", self.chatbot.model)


class ResearchAssistantBot(StreamlitBot):
    def __init__(self, username: str, chat: Chat, **kwargs):
        super().__init__(username=username, chat=chat, **kwargs)
        self.role = "Research Assistant"
        # self.chatbot = LLM(
        #     system_message=get_assistant_prompt(),
        #     temperature=0.1,
        #     messages=self.chat.chat_history2bot(),
        # )
        self.tools = [
            self.fetch_science_articles_tool,
            self.fetch_science_articles_and_other_documents_tool,
            self.conversational_response_tool,
        ]


class PodBot(StreamlitBot):
    """Two LLM agents construct a conversation using material from science articles."""

    def __init__(
        self,
        username: str,
        chat: Chat,
        subject: str,
        instructions: str = None,
        **kwargs,
    ):
        super().__init__(username=username, chat=chat, **kwargs)
        self.subject = subject
        self.instructions = instructions
        self.guest_name = kwargs.get("name_guest", "Merit")
        self.hostbot = HostBot(
            Chat(username=self.username, role="Host"),
            subject,
            username,
            instructions=instructions,
            **kwargs,
        )
        self.guestbot = GuestBot(
            Chat(username=self.username, role="Guest"),
            subject,
            username,
            name_guest=self.guest_name,
            **kwargs,
        )

    def run(self):
        notes = self.get_notes()
        notes_string = ""
        if self.instructions:
            instructions_string = f'''
            These are the instructions for the podcast from the producer:
            """
            {self.instructions}
            """
            '''
        else:
            instructions_string = ""

        for note in notes:
            if isinstance(note, dict):
                notes_string += (
                    f"\n# {note.get('title', 'No title')}\n{note.get('content', '')}\n---\n"
                )
            else:
                # get_notes() may return plain text strings
                notes_string += f"\n{note}\n---\n"

        # The seed prompt doubles as the first "answer" handed to the host's toolbot.
        a = f'''You will make a podcast interview with {self.guest_name}, an expert on "{self.subject}".
        {instructions_string}
        Below are notes on the subject that you can use to ask relevant questions:
        """
        {notes_string}
        """
        Say hello to the expert and start the interview. Remember to keep the interview to the subject of {self.subject} throughout the conversation.
        '''

        # Stop button for the podcast
        with st.sidebar:
            st.button("Stop podcast", on_click=self.stop_podcast)

        while st.session_state["make_podcast"]:
            self.chat.show_chat_history()
            # Wrap up the podcast once there are 14 messages in the chat
            if len(self.chat.chat_history) >= 14:
                result = self.hostbot.generate(
                    "The interview has ended. Say thank you to the expert and end the conversation."
                )
                self.chat.add_message("Host", result)
                with st.chat_message(
                    "assistant", avatar=self.chat.get_avatar(role="assistant")
                ):
                    st.write(result.strip('"'))
                st.stop()

            _q = self.hostbot.toolbot.generate(
                query=f"{self.guest_name} has answered: {a}. You have to choose a tool to help the host continue the interview.",
                tools=self.hostbot.tools,
                temperature=0.6,
                stream=False,
            )
            if "tool_calls" in _q:
                q = self.hostbot.answer_tool_call(_q, a)
            else:
                q = _q

            self.chat.add_message("Host", q)

            _a = self.guestbot.toolbot.generate(
                f'The podcast host has asked: "{q}" Choose a tool to help the expert answer with relevant facts and information.',
                tools=self.guestbot.tools,
            )
            if "tool_calls" in _a:
                print_yellow("Tool call response (guest)", _a)
                print_yellow(self.guestbot.chat.role)
                a = self.guestbot.answer_tool_call(_a, q)
            else:
                a = _a
            self.chat.add_message("Guest", a)

            self.update_session_state()

    def stop_podcast(self):
        st.session_state["make_podcast"] = False
        self.update_session_state()
        self.chat.show_chat_history()

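    # The run() loop above assumes st.session_state["make_podcast"] was set to True
    # by the page that launches the podcast, e.g. (hypothetical trigger):
    #
    #     if st.button("Start podcast"):
    #         st.session_state["make_podcast"] = True
    #         PodBot(username, chat, subject="battery recycling").run()

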
class HostBot(StreamlitBot):
    def __init__(
        self, chat: Chat, subject: str, username: str, instructions: str, **kwargs
    ):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Host")
        self.tools = [self.fetch_notes_tool, self.conversational_response_tool]
        self.instructions = instructions
        self.llm = LLM(
            system_message=f'''
            You are the host of a podcast and an expert on {subject}. You will ask one question at a time about the subject, and then wait for the guest to answer.
            Don't ask the guest to talk about herself/himself, only about the subject.
            Make your questions short and clear; only if necessary, add a brief context to the question.
            These are the instructions for the podcast from the producer:
            """
            {self.instructions}
            """
            If the expert's answer is complicated, try to make a very brief summary of it for the audience to understand. You can also ask follow-up questions to clarify the answer, or ask for examples.
            ''',
            messages=self.chat.chat_history2bot(),
        )
        self.toolbot = LLM(
            temperature=0,
            system_message="""
            You are assisting a podcast host in asking questions to an expert.
            Choose one or many tools to use in order to assist the host in asking relevant questions.
            Often "conversational_response_tool" is enough, but sometimes project notes are needed.
            Make sure to read the description of the tools carefully!""",
            chat=True,
            model="tools",
        )

    def generate(self, query):
        return self.llm.generate(query)


class GuestBot(StreamlitBot):
    def __init__(self, chat: Chat, subject: str, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Guest")
        self.tools = [
            self.fetch_notes_tool,
            self.fetch_science_articles_tool,
        ]

        self.llm = LLM(
            system_message=f"""
            You are {kwargs.get('name_guest', 'Merit')}, an expert on {subject}.
            Today you are a guest in a podcast about {subject}. A host will ask you questions about the subject and you will answer by using scientific facts and information.
            When answering, don't say things like "based on the documents" or the like, as neither the host nor the audience can see the documents. Act just as if you were talking to someone in a conversation.
            Try to be concise when answering, and remember that the audience of the podcast is not expert on the subject, so don't complicate things too much.
            It's very important that you answer in a "spoken" way, as if you were talking to someone in a conversation. That means you should avoid scientific jargon and complex terms, too many figures, or abstract concepts.
            Lists are also not recommended; use "for the first reason", "secondly", etc. instead.
            Use "..." to indicate a pause and "-" to indicate a break in the sentence, as if you were speaking.
            """,
            messages=self.chat.chat_history2bot(),
        )
        self.toolbot = LLM(
            temperature=0,
            system_message=f"You are an assistant to an expert on {subject}. Choose one or many tools to use in order to assist the expert in answering questions. Make sure to read the description of the tools carefully.",
            chat=False,
            model="tools",
        )

    def generate(self, query):
        return self.llm.generate(query)


# if __name__ == "__main__":
#     from _arango import ArangoDB

#     # Example usage
#     from dotenv import load_dotenv
#     import os

#     load_dotenv()
#     question = "What are the environmental impacts of lithium mining?"
#     username = "lasse"
#     user_arango = ArangoDB(user="lasse", password=os.getenv("ARANGO_PASSWORD"))
#     base_arango = ArangoDB(
#         user="admin", password=os.getenv("ARANGO_PASSWORD"), db_name="base"
#     )
#     project = Project(
#         username=username, project_name="Electric Cars", user_arango=user_arango
#     )
#     bot = Bot(username=username, project=project)
#     bot.run()

#     result = bot.fetch_science_articles_tool(
#         "lithium mining", n_documents=4, retrieve_full_articles=True
#     )
#     print(result.source_ids)
#     for _id in result.source_ids:
#         doc = base_arango.db.collection("sci_articles").get(_id)
#         text = ""
#         for chunk in doc["chunks"]:
#             text += chunk["text"]

#         q = f'''
#         You are a research assistant. You are helping a researcher answer the question "{question}".
#         The article below is probably relevant to the question. Please read it to make a memo.

#         """
#         {text}
#         """

#         Please write a memo based on the article with focus on the question: {question}
#         *Don't answer the question directly!* Just make a summary of the article – the researcher will use your summary to answer the question.
#         Make the memo structured and clear, and make sure to include all the relevant details.
#         '''
#         pm = bot.chatbot.generate(q, model="small")
#         print(pm)

#     # exit()