Enhance LLM class: add 'think' parameter for reasoning models and improve message handling

legacy
lasseedfast 9 months ago
parent c275704014
commit 17d3335ff6
  1. __init__.py (2 changed lines)
  2. _llm/__init__.py (10 changed lines)
  3. _llm/llm.py (368 changed lines)
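The headline change: `generate()` and its async counterpart gain a `think` flag that is forwarded to Ollama's native thinking API, and the reasoning comes back as a separate field instead of being spliced out of the content around `</think>` tags. A hedged usage sketch (the import path and constructor arguments are assumptions, not shown in this commit):

```python
# Usage sketch for the new `think` parameter. The import path and the LLM()
# constructor arguments are assumptions; only the generate(...) call mirrors the diff.
from _llm.llm import LLM

llm = LLM(chat=True)

reply = llm.generate(
    query="What is 17 * 23?",
    model="reasoning",   # one of "small", "standard", "vision", "reasoning", "tools"
    think=True,          # request native thinking instead of /think message markers
)
print(reply.content)                     # final answer
print(getattr(reply, "thinking", None))  # reasoning trace, when the backend returns one

# With think=False the new code strips any leftover thinking from the content.
plain = llm.generate(query="One-line summary of mutexes.", think=False)
print(plain.content)
```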

__init__.py
@@ -2,6 +2,6 @@
 llm_client: A Python package for interacting with LLM models through Ollama.
 """
-from _llm.llm import LLM
+from _llm._llm.llm import LLM
 __all__ = ["LLM"]

_llm/__init__.py
@@ -1,7 +1,7 @@
-"""
-llm_client: A Python package for interacting with LLM models through Ollama.
-"""
-from _llm.llm import LLM
-__all__ = ["LLM"]
+# """
+# llm_client: A Python package for interacting with LLM models through Ollama.
+# """
+# from ._llm.llm import LLM  # Use relative import with dot prefix
+# __all__ = ["LLM"]
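The `_llm/llm.py` changes below drop the `/think` and `/no_think` prompt markers and the manual `</think>` splitting in favor of Ollama's native thinking support: a `think=` keyword is passed to `chat()`, and the reasoning comes back on a separate `thinking` field of the message. For reference, a minimal standalone sketch of that underlying API (model name and prompt are placeholders; assumes an Ollama server and Python client recent enough to support thinking):

```python
# Standalone sketch of the native Ollama thinking API that llm.py now delegates to.
# "qwen3" is a placeholder for any locally available reasoning-capable model.
import ollama

response = ollama.chat(
    model="qwen3",
    messages=[{"role": "user", "content": "Is 9.11 larger than 9.9?"}],
    think=True,  # reasoning is returned separately, not inline in the content
)

print(response["message"].get("thinking"))  # the model's reasoning trace
print(response["message"]["content"])       # the final answer only
```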

_llm/llm.py
@@ -60,6 +60,7 @@ class LLM:
         chat: bool = True,
         chosen_backend: str = None,
         tools: list = None,
+        think: bool = False,
     ) -> None:
         """
         Initialize the assistant with the given parameters.
@@ -72,6 +73,7 @@ class LLM:
             messages (list[dict], optional): A list of initial messages. Defaults to None.
             chat (bool): Whether the assistant is in chat mode. Defaults to True.
             chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
+            think (bool): Whether to use thinking mode for reasoning models. Defaults to False.
         Returns:
             None
@@ -89,7 +91,6 @@ class LLM:
         self.chosen_backend = chosen_backend
         headers = {
             "Authorization": f"Basic {self.get_credentials()}",
         }
@@ -130,7 +131,9 @@ class LLM:
             num_tokens += len(tokens)
         return int(num_tokens)

-    def _prepare_messages_and_model(self, query, user_input, context, messages, images, model):
+    def _prepare_messages_and_model(
+        self, query, user_input, context, messages, images, model
+    ):
         """Prepare messages and select the appropriate model, handling images if present."""
         if messages:
             messages = [
@@ -157,25 +160,30 @@ class LLM:
     def _build_headers(self, model, tools, think):
         """Build HTTP headers for API requests, including auth and backend/model info."""
         headers = {"Authorization": f"Basic {self.get_credentials()}"}
-        if self.chosen_backend and model not in [self.get_model("vision"), self.get_model("tools"), self.get_model("reasoning")]:
+        if self.chosen_backend and model not in [
+            self.get_model("vision"),
+            self.get_model("tools"),
+            self.get_model("reasoning"),
+        ]:
             headers["X-Chosen-Backend"] = self.chosen_backend
         if model == self.get_model("small"):
             headers["X-Model-Type"] = "small"
         if model == self.get_model("tools"):
             headers["X-Model-Type"] = "tools"
-        if think and model and any([m in model for m in ['qwen3', 'deepseek']]):
-            self.messages[-1]['content'] = f"/think\n{self.messages[-1]['content']}"
-        elif model and any([m in model for m in ['qwen3', 'deepseek']]):
-            self.messages[-1]['content'] = f"/no_think\n{self.messages[-1]['content']}"
+        # No longer need to modify message content for thinking - handled by native API
         return headers

     def _get_options(self, temperature):
         """Build model options, setting temperature and other parameters."""
         options = Options(**self.options)
-        options.temperature = temperature if temperature is not None else self.options["temperature"]
+        options.temperature = (
+            temperature if temperature is not None else self.options["temperature"]
+        )
         return options

-    def _call_remote_api(self, model, tools, stream, options, format, headers):
+    def _call_remote_api(
+        self, model, tools, stream, options, format, headers, think=False
+    ):
         """Call the remote Ollama API synchronously."""
         self.call_model = model
         self.client: Client = Client(host=self.host_url, headers=headers, timeout=300)
@@ -187,11 +195,14 @@ class LLM:
             stream=stream,
             options=options,
             keep_alive=3600 * 24 * 7,
-            format=format
+            format=format,
+            think=think,
         )
         return response

-    async def _call_remote_api_async(self, model, tools, stream, options, format, headers):
+    async def _call_remote_api_async(
+        self, model, tools, stream, options, format, headers, think=False
+    ):
         """Call the remote Ollama API asynchronously."""
         print_yellow(f"🤖 Generating using {model} (remote, async)...")
         response = await self.async_client.chat(
@@ -202,12 +213,14 @@ class LLM:
             stream=stream,
             options=options,
             keep_alive=3600 * 24 * 7,
+            think=think,  # Use native Ollama thinking support
         )
         return response

-    def _call_local_ollama(self, model, stream, temperature):
+    def _call_local_ollama(self, model, stream, temperature, think=False):
         """Call the local Ollama instance synchronously."""
         import ollama
         print_yellow(f"🤖 Generating using {model} (local)...")
         options = {"temperature": temperature}
         if stream:
@@ -215,72 +228,130 @@ class LLM:
                 model=model,
                 messages=self.messages,
                 options=options,
-                stream=True
+                stream=True,
+                think=think,  # Pass thinking parameter to local ollama
             )

             def local_stream_adapter():
                 for chunk in response_stream:
-                    yield type('OllamaResponse', (), {
-                        'message': type('Message', (), {'content': chunk['message']['content']}),
-                        'done': chunk.get('done', False)
-                    })
+                    yield type(
+                        "OllamaResponse",
+                        (),
+                        {
+                            "message": type(
+                                "Message", (), {"content": chunk["message"]["content"]}
+                            ),
+                            "done": chunk.get("done", False),
+                        },
+                    )

             return self.read_stream(local_stream_adapter())
         else:
             response = ollama.chat(
                 model=model,
                 messages=self.messages,
-                options=options
+                options=options,
+                think=think,  # Pass thinking parameter to local ollama
             )
-            result = response['message']['content']
-            response_obj = type('LocalChatResponse', (), {
-                'message': type('Message', (), {
-                    'content': result,
-                    'get': lambda x: None
-                })
-            })
-            if '</think>' in result:
-                result = result.split('</think>')[-1].strip()
-                response_obj.message.content = result
+            result = response["message"]["content"]
+
+            # Handle thinking content if present (for backward compatibility)
+            thinking_content = response["message"].get("thinking", None)
+
+            response_obj = type(
+                "LocalChatResponse",
+                (),
+                {
+                    "message": type(
+                        "Message",
+                        (),
+                        {
+                            "content": result,
+                            "thinking": thinking_content,
+                            "get": lambda x: None,
+                        },
+                    )
+                },
+            )
+
+            # No longer need to manually parse </think> tags with native support
             self.messages.append({"role": "assistant", "content": result})
             if not self.chat:
                 self.messages = [self.messages[0]]
             return response_obj.message

-    async def _call_local_ollama_async(self, model, stream, temperature):
+    async def _call_local_ollama_async(self, model, stream, temperature, think=False):
         """Call the local Ollama instance asynchronously (using a thread pool)."""
         import ollama
         import asyncio
         print_yellow(f"🤖 Generating using {model} (local, async)...")
         options = {"temperature": temperature}
         loop = asyncio.get_event_loop()
         if stream:

             def run_stream():
                 return ollama.chat(
                     model=model,
                     messages=self.messages,
                     options=options,
-                    stream=True
+                    stream=True,
+                    think=think,  # Pass thinking parameter to local ollama
                 )

             response_stream = await loop.run_in_executor(None, run_stream)

             async def local_stream_adapter():
                 for chunk in response_stream:
-                    yield type('OllamaResponse', (), {
-                        'message': type('Message', (), {'content': chunk['message']['content']}),
-                        'done': chunk.get('done', False)
-                    })
+                    yield type(
+                        "OllamaResponse",
+                        (),
+                        {
+                            "message": type(
+                                "Message", (), {"content": chunk["message"]["content"]}
+                            ),
+                            "done": chunk.get("done", False),
+                        },
+                    )

             return local_stream_adapter()
         else:

             def run_chat():
                 return ollama.chat(
                     model=model,
                     messages=self.messages,
-                    options=options
+                    options=options,
+                    think=think,  # Pass thinking parameter to local ollama
                 )

             response_dict = await loop.run_in_executor(None, run_chat)
-            result = response_dict['message']['content']
+            result = response_dict["message"]["content"]
+
+            # Handle thinking content if present (for backward compatibility)
+            thinking_content = response_dict["message"].get("thinking", None)
+
+            # Create response object with thinking support
+            response_obj = type(
+                "LocalChatResponse",
+                (),
+                {
+                    "message": type(
+                        "Message",
+                        (),
+                        {
+                            "content": result,
+                            "thinking": thinking_content,
+                            "get": lambda x: None,
+                        },
+                    )
+                },
+            )
+
             self.messages.append({"role": "assistant", "content": result})
             if not self.chat:
                 self.messages = [self.messages[0]]
-            return result
+            return response_obj.message

     def generate(
         self,
@@ -292,44 +363,49 @@ class LLM:
         images: list = None,
         model: Optional[
             Literal["small", "standard", "vision", "reasoning", "tools"]
-        ] = 'standard',
+        ] = "standard",
         temperature: float = None,
         messages: list[dict] = None,
-        format = None,
-        think = False,
-        force_local: bool = False
+        format=None,
+        think=False,
+        force_local: bool = False,
     ):
         """
         Generate a response based on the provided query and context.
         """
-        model = self._prepare_messages_and_model(query, user_input, context, messages, images, model)
+        model = self._prepare_messages_and_model(
+            query, user_input, context, messages, images, model
+        )
         temperature = temperature if temperature else self.options["temperature"]
         if not force_local:
             try:
                 headers = self._build_headers(model, tools, think)
                 options = self._get_options(temperature)
-                response = self._call_remote_api(model, tools, stream, options, format, headers)
+                response = self._call_remote_api(
+                    model, tools, stream, options, format, headers, think=think
+                )
+                print_rainbow(response)
                 if stream:
                     return self.read_stream(response)
                 else:
                     if isinstance(response, ChatResponse):
                         result = response.message.content.strip('"')
-                        if '</think>' in result:
-                            result = result.split('</think>')[-1]
-                        self.messages.append({"role": "assistant", "content": result.strip('"')})
-                        if tools and not response.message.get("tool_calls"):
-                            pass
+                        message_content = result.strip('"')
+                        self.messages.append(
+                            {"role": "assistant", "content": message_content}
+                        )
                         if not self.chat:
                             self.messages = [self.messages[0]]
+                        if not think:
+                            response.message.content = remove_thinking(response.message.content)
                         return response.message
                     else:
                         return "An error occurred."
             except Exception as e:
                 traceback.print_exc()
         try:
-            return self._call_local_ollama(model, stream, temperature)
+            return self._call_local_ollama(model, stream, temperature, think=think)
         except Exception as e:
             traceback.print_exc()
             return "Both remote API and local Ollama failed. An error occurred."
@@ -344,26 +420,81 @@ class LLM:
         images: list = None,
         model: Optional[
             Literal["small", "standard", "vision", "reasoning", "tools"]
-        ] = 'standard',
+        ] = "standard",
         temperature: float = None,
+        messages: list[dict] = None,
+        format=None,
+        think=False,
         force_local: bool = False,
     ):
         """
         Asynchronously generates a response based on the provided query and other parameters.
+
+        Args:
+            query (str, optional): The query string to generate a response for.
+            user_input (str, optional): Additional user input to be included in the response.
+            context (str, optional): Context information to be used in generating the response.
+            stream (bool, optional): Whether to stream the response. Defaults to False.
+            tools (list, optional): List of tools to be used in generating the response.
+            images (list, optional): List of images to be included in the response.
+            model (Optional[Literal["small", "standard", "vision", "reasoning", "tools"]], optional): The model to be used for generating the response.
+            temperature (float, optional): The temperature setting for the model.
+            messages (list[dict], optional): List of messages to use instead of building from query.
+            format: Format specification for the response.
+            think (bool, optional): Whether to use thinking mode for reasoning models.
+            force_local (bool, optional): Force using local Ollama instead of remote API.
+
+        Returns:
+            The generated response message or an error message if an exception occurs.
         """
-        model = self._prepare_messages_and_model(query, user_input, context, None, images, model)
+        model = self._prepare_messages_and_model(
+            query, user_input, context, messages, images, model
+        )
         temperature = temperature if temperature else self.options["temperature"]
+
+        # First try with remote API
         if not force_local:
             try:
-                headers = self._build_headers(model, tools, False)
+                headers = self._build_headers(model, tools, think)
                 options = self._get_options(temperature)
-                response = await self._call_remote_api_async(model, tools, stream, options, None, headers)
-                # You can add async-specific response handling here if needed
+                response = await self._call_remote_api_async(
+                    model, tools, stream, options, format, headers, think=think
+                )
+                if stream:
+                    return self.read_stream(response)
+                else:
+                    if isinstance(response, ChatResponse):
+                        # Handle native thinking mode with separate thinking field
+                        result = response.message.content.strip('"')
+                        thinking_content = getattr(response.message, "thinking", None)
+
+                        # Store both content and thinking in message history
+                        message_content = result.strip('"')
+                        self.messages.append(
+                            {"role": "assistant", "content": message_content}
+                        )
+                        if not self.chat:
+                            self.messages = [self.messages[0]]
+
+                        # Return response with both content and thinking accessible
+                        if thinking_content and think:
+                            # Add thinking as an attribute for access if needed
+                            response.message.thinking = thinking_content
+                        return response.message
+                    else:
+                        return "An error occurred."
             except Exception as e:
                 traceback.print_exc()
-        if force_local or 'response' not in locals():
-            try:
-                return await self._call_local_ollama_async(model, stream, temperature)
-            except Exception as e:
-                traceback.print_exc()
-                return "Both remote API and local Ollama failed. An error occurred."
+
+        # Fallback to local Ollama or if force_local is True
+        try:
+            return await self._call_local_ollama_async(
+                model, stream, temperature, think=think
+            )
+        except Exception as e:
+            traceback.print_exc()
+            return "Both remote API and local Ollama failed. An error occurred."
@@ -396,108 +527,41 @@ class LLM:
     def read_stream(self, response):
         """
-        Yields tuples of (chunk_type, text). The first tuple is ('thinking', ...)
-        if in_thinking is True and stops at </think>. After that, yields ('normal', ...)
-        for the rest of the text.
+        Read streaming response and handle thinking content appropriately.
+        With native thinking mode, the thinking content is separate from the main content.
         """
-        thinking_buffer = ""
-        in_thinking = self.call_model == self.get_model("reasoning")
-        first_chunk = True
-        prev_content = None
+        accumulated_content = ""
+        accumulated_thinking = ""
         for chunk in response:
             if not chunk:
                 continue
-            content = chunk.message.content
-            # Remove leading quote if it's the first chunk
-            if first_chunk and content.startswith('"'):
-                content = content[1:]
-                first_chunk = False
-            if in_thinking:
-                thinking_buffer += content
-                if "</think>" in thinking_buffer:
-                    end_idx = thinking_buffer.index("</think>") + len("</think>")
-                    yield ("thinking", thinking_buffer[:end_idx])
-                    remaining = thinking_buffer[end_idx:].strip('"')
-                    if chunk.done and remaining:
-                        yield ("normal", remaining)
-                        break
-                    else:
-                        prev_content = remaining
-                    in_thinking = False
-            else:
-                if prev_content:
-                    yield ("normal", prev_content)
-                prev_content = content
+
+            # Handle thinking content (if present in streaming)
+            thinking_content = getattr(chunk.message, "thinking", None)
+            if thinking_content:
+                accumulated_thinking += thinking_content
+                yield ("thinking", thinking_content)
+
+            # Handle regular content
+            content = chunk.message.content
+            if content:
+                # Remove leading/trailing quotes that sometimes appear
+                if content.startswith('"') and len(accumulated_content) == 0:
+                    content = content[1:]
+                if chunk.done and content.endswith('"'):
+                    content = content[:-1]
+
+                accumulated_content += content
+                yield ("normal", content)
+
             if chunk.done:
-                if prev_content and prev_content.endswith('"'):
-                    prev_content = prev_content[:-1]
-                if prev_content:
-                    yield ("normal", prev_content)
                 break
-        self.messages.append({"role": "assistant", "content": ""})
+
+        # Store the complete response in message history
+        self.messages.append({"role": "assistant", "content": accumulated_content})
+        if not self.chat:
+            self.messages = [self.messages[0]]

-    async def async_generate(
-        self,
-        query: str = None,
-        user_input: str = None,
-        context: str = None,
-        stream: bool = False,
-        tools: list = None,
-        images: list = None,
-        model: Optional[Literal["small", "standard", "vision"]] = None,
-        temperature: float = None,
-        force_local: bool = False,  # New parameter to force local Ollama
-    ):
-        """
-        Asynchronously generates a response based on the provided query and other parameters.
-        Args:
-            query (str, optional): The query string to generate a response for.
-            user_input (str, optional): Additional user input to be included in the response.
-            context (str, optional): Context information to be used in generating the response.
-            stream (bool, optional): Whether to stream the response. Defaults to False.
-            tools (list, optional): List of tools to be used in generating the response. Will set the model to 'tools'.
-            images (list, optional): List of images to be included in the response.
-            model (Optional[Literal["small", "standard", "vision", "tools"]], optional): The model to be used for generating the response.
-            temperature (float, optional): The temperature setting for the model.
-            force_local (bool, optional): Force using local Ollama instead of remote API.
-        Returns:
-            str: The generated response or an error message if an exception occurs.
-        """
-        print_yellow("ASYNC GENERATE")
-        # Prepare the model and temperature
-        model = self._prepare_messages_and_model(query, user_input, context, None, images, model)
-        temperature = temperature if temperature else self.options["temperature"]
-        # First try with remote API
-        if not force_local:
-            try:
-                headers = self._build_headers(model, tools, False)
-                options = self._get_options(temperature)
-                response = await self._call_remote_api_async(model, tools, stream, options, None, headers)
-                # Process response from async client
-                # [Rest of the response processing code as in the original method]
-            except Exception as e:
-                print_red(f"Remote API error: {str(e)}")
-                print_yellow("Falling back to local Ollama...")
-        # Fallback to local Ollama (for async we'll need to use the sync version)
-        if force_local or 'response' not in locals():
-            try:
-                return await self._call_local_ollama_async(model, stream, temperature)
-            except Exception as e:
-                print_red(f"Local Ollama error: {str(e)}")
-                return "Both remote API and local Ollama failed. An error occurred."

     def prepare_images(self, images, message):
         """
@@ -532,12 +596,6 @@ class LLM:
             message["images"] = base64_images
         return message

-def remove_thinking(response):
-    """Remove the thinking section from the response"""
-    response_text = response.content if hasattr(response, "content") else str(response)
-    if "</think>" in response_text:
-        return response_text.split("</think>")[1].strip()
-    return response_text

 if __name__ == "__main__":
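For streaming callers, the rewritten `read_stream()` now yields `("thinking", text)` chunks while the model reasons and `("normal", text)` chunks for the visible answer, and only the accumulated normal content is appended to the message history. A sketch of how a caller might consume it (the `llm` instance and the prompt are placeholders):

```python
# Consuming the new (kind, text) tuples from a streaming generate() call.
# `llm` is assumed to be an already-constructed LLM instance.
stream = llm.generate(
    query="Explain what a mutex is in two sentences.",
    model="reasoning",
    stream=True,
    think=True,
)

for kind, text in stream:
    if kind == "thinking":
        print(f"\033[2m{text}\033[0m", end="", flush=True)  # dim the reasoning trace
    else:
        print(text, end="", flush=True)
print()
```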
