Improved tool use

main
lasseedfast 5 months ago
parent d2319209e1
commit 238a5146f8
  1. 2
      __init__.py
  2. 2
      _llm/__init__.py
  3. 68
      _llm/llm.py
  4. 13
      _llm/ollama-cloud-test.py
  5. 176
      _llm/tool_registry.py
  6. 214
      _llm/tool_registy.py

@ -3,6 +3,6 @@ llm_client: A Python package for interacting with LLM models through Ollama.
""" """
from _llm._llm.llm import LLM from _llm._llm.llm import LLM
from _llm._llm.tool_registy import register_tool, get_tools from _llm._llm.tool_registry import register_tool, get_tools
__all__ = ["LLM", "register_tool", "get_tools"] __all__ = ["LLM", "register_tool", "get_tools"]

@ -1,6 +1,6 @@
from .llm import LLM # re-export the class from the module from .llm import LLM # re-export the class from the module
from .tool_registy import register_tool, get_tools from .tool_registry import register_tool, get_tools
# Define public API # Define public API
__all__ = ["LLM", "register_tool", "get_tools"] __all__ = ["LLM", "register_tool", "get_tools"]

@ -9,19 +9,19 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.responses import ParsedResponse from openai.types.responses import ParsedResponse
import backoff import backoff
import env_manager import env_manager
import json import json
try: try:
from .tool_registy import get_tools, register_tool from .tool_registry import get_tools, parse_function_call_arguments, execute_tool
except ImportError: except ImportError:
from tool_registy import get_tools, register_tool
from _llm._llm.tool_registry import get_tools, parse_function_call_arguments, execute_tool
try: try:
from colorprinter.print_color import * from colorprinter.print_color import *
except ImportError: except ImportError:
from colorprinter.print_color import * from colorprinter.colorprinter.print_color import *
@ -487,6 +487,7 @@ class LLM:
# Call the OpenAI API # Call the OpenAI API
else: else:
response: ChatCompletion = self.client.chat.completions.create(**kwargs) response: ChatCompletion = self.client.chat.completions.create(**kwargs)
# Try to extract backend information if available # Try to extract backend information if available
try: try:
response_headers = getattr(response, "_headers", {}) response_headers = getattr(response, "_headers", {})
@ -720,10 +721,64 @@ class LLM:
else: else:
choice = response.choices[0] choice = response.choices[0]
message: ChatCompletionMessage = choice.message message: ChatCompletionMessage = choice.message
print(message)
if hasattr(message, 'tool_calls') and message.tool_calls:
# Hantera flera verktygsanrop sequensielt
for tool_call in message.tool_calls:
try:
fn = getattr(tool_call, "function", None) or (tool_call.get("function") if isinstance(tool_call, dict) else None)
if not fn:
continue
func_name = getattr(fn, "name", None) or (fn.get("name") if isinstance(fn, dict) else None)
raw_args = getattr(fn, "arguments", None) or (fn.get("arguments") if isinstance(fn, dict) else None)
# Automatisk JSON-parsing av argument om de kommer som sträng (för vLLM-kompatibilitet)
if isinstance(raw_args, str):
try:
parsed_args = json.loads(raw_args)
# Uppdatera function.arguments med den parsade versionen för enklare användning
if hasattr(fn, "arguments"):
fn.arguments = parsed_args # Uppdatera objektet direkt
elif isinstance(fn, dict):
fn["arguments"] = parsed_args
except json.JSONDecodeError as e:
print_red(f"Warning: Could not parse tool arguments as JSON: {e}")
# Fallback till parse_function_call_arguments för robusthet
parsed_args = parse_function_call_arguments(raw_args)
else:
parsed_args = raw_args if isinstance(raw_args, dict) else {}
# Kör verktyget via tool_registry.execute_tool (validering och typ-coercion görs där)
tool_result = execute_tool(func_name, parsed_args)
# Sätt in tool-result i messages så modellen kan läsa det vidare
tool_content = tool_result if isinstance(tool_result, str) else json.dumps(tool_result, ensure_ascii=False)
self.messages.append({"role": "tool", "name": func_name, "content": tool_content})
except Exception as e:
print_red(f"Error executing tool {func_name}: {e}")
# append error to messages so model sees it (and you can debug)
self.messages.append({"role": "tool", "name": func_name or "unknown", "content": json.dumps({"error": str(e)})})
# fallback: older SDKs / shapes:
if hasattr(message, 'function_call') and message.function_call:
fc = message.function_call
func_name = getattr(fc, "name", None) or (fc.get("name") if isinstance(fc, dict) else None)
args_raw = getattr(fc, "arguments", None) or (fc.get("arguments") if isinstance(fc, dict) else None)
parsed_args = parse_function_call_arguments(args_raw)
try:
tool_result = execute_tool(func_name, parsed_args)
tool_content = tool_result if isinstance(tool_result, str) else json.dumps(tool_result, ensure_ascii=False)
self.messages.append({"role": "tool", "name": func_name, "content": tool_content})
except Exception as e:
self.messages.append({"role": "tool", "name": func_name or "unknown", "content": json.dumps({"error": str(e)})})
# Hämta textsvar från meddelandet
result: str = message.content result: str = message.content
if hasattr(message, 'content_text'): if hasattr(message, 'content_text'):
result: str = message.content_text result: str = message.content_text
# Store in message history (without tool calls for clean history)
# Spara i meddelandehistorik (utan verktygsanrop för ren historik)
self.messages.append({"role": "assistant", "content": result}) self.messages.append({"role": "assistant", "content": result})
if not self.chat: if not self.chat:
self.messages = [self.messages[0]] self.messages = [self.messages[0]]
@ -1120,7 +1175,6 @@ if __name__ == "__main__":
print(response.__dict__) print(response.__dict__)
response = llm.generate("What's the weather like in San Francisco? Also calculate 15 * 7 for me.", model='vllm') response = llm.generate("What's the weather like in San Francisco? Also calculate 15 * 7 for me.", model='vllm')
print(response.__dict__) print(response.__dict__)
exit()
# Define a tool for calculations # Define a tool for calculations
@register_tool @register_tool
@ -1198,7 +1252,7 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
print(f"❌ Tools test failed: {e}") print(f"❌ Tools test failed: {e}")
# Test 3: Thinking mode (use vllm model since reasoning model doesn't exist) # Test 3: Thinking mode (use vLLM model since reasoning model doesn't exist)
print("\n3 Thinking Mode Test (using vllm)") print("\n3 Thinking Mode Test (using vllm)")
print("-" * 30) print("-" * 30)
try: try:

