parent
732793b79f
commit
83044a905b
11 changed files with 1501 additions and 275 deletions
@ -0,0 +1,800 @@ |
||||
from datetime import datetime |
||||
import streamlit as st |
||||
from _base_class import StreamlitBaseClass, BaseClass |
||||
from _llm import LLM |
||||
from prompts import * |
||||
from colorprinter.print_color import * |
||||
from llm_tools import ToolRegistry |
||||
|
||||
class Chat(StreamlitBaseClass):
    """A chat session: holds the message history and persists it in ArangoDB."""

    def __init__(self, username=None, **kwargs):
        super().__init__(username=username, **kwargs)
        self.name = kwargs.get("name", None)
        self.chat_history = kwargs.get("chat_history", [])

    def add_message(self, role, content):
        """Append a message to the history.

        Content is stripped of surrounding whitespace and double quotes
        (LLM answers are often wrapped in quotes).
        """
        self.chat_history.append(
            {
                "role": role,
                "content": content.strip().strip('"'),
                # NOTE(review): self.role is assumed to be set by the base
                # class from kwargs — confirm StreamlitBaseClass does this.
                "role_type": self.role,
            }
        )

    def to_dict(self):
        """Serialize the chat for storage in the `chats` collection."""
        return {
            "_key": self._key,
            "name": self.name,
            "chat_history": self.chat_history,
            "role": self.role,
            "username": self.username,
        }

    def update_in_arango(self):
        """Upsert this chat document (overwrite_mode='update' merges fields)."""
        self.last_updated = datetime.now().isoformat()
        self.user_arango.db.collection("chats").insert(
            self.to_dict(), overwrite=True, overwrite_mode="update"
        )

    def set_name(self, user_input):
        """Ask a small LLM to name the chat from the first user input.

        Appends the current date, disambiguates duplicates with a time
        suffix, and tags the name with the chat role.

        Returns:
            str: The chosen (and stored) chat name.
        """
        llm = LLM(
            model="small",
            max_length_answer=50,
            temperature=0.4,
            system_message="You are a chatbot who will be chatting with a user",
        )
        prompt = (
            f'Give a short name to the chat based on this user input: "{user_input}" '
            "No more than 30 characters. Answer ONLY with the name of the chat."
        )
        name = llm.generate(prompt).content.strip('"')
        name = f'{name} - {datetime.now().strftime("%B %d")}'
        # Use bind variables instead of interpolating the LLM-generated name
        # into the AQL string: a name containing quotes would break (or
        # inject into) the query.
        existing_chat = self.user_arango.db.aql.execute(
            "FOR doc IN chats FILTER doc.name == @name RETURN doc",
            bind_vars={"name": name},
            count=True,
        )
        if existing_chat.count() > 0:
            name = f'{name} ({datetime.now().strftime("%H:%M")})'
        name += f" - [{self.role}]"
        self.name = name
        return name

    @classmethod
    def from_dict(cls, data):
        """Rebuild a Chat from a stored Arango document."""
        return cls(
            username=data.get("username"),
            name=data.get("name"),
            chat_history=data.get("chat_history", []),
            role=data.get("role", "Research Assistant"),
            _key=data.get("_key"),
        )

    def chat_history2bot(self, n_messages: int = None, remove_system: bool = False):
        """Return the history as plain {role, content} dicts for an LLM.

        Args:
            n_messages: If given, keep only the most recent n messages.
            remove_system: Drop a leading system message if present.

        A leading assistant message is always dropped (most chat APIs
        require the conversation to start with a user or system turn).
        """
        history = [
            {"role": m["role"], "content": m["content"]} for m in self.chat_history
        ]
        if n_messages and len(history) > n_messages:
            history = history[-n_messages:]
        # Guard against an empty history: the original indexed history[0]
        # unconditionally and raised IndexError for a brand-new chat.
        if history and (
            (remove_system and history[0]["role"] == "system")
            or history[0]["role"] == "assistant"
        ):
            history = history[1:]
        return history
