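"""Thin client for a chat-completions HTTP API.

The payload shape (``messages``, ``options`` with ``num_ctx``, ``keep_alive``) suggests an
Ollama-style ``/api/chat`` backend, but the exact server is an assumption. The endpoint,
credentials and model names are read from the environment: LLM_API_URL, LLM_API_USER,
LLM_API_PWD_LASSE, LLM_MODEL and LLM_MODEL_SMALL.
"""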
import json
import os

import requests
from requests.auth import HTTPBasicAuth
import tiktoken

import env_manager  # imported for its side effects; assumed to load the LLM_* environment variables
from colorprinter.print_color import *  # provides print_red used below

# cl100k_base is only a rough token estimate; the backend model may tokenize differently.
tokenizer = tiktoken.get_encoding("cl100k_base")

class LLM:
    def __init__(
        self,
        system_message='You are an assistant.',
        num_ctx=4096,
        temperature=0,
        chat=True,
        model='standard',
        max_length_answer=4000,
    ) -> None:
        # Pick the model name from the environment; 'standard' and 'small' are the only presets.
        if model == 'standard':
            self.model = os.getenv("LLM_MODEL")
        elif model == 'small':
            self.model = os.getenv('LLM_MODEL_SMALL')
        else:
            raise ValueError(f"Unknown model preset: {model!r}")

        self.system_message = system_message
        self.options = {"temperature": temperature, "num_ctx": num_ctx}
        self.messages = [{"role": "system", "content": self.system_message}]
        self.chat = chat
        self.max_length_answer = max_length_answer

    def count_tokens(self):
        """Return the approximate number of tokens in the accumulated messages."""
        num_tokens = 0
        for message in self.messages:
            for key, value in message.items():
                if key == "content":
                    if not isinstance(value, str):
                        value = str(value)
                    num_tokens += len(tokenizer.encode(value))
        return num_tokens

    def read_stream(self, response):
        """Yield the message content of each newline-delimited JSON chunk in a streaming response."""
        buffer = ""
        for chunk in response.iter_content(chunk_size=64):
            if chunk:
                try:
                    buffer += chunk.decode('utf-8')
                except UnicodeDecodeError:
                    # A chunk that ends mid multi-byte character is dropped here.
                    continue
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    if line:
                        try:
                            json_data = json.loads(line)
                            yield json_data["message"]["content"]
                        except json.JSONDecodeError:
                            continue

    def generate(self, query, stream=False):
        self.messages.append({"role": "user", "content": query})

        # Reserve context for the answer: message tokens plus half of the maximum answer length.
        num_tokens = self.count_tokens() + self.max_length_answer // 2
        if num_tokens < 4096:
            # The backend's default context window is enough; let it apply its own num_ctx.
            self.options.pop("num_ctx", None)
        else:
            self.options["num_ctx"] = num_tokens

        headers = {"Content-Type": "application/json"}
        data = {
            "messages": self.messages,
            "stream": stream,
            "keep_alive": 3600 * 24 * 7,  # keep the model loaded between calls (a week, if interpreted as seconds)
            "model": self.model,
            "options": self.options,
        }

        # The endpoint and credentials are taken from environment variables.
        response = requests.post(
            os.getenv("LLM_API_URL"),
            headers=headers,
            json=data,
            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
            stream=stream,
            timeout=3600,
        )

        if response.status_code == 404:
            return "Target endpoint not found"

        if response.status_code == 504:
            return f"Gateway Timeout: {response.content.decode('utf-8')}"

        if stream:
            # Note: in streaming mode the assistant reply is not appended to the history.
            return self.read_stream(response)

        try:
            result = response.json()["message"]["content"]
        except requests.exceptions.JSONDecodeError:
            print_red("Error: ", response.status_code, response.text)
            return "An error occurred."

        self.messages.append({"role": "assistant", "content": result})
        if not self.chat:
            # In single-shot mode, keep only the system message between calls.
            self.messages = self.messages[:1]
        return result

if __name__ == "__main__":
    llm = LLM()
    for chunk in llm.generate("Why is the sky red?", stream=True):
        print(chunk, end='', flush=True)
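    # A non-streaming call, as a small usage sketch: it assumes the same LLM_* environment
    # variables are set and returns the full reply as one string.
    print()
    answer = llm.generate("Why is the sky blue at noon?")
    print(answer)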