# Streamlit chat app: answers questions about electric cars using snippets
# from scientific articles (ChromaDB vector search + ArangoDB metadata).
import streamlit as st |
|
from _llm import LLM |
|
from _chromadb import ChromaDB |
|
from _arango import ArangoDB |
|
from colorprinter.print_color import * |
|
import re |
|
|
|
|
|
def get_stream(response):
    """Yield every item of *response* coerced to ``str``.

    Adapter so ``st.write_stream`` can consume whatever the LLM's
    streaming generator produces, one string piece at a time.
    """
    yield from map(str, response)
|
def get_chunks(user_input, n_results=5):
    """Retrieve article chunks relevant to *user_input*, grouped by article title.

    Uses the module-level ``helperbot`` LLM to rewrite the question into a
    vector-database query, fetches the top ``n_results`` chunks from the
    ChromaDB ``sci_articles`` collection, enriches each chunk with the
    Crossref metadata stored in ArangoDB, and groups the chunks per article.

    Args:
        user_input: The user's question, verbatim.
        n_results: Number of chunks to request from the vector database.

    Returns:
        dict mapping article title to
        ``{"article_number": int, "chunks": [chunk, ...]}``, where each chunk
        is ``{"document": str, "metadata": dict, "crossref_info": dict,
        "article_number": int}``. Articles are numbered 1.. in the order their
        titles first appear after sorting by (published_date, title).
    """
    query = helperbot.generate(
        f"""A user asked this question: "{user_input}"".
Generate a query for the vector database. Make sure to follow the instructions you got earlier!"""
    )

    # Strip the query from anything that is not a word character, number, or
    # space, to keep the embedding query plain text.
    query = re.sub(r"[^\w\d\s]", "", query)
    print_purple(query)  # debug: show the generated query in the console

    chunks = chromadb.db.get_collection("sci_articles").query(
        query_texts=query, n_results=n_results
    )
    combined_chunks = [
        {"document": doc, "metadata": meta}
        for doc, meta in zip(chunks["documents"][0], chunks["metadatas"][0])
    ]

    # Enrich every chunk with Crossref metadata from ArangoDB. The collection
    # handle is hoisted out of the loop; a fresh placeholder dict is built per
    # chunk so entries never share one mutable fallback. FIX: the original did
    # `.get(_key)["metadata"]`, which raised TypeError when the document was
    # missing (python-arango's get() returns None for unknown keys).
    sci_articles = arango.db.collection("sci_articles")
    for item in combined_chunks:
        document = sci_articles.get(item["metadata"]["_key"])
        arango_metadata = (document or {}).get("metadata")
        item["crossref_info"] = arango_metadata or {
            "title": "No title",
            "published_date": "No published date",
            "journal": "No journal",
        }

    # Sort the combined chunks first by published_date, then by title, so
    # article numbering is deterministic. Use .get() with the placeholder
    # values so partial metadata sorts instead of raising KeyError.
    sorted_chunks = sorted(
        combined_chunks,
        key=lambda x: (
            x["crossref_info"].get("published_date", "No published date"),
            x["crossref_info"].get("title", "No title"),
        ),
    )

    # Group the chunks by title, assigning a stable 1-based article number
    # when a title first appears. FIX: the original stamped the running
    # counter onto every chunk *before* the new-title check, so non-first
    # chunks of an article carried the next unused number instead of their
    # own article's number; each chunk now takes its group's number.
    grouped_chunks = {}
    article_number = 1
    for chunk in sorted_chunks:
        title = chunk["crossref_info"].get("title", "No title")
        if title not in grouped_chunks:
            grouped_chunks[title] = {"article_number": article_number, "chunks": []}
            article_number += 1
        chunk["article_number"] = grouped_chunks[title]["article_number"]
        grouped_chunks[title]["chunks"].append(chunk)

    return grouped_chunks
|
|
|
|
|
# Initialize session state for chat history.
# First run of a session: seed the conversation log and the two LLM
# message-history slots (Streamlit session_state supports item access
# equivalently to attribute access).
if "chat_history" not in st.session_state:
    for key, default in (
        ("chat_history", []),
        ("chatbot_memory", None),
        ("helperbot_memory", None),
    ):
        st.session_state[key] = default
|
|
|
# Initialize databases and chatbot
chromadb = ChromaDB()  # vector store; queried via its "sci_articles" collection
arango = ArangoDB()  # document store; holds per-article Crossref metadata

