From fe972973c4247676165783213e867b6d2eee0c20 Mon Sep 17 00:00:00 2001
From: Lasse Server
Date: Wed, 1 May 2024 09:15:57 +0200
Subject: [PATCH] Stop messages from piling up when chat=False

---
 _llm.py | 108 ++++++++++++++++++--------------------------------------
 1 file changed, 35 insertions(+), 73 deletions(-)

diff --git a/_llm.py b/_llm.py
index 9d2171e..f444d66 100644
--- a/_llm.py
+++ b/_llm.py
@@ -1,10 +1,10 @@
+import datetime
 from time import sleep
 import requests
 import concurrent.futures
 import queue
 import threading
-from _arango import arango
-
+from pprint import pprint
 
 class LLM:
     def __init__(
@@ -27,11 +27,11 @@ class LLM:
         """
 
         self.server = "192.168.1.12"
-        self.port = 3300 # 11440 All 4 GPU # 4500 "SW"
+        self.port = 3300 # 3300 = balancer, 11440 = all 4 GPUs, 4500 = "SW"
         self.model = model
         self.temperature = 0
-        self.system_message = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".'
-        self.messages = [{"role": "system", "content": self.system_message}]
+        self.system_message = {"role": "system", "content": 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".'}
+        self.messages = [self.system_message]
         self.chat = chat
         self.max_tokens = 24000
         self.keep_alive = keep_alive
@@ -55,8 +55,7 @@ class LLM:
             self.build_message(message)
             messages = self.messages
         else:
-            self.messages.append({"role": "user", "content": message})
-            messages = self.messages
+            messages = [self.system_message, {"role": "user", "content": message}]
 
         data = {
             "model": self.model,
@@ -71,10 +70,23 @@ class LLM:
             f"http://{self.server}:{self.port}/api/chat", json=data
         ).json()
 
+        # print_data = result.copy()
+        # del print_data["message"]
+        # del print_data["model"]
+
+        # # Convert durations from nanoseconds to seconds
+        # for key in ['eval_duration', 'total_duration']:
+        #     if key in print_data:
+        #         duration = print_data[key] / 1e9  # Convert nanoseconds to seconds
+        #         minutes, seconds = divmod(duration, 60)  # Convert seconds to minutes and remainder seconds
+        #         print_data[key] = f'{int(minutes)}:{seconds:02.0f}'  # Format as minutes:seconds
+
+        # pprint(print_data)
+        # print('Number of messages', len(messages))
+
         if "message" in result:
             answer = result["message"]["content"]
         else:
-            from pprint import pprint
             pprint(result)
             raise Exception("Error occurred during API request")
 
@@ -197,68 +209,18 @@ class LLM:
 
 
 if __name__ == "__main__":
-    # Initialize the LLM object
-    llm = LLM(chat=False, model="llama3:8b-instruct-q5_K_M")
-
-    # Create a queue for requests and a queue for results
-    request_queue = queue.Queue()
-    result_queue = queue.Queue()
-
-    # Create an event to signal when all requests have been added
-    all_requests_added_event = threading.Event()
-    all_results_processed_event = threading.Event()
-
-    # Start a separate thread that runs the generate_concurrent method
-    threading.Thread(
-        target=llm.generate_concurrent,
-        args=(
-            request_queue,
-            result_queue,
-            all_requests_added_event,
-            all_results_processed_event,
-        ),
-    ).start()
-
-    # Add requests to the request queue
-    from _arango import arango
-
-    interrogations = arango.db.collection("interrogations").all()
-    for doc in interrogations:
-        text = doc["text"]
-        prompt = f'Kolla på texten nedan: \n\n """{text}""" \n\n Sammanfatta förhöret med fokus på vad som sades, inte var det hölls eller annat formalia. Svara så kort som möjligt men var noga med detaljer som händelser som beskrivs, namn, datum och platser.\nKort sammanfattning:'
-        request_queue.put((doc["_key"], prompt))
-
-    # Signal that all requests have been added
-    all_requests_added_event.set()
-
-    # Process the results
-    while True:
-        try:
-            # Take a result from the result queue
-            doc_id, summary = result_queue.get(timeout=1)
-            print("\033[92m" + doc_id + "\033[0m", summary)
-            # Update the document with the summary
-            arango.db.collection("interrogations").update_match(
-                {"_key": doc_id}, {"summary": summary}
-            )
-        except queue.Empty:
-            # If the result queue is empty and all results have been processed, break the loop
-            if all_results_processed_event.is_set():
-                break
-            else:
-                continue
-
-    # import argparse
-    # parser = argparse.ArgumentParser()
-    # parser.add_argument("--unload", action="store_true", help="Unload the model")
-    # args = parser.parse_args()
-
-    # #llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
-    # llm = LLM(keep_alive=60, chat=True)
-
-    # if args.unload:
-    #     llm.unload_model()
-    # else:
-    #     while True:
-    #         message = input(">>> ")
-    #         print(llm.generate(message))
+
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--unload", action="store_true", help="Unload the model")
+    args = parser.parse_args()
+
+    #llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
+    llm = LLM(keep_alive=6000, chat=True)
+
+    if args.unload:
+        llm.unload_model()
+    else:
+        while True:
+            message = input(">>> ")
+            print(llm.generate(message))
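
A minimal usage sketch of the behaviour the subject line describes, assuming the defaults in __init__ and a reachable Ollama-style server; the prompt strings are placeholders:

    from _llm import LLM

    # chat=False: every call sends a fresh [system, user] pair, so nothing
    # accumulates between calls
    llm = LLM(chat=False)
    print(llm.generate("First question"))
    print(llm.generate("Second question"))  # independent of the first call

    # chat=True: history still accumulates via build_message()
    chat_llm = LLM(chat=True)
    chat_llm.generate("First question")
    chat_llm.generate("A follow-up that relies on the previous answer")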
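
The queue/event protocol expected by generate_concurrent (the method itself stays in the class; only its __main__ driver is removed here) can be exercised with a stripped-down sketch reconstructed from the deleted block, with the ArangoDB bookkeeping left out and hypothetical request items:

    import queue
    import threading

    from _llm import LLM

    llm = LLM(chat=False, model="llama3:8b-instruct-q5_K_M")
    request_queue, result_queue = queue.Queue(), queue.Queue()
    all_requests_added = threading.Event()
    all_results_processed = threading.Event()

    # Worker thread consumes (key, prompt) tuples and emits (key, answer) tuples
    threading.Thread(
        target=llm.generate_concurrent,
        args=(request_queue, result_queue, all_requests_added, all_results_processed),
    ).start()

    for key, prompt in [("doc-1", "Summarize text one"), ("doc-2", "Summarize text two")]:
        request_queue.put((key, prompt))
    all_requests_added.set()  # signal that no more requests are coming

    while True:
        try:
            key, answer = result_queue.get(timeout=1)
            print(key, answer)
        except queue.Empty:
            if all_results_processed.is_set():
                break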