||||
class Bot(BaseClass):
    """Base (non-Streamlit) chatbot.

    Wires a Chat to three LLMs — a main chatbot, a helper for building
    vector-search queries, and a tool-choosing bot — plus the registered
    tools that fetch article chunks and notes.
    """

    def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
        super().__init__(username=username, **kwargs)

        # Use the passed-in chat or create a new Chat.
        self.chat = chat if chat else Chat(username=username, role="Research Assistant")
        print_yellow("Chat:", chat, type(chat))

        # Store project/collection if available.
        self.project = kwargs.get("project", None)
        self.collection = kwargs.get("collection", None)
        if self.collection and not isinstance(self.collection, list):
            self.collection = [self.collection]

        # Collect the Arango _ids of all articles in the chosen collections;
        # they are later used as a ChromaDB `where` filter.
        self.arango_ids = []
        if self.collection:
            for c in self.collection:
                for _id in self.user_arango.db.aql.execute(
                    """
                    FOR doc IN article_collections
                        FILTER doc.name == @collection
                        FOR article IN doc.articles
                            RETURN article._id
                    """,
                    bind_vars={"collection": c},
                ):
                    self.arango_ids.append(_id)

        # A standard LLM for normal chat.
        self.chatbot = LLM(messages=self.chat.chat_history2bot())
        # A helper bot for generating queries or short prompts.
        self.helperbot = LLM(
            temperature=0,
            model="small",
            max_length_answer=500,
            system_message=get_query_builder_system_message(),
            messages=self.chat.chat_history2bot(n_messages=4, remove_system=True),
        )
        # A specialized LLM picking which tool to use.
        self.toolbot = LLM(
            temperature=0,
            system_message="""
You are an assistant bot helping an answering bot to answer a user's messages.
Your task is to choose one or multiple tools that will help the answering bot to provide the user with the best possible answer.
You should NEVER directly answer the user. You MUST choose a tool.
""",
            chat=False,
            model="small",
        )

        # Load the requested tools, or all registered tools by default.
        self.tools = ToolRegistry.get_tools(tools=tools) if tools else ToolRegistry.get_tools()

        # Store any remaining kwargs as attributes.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def get_chunks(
        self,
        user_input,
        collections=None,
        n_results=7,
        n_sources=4,
        filter=True,
    ):
        """Vector-search the given collections and return grouped chunks.

        Args:
            user_input: The user's message; a search query is derived from it.
            collections: Collection names to search. Defaults to
                ["sci_articles", "other_documents"]. (A None sentinel is used
                to avoid a shared mutable default argument.)
            n_results: Total number of chunks to keep.
            n_sources: Preferred number of distinct source documents.
            filter: Restrict results to self.arango_ids when True.

        Returns:
            dict: {title: {"article_number": int, "chunks": [chunk, ...]}}
        """
        if collections is None:
            collections = ["sci_articles", "other_documents"]

        # Build a vector-search query from the user input.
        query = self.helperbot.generate(
            get_generate_vector_query_prompt(user_input, self.chat.role)
        ).content.strip('"')

        combined_chunks = []
        if collections:
            for collection in collections:
                where_filter = {"_id": {"$in": self.arango_ids}} if filter else {}
                chunks = self.get_chromadb().query(
                    query=query,
                    collection=collection,
                    n_results=n_results,
                    n_sources=n_sources,
                    where=where_filter,
                    max_retries=3,
                )
                for doc, meta, dist in zip(
                    chunks["documents"][0],
                    chunks["metadatas"][0],
                    chunks["distances"][0],
                ):
                    combined_chunks.append(
                        {"document": doc, "metadata": meta, "distance": dist}
                    )
        # Smaller distance == better match.
        combined_chunks.sort(key=lambda x: x["distance"])

        # Keep the best chunk per source until n_sources distinct sources.
        sources = set()
        closest_chunks = []
        for chunk in combined_chunks:
            source_id = chunk["metadata"].get("_id", "no_id")
            if source_id not in sources:
                sources.add(source_id)
                closest_chunks.append(chunk)
                if len(sources) >= n_sources:
                    break
        # Top up with remaining chunks (possibly repeating sources).
        if len(closest_chunks) < n_results:
            remaining_chunks = [
                c for c in combined_chunks if c not in closest_chunks
            ]
            closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)])

        # Replace the vector-store metadata with the real metadata from Arango.
        for chunk in closest_chunks:
            _id = chunk["metadata"].get("_id")
            if not _id:
                continue
            # Science articles live in the shared base DB; everything else
            # in the user's own DB.
            if _id.startswith("sci_articles"):
                arango_doc = self.base_arango.db.document(_id)
            else:
                arango_doc = self.user_arango.db.document(_id)
            if arango_doc:
                arango_metadata = arango_doc.get("metadata", {})
                if "user_notes" in arango_doc:
                    arango_metadata["user_notes"] = arango_doc["user_notes"]
                chunk["metadata"] = arango_metadata

        # Group chunks by article title, numbering articles for citations.
        grouped_chunks = {}
        article_number = 1
        for chunk in closest_chunks:
            title = chunk["metadata"].get("title", "No title")
            chunk["article_number"] = article_number
            if title not in grouped_chunks:
                grouped_chunks[title] = {
                    "article_number": article_number,
                    "chunks": [],
                }
                article_number += 1
            grouped_chunks[title]["chunks"].append(chunk)
        return grouped_chunks

    def answer_tool_call(self, response, user_input):
        """Dispatch the toolbot's tool calls and join their answers.

        Returns an empty string when the response contains no tool calls.
        (No Streamlit calls in this base version.)
        """
        bot_responses = []
        if not response.get("tool_calls"):
            return ""

        for tool in response.get("tool_calls"):
            function_name = tool.function.get('name')
            arguments = tool.function.arguments
            # Always pass the raw user input through as the query.
            arguments["query"] = user_input

            if hasattr(self, function_name):
                if function_name in [
                    "fetch_other_documents_tool",
                    "fetch_science_articles_tool",
                    "fetch_science_articles_and_other_documents_tool",
                ]:
                    chunks = getattr(self, function_name)(**arguments)
                    bot_responses.append(
                        self.generate_from_chunks(user_input, chunks).strip('"')
                    )
                elif function_name == "fetch_notes_tool":
                    notes = getattr(self, function_name)()
                    bot_responses.append(
                        self.generate_from_notes(user_input, notes).strip('"')
                    )
                elif function_name == "conversational_response_tool":
                    bot_responses.append(
                        getattr(self, function_name)(user_input).strip('"')
                    )
        return "\n\n".join(bot_responses)

    def process_user_input(self, user_input, content_attachment=None):
        """Handle one user turn: pick tools, answer, persist the chat."""
        self.chat.add_message("user", user_input)

        if not content_attachment:
            prompt = get_tools_prompt(user_input)
            response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
            if response.get("tool_calls"):
                bot_response = self.answer_tool_call(response, user_input)
            else:
                # No tool chosen: respond directly.
                bot_response = response.content.strip('"')
        else:
            # The base bot does not process attachments.
            bot_response = "Content attachment received (Base Bot)."

        # Avoid two consecutive assistant messages in the history.
        if self.chat.chat_history[-1]["role"] != "assistant":
            self.chat.add_message("assistant", bot_response)

        self.chat.update_in_arango()
        return bot_response

    def generate_from_notes(self, user_input, notes):
        """Answer the user from project notes (no Streamlit calls)."""
        notes_string = ""
        for note in notes:
            notes_string += f"\n# {note.get('title','No title')}\n{note.get('content','')}\n---\n"
        prompt = get_chat_prompt(user_input, content_string=notes_string, role=self.chat.role)
        return self.chatbot.generate(prompt, stream=True)

    def generate_from_chunks(self, user_input, chunks):
        """Answer the user from grouped article chunks (no Streamlit calls)."""
        chunks_string = ""
        for title, group in chunks.items():
            user_notes_string = ""
            if "user_notes" in group["chunks"][0]["metadata"]:
                notes = group["chunks"][0]["metadata"]["user_notes"]
                user_notes_string = f'\n\nUser notes:\n"""\n{notes}\n"""\n\n'
            # "(...)" marks gaps between non-contiguous chunks of one article.
            docs = "\n(...)\n".join([c["document"] for c in group["chunks"]])
            chunks_string += (
                f"\n# {title}\n## Article #{group['article_number']}\n{user_notes_string}{docs}\n---\n"
            )
        prompt = get_chat_prompt(user_input, content_string=chunks_string, role=self.chat.role)
        return self.chatbot.generate(prompt, stream=True)

    def run(self):
        """Base Bot has no run loop; Streamlit subclasses override this."""
        pass

    def get_notes(self):
        """Fetch all notes belonging to the current project.

        Uses bind variables rather than interpolating the project name into
        the AQL string (injection-safe, quote-safe).
        """
        notes = self.user_arango.db.aql.execute(
            "FOR doc IN notes FILTER doc.project == @project RETURN doc",
            bind_vars={"project": self.project.name if self.project else ""},
        )
        return list(notes)

    @ToolRegistry.register
    def fetch_science_articles_tool(self, query: str, n_documents: int):
        """
        "Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles."

        Parameters:
            query (str): The search query to find relevant scientific articles.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.

        Returns:
            list: A list of chunks containing information from the fetched scientific articles.
        """
        print_purple('Query:', query)

        # Clamp the (LLM-chosen) document count into [3, 10].
        n_documents = max(3, min(10, int(n_documents)))
        return self.get_chunks(
            query, collections=["sci_articles"], n_results=n_documents
        )

    @ToolRegistry.register
    def fetch_other_documents_tool(self, query: str, n_documents: int):
        """
        Fetches information from other documents based on the user's query.

        This method retrieves information from various types of documents such as reports, news articles, and other texts. It should be used only when it is clear that the user is not seeking scientific articles.

        Args:
            query (str): The search query provided by the user.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.

        Returns:
            list: A list of document chunks that match the query.
        """
        assert isinstance(self, Bot), "The first argument must be a Bot object."
        # Clamp the (LLM-chosen) document count into [2, 10].
        n_documents = max(2, min(10, int(n_documents)))
        return self.get_chunks(
            query,
            collections=[f"{self.username}__other_documents"],
            n_results=n_documents,
        )

    @ToolRegistry.register
    def fetch_science_articles_and_other_documents_tool(
        self, query: str, n_documents: int
    ):
        """
        Fetches information from both scientific articles and other documents.

        This method is often used when the user hasn't specified what kind of sources they are interested in.

        Args:
            query (str): The search query to fetch information for.
            n_documents (int): How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.

        Returns:
            list: A list of document chunks that match the search query.
        """
        assert isinstance(self, Bot), "The first argument must be a Bot object."
        # Clamp the (LLM-chosen) document count into [3, 10].
        n_documents = max(3, min(10, int(n_documents)))
        return self.get_chunks(
            query,
            collections=["sci_articles", f"{self.username}__other_documents"],
            n_results=n_documents,
        )

    @ToolRegistry.register
    def fetch_notes_tool(bot):
        """
        Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools! No arguments needed.

        Returns:
            list: A list of notes.
        """
        assert isinstance(bot, Bot), "The first argument must be a Bot object."
        return bot.get_notes()

    @ToolRegistry.register
    def conversational_response_tool(self, query: str):
        """
        Generate a conversational response to a user's query.

        This method is designed to provide a short and conversational response
        without fetching additional data. It should be used only when it is clear
        that the user is engaging in small talk (like saying 'hi') and not seeking detailed information.

        Args:
            query (str): The user's message to which the bot should respond.

        Returns:
            str: The generated conversational response.
        """
        query = f"""
User message: "{query}".
Make your answer short and conversational.
This is perhaps not a conversation about a journalistic project, so try not to be too informative.
Don't answer with anything you're not sure of!
"""

        # NOTE(review): self.llm only exists on some subclasses (HostBot,
        # GuestBot); the base Bot always has self.chatbot.
        result = (
            self.chatbot.generate(query, stream=True)
            if self.chatbot
            else self.llm.generate(query, stream=True)
        )
        return result
