import concurrent.futures
import queue
import threading
from pprint import pprint
from time import sleep

import requests


class LLM:
    def __init__(
        self,
        chat: bool = False,
        model: str = "llama3:8b-instruct-q5_K_M",
        keep_alive: int = 3600 * 24,
        start: bool = False,
        # Default system prompt (Swedish): "Always answer in Swedish. Only answer
        # what is asked. If you cannot answer, write 'I don't know'."
        system_prompt: str = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        temperature: float = 0,
    ):
""" |
|
Initializes an instance of MyClass. |
|
|
|
Args: |
|
chat (bool, optional): Specifies whether the instance is for chat purposes. Defaults to False. |
|
model (str, optional): The model to be used. Defaults to "llama3:8b-instruct-q5_K_M". |
|
keep_alive (int, optional): The duration in seconds to keep the instance alive. Defaults to 3600*24. |
|
start (bool, optional): If True, the instance will automatically start processing requests upon initialization. |
|
This means that a separate thread will be started that runs the generate_concurrent method, |
|
which processes requests concurrently. Defaults to False. |
|
""" |
|
        self.server = "192.168.1.12"
        self.port = 3300  # 11440 = all 4 GPUs, 4500 = "SW", 3300 = load balancer
        self.model = model
        self.temperature = temperature
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.chat = chat
        self.max_tokens = 24000
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()

        if start:
            self.start()

    def generate(self, message):
        # Strip leading and trailing whitespace from each line of the message
        message = "\n".join(line.strip() for line in message.split("\n"))

        # Prepare the request data
        options = {
            "temperature": self.temperature,
        }

        if self.chat:
            self.build_message(message)
            messages = self.messages
        else:
            messages = [self.system_message, {"role": "user", "content": message}]

        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": False,
        }

        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        ).json()

        # print_data = result.copy()
        # del print_data["message"]
        # del print_data["model"]

        # # Convert durations from nanoseconds to seconds
        # for key in ["eval_duration", "total_duration"]:
        #     if key in print_data:
        #         duration = print_data[key] / 1e9  # Convert nanoseconds to seconds
        #         minutes, seconds = divmod(duration, 60)  # Convert to minutes and remainder seconds
        #         print_data[key] = f"{int(minutes)}:{seconds:02.0f}"  # Format as minutes:seconds

        # pprint(print_data)
        # print("Number of messages", len(messages))

        if "message" in result:
            answer = result["message"]["content"]
        else:
            pprint(result)
            raise Exception("Error occurred during API request")

        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})

        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If fewer than buffer_size tasks are being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # If the queue is empty and all requests have been added, stop submitting
                    if all_requests_added_event.is_set():
                        break
                    else:
                        continue

                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))

            # Drain any futures that were still pending when the loop exited,
            # so their results are not lost
            for future in concurrent.futures.as_completed(future_to_message):
                doc_id = future_to_message[future]
                try:
                    summary = future.result()
                except Exception as exc:
                    print("Document %r generated an exception: %s" % (doc_id, exc))
                else:
                    result_queue.put((doc_id, summary))

        all_results_processed_event.set()

    def start(self):
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
        # Take the next result from the result queue; returns None once all
        # results have been processed
        while True:
            try:
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary
            except queue.Empty:
                # If the result queue is empty and all results have been processed,
                # stop waiting
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})

        # Approximate the total token length of the messages by character count
        total_tokens = sum(len(msg["content"]) for msg in self.messages)

        # While the total length exceeds the limit, remove the oldest messages
        while total_tokens > self.max_tokens:
            # Remove the oldest message (but never the system message at index 0)
            removed_message = self.messages.pop(1)
            total_tokens -= len(removed_message["content"])

    def unload_model(self):
        # A request with keep_alive=0 asks the server to unload the model
        data = {
            "model": self.model,
            "messages": self.messages,
            "keep_alive": 0,
            "stream": False,
        }

        # Make a POST request to the API endpoint
        requests.post(f"http://{self.server}:{self.port}/api/chat", json=data)


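# ---------------------------------------------------------------------------
# Example: a minimal sketch of the concurrent (batch) workflow described in
# the class docstring. The document IDs and prompts below are made up for
# illustration only; adapt them to your own data.
#
#     llm = LLM(start=True)
#     for doc_id, prompt in [(1, "Summarize document one"), (2, "Summarize document two")]:
#         llm.add_request(doc_id, prompt)
#     llm.finish_adding_requests()
#
#     # get_results() returns (doc_id, summary) tuples and None once every
#     # queued request has been processed.
#     while (result := llm.get_results()) is not None:
#         doc_id, summary = result
#         print(doc_id, summary)
# ---------------------------------------------------------------------------

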
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unload", action="store_true", help="Unload the model")
    args = parser.parse_args()

    # llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
    llm = LLM(keep_alive=6000, chat=True)

    if args.unload:
        llm.unload_model()
    else:
        # Simple interactive chat loop
        while True:
            message = input(">>> ")
            print(llm.generate(message))