|
|
|
|
@ -1,43 +1,61 @@ |
|
|
|
|
import streamlit as st |
|
|
|
|
import crossref_commons.retrieval |
|
|
|
|
from _llm import LLM |
|
|
|
|
from _chromadb import ChromaDB |
|
|
|
|
from _arango import ArangoDB |
|
|
|
|
from pprint import pprint |
|
|
|
|
from colorprinter.print_color import * |
|
|
|
|
|
|
|
|
|
# Initialize databases and chatbot |
|
|
|
|
chromadb = ChromaDB() |
|
|
|
|
arango = ArangoDB() |
|
|
|
|
chatbot = LLM(temperature=0.1) |
|
|
|
|
|
|
|
|
|
while True: |
|
|
|
|
user_input = "What problems are there in battery production?" # input("Enter a prompt: ") |
|
|
|
|
# Streamlit app setup |
|
|
|
|
st.title("EV Cars Chatbot") |
|
|
|
|
st.write("Ask a question about EV car battery production:") |
|
|
|
|
|
|
|
|
|
# User input |
|
|
|
|
user_input = st.text_input("Ask something") |
|
|
|
|
|
|
|
|
|
chunks = chromadb.db.get_collection('sci_articles').query(query_texts=user_input, n_results=7) |
|
|
|
|
if user_input: |
|
|
|
|
chunks = chromadb.db.get_collection("sci_articles").query( |
|
|
|
|
query_texts=user_input, n_results=7 |
|
|
|
|
) |
|
|
|
|
combined_chunks = [ |
|
|
|
|
{"document": doc, "metadata": meta} |
|
|
|
|
for doc, meta in zip(chunks['documents'][0], chunks['metadatas'][0]) |
|
|
|
|
for doc, meta in zip(chunks["documents"][0], chunks["metadatas"][0]) |
|
|
|
|
] |
|
|
|
|
for i in combined_chunks: |
|
|
|
|
_key = i['metadata']['_key'] |
|
|
|
|
arango_metadata = arango.db.collection('sci_articles').get(_key)['metadata'] |
|
|
|
|
i['crossref_info'] = arango_metadata |
|
|
|
|
_key = i["metadata"]["_key"] |
|
|
|
|
arango_metadata = arango.db.collection("sci_articles").get(_key)["metadata"] |
|
|
|
|
i["crossref_info"] = arango_metadata if arango_metadata else {'title': 'No title', 'published_date': 'No published date', 'journal': 'No journal'} |
|
|
|
|
|
|
|
|
|
# Sort the combined_chunks list first by published_date, then by title |
|
|
|
|
sorted_chunks = sorted(combined_chunks, key=lambda x: (x['crossref_info']['published_date'], x['crossref_info']['title'])) |
|
|
|
|
sorted_chunks = sorted( |
|
|
|
|
combined_chunks, |
|
|
|
|
key=lambda x: ( |
|
|
|
|
x["crossref_info"]["published_date"], |
|
|
|
|
x["crossref_info"]["title"], |
|
|
|
|
), |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
# Group the chunks by title |
|
|
|
|
grouped_chunks = {} |
|
|
|
|
for chunk in sorted_chunks: |
|
|
|
|
title = chunk['crossref_info']['title'] |
|
|
|
|
title = chunk["crossref_info"]["title"] |
|
|
|
|
if title not in grouped_chunks: |
|
|
|
|
grouped_chunks[title] = [] |
|
|
|
|
grouped_chunks[title].append(chunk) |
|
|
|
|
|
|
|
|
|
chunks_string = '' |
|
|
|
|
chunks_string = "" |
|
|
|
|
for title, chunks in grouped_chunks.items(): |
|
|
|
|
chunks_content_string = '\n(...)\n'.join([chunk['document'] for chunk in chunks]) |
|
|
|
|
chunks_content_string = "\n(...)\n".join( |
|
|
|
|
[chunk["document"] for chunk in chunks] |
|
|
|
|
) |
|
|
|
|
chunks_string += f"""\n |
|
|
|
|
## {title} |
|
|
|
|
### {chunks[0]['crossref_info']['published_date']} in {chunks[0]['crossref_info']['journal']} |
|
|
|
|
# {title} |
|
|
|
|
## {chunks[0]['crossref_info']['published_date']} in {chunks[0]['crossref_info']['journal']} |
|
|
|
|
{chunks_content_string}\n |
|
|
|
|
--- |
|
|
|
|
\n |
|
|
|
|
@ -53,8 +71,5 @@ ONLY use the information below to answer the question. Do not use any other info |
|
|
|
|
|
|
|
|
|
{user_input} |
|
|
|
|
''' |
|
|
|
|
print(prompt) |
|
|
|
|
exit() |
|
|
|
|
response = chatbot.generate(prompt) |
|
|
|
|
print(response) |
|
|
|
|
print() |
|
|
|
|
st.write(response) |