diff --git a/chatbot.py b/chatbot.py index 1545a3b..ef3ebed 100644 --- a/chatbot.py +++ b/chatbot.py @@ -1,43 +1,61 @@ +import streamlit as st +import crossref_commons.retrieval from _llm import LLM from _chromadb import ChromaDB from _arango import ArangoDB from pprint import pprint +from colorprinter.print_color import * +# Initialize databases and chatbot chromadb = ChromaDB() arango = ArangoDB() chatbot = LLM(temperature=0.1) -while True: - user_input = "What problems are there in battery production?" # input("Enter a prompt: ") +# Streamlit app setup +st.title("EV Cars Chatbot") +st.write("Ask a question about EV car battery production:") +# User input +user_input = st.text_input("Ask something") - chunks = chromadb.db.get_collection('sci_articles').query(query_texts=user_input, n_results=7) +if user_input: + chunks = chromadb.db.get_collection("sci_articles").query( + query_texts=user_input, n_results=7 + ) combined_chunks = [ {"document": doc, "metadata": meta} - for doc, meta in zip(chunks['documents'][0], chunks['metadatas'][0]) + for doc, meta in zip(chunks["documents"][0], chunks["metadatas"][0]) ] for i in combined_chunks: - _key = i['metadata']['_key'] - arango_metadata = arango.db.collection('sci_articles').get(_key)['metadata'] - i['crossref_info'] = arango_metadata + _key = i["metadata"]["_key"] + arango_metadata = arango.db.collection("sci_articles").get(_key)["metadata"] + i["crossref_info"] = arango_metadata if arango_metadata else {'title': 'No title', 'published_date': 'No published date', 'journal': 'No journal'} # Sort the combined_chunks list first by published_date, then by title - sorted_chunks = sorted(combined_chunks, key=lambda x: (x['crossref_info']['published_date'], x['crossref_info']['title'])) + sorted_chunks = sorted( + combined_chunks, + key=lambda x: ( + x["crossref_info"]["published_date"], + x["crossref_info"]["title"], + ), + ) # Group the chunks by title grouped_chunks = {} for chunk in sorted_chunks: - title = chunk['crossref_info']['title'] + title = chunk["crossref_info"]["title"] if title not in grouped_chunks: grouped_chunks[title] = [] grouped_chunks[title].append(chunk) - chunks_string = '' + chunks_string = "" for title, chunks in grouped_chunks.items(): - chunks_content_string = '\n(...)\n'.join([chunk['document'] for chunk in chunks]) + chunks_content_string = "\n(...)\n".join( + [chunk["document"] for chunk in chunks] + ) chunks_string += f"""\n -## {title} -### {chunks[0]['crossref_info']['published_date']} in {chunks[0]['crossref_info']['journal']} +# {title} +## {chunks[0]['crossref_info']['published_date']} in {chunks[0]['crossref_info']['journal']} {chunks_content_string}\n --- \n @@ -53,8 +71,5 @@ ONLY use the information below to answer the question. Do not use any other info {user_input} ''' - print(prompt) - exit() response = chatbot.generate(prompt) - print(response) - print() \ No newline at end of file + st.write(response) \ No newline at end of file