||||
class StreamlitBot(Bot):
    """A Bot that renders the conversation through Streamlit widgets."""

    # NOTE: the original annotated `chat` as `StreamlitChat`, a name that is
    # neither imported nor defined — evaluating that annotation raised
    # NameError at class-definition time. `Chat` is what subclasses pass.
    def __init__(self, username: str, chat: Chat = None, tools: list = None, **kwargs):
        print_purple("StreamlitBot init chat:", chat)
        super().__init__(username=username, chat=chat, tools=tools, **kwargs)

        if 'llm_chosen_backend' not in st.session_state:
            st.session_state['llm_chosen_backend'] = None

        # Reuse the backend stored in the session if there is one; otherwise
        # remember the backend the chatbot picked. (The original assigned the
        # session value — possibly None — into the chatbot first, which made
        # the write-back below a no-op.)
        if st.session_state['llm_chosen_backend']:
            self.chatbot.chosen_backend = st.session_state['llm_chosen_backend']
        else:
            st.session_state['llm_chosen_backend'] = self.chatbot.chosen_backend

    def run(self):
        """Streamlit run loop: show history, read input, answer."""
        self.chat.show_chat_history()
        if user_input := st.chat_input("Write your message here...", accept_file=True):
            # Triple quotes would clash with prompt delimiters downstream.
            text_input = user_input.text.replace('"""', "---")
            if len(user_input.files) > 1:
                st.error("Please upload only one file at a time.")
                return
            attached_file = user_input.files[0] if user_input.files else None

            content_attachment = None
            if attached_file:
                if attached_file.type == "application/pdf":
                    # Extract all page text from the uploaded PDF (PyMuPDF).
                    import fitz
                    pdf_document = fitz.open(stream=attached_file.read(), filetype="pdf")
                    pdf_text = ""
                    for page_num in range(len(pdf_document)):
                        page = pdf_document.load_page(page_num)
                        pdf_text += page.get_text()
                    content_attachment = pdf_text
                elif attached_file.type in ["image/png", "image/jpeg"]:
                    self.chat.message_attachments = "image"
                    content_attachment = attached_file.read()
                    with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
                        st.image(content_attachment)

            with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
                st.write(text_input)

            # First message of a new chat: name it and persist immediately.
            if not self.chat.name:
                self.chat.set_name(text_input)
                self.chat.last_updated = datetime.now().isoformat()
                self.chat.saved = False
                self.user_arango.db.collection("chats").insert(
                    self.chat.to_dict(), overwrite=True, overwrite_mode="update"
                )

            self.process_user_input(text_input, content_attachment)

    def process_user_input(self, user_input, content_attachment=None):
        """Override: render answers in Streamlit instead of just storing them."""
        self.chat.add_message("user", user_input)
        if not content_attachment:
            prompt = get_tools_prompt(user_input)
            response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
            if response.get("tool_calls"):
                bot_response = self.answer_tool_call(response, user_input)
            else:
                bot_response = response.content.strip('"')
                with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
                    st.write(bot_response)
        else:
            with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
                with st.spinner("Reading the content..."):
                    if self.chat.message_attachments == "image":
                        # Images go to the vision model, non-streaming.
                        prompt = get_chat_prompt(user_input, role=self.chat.role, image_attachment=True)
                        bot_resp = self.chatbot.generate(prompt, stream=False, images=[content_attachment], model="vision")
                        st.write(bot_resp)
                        bot_response = bot_resp
                    else:
                        prompt = get_chat_prompt(user_input, content_attachment=content_attachment, role=self.chat.role)
                        response = self.chatbot.generate(prompt, stream=True)
                        bot_response = st.write_stream(response)

        # Avoid two consecutive assistant messages in the history.
        if self.chat.chat_history[-1]["role"] != "assistant":
            self.chat.add_message("assistant", bot_response)

        self.chat.update_in_arango()

    def answer_tool_call(self, response, user_input):
        """Dispatch tool calls, streaming each answer into the chat UI."""
        bot_responses = []
        for tool in response.get("tool_calls", []):
            function_name = tool.function.get('name')
            arguments = tool.function.arguments
            arguments["query"] = user_input

            with st.chat_message("assistant", avatar=self.chat.get_avatar(role="assistant")):
                if function_name in [
                    "fetch_other_documents_tool",
                    "fetch_science_articles_tool",
                    "fetch_science_articles_and_other_documents_tool",
                ]:
                    chunks = getattr(self, function_name)(**arguments)
                    response_text = self.generate_from_chunks(user_input, chunks)
                    bot_response = st.write_stream(response_text).strip('"')
                    # Append a numbered source list below the streamed answer.
                    if chunks:
                        sources = "###### Sources:\n"
                        for title, group in chunks.items():
                            j = group["chunks"][0]["metadata"].get("journal", "No Journal")
                            d = group["chunks"][0]["metadata"].get("published_date", "No Date")
                            sources += f"[{group['article_number']}] **{title}** :gray[{j} ({d})]\n"
                        st.markdown(sources)
                        bot_response += f"\n\n{sources}"
                    bot_responses.append(bot_response)

                elif function_name == "fetch_notes_tool":
                    notes = getattr(self, function_name)()
                    response_text = self.generate_from_notes(user_input, notes)
                    bot_responses.append(st.write_stream(response_text).strip('"'))

                elif function_name == "conversational_response_tool":
                    response_text = getattr(self, function_name)(user_input)
                    bot_responses.append(st.write_stream(response_text).strip('"'))

        return "\n\n".join(bot_responses)

    def generate_from_notes(self, user_input, notes):
        """Same as the base version, wrapped in a spinner."""
        with st.spinner("Reading project notes..."):
            return super().generate_from_notes(user_input, notes)

    def generate_from_chunks(self, user_input, chunks):
        """Same as the base version, with a spinner naming the journals read."""
        magazines = set()
        for group in chunks.values():
            j = group["chunks"][0]["metadata"].get("journal", "No Journal")
            magazines.add(f"*{j}*")
        s = (
            f"Reading articles from {', '.join(list(magazines)[:-1])} and {list(magazines)[-1]}..."
            if len(magazines) > 1
            else "Reading articles..."
        )
        with st.spinner(s):
            return super().generate_from_chunks(user_input, chunks)

    def sidebar_content(self):
        """Render the chat name and a delete button in the sidebar."""
        with st.sidebar:
            st.write("---")
            st.markdown(f'#### {self.chat.name if self.chat.name else ""}')
            st.button("Delete this chat", on_click=self.delete_chat)

    def delete_chat(self):
        """Delete the stored chat document and start a fresh chat."""
        self.user_arango.db.collection("chats").delete_match(
            filters={"name": self.chat.name}
        )
        # Keep the username on the replacement chat (the original created a
        # Chat with username=None).
        self.chat = Chat(username=self.username)

    def get_notes(self):
        """Same as the base version, wrapped in a spinner."""
        with st.spinner("Fetching notes..."):
            return super().get_notes()
