import os
import base64
import re
import json
from typing import Any, Callable, Iterator, Literal, Mapping, Optional, Sequence, Union

import tiktoken
from ollama import Client, AsyncClient, ResponseError, ChatResponse, Message, Tool, Options
from ollama._types import JsonSchemaValue, ChatRequest

import env_manager
from colorprinter.print_color import *

env_manager.set_env()
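
# Environment variables read by this module (loaded via env_manager.set_env()):
# LLM_API_URL, LLM_API_USER and LLM_API_PWD_LASSE configure the endpoint and
# Basic-Auth credentials; LLM_MODEL, LLM_MODEL_SMALL, LLM_MODEL_VISION and
# LLM_MODEL_64K name the models resolved by LLM.get_model().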

tokenizer = tiktoken.get_encoding("cl100k_base")


# Base class for shared functionality: unlike the library's chat(), this one
# accepts per-call headers and forwards them to the subclass _request().
class BaseClient:
    def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        headers: Optional[Mapping[str, str]] = None,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        return self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            headers=headers,
            json=ChatRequest(
                model=model,
                messages=[message for message in messages or []],
                tools=[tool for tool in tools or []],
                stream=stream,
                format=format,
                options=options,
                keep_alive=keep_alive,
            ).model_dump(exclude_none=True),
            stream=stream,
        )


# Custom async client: merges per-call headers and keeps the last raw response.
class MyAsyncClient(AsyncClient, BaseClient):
    async def _request(self, response_type, method, path, headers=None, **kwargs):
        # Merge default headers with per-call headers
        all_headers = {**self._client.headers, **(headers or {})}

        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            async with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    await response.aread()
                    raise ResponseError(response.text, response.status_code)
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            kwargs.pop('stream')
            response = await self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object

            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)

    async def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        headers: Optional[Mapping[str, str]] = None,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        return await self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            headers=headers,
            json=ChatRequest(
                model=model,
                messages=[message for message in messages or []],
                tools=[tool for tool in tools or []],
                stream=stream,
                format=format,
                options=options,
                keep_alive=keep_alive,
            ).model_dump(exclude_none=True),
            stream=stream,
        )


# Custom sync client; BaseClient comes first so its header-aware chat() is used.
class MyClient(BaseClient, Client):
    def _request(self, response_type, method, path, headers=None, **kwargs):
        # Merge default headers with per-call headers
        all_headers = {**self._client.headers, **(headers or {})}

        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    response.read()  # The body must be read before .text is available on a stream
                    raise ResponseError(response.text, response.status_code)
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            kwargs.pop('stream')
            response = self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object

            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)
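

# Minimal usage sketch for the client subclasses (host and model name below are
# hypothetical; the LLM class builds them from environment variables instead):
#
#     client = MyClient(host="http://localhost:11434", headers={"Authorization": "Basic ..."})
#     reply = client.chat(model="llama3", messages=[{"role": "user", "content": "Hi"}],
#                         headers={"X-Model-Type": "small"})
#     backend = client.last_response.headers.get("X-Chosen-Backend")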


class LLM:
    """
    LLM class for interacting with a language model.
    """

    def __init__(
        self,
        system_message="You are an assistant.",
        temperature=0.01,
        model: Optional[Literal["small", "standard", "vision"]] = "standard",
        max_length_answer=4096,
        messages=None,
        chat=True,
        chosen_backend=None,
    ) -> None:

        self.model = self.get_model(model)
        self.system_message = system_message
        self.options = {"temperature": temperature}
        self.messages = messages or [{"role": "system", "content": self.system_message}]
        self.max_length_answer = max_length_answer
        self.chat = chat
        self.chosen_backend = chosen_backend

        # Initialize the clients with the host and default Basic-Auth header
        credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()
        default_headers = {
            "Authorization": f"Basic {encoded_credentials}",
        }
        # LLM_API_URL may include the /api/chat path; the client only needs the host
        host_url = re.sub(r"/api/chat/?$", "", os.getenv("LLM_API_URL"))
        self.client = MyClient(host=host_url, headers=default_headers)
        self.async_client = MyAsyncClient(host=host_url, headers=default_headers)

    def get_model(self, model_alias):
        # Resolve a model alias to the model name configured in the environment
        models = {
            "standard": "LLM_MODEL",
            "small": "LLM_MODEL_SMALL",
            "vision": "LLM_MODEL_VISION",
            "standard_64k": "LLM_MODEL_64K",
        }
        return os.getenv(models.get(model_alias, "LLM_MODEL"))

    def count_tokens(self):
        # Rough token count of the whole message history
        num_tokens = 0
        for message in self.messages:
            for key, value in message.items():
                if key == "content":
                    if not isinstance(value, str):
                        value = str(value)
                    num_tokens += len(tokenizer.encode(value))
        return num_tokens
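
    # Note: count_tokens() uses the cl100k_base encoding as an approximation of
    # the served model's tokenizer, so the result is only a rough size estimate
    # (used below to decide when to switch to the 64k-context model).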

    def generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Generates a response from the language model based on the provided inputs.
        """

        # Prepare the model and temperature
        model = self.get_model(model) if model else self.model
        temperature = temperature if temperature is not None else self.options["temperature"]

        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any: accept base64 strings, file paths or raw bytes
        if images:
            base64_images = []
            base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")

            for image in images:
                if isinstance(image, str):
                    if base64_pattern.match(image):
                        # Already base64-encoded
                        base64_images.append(image)
                    else:
                        # Treat the string as a file path
                        with open(image, "rb") as image_file:
                            base64_images.append(
                                base64.b64encode(image_file.read()).decode("utf-8")
                            )
                elif isinstance(image, bytes):
                    base64_images.append(base64.b64encode(image).decode("utf-8"))
                else:
                    print_red("Invalid image type")

            message["images"] = base64_images
            # Use the vision model
            model = self.get_model("vision")

        self.messages.append(message)

        # Prepare per-request routing headers
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend

        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Switch to the long-context model when the conversation grows large
        if self.chat or len(self.messages) > 15000:
            num_tokens = self.count_tokens() + self.max_length_answer // 2
            if num_tokens > 8000:
                model = self.get_model("standard_64k")
                headers["X-Model-Type"] = "large"

        # Call the client.chat method
        try:
            response = self.client.chat(
                model=model,
                messages=self.messages,
                headers=headers,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # If user_input is provided, replace the last stored message with it,
        # summarising an overly long context first
        if user_input:
            if context:
                if len(context) > 2000:
                    context = self.make_summary(context)
                user_input = (
                    f"{user_input}\n\nUse the information below to answer the question.\n"
                    f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
                )
                system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
                if system_message_info not in self.messages[0]["content"]:
                    self.messages[0]["content"] += system_message_info
            self.messages[-1] = {"role": "user", "content": user_input}

        # Remember which backend served this request for subsequent calls
        self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")

        # Handle streaming response
        if stream:
            return self.read_stream(response)
        else:
            # Process the response
            if isinstance(response, ChatResponse):
                result = response.message.content.strip('"')
                self.messages.append({"role": "assistant", "content": result})
                if tools and not response.message.get("tool_calls"):
                    print_yellow("No tool calls in response".upper())
                if not self.chat:
                    # Stateless mode: keep only the system message
                    self.messages = [self.messages[0]]
                return result
            else:
                print_red("Unexpected response type")
                return "An error occurred."

    def make_summary(self, text):
        # Summarise a long context with the small model
        summary_message = {
            "role": "user",
            "content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
        }
        messages = [
            {"role": "system", "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information."},
            summary_message,
        ]
        try:
            response = self.client.chat(
                model=self.get_model("small"),
                messages=messages,
                options=Options(temperature=0.01),
                keep_alive=3600 * 24 * 7,
            )
            summary = response.message.content.strip()
            print_blue("Summary:", summary)
            return summary
        except ResponseError as e:
            print_red("Error generating summary:", e)
            return "Summary generation failed."

    def read_stream(self, response):
        # Yield streamed chunks while stripping a leading/trailing quote, and
        # store the full reply in the chat history once the stream is done.
        message = ""
        first_chunk = True
        prev_content = None
        for chunk in response:
            if chunk:
                content = chunk.message.content
                if first_chunk and content.startswith('"'):
                    content = content[1:]
                first_chunk = False

                if chunk.done:
                    # Last chunk: drop a trailing quote before emitting
                    if prev_content and prev_content.endswith('"'):
                        prev_content = prev_content[:-1]
                    if prev_content:
                        message += prev_content
                        yield prev_content
                    break
                else:
                    if prev_content:
                        message += prev_content
                        yield prev_content
                    prev_content = content
        self.messages.append({"role": "assistant", "content": message.strip('"')})

    async def async_generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Asynchronous method to generate a response from the language model.
        """

        # Prepare the model and temperature
        model = self.get_model(model) if model else self.model
        temperature = temperature if temperature is not None else self.options["temperature"]

        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any
        if images:
            # (Image handling code as in the generate method)
            ...

        self.messages.append(message)

        # Prepare per-request routing headers
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend

        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Adjust options for long messages
        # (Adjustments as needed)
        ...

        # Call the async client's chat method
        try:
            response = await self.async_client.chat(
                model=model,
                messages=self.messages,
                headers=headers,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # Process the response
        if isinstance(response, ChatResponse):
            result = response.message.content.strip('"')
            self.messages.append({"role": "assistant", "content": result})
            return result
        else:
            print_red("Unexpected response type")
            return "An error occurred."


# Usage example
if __name__ == "__main__":
    import asyncio

    llm = LLM()

    async def main():
        result = await llm.async_generate(query="Hello, how are you?")
        print(result)

    asyncio.run(main())
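
    # A minimal sketch of synchronous streaming usage: generate(stream=True)
    # returns the generator from read_stream(), which yields content chunks.
    for chunk in llm.generate(query="Tell me a short joke.", stream=True):
        print(chunk, end="", flush=True)
    print()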