# Main answering LLM. The system message forces Markdown output with numeric
# [n] references that line up with the sources list rendered after the answer.
chatbot = LLM(
    temperature=0.1,
    system_message="""You are chatting about electric cars. Only use the information from scientific articles you are provided with to answer questions.
Format your answers in Markdown format. Be sure to reference the source of the information with ONLY the number of the article in the running text (e.g. "<answer based on an article> [<article number>]"). """,
)
# On reruns with an existing conversation, restore the chatbot's message
# history that was saved into session state at the end of the previous turn.
if st.session_state.chat_history:
    chatbot.messages = st.session_state.chatbot_memory
|
|
|
# Helper LLM used by get_chunks() to rewrite the raw user question into a
# vector-database query. temperature=0 and a short answer cap since the
# expected output is a single query sentence; "small" presumably selects a
# lighter model — semantics defined in the LLM wrapper (_llm).
helperbot = LLM(
    temperature=0,
    model="small",
    max_length_answer=500,
    system_message="""Take the user input and write it as a sentence that could be used as a query for a vector database.
The vector database will return text snippets that semantically match the query, so you CAN'T USE NEGATIONS or other complex language constructs. If there is a negation in the user input, exclude that part from the query.
If the user input seems to be a follow-up question or comment, use the context from the chat history to make a relevant query.
Answer ONLY with the query, no explanation or reasoning!
""",
)
# Restore the helper's message history on reruns so follow-up questions can
# be rewritten with conversational context.
if st.session_state.chat_history:
    helperbot.messages = st.session_state.helperbot_memory
|
|
|
# Streamlit app setup
st.title("🚗 Electric Cars Chatbot")


# User input: st.chat_input renders the chat box and returns the submitted
# text for this rerun, or None when nothing was entered.
user_input = st.chat_input("")
|
|
|
# Handle one question/answer turn. NOTE(review): the original indentation was
# mangled in extraction; the nesting below (everything inside `if user_input:`,
# sources rendered inside the assistant chat bubble) is the reconstruction
# consistent with the data flow — confirm against the original file.
if user_input:
    # Record the new user message first so it is included when the history
    # is rendered below on this same rerun.
    st.session_state.chat_history.append({"role": "user", "content": user_input})

    # Re-render the whole conversation (Streamlit reruns the script on every
    # interaction); skip entries with empty content.
    for message in st.session_state.chat_history:
        with st.chat_message(message["role"]):
            if message['content']:
                st.markdown(message["content"])

    # Show a loading message
    with st.spinner("Getting information from database..."):
        relevant_chunks = get_chunks(user_input, n_results=5)  #! Change n_results to 7

    # Build one Markdown document with all retrieved snippets, one section per
    # article; "(...)" marks the gap between chunks of the same article.
    chunks_string = ""
    for title, chunks in relevant_chunks.items():
        chunks_content_string = "\n(...)\n".join(
            [chunk["document"] for chunk in chunks['chunks']]
        )
        chunks_string += f"""\n
# {title}
## Article number: {chunks['article_number']}
## {chunks['chunks'][0]['crossref_info']['published_date']} in {chunks['chunks'][0]['crossref_info']['journal']}
{chunks_content_string}\n
---
\n
"""

    # The user question is repeated after the snippets so it is the last
    # thing the model reads before answering.
    prompt = f'''{user_input}
Below are snippets from different articles with title and date of publication.
ONLY use the information below to answer the question. Do not use any other information.

"""
{chunks_string}
"""

{user_input}
'''

    # Stream the answer into the assistant chat bubble as it is generated.
    response = chatbot.generate(prompt, stream=True)  # Assuming chatbot.generate returns a generator
    with st.chat_message("assistant"):
        bot_response = st.write_stream(get_stream(response))

        # Numbered source list matching the [n] references the system
        # message instructs the chatbot to emit.
        sources = '###### Sources: \n'
        for title, chunks in relevant_chunks.items():
            sources += f'''[{chunks['article_number']}] **{title}** :gray[{chunks['chunks'][0]['crossref_info']['journal']} ({chunks['chunks'][0]['crossref_info']['published_date']})] \n'''
        st.markdown(sources)
        # Store the answer together with its sources in the history entry.
        bot_response = f'{bot_response}\n\n{sources}'
    # Append user input and response to chat history
    st.session_state.chat_history.append(
        {"role": "assistant", "content": bot_response}
    )
    # Persist both LLMs' message histories so the next rerun can restore them.
    st.session_state.chatbot_memory = chatbot.messages
    st.session_state.helperbot_memory = helperbot.messages
|
|
|