from time import sleep

import requests
import concurrent.futures
import queue
import threading
from pprint import pprint
from dotenv import load_dotenv
import os

load_dotenv()


class LLM:
    def __init__(
        self,
        chat: bool = False,
        model: str = "llama3:8b-instruct-q5_K_M",
        keep_alive: int = 3600 * 24,
        start: bool = False,
        # Default system prompt (Swedish): "Always answer in Swedish. Only answer
        # what is asked. If you cannot answer, write 'I don't know'."
        system_prompt: str = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        temperature: float = 0,
    ):
        """
        Initializes an instance of LLM.

        Args:
            chat (bool, optional): Specifies whether the instance is for chat purposes. Defaults to False.
            model (str, optional): The model to be used. Defaults to "llama3:8b-instruct-q5_K_M".
            keep_alive (int, optional): The duration in seconds to keep the model loaded. Defaults to 3600 * 24.
            start (bool, optional): If True, the instance will automatically start processing requests upon
                initialization, i.e. a separate thread is started that runs the generate_concurrent method,
                which processes requests concurrently. Defaults to False.
            system_prompt (str, optional): The system prompt sent with every request. Defaults to a Swedish
                prompt instructing the model to always answer in Swedish and to say "Jag vet inte" when it
                cannot answer.
            temperature (float, optional): Sampling temperature passed to the model. Defaults to 0.
        """
        self.server = os.getenv("LLM_URL")
        self.port = os.getenv("LLM_PORT")
        self.model = model
        self.temperature = temperature
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.chat = chat
        self.max_tokens = 24000  # Rough context budget, measured in characters (see build_message)
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()
        if start:
            self.start()

    def generate(self, message):
        # Strip leading and trailing whitespace from every line of the message
        message = "\n".join(line.strip() for line in message.split("\n"))

        # Prepare the request data
        options = {
            "temperature": self.temperature,
        }
        if self.chat:
            self.build_message(message)
            messages = self.messages
        else:
            messages = [self.system_message, {"role": "user", "content": message}]
        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": False,
        }
        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        ).json()

        # print_data = result.copy()
        # del print_data["message"]
        # del print_data["model"]
        # # Convert durations from nanoseconds to seconds
        # for key in ["eval_duration", "total_duration"]:
        #     if key in print_data:
        #         duration = print_data[key] / 1e9  # Convert nanoseconds to seconds
        #         minutes, seconds = divmod(duration, 60)  # Split into minutes and remaining seconds
        #         print_data[key] = f"{int(minutes)}:{seconds:02.0f}"  # Format as minutes:seconds
        # pprint(print_data)
        # print("Number of messages", len(messages))

        if "message" in result:
            answer = result["message"]["content"]
        else:
            pprint(result)
            raise Exception("Error occurred during API request")

        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})
        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If fewer than buffer_size tasks are being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # Stop only when all requests have been added and no tasks are
                    # still in flight; otherwise fall through and collect any
                    # completed futures before polling the queue again
                    if all_requests_added_event.is_set() and not future_to_message:
                        break

                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))

        all_results_processed_event.set()

    def start(self):
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
        # Return the next (doc_id, summary) pair, or None once all results
        # have been processed and consumed
        while True:
            try:
                # Take a result from the result queue
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary
            except queue.Empty:
                # If the result queue is empty and all results have been processed, stop waiting
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})
        # Approximate the token count of the messages by their total character length
        total_tokens = sum(len(msg["content"]) for msg in self.messages)
        # While the total length exceeds the limit, remove the oldest messages
        while total_tokens > self.max_tokens:
            removed_message = self.messages.pop(
                1
            )  # Remove the oldest message (not the system message)
            total_tokens -= len(removed_message["content"])

    def unload_model(self):
        data = {
            "model": self.model,
            "messages": self.messages,
            "keep_alive": 0,
            "stream": False,
        }
        # POST with keep_alive=0, which asks the server to unload the model;
        # the response body is not used
        requests.post(f"http://{self.server}:{self.port}/api/chat", json=data)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unload", action="store_true", help="Unload the model")
    args = parser.parse_args()

    # llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
    llm = LLM(keep_alive=6000, chat=True)
    if args.unload:
        llm.unload_model()
    else:
        while True:
            message = input(">>> ")
            print(llm.generate(message))
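

# The sketch below is an illustrative, non-invoked example of the queue-based
# concurrent workflow exposed above (start, add_request, finish_adding_requests,
# get_results). The document IDs and prompt texts are made-up placeholders, and
# it assumes an Ollama-compatible server is reachable through the LLM_URL and
# LLM_PORT environment variables.
def _concurrent_usage_example():
    llm = LLM(start=True)  # spawns the generate_concurrent worker thread

    # Queue a few documents for processing
    prompts = {
        1: "Summarize the first document: ...",
        2: "Summarize the second document: ...",
    }
    for doc_id, prompt in prompts.items():
        llm.add_request(doc_id, prompt)
    llm.finish_adding_requests()

    # Drain results; get_results returns None once everything has been processed
    while True:
        result = llm.get_results()
        if result is None:
            break
        doc_id, summary = result
        print(doc_id, summary)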