@ -0,0 +1,13 @@
from ollama import Client
client = Client()
messages = [
{
'role': 'user',
'content': 'Why is the sky blue?',
},
]
for part in client.chat('gpt-oss:120b-cloud', messages=messages, stream=True):
print(part['message']['content'], end='', flush=True)

@ -0,0 +1,176 @@
import inspect, json, re, ast
from typing import Callable, Dict, Any, List, get_origin, get_args
from pydantic import BaseModel
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {}
# --- type mapping ---
def _pytype_to_jsonschema(t):
origin = get_origin(t)
if origin is list or origin is List:
args = get_args(t)
item_type = args[0] if args else str
return {"type": "array", "items": _pytype_to_jsonschema(item_type)}
if inspect.isclass(t) and issubclass(t, BaseModel):
sch = t.schema()
return {"type": "object", **sch}
mapping = {
str: {"type": "string"},
int: {"type": "integer"},
float: {"type": "number"},
bool: {"type": "boolean"},
dict: {"type": "object"},
list: {"type": "array", "items": {"type": "string"}},
}
return mapping.get(t, {"type": "string"})
# --- docstring parser (Google style) ---
def _parse_google_docstring(docstring: str):
if not docstring:
return {"description": "", "params": {}}
lines = [ln.rstrip() for ln in docstring.splitlines()]
desc_lines = []
i = 0
while i < len(lines) and not lines[i].lower().startswith(("args:", "arguments:")):
if lines[i].strip():
desc_lines.append(lines[i].strip())
i += 1
description = " ".join(desc_lines).strip()
params = {}
if i < len(lines):
i += 1
while i < len(lines):
line = lines[i].strip()
if not line:
i += 1
continue
m = re.match(r'^(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line)
if m:
name = m.group(1)
desc = m.group(3)
j = i + 1
while j < len(lines) and not re.match(r'^\w+\s*(?:\([^)]+\))?\s*:', lines[j].strip()):
if lines[j].strip():
desc += " " + lines[j].strip()
j += 1
params[name] = {"description": desc.strip(), "type": m.group(2)}
i = j
continue
i += 1
return {"description": description, "params": params}
# --- helper: make OpenAI-style function spec ---
def _wrap_openai_function_schema(name: str, description: str, parameters: dict):
"""Create OpenAI function calling format with 'function' wrapper"""
params = parameters.copy()
if params.get("type") != "object":
params = {"type": "object", "properties": params.get("properties", params), "required": params.get("required", [])}
params.setdefault("additionalProperties", False)
# Return in OpenAI function calling format with 'function' wrapper
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": params
}
}
# --- decorator to register tools ---
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None):
def _register(f):
fname = name or f.__name__
doc = _parse_google_docstring(f.__doc__)
func_description = description or doc["description"] or ""
if schema is not None:
func_schema = schema
else:
sig = inspect.signature(f)
props = {}
required = []
for param_name, param in sig.parameters.items():
ann = param.annotation if param.annotation is not inspect._empty else str
prop_schema = _pytype_to_jsonschema(ann)
if param_name in doc["params"]:
prop_schema["description"] = doc["params"][param_name]["description"]
props[param_name] = prop_schema
if param.default is inspect._empty:
required.append(param_name)
func_schema = {"type": "object", "properties": props, "required": required, "additionalProperties": False}
TOOL_REGISTRY[fname] = {
"callable": f,
"schema": _wrap_openai_function_schema(fname, func_description, func_schema)
}
return f
if func is None:
return _register
else:
return _register(func)
# --- what to send to model ---
def get_tools() -> List[dict]:
"""Return OpenAI-compatible functions list with proper 'function' wrapper."""
return [entry["schema"] for entry in TOOL_REGISTRY.values()]
# --- robust parser for arguments ---
def parse_function_call_arguments(raw) -> dict:
if isinstance(raw, dict):
return raw
if not isinstance(raw, str):
return {"_raw_unexpected": str(type(raw)), "value": raw}
try:
return json.loads(raw)
except json.JSONDecodeError:
pass
try:
return ast.literal_eval(raw)
except Exception:
pass
stripped = raw.strip()
if re.match(r'^(SELECT|WITH)\b', stripped, flags=re.IGNORECASE):
return {"sql_query": stripped}
m = re.search(r'\{.*\}', raw, flags=re.DOTALL)
if m:
candidate = m.group(0)
try:
return json.loads(candidate)
except Exception:
try:
return ast.literal_eval(candidate)
except Exception:
pass
return {"_raw": raw}
# --- safe executor ---
def execute_tool(name: str, args: dict):
"""
Execute registered callable with args (basic validation).
Returns Python object (dict/list/str).
"""
entry = TOOL_REGISTRY.get(name)
if not entry:
raise RuntimeError(f"Function {name} not registered")
fn = entry["callable"]
# simple SQL safety example: if function expects sql_query ensure SELECT
if "sql_query" in args:
q = args["sql_query"].strip()
if not re.match(r'^(SELECT|WITH)\b', q, flags=re.IGNORECASE):
raise ValueError("Only SELECT/ WITH queries allowed in sql_query")
if q.endswith(";"):
args["sql_query"] = q[:-1]
# Prepare kwargs with minimal type coercion
sig = inspect.signature(fn)
kwargs = {}
for pname, param in sig.parameters.items():
if pname not in args:
continue
val = args[pname]
ann = param.annotation if param.annotation is not inspect._empty else None
origin = get_origin(ann)
if origin in (list, List) and isinstance(val, str):
kwargs[pname] = [x.strip() for x in val.split(",") if x.strip() != ""]
else:
kwargs[pname] = val
result = fn(**kwargs)
return result

