parent
d2319209e1
commit
238a5146f8
6 changed files with 252 additions and 223 deletions
@ -1,6 +1,6 @@ |
|||||||
|
|
||||||
from .llm import LLM # re-export the class from the module |
from .llm import LLM # re-export the class from the module |
||||||
from .tool_registy import register_tool, get_tools |
from .tool_registry import register_tool, get_tools |
||||||
|
|
||||||
# Define public API |
# Define public API |
||||||
__all__ = ["LLM", "register_tool", "get_tools"] |
__all__ = ["LLM", "register_tool", "get_tools"] |
||||||
@ -0,0 +1,13 @@ |
|||||||
|
from ollama import Client |
||||||
|
|
||||||
|
# Client created without any explicit configuration.
client = Client()

# Single-turn conversation consisting of one user question.
question = {'role': 'user', 'content': 'Why is the sky blue?'}
messages = [question]

# Ask for a streamed reply and echo each chunk as soon as it arrives.
stream = client.chat('gpt-oss:120b-cloud', messages=messages, stream=True)
for part in stream:
    chunk = part['message']['content']
    print(chunk, end='', flush=True)
||||||
@ -0,0 +1,176 @@ |
|||||||
|
import inspect, json, re, ast |
||||||
|
from typing import Callable, Dict, Any, List, get_origin, get_args |
||||||
|
from pydantic import BaseModel |
||||||
|
|
||||||
|
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {} |
||||||
|
|
||||||
|
# --- type mapping --- |
||||||
|
def _pytype_to_jsonschema(t): |
||||||
|
origin = get_origin(t) |
||||||
|
if origin is list or origin is List: |
||||||
|
args = get_args(t) |
||||||
|
item_type = args[0] if args else str |
||||||
|
return {"type": "array", "items": _pytype_to_jsonschema(item_type)} |
||||||
|
if inspect.isclass(t) and issubclass(t, BaseModel): |
||||||
|
sch = t.schema() |
||||||
|
return {"type": "object", **sch} |
||||||
|
mapping = { |
||||||
|
str: {"type": "string"}, |
||||||
|
int: {"type": "integer"}, |
||||||
|
float: {"type": "number"}, |
||||||
|
bool: {"type": "boolean"}, |
||||||
|
dict: {"type": "object"}, |
||||||
|
list: {"type": "array", "items": {"type": "string"}}, |
||||||
|
} |
||||||
|
return mapping.get(t, {"type": "string"}) |
||||||
|
|
||||||
|
# --- docstring parser (Google style) --- |
||||||
|
def _parse_google_docstring(docstring: str): |
||||||
|
if not docstring: |
||||||
|
return {"description": "", "params": {}} |
||||||
|
lines = [ln.rstrip() for ln in docstring.splitlines()] |
||||||
|
desc_lines = [] |
||||||
|
i = 0 |
||||||
|
while i < len(lines) and not lines[i].lower().startswith(("args:", "arguments:")): |
||||||
|
if lines[i].strip(): |
||||||
|
desc_lines.append(lines[i].strip()) |
||||||
|
i += 1 |
||||||
|
description = " ".join(desc_lines).strip() |
||||||
|
params = {} |
||||||
|
if i < len(lines): |
||||||
|
i += 1 |
||||||
|
while i < len(lines): |
||||||
|
line = lines[i].strip() |
||||||
|
if not line: |
||||||
|
i += 1 |
||||||
|
continue |
||||||
|
m = re.match(r'^(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line) |
||||||
|
if m: |
||||||
|
name = m.group(1) |
||||||
|
desc = m.group(3) |
||||||
|
j = i + 1 |
||||||
|
while j < len(lines) and not re.match(r'^\w+\s*(?:\([^)]+\))?\s*:', lines[j].strip()): |
||||||
|
if lines[j].strip(): |
||||||
|
desc += " " + lines[j].strip() |
||||||
|
j += 1 |
||||||
|
params[name] = {"description": desc.strip(), "type": m.group(2)} |
||||||
|
i = j |
||||||
|
continue |
||||||
|
i += 1 |
||||||
|
return {"description": description, "params": params} |
||||||
|
|
||||||
|
# --- helper: make OpenAI-style function spec --- |
||||||
|
def _wrap_openai_function_schema(name: str, description: str, parameters: dict): |
||||||
|
"""Create OpenAI function calling format with 'function' wrapper""" |
||||||
|
params = parameters.copy() |
||||||
|
if params.get("type") != "object": |
||||||
|
params = {"type": "object", "properties": params.get("properties", params), "required": params.get("required", [])} |
||||||
|
params.setdefault("additionalProperties", False) |
||||||
|
|
||||||
|
# Return in OpenAI function calling format with 'function' wrapper |
||||||
|
return { |
||||||
|
"type": "function", |
||||||
|
"function": { |
||||||
|
"name": name, |
||||||
|
"description": description, |
||||||
|
"parameters": params |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
# --- decorator to register tools --- |
||||||
|
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None):
    """Register a callable in ``TOOL_REGISTRY``.

    Works both as a bare decorator (``@register_tool``) and as a decorator
    factory or direct call with keyword options
    (``@register_tool(name="echo")``).

    Args:
        func: The callable to register; ``None`` when used with options.
        name: Registry key and advertised name; defaults to ``f.__name__``.
        description: Advertised description; defaults to the summary parsed
            from the callable's docstring.
        schema: Explicit parameter schema. When omitted, one is derived from
            the callable's signature and docstring.

    Returns:
        The callable itself (unchanged), or the decorator when ``func`` is
        ``None``.
    """
    # *args / **kwargs cannot be expressed as named JSON properties.
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

    def _register(f):
        fname = name or f.__name__
        doc = _parse_google_docstring(f.__doc__)
        func_description = description or doc["description"] or ""

        if schema is not None:
            func_schema = schema
        else:
            props = {}
            required = []
            for param_name, param in inspect.signature(f).parameters.items():
                if param.kind in variadic_kinds:
                    continue
                # Parameters without an annotation are advertised as strings.
                ann = str if param.annotation is inspect.Parameter.empty else param.annotation
                # Copy before decorating so the helper's result is never
                # modified in place.
                prop_schema = dict(_pytype_to_jsonschema(ann))
                documented = doc["params"].get(param_name)
                if documented is not None:
                    prop_schema["description"] = documented["description"]
                props[param_name] = prop_schema
                if param.default is inspect.Parameter.empty:
                    required.append(param_name)
            func_schema = {
                "type": "object",
                "properties": props,
                "required": required,
                "additionalProperties": False,
            }

        TOOL_REGISTRY[fname] = {
            "callable": f,
            "schema": _wrap_openai_function_schema(fname, func_description, func_schema),
        }
        return f

    if func is None:
        return _register
    return _register(func)
||||||
|
|
||||||
|
# --- what to send to model --- |
||||||
|
def get_tools() -> List[dict]:
    """Collect the stored ``"schema"`` of every entry in ``TOOL_REGISTRY``.

    Returns:
        A new list with one schema per registered tool, in registry order.
    """
    schemas = []
    for entry in TOOL_REGISTRY.values():
        schemas.append(entry["schema"])
    return schemas
||||||
|
|
||||||
|
# --- robust parser for arguments --- |
||||||
|
def _loads_dict_or_none(text: str):
    """Decode *text* as JSON, then as a Python literal; return it only if a dict."""
    try:
        value = json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError is a ValueError
        value = None
    if isinstance(value, dict):
        return value
    try:
        value = ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return None
    return value if isinstance(value, dict) else None


def parse_function_call_arguments(raw) -> dict:
    """Best-effort conversion of model-supplied tool arguments into a dict.

    Attempts, in order: pass dicts through; decode the whole string as JSON
    or as a Python literal; treat text starting with SELECT/WITH as a bare
    SQL query; decode the outermost ``{...}`` span found in the text.

    Args:
        raw: The arguments as delivered by the model/SDK.

    Returns:
        Always a dict. Non-string, non-dict input yields
        ``{"_raw_unexpected": <type text>, "value": raw}``. Text that cannot
        be turned into a mapping (including valid JSON that is not an
        object, such as a list or a number) yields ``{"_raw": raw}``.
    """
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {"_raw_unexpected": str(type(raw)), "value": raw}

    parsed = _loads_dict_or_none(raw)
    if parsed is not None:
        return parsed

    stripped = raw.strip()
    if re.match(r'^(SELECT|WITH)\b', stripped, flags=re.IGNORECASE):
        return {"sql_query": stripped}

    # Last resort: the model may have wrapped the object in extra prose.
    m = re.search(r'\{.*\}', raw, flags=re.DOTALL)
    if m:
        parsed = _loads_dict_or_none(m.group(0))
        if parsed is not None:
            return parsed

    return {"_raw": raw}
||||||
|
|
||||||
|
# --- safe executor --- |
||||||
|
def execute_tool(name: str, args: dict):
    """Execute a registered callable with the given arguments.

    Only keys of ``args`` that match a parameter name of the callable are
    forwarded. A string supplied for a ``List[...]``-annotated parameter is
    split on commas into a list of trimmed, non-empty items.

    Args:
        name: Key of the tool in ``TOOL_REGISTRY``.
        args: Argument mapping; ``None`` is treated as empty. The mapping is
            never modified.

    Returns:
        Whatever the registered callable returns.

    Raises:
        RuntimeError: ``name`` is not registered.
        ValueError: ``sql_query`` is present but is not a string, or does
            not start with SELECT or WITH.
    """
    entry = TOOL_REGISTRY.get(name)
    if not entry:
        raise RuntimeError(f"Function {name} not registered")
    fn = entry["callable"]

    # Private copy: normalising sql_query must not leak into the caller's dict.
    call_args = dict(args or {})

    if "sql_query" in call_args:
        q = call_args["sql_query"]
        if not isinstance(q, str):
            raise ValueError("sql_query must be a string")
        q = q.strip()
        if not re.match(r'^(SELECT|WITH)\b', q, flags=re.IGNORECASE):
            raise ValueError("Only SELECT/ WITH queries allowed in sql_query")
        # NOTE(review): this is a prefix check only. It does not reject
        # stacked statements or data-modifying CTEs; do not rely on it as
        # the sole safeguard for untrusted SQL.
        call_args["sql_query"] = q.rstrip(";").rstrip()

    # Prepare kwargs with minimal type coercion.
    kwargs = {}
    for pname, param in inspect.signature(fn).parameters.items():
        if pname not in call_args:
            continue
        val = call_args[pname]
        if get_origin(param.annotation) is list and isinstance(val, str):
            kwargs[pname] = [piece.strip() for piece in val.split(",") if piece.strip()]
        else:
            kwargs[pname] = val

    return fn(**kwargs)
||||||
@ -1,214 +0,0 @@ |
|||||||
# assume your client already has: import inspect, json |
|
||||||
from typing import Callable, Dict, Any |
|
||||||
import inspect, json |
|
||||||
import re |
|
||||||
from pydantic import BaseModel |
|
||||||
|
|
||||||
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {} |
|
||||||
|
|
||||||
def _parse_google_docstring(docstring: str) -> Dict[str, Any]: |
|
||||||
"""Parse Google-style docstring to extract description and parameter info.""" |
|
||||||
if not docstring: |
|
||||||
return {"description": "", "params": {}} |
|
||||||
|
|
||||||
# Split into lines and clean up |
|
||||||
lines = [line.strip() for line in docstring.strip().split('\n')] |
|
||||||
|
|
||||||
# Find the main description (everything before Args:) |
|
||||||
description_lines = [] |
|
||||||
i = 0 |
|
||||||
while i < len(lines): |
|
||||||
if lines[i].lower().startswith('args:') or lines[i].lower().startswith('arguments:'): |
|
||||||
break |
|
||||||
description_lines.append(lines[i]) |
|
||||||
i += 1 |
|
||||||
|
|
||||||
description = ' '.join(description_lines).strip() |
|
||||||
|
|
||||||
# Parse parameters section |
|
||||||
params = {} |
|
||||||
if i < len(lines): |
|
||||||
i += 1 # Skip the "Args:" line |
|
||||||
while i < len(lines): |
|
||||||
line = lines[i] |
|
||||||
if line.lower().startswith(('returns:', 'yields:', 'raises:', 'note:', 'example:')): |
|
||||||
break |
|
||||||
|
|
||||||
# Match parameter format: param_name (type): description |
|
||||||
match = re.match(r'^\s*(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line) |
|
||||||
if match: |
|
||||||
param_name = match.group(1) |
|
||||||
param_type = match.group(2) |
|
||||||
param_desc = match.group(3) |
|
||||||
|
|
||||||
# Collect multi-line descriptions |
|
||||||
j = i + 1 |
|
||||||
while j < len(lines) and lines[j] and not re.match(r'^\s*\w+\s*(?:\([^)]+\))?\s*:', lines[j]): |
|
||||||
param_desc += ' ' + lines[j].strip() |
|
||||||
j += 1 |
|
||||||
|
|
||||||
params[param_name] = { |
|
||||||
"description": param_desc.strip(), |
|
||||||
"type": param_type.strip() if param_type else None |
|
||||||
} |
|
||||||
i = j - 1 |
|
||||||
|
|
||||||
i += 1 |
|
||||||
|
|
||||||
return {"description": description, "params": params} |
|
||||||
|
|
||||||
def _pytype_to_jsonschema(t): |
|
||||||
# Very-small helper; extend as needed or use pydantic models for complex types |
|
||||||
mapping = {str: {"type": "string"}, int: {"type": "integer"}, |
|
||||||
float: {"type": "number"}, bool: {"type": "boolean"}, |
|
||||||
dict: {"type": "object"}, list: {"type": "array"}} |
|
||||||
return mapping.get(t, {"type": "string"}) # fallback to string |
|
||||||
|
|
||||||
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None):
    """
    Use as decorator or call directly:

        @register_tool
        def foo(x: int): ...

    or

        register_tool(func=myfunc, name="myfunc", schema=...)

    Args:
        func: Callable to register; ``None`` when used with options.
        name: Registry key and advertised name; defaults to ``f.__name__``.
        description: Advertised description; defaults to the summary parsed
            from the callable's docstring.
        schema: Explicit parameter schema; derived from the signature when
            omitted.

    Returns:
        The callable itself, or the decorator when ``func`` is ``None``.
    """
    def _register(f):
        fname = name or f.__name__

        # Docstring supplies the default description and per-parameter text.
        docstring_info = _parse_google_docstring(f.__doc__)
        func_description = description or docstring_info["description"] or ""

        # ``None`` means "not decided yet"; this replaces probing locals().
        func_schema = schema
        if func_schema is None:
            props = {}
            required = []
            for param_name, param in inspect.signature(f).parameters.items():
                ann = param.annotation
                # A pydantic model parameter provides the whole schema and
                # ends the scan; properties gathered so far are not used.
                if inspect.isclass(ann) and issubclass(ann, BaseModel):
                    func_schema = ann.schema()
                    break

                prop_schema = _pytype_to_jsonschema(ann)
                documented = docstring_info["params"].get(param_name)
                if documented is not None:
                    prop_schema["description"] = documented["description"]
                props[param_name] = prop_schema
                if param.default is inspect.Parameter.empty:
                    required.append(param_name)

            if func_schema is None:
                func_schema = {
                    "type": "object",
                    "properties": props,
                    "required": required,
                }

        TOOL_REGISTRY[fname] = {
            "callable": f,
            "schema": {
                "type": "function",
                "function": {
                    "name": fname,
                    "description": func_description,
                    "parameters": func_schema,
                },
            },
        }
        return f

    if func is None:
        return _register
    return _register(func)
|
||||||
|
|
||||||
def get_tools() -> list:
    """Return the stored schema of every registered tool, in registry order."""
    schemas = []
    for registered in TOOL_REGISTRY.values():
        schemas.append(registered["schema"])
    return schemas
|
||||||
|
|
||||||
def _read_field(obj, key):
    """Fetch *key* from a dict or an attribute-style object; None if absent."""
    if obj is None:
        return None
    if isinstance(obj, dict):
        return obj.get(key)
    return getattr(obj, key, None)


def _coerce_call_arguments(raw_args):
    """Return call arguments as a dict, decoding JSON text; None if unusable."""
    if isinstance(raw_args, str):
        try:
            raw_args = json.loads(raw_args)
        except ValueError:  # json.JSONDecodeError is a ValueError
            return None
    return raw_args if isinstance(raw_args, dict) else None


def handle_function_call_and_inject_result(response_choice, messages):
    """
    Given the model choice (response.choices[0]) and your messages list:
    - extracts function/tool call
    - executes the registered python callable
    - appends the tool result as a tool message and returns it

    Both dict-shaped and attribute-shaped responses are accepted. Only the
    first entry of ``tool_calls`` is handled; ``function_call`` is used as a
    fallback. Arguments given as JSON text are decoded on either path;
    arguments that cannot be decoded into a dict result in a call with no
    arguments.

    Args:
        response_choice: The choice object or dict holding ``message``.
        messages: List that receives the ``{"role": "tool", ...}`` message.

    Returns:
        The tool output as a string (JSON-encoded unless already a string),
        or ``None`` when the response contains no function call.

    Raises:
        RuntimeError: The requested function is not in ``TOOL_REGISTRY``.
    """
    msg = _read_field(response_choice, "message")
    func_name = None
    raw_args = None

    # tool_calls style
    tool_calls = _read_field(msg, "tool_calls")
    if tool_calls:
        fn = _read_field(tool_calls[0], "function")
        func_name = _read_field(fn, "name")
        raw_args = _read_field(fn, "arguments")

    # fallback to function_call
    if func_name is None:
        fc = _read_field(msg, "function_call")
        if fc:
            func_name = _read_field(fc, "name")
            raw_args = _read_field(fc, "arguments")

    if not func_name:
        return None  # no function call found

    entry = TOOL_REGISTRY.get(func_name)
    if not entry:
        raise RuntimeError(f"Function {func_name} not registered")

    func_args = _coerce_call_arguments(raw_args)
    result = entry["callable"](**(func_args or {}))

    # convert result to string/JSON for tool message
    tool_content = result if isinstance(result, str) else json.dumps(result)
    # append tool message so model can see the result
    messages.append({"role": "tool", "name": func_name, "content": tool_content})
    return tool_content
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Demo: register two sample tools, then dump every stored schema.
    import pprint

    @register_tool
    def add(x: int, y: int) -> int:
        """Add two integers
        Args:
        x (int): First integer
        y (int): Second integer
        Returns:
        int: Sum of x and y
        """
        return x + y

    # Explicit name/description take precedence over the function name and
    # the docstring summary.
    @register_tool(name="echo", description="Echoes the input string")
    def echo_message(message: str) -> str:
        """Echo the input message
        Args:
        message (str): The message to echo
        Returns:
        str: The echoed message
        """
        return message

    print("Registered tools:")
    for tool_schema in get_tools():
        pprint.pprint(tool_schema)
|
||||||
Loading…
Reference in new issue