from time import sleep
import requests
import concurrent.futures
import queue
import threading
import re
from dotenv import load_dotenv
import os
import json
from print_color import *
load_dotenv()
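
# Configuration is read from the environment (via .env). A minimal .env might
# look like the sketch below; the host, port, and small-model values are
# deployment-specific assumptions, not fixed values (the q5_K_M model name is
# taken from this module's docstring).
#
#   LLM_URL=localhost
#   LLM_PORT=11434
#   LLM_MODEL=llama3:8b-instruct-q5_K_M
#   LLM_SMALL_URL=localhost
#   LLM_SMALL_PORT=11434
#   LLM_SMALL_MODEL=llama3:8b-instruct-q4_K_M
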
class LLM:
    def __init__(
        self,
        chat: bool = False,
        model: str = os.getenv("LLM_MODEL"),
        keep_alive: int = 3600 * 24,
        start: bool = False,
        system_prompt: str = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        temperature: float = 0,
        stream: bool = False,
        small: bool = False,
    ):
        """
        Initializes an LLM instance.

        Args:
            chat (bool, optional): Whether the instance keeps a running chat history. Defaults to False.
            model (str, optional): The model to use. Defaults to the LLM_MODEL environment variable.
            keep_alive (int, optional): How long, in seconds, the server keeps the model loaded. Defaults to 3600 * 24.
            start (bool, optional): If True, a separate thread running generate_concurrent is started
                on initialization, so queued requests are processed concurrently. Defaults to False.
            system_prompt (str, optional): The system prompt. Defaults to a Swedish prompt
                ("Always answer in Swedish. Only answer what is asked. If you cannot answer, write 'I don't know'.").
            temperature (float, optional): Sampling temperature. Defaults to 0.
            stream (bool, optional): Whether responses are streamed. Defaults to False.
            small (bool, optional): If True, use the LLM_SMALL_* environment variables instead. Defaults to False.
        """
        self.model = model
        self.server = os.getenv("LLM_URL")
        self.port = os.getenv("LLM_PORT")
        if small:
            self.model = os.getenv("LLM_SMALL_MODEL")
            self.server = os.getenv("LLM_SMALL_URL")
            self.port = os.getenv("LLM_SMALL_PORT")
        self.temperature = temperature
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.chat = chat
        self.max_length = 24000
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()
        self.stream = stream
        if start:
            self.start()

    def create_data_request(self, message):
        # Remove leading and trailing whitespace from each line
        message = '\n'.join(line.strip() for line in message.split('\n'))
        # Prepare the request data
        options = {
            "temperature": self.temperature,
        }
        if self.chat:
            self.build_message(message)
            messages = self.messages
        else:
            messages = [self.system_message, {"role": "user", "content": message}]
        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": self.stream,
        }
        return data

    def generate_stream(self, message):
        data = self.create_data_request(message)
        # Make a POST request to the API endpoint
        response = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data, stream=True
        )
        # Iterate over the response
        for line in response.iter_lines():
            # Filter out keep-alive new lines
            if line:
                decoded_line = line.decode('utf-8')
                json_line = json.loads(decoded_line)  # Parse the line as JSON
                yield json_line['message']['content']
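
    # Example (sketch, not called anywhere in this module): consuming the stream
    # chunk by chunk, assuming the server speaks the Ollama-style /api/chat
    # protocol used above (one JSON object per line):
    #
    #   llm = LLM(stream=True)
    #   for chunk in llm.generate_stream("Hej!"):
    #       print(chunk, end="", flush=True)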

    def generate(self, message):
        data = self.create_data_request(message)
        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        )
        try:
            response = result.json()
            if 'message' in response:
                answer = response["message"]["content"]
            else:
                print_red(result.content)
                raise Exception("Error occurred during API request")
        except requests.exceptions.JSONDecodeError:
            print_red(result.content)
            raise Exception("Error occurred during API request")
        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})
        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If fewer than buffer_size tasks are being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # If the queue is empty and all requests have been added, stop taking new work
                    if all_requests_added_event.is_set():
                        break
                    else:
                        continue
                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))
            # Collect any futures that are still in flight so their results are not lost
            for future, doc_id in future_to_message.items():
                try:
                    summary = future.result()
                except Exception as exc:
                    print("Document %r generated an exception: %s" % (doc_id, exc))
                else:
                    result_queue.put((doc_id, summary))
        all_results_processed_event.set()

    def start(self):
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
        # Process the results
        while True:
            try:
                # Take a result from the result queue
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary
            except queue.Empty:
                # If the result queue is empty and all results have been processed,
                # stop waiting (the implicit return value is None)
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})
        # Calculate the total length of the messages
        total_length = sum(len(msg["content"]) for msg in self.messages)
        # While the total length exceeds the limit, remove the oldest messages
        while total_length > self.max_length:
            # Remove the oldest message (but keep the system message at index 0)
            removed_message = self.messages.pop(1)
            total_length -= len(removed_message["content"])

    def unload_model(self):
        data = {
            "model": self.model,
            "messages": self.messages,
            "keep_alive": 0,
            "stream": False,
        }
        # Make a POST request with keep_alive=0 so the server unloads the model;
        # the response content is not used.
        requests.post(f"http://{self.server}:{self.port}/api/chat", json=data)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unload", action="store_true", help="Unload the model")
    args = parser.parse_args()
    llm = LLM(keep_alive=60000, chat=True, small=False)
    if args.unload:
        llm.unload_model()
    else:
        while True:
            message = input(">>> ")
            print(llm.generate(message))