from openai import OpenAI, RateLimitError
from dotenv import load_dotenv
import os
from time import sleep

from _llm import LLM as LLM_ollama
from print_color import *

load_dotenv()


class LLM_OpenAI:
    # The default system_prompt is Swedish for: "Always answer in Swedish.
    # Only answer what is asked. If you cannot answer, write 'I don't know'."
    def __init__(
        self,
        system_prompt='Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        chat=False,
        model="gpt-3.5-turbo-0125",
        max_tokens=24000,
        sleep_time=0,
    ):
        self.chat = chat
        self.model = model
        self.temperature = 0
        self.max_tokens = max_tokens
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.client = OpenAI(
            # This is the default and can be omitted
            api_key=os.getenv("OPEN_AI"),
        )
        self.llm_ollama = LLM_ollama(chat=False, stream=True)  # Local fallback model
        self.sleep_time = sleep_time

    def build_message(self, message):
        # Add the new user message to the conversation history
        self.messages.append({"role": "user", "content": message})

        # Total length of the conversation; note this counts characters, not
        # tokens, and is compared against max_tokens as a rough proxy
        total_tokens = sum(len(msg["content"]) for msg in self.messages)

        # While the total length exceeds the limit, drop the oldest messages,
        # starting at index 1 so the system message at index 0 is kept
        while total_tokens > self.max_tokens:
            removed_message = self.messages.pop(1)
            total_tokens -= len(removed_message["content"])
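
    # The length check in build_message uses character counts as a stand-in
    # for tokens. A minimal sketch of exact counting, assuming the tiktoken
    # package (not imported or used anywhere else in this file), could look like:
    #
    #     import tiktoken
    #     enc = tiktoken.encoding_for_model(self.model)
    #     total_tokens = sum(len(enc.encode(m["content"])) for m in self.messages)
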
    def generate(self, prompt, stream=False, local=False):
        sleep(self.sleep_time)
        if self.chat:
            # Chat mode: reuse the trimmed conversation history
            self.build_message(prompt)
            messages = self.messages
        else:
            # Single-turn mode: only the system message and the current prompt
            messages = [self.system_message, {"role": "user", "content": prompt}]
        print(sum(len(msg["content"]) for msg in messages))

        if local:
            # Use the local Ollama model instead of the OpenAI API
            response = self.llm_ollama.generate_stream(prompt)
        else:
            try:
                response = self.client.chat.completions.create(
                    messages=messages,
                    model=self.model,
                    stream=stream,
                )
            except RateLimitError as e:
                # Fall back to the local Ollama model when rate limited
                print_red(e)
                response = self.llm_ollama.generate_stream(prompt)

        if stream:
            return response
        else:
            answer = response.choices[0].message.content
            if self.chat:
                self.messages.append({"role": "assistant", "content": answer})
            return answer
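

# A minimal usage sketch, assuming an OPEN_AI key in a local .env file and the
# local _llm.LLM model available as a fallback; the prompt below is only an
# illustration and not part of the original class:
if __name__ == "__main__":
    llm = LLM_OpenAI(chat=True)
    print(llm.generate("Name the three largest cities in Sweden."))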