parent b2725b1376
commit 8381c3bb63
1 changed file with 84 additions and 18 deletions
@@ -1,37 +1,103 @@
-from ollama import Client
 import os
+import requests
+from requests.auth import HTTPBasicAuth
+import tiktoken
+import json
 import env_manager
+from colorprinter.print_color import *
 
-env_manager.set_env()
+tokenizer = tiktoken.get_encoding("cl100k_base")
 
 
 class LLM:
     def __init__(
-        self, system_message=None, num_ctx=20000, temperature=0, chat=True
+        self, system_message='You are an assistant.', num_ctx=4096, temperature=0, chat=True, model='standard', max_length_answer=4000
     ) -> None:
-        self.llm_model = "mistral-nemo:12b-instruct-2407-q5_K_M" #os.getenv("LLM_MODEL")
+        if model == 'standard':
+            self.model = os.getenv("LLM_MODEL")
+        if model == 'small':
+            self.model = os.getenv('LLM_MODEL_SMALL')
         self.system_message = system_message
         self.options = {"temperature": temperature, "num_ctx": num_ctx}
         self.messages = [{"role": "system", "content": self.system_message}]
         self.chat = chat
-        self.ollama = Client(
-            host=f'{os.getenv("LLM_URL")}:{os.getenv("LLM_PORT")}',
-        )
+        self.max_length_answer = max_length_answer
+
+    def count_tokens(self):
+        num_tokens = 0
+        for i in self.messages:
+            for k, v in i.items():
+                if k == "content":
+                    if not isinstance(v, str):
+                        v = str(v)
+                    tokens = tokenizer.encode(v)
+                    num_tokens += len(tokens)
+        return int(num_tokens)
+
+    def read_stream(self, response):
+        buffer = ""
+        for chunk in response.iter_content(chunk_size=64):
+            if chunk:
+                buffer += chunk.decode('utf-8')
+                while "\n" in buffer:
+                    line, buffer = buffer.split("\n", 1)
+                    if line:
+                        try:
+                            json_data = json.loads(line)
+                            yield json_data["message"]["content"]
+                        except json.JSONDecodeError:
+                            continue
 
-    def generate(self, prompt: str) -> str:
-        self.messages.append({"role": "user", "content": prompt})
+    def generate(self, query, stream=False):
+        self.messages.append({"role": "user", "content": query})
+
+        # Set the number of tokens to be the sum of the tokens in the messages and half of the max length of the answer
+        num_tokens = self.count_tokens() + self.max_length_answer / 2
+        if num_tokens < 4096:
+            del self.options["num_ctx"]
+        else:
+            self.options["num_ctx"] = num_tokens
 
-        result = self.ollama.chat(
-            model=self.llm_model, messages=self.messages, options=self.options
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "messages": self.messages,
+            "stream": stream,
+            "keep_alive": 3600 * 24 * 7,
+            "model": self.model,
+            "options": self.options,
+        }
+
+        # Ensure the environment variable is correctly referenced
+        response = requests.post(
+            os.getenv("LLM_API_URL"),
+            headers=headers,
+            json=data,
+            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
+            stream=stream,
         )
 
-        answer = result["message"]["content"]
-        self.messages.append({"role": "assistant", "content": answer})
-        if not self.chat:
-            self.messages = [{"role": "system", "content": self.system_message}]
+        if response.status_code == 404:
+            return "Target endpoint not found"
+
+        if response.status_code == 504:
+            return f"Gateway Timeout: {response.content.decode('utf-8')}"
+
+        if stream:
+            return self.read_stream(response)
+
+        else:
+            try:
+                result = response.json()["message"]["content"]
+            except requests.exceptions.JSONDecodeError:
+                print_red("Error: ", response.status_code, response.text)
+                return "An error occurred."
 
-        return answer
+        self.messages.append({"role": "assistant", "content": result})
+        if not self.chat:
+            self.messages = self.messages[:1]
+        return result
 
 
 if __name__ == "__main__":
     llm = LLM()
-    print(llm.generate("Why is the sky red?"))
+    for chunk in llm.generate("Why is the sky red?", stream=True):
+        print(chunk, end='', flush=True)
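
As a quick illustration of the context-size logic added in generate() (the num_ctx adjustment commented in the diff), a sketch with made-up token counts, assuming the default max_length_answer of 4000:

# Illustrative numbers only -- the count_tokens() values here are invented for the example.
max_length_answer = 4000

history_tokens = 1500                                   # pretend count_tokens() returned 1500
num_tokens = history_tokens + max_length_answer / 2     # 3500.0 < 4096, so "num_ctx" is dropped from options

history_tokens = 3000                                   # a longer conversation
num_tokens = history_tokens + max_length_answer / 2     # 5000.0 >= 4096, so options["num_ctx"] = 5000.0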
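
For context, a minimal sketch of how the new requests-based client would be configured and called. The environment variable names are taken from the diff; the example values, the .env loading via env_manager, and the module name llm.py are assumptions for illustration:

# Assumed .env contents read by env_manager (values are placeholders):
#   LLM_API_URL=https://ollama.example.org/api/chat
#   LLM_API_USER=api-user
#   LLM_API_PWD_LASSE=api-password
#   LLM_MODEL=mistral-nemo:12b-instruct-2407-q5_K_M
#   LLM_MODEL_SMALL=llama3.2:3b-instruct-q5_K_M

from llm import LLM  # assuming the file in the diff is saved as llm.py

# Non-streaming call: generate() returns the assistant's full reply as a string.
assistant = LLM(system_message="You are an assistant.", chat=False)
print(assistant.generate("Why is the sky red?"))

# Streaming call: generate(..., stream=True) returns a generator over content chunks.
small = LLM(model="small")
for chunk in small.generate("Why is the sky red?", stream=True):
    print(chunk, end="", flush=True)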