diff --git a/chatbot.py b/chatbot.py
index c08f123..cc997e4 100644
--- a/chatbot.py
+++ b/chatbot.py
@@ -4,19 +4,29 @@ from _chromadb import ChromaDB
 from _arango import ArangoDB
 from colorprinter.print_color import *
 import re
+from highlighter.highlight_pdf import Highlighter
+
+async def highlight_pdf(user_input, pdf_file, make_comments):
+    highlighter = Highlighter(comment=make_comments)
+    pdf_buffer = io.BytesIO(pdf_file.read())
+    highlighted_pdf_buffer = await highlighter.highlight(user_input, pdf_buffer=pdf_buffer)
+    return highlighted_pdf_buffer
 
 
 def get_stream(response):
     for i in response:
         yield str(i)
+
+
 def get_chunks(user_input, n_results=5):
-    query = helperbot.generate(f"""A user asked this question: "{user_input}"".
+    query = helperbot.generate(
+        f"""A user asked this question: "{user_input}".
     Generate a query for the vector database. Make sure to follow the instructions you got earlier!"""
     )
     # Strip the query from anything that is not a word character, number, or space
     query = re.sub(r"[^\w\d\s]", "", query)
     print_purple(query)
-    
+
     chunks = chromadb.db.get_collection("sci_articles").query(
         query_texts=query, n_results=n_results
     )
@@ -51,12 +61,15 @@ def get_chunks(user_input, n_results=5):
     article_number = 1  # Initialize article counter
     for chunk in sorted_chunks:
         title = chunk["crossref_info"]["title"]
+        chunk['file_path'] = chunk['file']
         chunk["article_number"] = article_number  # Add article number to chunk
         if title not in grouped_chunks:
-            grouped_chunks[title] = {'article_number': article_number, 'chunks': []}
-            article_number += 1  # Increment article counter when a new title is encountered
-        grouped_chunks[title]['chunks'].append(chunk)
-    
+            grouped_chunks[title] = {"article_number": article_number, "chunks": []}
+            article_number += (
+                1  # Increment article counter when a new title is encountered
+            )
+        grouped_chunks[title]["chunks"].append(chunk)
+
     return grouped_chunks
 
 
@@ -69,11 +82,16 @@ if "chat_history" not in st.session_state:
 # Initialize databases and chatbot
 chromadb = ChromaDB()
 arango = ArangoDB()
+highlighter = Highlighter()
 
 chatbot = LLM(
     temperature=0.1,
     system_message="""You are chatting about electric cars. Only use the information from scientific articles you are provided with to answer questions. 
-    Format your answers in Markdown format. Be sure to reference the source of the information with ONLY the number of the article in the running text (e.g. " [1]"). """,
+    The articles are ordered by publication date, so the first article is the oldest one. Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account.
+    Be sure to reference the source of the information with the number of the article inside square brackets (e.g. " [article number]").
+    If you have to reference the articles in running text, e.g. in a headline or the beginning of a bullet point, use the title of the article.
+    You should not write a reference section as this will be added later.
+    Format your answers in Markdown format. """,
 )
 if st.session_state.chat_history:
     chatbot.messages = st.session_state.chatbot_memory
@@ -103,17 +121,17 @@ if user_input:
 
     for message in st.session_state.chat_history:
         with st.chat_message(message["role"]):
-            if message['content']:
+            if message["content"]:
                 st.markdown(message["content"])
 
     # Show a loading message
     with st.spinner("Getting information from database..."):
-        relevant_chunks = get_chunks(user_input, n_results=5) #! Change n_results to 7
+        relevant_chunks = get_chunks(user_input, n_results=5)  #! Change n_results to 7?
 
         chunks_string = ""
         for title, chunks in relevant_chunks.items():
             chunks_content_string = "\n(...)\n".join(
-                [chunk["document"] for chunk in chunks['chunks']]
+                [chunk["document"] for chunk in chunks["chunks"]]
             )
             chunks_string += f"""\n
 # {title}
@@ -134,19 +152,39 @@ ONLY use the information below to answer the question. 
 Do not use any other info
 {user_input}
 '''
+    magazines = list(set([f"*{chunks['chunks'][0]['crossref_info']['journal']}*" for _, chunks in relevant_chunks.items()]))
+    with st.spinner(f"Reading articles from {', '.join(magazines[:-1])} and {magazines[-1]}..."):
+        response = chatbot.generate(prompt, stream=True)
 
-    response = chatbot.generate(prompt, stream=True)  # Assuming chatbot.generate returns a generator
     with st.chat_message("assistant"):
         bot_response = st.write_stream(get_stream(response))
-
-    sources = '###### Sources: \n'
+
+    sources = "###### Sources: \n"
     for title, chunks in relevant_chunks.items():
-        sources += f'''[{chunks['article_number']}] **{title}** :gray[{chunks['chunks'][0]['crossref_info']['journal']} ({chunks['chunks'][0]['crossref_info']['published_date']})] \n'''
+        sources += f"""[{chunks['article_number']}] **{title}** :gray[{chunks['chunks'][0]['crossref_info']['journal']} ({chunks['chunks'][0]['crossref_info']['published_date']})] \n"""
     st.markdown(sources)
-    bot_response = f'{bot_response}\n\n{sources}'
+    bot_response = f"{bot_response}\n\n{sources}"
     # Append user input and response to chat history
-    st.session_state.chat_history.append(
-        {"role": "assistant", "content": bot_response}
-    )
+    st.session_state.chat_history.append({"role": "assistant", "content": bot_response})
     st.session_state.chatbot_memory = chatbot.messages
     st.session_state.helperbot_memory = helperbot.messages
+
+#TODO Add highlighter
+#TODO Add preview
+# import base64
+# base64_pdf = base64.b64encode(highlighted_pdf_buffer.getvalue()).decode('utf-8')
+
+# # Embed PDF in HTML
+# pdf_display = F'<iframe src="data:application/pdf;base64,{base64_pdf}" width="700" height="1000" type="application/pdf"></iframe>'
+
+# with st.sidebar:
+#     # Display file
+#     st.markdown("_Preview of highlighted PDF:_")
+#     st.markdown(pdf_display, unsafe_allow_html=True)
+
+# st.download_button(
+#     label="Download Highlighted PDF",
+#     data=highlighted_pdf_buffer,
+#     file_name="highlighted_document.pdf",
+#     mime="application/pdf"
+# )
\ No newline at end of file