You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
57 lines
1.8 KiB
57 lines
1.8 KiB
from _llm import LLM |
|
from _chromadb import ChromaDB |
|
from _arango import ArangoDB |
|
from pprint import pprint |
|
|
|
def _retrieve_chunks(chromadb, arango, question, n_results=7):
    """Return the Chroma chunks closest to *question*, enriched with Arango metadata.

    Queries the ``sci_articles`` Chroma collection for ``n_results`` chunks and,
    for each one, fetches the ``sci_articles`` Arango document named by the
    chunk's ``_key`` metadata field.

    Returns:
        A list of dicts with keys ``document`` (chunk text), ``metadata``
        (Chroma metadata) and ``crossref_info`` (the Arango document's
        ``metadata`` field). Chunks with no matching Arango document are skipped.
    """
    result = chromadb.db.get_collection('sci_articles').query(
        query_texts=question, n_results=n_results
    )
    # Resolve the Arango collection once rather than once per chunk.
    articles = arango.db.collection('sci_articles')

    enriched = []
    # Chroma nests results per query text; a single query was sent, hence [0].
    for document, metadata in zip(result['documents'][0], result['metadatas'][0]):
        article = articles.get(metadata['_key'])
        if article is None:
            # NOTE(review): assuming `arango.db` is a python-arango database,
            # `get` returns None for a missing document; indexing it would
            # raise TypeError, so the orphaned chunk is dropped instead.
            continue
        enriched.append({
            "document": document,
            "metadata": metadata,
            "crossref_info": article['metadata'],
        })
    return enriched


def _group_by_title(chunks):
    """Group *chunks* by article title.

    Chunks are first sorted by (published_date, title). Because dicts keep
    insertion order, the returned title -> [chunks] mapping lists articles in
    that ascending order.
    """
    ordered = sorted(
        chunks,
        key=lambda c: (c['crossref_info']['published_date'], c['crossref_info']['title']),
    )
    grouped = {}
    for chunk in ordered:
        grouped.setdefault(chunk['crossref_info']['title'], []).append(chunk)
    return grouped


def _format_context(grouped):
    """Render grouped chunks as Markdown sections for the prompt.

    Each article gets a ``## title`` heading and a ``### date in journal``
    subheading, followed by its snippets separated by ``(...)``. Each section
    ends with a ``---`` rule.
    """
    sections = []
    for title, article_chunks in grouped.items():
        # Date and journal are per-article, so read them from the first chunk.
        info = article_chunks[0]['crossref_info']
        snippets = '\n(...)\n'.join(chunk['document'] for chunk in article_chunks)
        sections.append(
            f"\n\n## {title}\n"
            f"### {info['published_date']} in {info['journal']}\n"
            f"{snippets}\n\n"
            "---\n\n\n"
        )
    return ''.join(sections)


def _build_prompt(question, context):
    """Wrap *context* in a grounding instruction, repeating *question* before and after."""
    instructions = (
        "Below are snippets from different articles with title and date of publication. "
        "ONLY use the information below to answer the question. Do not use any other information."
    )
    return f'{question}\n{instructions}\n\n"""\n{context}\n"""\n\n{question}\n'


def main():
    """Interactively answer questions from the ``sci_articles`` corpus.

    Reads a question, retrieves supporting snippets, and prints the LLM's
    answer. An empty line or end-of-input (Ctrl-D) ends the session.
    """
    chromadb = ChromaDB()
    arango = ArangoDB()
    llm = LLM(temperature=0.1)

    while True:
        try:
            user_input = input("Enter a prompt: ").strip()
        except EOFError:
            break
        if not user_input:
            break

        chunks = _retrieve_chunks(chromadb, arango, user_input)
        if not chunks:
            # Nothing to ground the answer in; don't ask the LLM to guess.
            print("No matching articles found.")
            print()
            continue

        context = _format_context(_group_by_title(chunks))
        prompt = _build_prompt(user_input, context)
        response = llm.generate(prompt)
        print(response)
        print()


if __name__ == "__main__":
    main()