from time import sleep
import requests
import concurrent.futures
import queue
import threading
from pprint import pprint
import re
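
# Client for an Ollama-style /api/chat endpoint. Supports one-shot generation,
# a rolling chat history (chat=True), and a queue-based interface for
# processing many requests concurrently (generate_concurrent).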


class LLM:
    def __init__(
        self,
        chat: bool = False,
        model: str = "llama3:8b-instruct-q5_K_M",
        keep_alive: int = 3600 * 24,
        start: bool = False,
        system_prompt: str = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        temperature: float = 0,
    ):
        """
        Initializes an LLM instance.

        Args:
            chat (bool, optional): Whether the instance keeps a running chat history between calls. Defaults to False.
            model (str, optional): The model to be used. Defaults to "llama3:8b-instruct-q5_K_M".
            keep_alive (int, optional): How long, in seconds, the server should keep the model loaded. Defaults to 3600 * 24.
            start (bool, optional): If True, the instance automatically starts processing requests upon initialization:
                a separate thread runs the generate_concurrent method, which processes queued requests
                concurrently (see the usage sketch above the __main__ block). Defaults to False.
            system_prompt (str, optional): System prompt sent with every request. The default, in Swedish, tells the
                model to always answer in Swedish, answer only what is asked, and reply "Jag vet inte" ("I don't know")
                when it cannot answer.
            temperature (float, optional): Sampling temperature passed to the API. Defaults to 0.
        """
        self.server = "192.168.1.12"
        self.port = 3300  # 11440 All 4 GPU # 4500 "SW" 3300 balancer
        self.model = model
        self.temperature = temperature
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.chat = chat
        self.max_tokens = 24000
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()
        if start:
            self.start()

    def generate(self, message):
        # Strip leading and trailing whitespace from each line of the prompt
        message = '\n'.join(line.strip() for line in message.split('\n'))
        # Prepare the request data
        options = {
            "temperature": self.temperature,
        }
        if self.chat:
            self.build_message(message)
            messages = self.messages
        else:
            messages = [self.system_message, {"role": "user", "content": message}]
        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": False,
        }
        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        ).json()
        # print_data = result.copy()
        # del print_data["message"]
        # del print_data["model"]
        # # Convert durations from nanoseconds to seconds
        # for key in ['eval_duration', 'total_duration']:
        #     if key in print_data:
        #         duration = print_data[key] / 1e9  # Convert nanoseconds to seconds
        #         minutes, seconds = divmod(duration, 60)  # Convert seconds to minutes and remainder seconds
        #         print_data[key] = f'{int(minutes)}:{seconds:02.0f}'  # Format as minutes:seconds
        # pprint(print_data)
        # print('Number of messages', len(messages))
        if "message" in result:
            answer = result["message"]["content"]
        else:
            pprint(result)
            raise Exception("Error occurred during API request")
        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})
        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If fewer than buffer_size tasks are being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # Stop only when all requests have been added and every
                    # submitted task has been collected; otherwise fall through
                    # so completed futures still get processed below
                    if all_requests_added_event.is_set() and not future_to_message:
                        break
                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))
        all_results_processed_event.set()

    def start(self):
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
        # Return the next (doc_id, summary) pair from the result queue,
        # or None once all results have been processed
        while True:
            try:
                # Take a result from the result queue
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary
            except queue.Empty:
                # If the result queue is empty and all results have been processed, stop
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})
        # Approximate the total token length of the messages by character count
        total_tokens = sum([len(msg["content"]) for msg in self.messages])
        # While the total length exceeds the limit, remove the oldest messages
        while total_tokens > self.max_tokens:
            # Remove the oldest message (but keep the system message at index 0)
            removed_message = self.messages.pop(1)
            total_tokens -= len(removed_message["content"])

    def unload_model(self):
        # Sending a request with keep_alive=0 asks the server to unload the model
        data = {
            "model": self.model,
            "messages": self.messages,
            "keep_alive": 0,
            "stream": False,
        }
        # Make a POST request to the API endpoint
        requests.post(f"http://{self.server}:{self.port}/api/chat", json=data)
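

# Usage sketch for the concurrent interface (assumes the server configured in
# __init__ is reachable; the document texts below are placeholders):
#
#   llm = LLM(start=True)
#   for doc_id, text in enumerate(["first document ...", "second document ..."]):
#       llm.add_request(doc_id, text)
#   llm.finish_adding_requests()
#   while (result := llm.get_results()) is not None:
#       doc_id, summary = result
#       print(doc_id, summary)
#   llm.stop()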


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unload", action="store_true", help="Unload the model")
    args = parser.parse_args()

    # llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
    llm = LLM(keep_alive=6000, chat=True)
    if args.unload:
        llm.unload_model()
    else:
        # Simple interactive chat loop
        while True:
            message = input(">>> ")
            print(llm.generate(message))