parent b2725b1376
commit 8381c3bb63
1 changed file with 84 additions and 18 deletions
@@ -1,37 +1,103 @@
-from ollama import Client
 import os
+import requests
+from requests.auth import HTTPBasicAuth
 import tiktoken
+import json
 import env_manager
+from colorprinter.print_color import *
 
 env_manager.set_env()
 
 tokenizer = tiktoken.get_encoding("cl100k_base")
 
 class LLM:
     def __init__(
-        self, system_message=None, num_ctx=20000, temperature=0, chat=True
+        self, system_message='You are an assistant.', num_ctx=4096, temperature=0, chat=True, model='standard', max_length_answer=4000
     ) -> None:
-        self.llm_model = "mistral-nemo:12b-instruct-2407-q5_K_M" #os.getenv("LLM_MODEL")
+        if model == 'standard':
+            self.model = os.getenv("LLM_MODEL")
+        if model == 'small':
+            self.model = os.getenv('LLM_MODEL_SMALL')
         self.system_message = system_message
         self.options = {"temperature": temperature, "num_ctx": num_ctx}
         self.messages = [{"role": "system", "content": self.system_message}]
         self.chat = chat
-        self.ollama = Client(
-            host=f'{os.getenv("LLM_URL")}:{os.getenv("LLM_PORT")}',
-        )
+        self.max_length_answer = max_length_answer
 
+    def count_tokens(self):
+        num_tokens = 0
+        for i in self.messages:
+            for k, v in i.items():
+                if k == "content":
+                    if not isinstance(v, str):
+                        v = str(v)
+                    tokens = tokenizer.encode(v)
+                    num_tokens += len(tokens)
+        return int(num_tokens)
+
+    def read_stream(self, response):
+        buffer = ""
+        for chunk in response.iter_content(chunk_size=64):
+            if chunk:
+                buffer += chunk.decode('utf-8')
+                while "\n" in buffer:
+                    line, buffer = buffer.split("\n", 1)
+                    if line:
+                        try:
+                            json_data = json.loads(line)
+                            yield json_data["message"]["content"]
+                        except json.JSONDecodeError:
+                            continue
+    def generate(self, query, stream=False):
+        self.messages.append({"role": "user", "content": query})
+
+        # Set the number of tokens to be the sum of the tokens in the messages and half of the max length of the answer
+        num_tokens = self.count_tokens() + self.max_length_answer / 2
+        if num_tokens < 4096:
+            del self.options["num_ctx"]
+        else:
+            self.options["num_ctx"] = num_tokens
+
-    def generate(self, prompt: str) -> str:
-        self.messages.append({"role": "user", "content": prompt})
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "messages": self.messages,
+            "stream": stream,
+            "keep_alive": 3600 * 24 * 7,
+            "model": self.model,
+            "options": self.options,
+        }
+
-        result = self.ollama.chat(
-            model=self.llm_model, messages=self.messages, options=self.options
-        )
+        # Ensure the environment variable is correctly referenced
+
+
+        response = requests.post(
+            os.getenv("LLM_API_URL"),
+            headers=headers,
+            json=data,
+            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
+            stream=stream,
+        )
+
-        answer = result["message"]["content"]
-        self.messages.append({"role": "assistant", "content": answer})
-        if not self.chat:
-            self.messages = [{"role": "system", "content": self.system_message}]
+        if response.status_code == 404:
+            return "Target endpoint not found"
+
+        if response.status_code == 504:
+            return f"Gateway Timeout: {response.content.decode('utf-8')}"
+
+        if stream:
+            return self.read_stream(response)
+
+        else:
+            try:
+                result = response.json()["message"]["content"]
+            except requests.exceptions.JSONDecodeError:
+                print_red("Error: ", response.status_code, response.text)
+                return "An error occurred."
+
-        return answer
+        self.messages.append({"role": "assistant", "content": result})
+        if not self.chat:
+            self.messages = self.messages[:1]
+        return result
 
 if __name__ == "__main__":
     llm = LLM()
-    print(llm.generate("Why is the sky red?"))
+    for chunk in llm.generate("Why is the sky red?", stream=True):
+        print(chunk, end='', flush=True)
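
The new generate() sizes the context window from the running conversation: the token count of self.messages plus half of max_length_answer, dropping num_ctx entirely when the total stays under 4096 so the server falls back to its default window. A minimal standalone sketch of that arithmetic, with choose_num_ctx as an illustrative helper that is not part of the commit:

# Sketch of the context-window sizing used in generate() above; illustrative only.
def choose_num_ctx(prompt_tokens: int, max_length_answer: int = 4000):
    needed = prompt_tokens + max_length_answer // 2
    if needed < 4096:
        return None  # omit num_ctx and let the server use its default window
    return needed

assert choose_num_ctx(1000) is None   # 1000 + 2000 = 3000 < 4096
assert choose_num_ctx(3000) == 5000   # 3000 + 2000 = 5000 >= 4096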
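When stream=True, read_stream() assumes the endpoint returns newline-delimited JSON, one object per reply fragment, each carrying a message.content field in the Ollama-style chat format. A minimal sketch of parsing that format from hard-coded sample lines; the payloads below are invented for illustration:

import json

# Hypothetical sample of the newline-delimited JSON the streaming branch expects.
sample = (
    b'{"message": {"role": "assistant", "content": "The sky"}, "done": false}\n'
    b'{"message": {"role": "assistant", "content": " is not red."}, "done": true}\n'
)

for line in sample.splitlines():
    if line:
        print(json.loads(line)["message"]["content"], end="")
print()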