||||
|
||||
# The original declared `class EditorBot(StreamlitBot(Bot))`, which *calls*
# StreamlitBot(username=Bot) and then tries to inherit from the resulting
# instance — a TypeError at import time. The base is simply StreamlitBot.
class EditorBot(StreamlitBot):
    """Streamlit bot playing the Editor role for a project."""

    def __init__(self, chat: Chat, username: str, **kwargs):
        print_blue("EditorBot init chat:", chat)
        super().__init__(chat=chat, username=username, **kwargs)
        self.role = "Editor"
        # The editor gets access to every registered tool.
        self.tools = ToolRegistry.get_tools()
        self.chatbot = LLM(
            system_message=get_editor_prompt(kwargs.get("project")),
            messages=self.chat.chat_history2bot(),
            chosen_backend=kwargs.get("chosen_backend"),
        )
||||
|
||||
# Same base-class fix as EditorBot: `StreamlitBot(Bot)` was a call, not a base.
class ResearchAssistantBot(StreamlitBot):
    """Streamlit bot playing the Research Assistant role."""

    def __init__(self, chat: Chat, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.role = "Research Assistant"
        self.chatbot = LLM(
            system_message=get_assistant_prompt(),
            temperature=0.1,
            messages=self.chat.chat_history2bot(),
        )
        # Restrict the assistant to article-fetching tools only.
        self.tools = [
            self.fetch_science_articles_tool,
            self.fetch_science_articles_and_other_documents_tool,
        ]
||||
|
||||
# Same base-class fix as EditorBot: `StreamlitBot(Bot)` was a call, not a base.
class PodBot(StreamlitBot):
    """Two LLM agents construct a conversation using material from science articles."""

    def __init__(
        self,
        chat: Chat,
        subject: str,
        username: str,
        instructions: str = None,
        **kwargs,
    ):
        super().__init__(chat=chat, username=username, **kwargs)
        self.subject = subject
        self.instructions = instructions
        self.guest_name = kwargs.get("name_guest", "Merit")
        # One bot per podcast role, each with its own Chat.
        self.hostbot = HostBot(
            Chat(username=self.username, role="Host"),
            subject,
            username,
            instructions=instructions,
            **kwargs,
        )
        self.guestbot = GuestBot(
            Chat(username=self.username, role="Guest"),
            subject,
            username,
            name_guest=self.guest_name,
            **kwargs,
        )

    def run(self):
        """Alternate host questions and guest answers until stopped."""
        notes = self.get_notes()
        notes_string = ""
        if self.instructions:
            instructions_string = f'''
These are the instructions for the podcast from the producer:
"""
{self.instructions}
"""
'''
        else:
            instructions_string = ""

        for note in notes:
            notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"
        # The kickoff prompt doubles as the "previous answer" in the first
        # loop iteration (variable `a` below).
        a = f'''You will make a podcast interview with {self.guest_name}, an expert on "{self.subject}".
{instructions_string}
Below are notes on the subject that you can use to ask relevant questions:
"""
{notes_string}
"""
Say hello to the expert and start the interview. Remember to keep the interview to the subject of {self.subject} throughout the conversation.
'''

        # Stop button for the podcast.
        with st.sidebar:
            stop = st.button("Stop podcast", on_click=self.stop_podcast)

        while st.session_state["make_podcast"]:
            self.chat.show_chat_history()
            # End the interview once the chat reaches 14 messages.
            # (>= instead of the original ==, so an overshoot can't skip
            # the termination branch and loop forever.)
            if len(self.chat.chat_history) >= 14:
                result = self.hostbot.generate(
                    "The interview has ended. Say thank you to the expert and end the conversation."
                )
                self.chat.add_message("Host", result)
                with st.chat_message(
                    "assistant", avatar=self.chat.get_avatar(role="assistant")
                ):
                    st.write(result.strip('"'))
                st.stop()

            # Host turn: pick a tool, then produce the next question.
            _q = self.hostbot.toolbot.generate(
                query=f"{self.guest_name} has answered: {a}. You have to choose a tool to help the host continue the interview.",
                tools=self.hostbot.tools,
                temperature=0.6,
                stream=False,
            )
            if "tool_calls" in _q:
                q = self.hostbot.answer_tool_call(_q, a)
            else:
                q = _q

            self.chat.add_message("Host", q)

            # Guest turn: pick a tool, then produce the answer.
            _a = self.guestbot.toolbot.generate(
                f'The podcast host has asked: "{q}" Choose a tool to help the expert answer with relevant facts and information.',
                tools=self.guestbot.tools,
            )
            if "tool_calls" in _a:
                print_yellow("Tool call response (guest)", _a)
                print_yellow(self.guestbot.chat.role)
                a = self.guestbot.answer_tool_call(_a, q)
            else:
                a = _a
            self.chat.add_message("Guest", a)

            self.update_session_state()

    def stop_podcast(self):
        """Flag the run loop to stop and re-render the finished chat."""
        st.session_state["make_podcast"] = False
        self.update_session_state()
        self.chat.show_chat_history()
||||
|
||||
# Same base-class fix as EditorBot: `StreamlitBot(Bot)` was a call, not a base.
class HostBot(StreamlitBot):
    """The podcast-host agent: asks one question at a time about a subject."""

    def __init__(
        self, chat: Chat, subject: str, username: str, instructions: str, **kwargs
    ):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Host")
        self.tools = ToolRegistry.get_tools(
            tools=[
                self.fetch_notes_tool,
                self.conversational_response_tool,
                # "fetch_other_documents", #TODO Should this be included?
            ]
        )
        self.instructions = instructions
        self.llm = LLM(
            system_message=f'''
You are the host of a podcast and an expert on {subject}. You will ask one question at a time about the subject, and then wait for the guest to answer.
Don't ask the guest to talk about herself/himself, only about the subject.
Make your questions short and clear, only if necessary add a brief context to the question.
These are the instructions for the podcast from the producer:
"""
{self.instructions}
"""
If the experts' answer is complicated, try to make a very brief summary of it for the audience to understand. You can also ask follow-up questions to clarify the answer, or ask for examples.
''',
            messages=self.chat.chat_history2bot()
        )
        self.toolbot = LLM(
            temperature=0,
            system_message="""
You are assisting a podcast host in asking questions to an expert.
Choose one or many tools to use in order to assist the host in asking relevant questions.
Often "conversational_response_tool" is enough, but sometimes project notes are needed.
Make sure to read the description of the tools carefully!""",
            chat=False,
            model="small",
        )

    def generate(self, query):
        """Generate a host turn with the role-specific LLM."""
        return self.llm.generate(query)
||||
|
||||
# Same base-class fix as EditorBot: `StreamlitBot(Bot)` was a call, not a base.
class GuestBot(StreamlitBot):
    """The podcast-guest agent: answers questions as an expert on a subject."""

    def __init__(self, chat: Chat, subject: str, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Guest")
        self.tools = ToolRegistry.get_tools(
            tools=[
                self.fetch_notes_tool,
                self.fetch_science_articles_tool,
            ]
        )

        # PodBot passes the guest's name as `name_guest`; the original read
        # `name` here, so the configured name was silently ignored. Accept
        # `name_guest` first, keeping `name` as a backward-compatible fallback.
        guest_name = kwargs.get("name_guest", kwargs.get("name", "Merit"))
        self.llm = LLM(
            system_message=f"""
You are {guest_name}, an expert on {subject}.
Today you are a guest in a podcast about {subject}. A host will ask you questions about the subject and you will answer by using scientific facts and information.
When answering, don't say things like "based on the documents" or alike, as neither the host nor the audience can see the documents. Act just as if you were talking to someone in a conversation.
Try to be concise when answering, and remember that the audience of the podcast is not expert on the subject, so don't complicate things too much.
It's very important that you answer in a "spoken" way, as if you were talking to someone in a conversation. That means you should avoid using scientific jargon and complex terms, too many figures or abstract concepts.
Lists are also not recommended, instead use "for the first reason", "secondly", etc.
Instead, use "..." to indicate a pause, "-" to indicate a break in the sentence, as if you were speaking.
""",
            messages=self.chat.chat_history2bot()
        )
        self.toolbot = LLM(
            temperature=0,
            system_message=f"You are an assistant to an expert on {subject}. Choose one or many tools to use in order to assist the expert in answering questions. Make sure to read the description of the tools carefully.",
            chat=False,
            model="small",
        )

    def generate(self, query):
        """Generate a guest turn with the role-specific LLM."""
        return self.llm.generate(query)
Loading…
Reference in new issue