@ -1,214 +0,0 @@
# assume your client already has: import inspect, json
from typing import Callable, Dict, Any
import inspect, json
import re
from pydantic import BaseModel
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {}
def _parse_google_docstring(docstring: str) -> Dict[str, Any]:
"""Parse Google-style docstring to extract description and parameter info."""
if not docstring:
return {"description": "", "params": {}}
# Split into lines and clean up
lines = [line.strip() for line in docstring.strip().split('\n')]
# Find the main description (everything before Args:)
description_lines = []
i = 0
while i < len(lines):
if lines[i].lower().startswith('args:') or lines[i].lower().startswith('arguments:'):
break
description_lines.append(lines[i])
i += 1
description = ' '.join(description_lines).strip()
# Parse parameters section
params = {}
if i < len(lines):
i += 1 # Skip the "Args:" line
while i < len(lines):
line = lines[i]
if line.lower().startswith(('returns:', 'yields:', 'raises:', 'note:', 'example:')):
break
# Match parameter format: param_name (type): description
match = re.match(r'^\s*(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line)
if match:
param_name = match.group(1)
param_type = match.group(2)
param_desc = match.group(3)
# Collect multi-line descriptions
j = i + 1
while j < len(lines) and lines[j] and not re.match(r'^\s*\w+\s*(?:\([^)]+\))?\s*:', lines[j]):
param_desc += ' ' + lines[j].strip()
j += 1
params[param_name] = {
"description": param_desc.strip(),
"type": param_type.strip() if param_type else None
}
i = j - 1
i += 1
return {"description": description, "params": params}
def _pytype_to_jsonschema(t):
# Very-small helper; extend as needed or use pydantic models for complex types
mapping = {str: {"type": "string"}, int: {"type": "integer"},
float: {"type": "number"}, bool: {"type": "boolean"},
dict: {"type": "object"}, list: {"type": "array"}}
return mapping.get(t, {"type": "string"}) # fallback to string
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None):
"""
Use as decorator or call directly:
@register_tool
def foo(x: int): ...
or
register_tool(func=myfunc, name="myfunc", schema=...)
"""
def _register(f):
fname = name or f.__name__
# Parse docstring for description and parameter info
docstring_info = _parse_google_docstring(f.__doc__)
func_description = description or docstring_info["description"] or ""
# If explicit schema provided, use it
if schema is not None:
func_schema = schema
else:
sig = inspect.signature(f)
props = {}
required = []
for param_name, param in sig.parameters.items():
ann = param.annotation
# If user used a Pydantic BaseModel as a single arg, use its schema
if inspect.isclass(ann) and issubclass(ann, BaseModel):
func_schema = ann.schema()
# wrap into a single-arg object if necessary
props = func_schema.get("properties", {})
required = func_schema.get("required", [])
# done early - for single-model param
break
# Create property schema from type annotation
prop_schema = _pytype_to_jsonschema(ann)
# Add description from docstring if available
if param_name in docstring_info["params"]:
prop_schema["description"] = docstring_info["params"][param_name]["description"]
props[param_name] = prop_schema
if param.default is inspect._empty:
required.append(param_name)
if 'func_schema' not in locals():
func_schema = {
"type": "object",
"properties": props,
"required": required
}
TOOL_REGISTRY[fname] = {
"callable": f,
"schema": {
"type": "function",
"function": {
"name": fname,
"description": func_description,
"parameters": func_schema
}
}
}
return f
if func is None:
return _register
else:
return _register(func)
def get_tools() -> list:
"""Return list of function schemas (JSON) to send to the model"""
return [v["schema"] for v in TOOL_REGISTRY.values()]
def handle_function_call_and_inject_result(response_choice, messages):
"""
Given the model choice (response.choices[0]) and your messages list:
- extracts function/tool call
- executes the registered python callable
- appends the tool result as a tool message and returns it
"""
# Support different shapes: some SDKs use .message.tool_calls, others .message.function_call
msg = getattr(response_choice, "message", None) or (response_choice.get("message") if isinstance(response_choice, dict) else None)
func_name = None
func_args = None
# try tool_calls style
if msg:
tool_calls = getattr(msg, "tool_calls", None) or (msg.get("tool_calls") if isinstance(msg, dict) else None)
if tool_calls:
tc = tool_calls[0]
fn = getattr(tc, "function", None) or (tc.get("function") if isinstance(tc, dict) else None)
func_name = getattr(fn, "name", None) or (fn.get("name") if isinstance(fn, dict) else None)
func_args = getattr(fn, "arguments", None) or (fn.get("arguments") if isinstance(fn, dict) else None)
# fallback to function_call
if func_name is None:
fc = getattr(msg, "function_call", None) or (msg.get("function_call") if isinstance(msg, dict) else None)
if fc:
func_name = getattr(fc, "name", None) or fc.get("name")
args_raw = getattr(fc, "arguments", None) or fc.get("arguments")
# arguments are often a JSON string depending on SDK shape
if isinstance(args_raw, str):
try:
func_args = json.loads(args_raw)
except Exception:
func_args = None
else:
func_args = args_raw
if not func_name:
return None # no function call found
entry = TOOL_REGISTRY.get(func_name)
if not entry:
raise RuntimeError(f"Function {func_name} not registered")
result = entry["callable"](**(func_args or {}))
# convert result to string/JSON for tool message
tool_content = result if isinstance(result, str) else json.dumps(result)
# append tool message so model can see the result
messages.append({"role": "tool", "name": func_name, "content": tool_content})
return tool_content
if __name__ == "__main__":
# Example usage and test
@register_tool
def add(x: int, y: int) -> int:
"""Add two integers
Args:
x (int): First integer
y (int): Second integer
Returns:
int: Sum of x and y
"""
return x + y
@register_tool(name="echo", description="Echoes the input string")
def echo_message(message: str) -> str:
"""Echo the input message
Args:
message (str): The message to echo
Returns:
str: The echoed message
"""
return message
print("Registered tools:")
import pprint
for info in get_tools():
pprint.pprint(info)
Loading…
Cancel
Save