You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

236 lines
8.6 KiB

import inspect, json, re, ast
from typing import Callable, Dict, Any, List, get_origin, get_args
from pydantic import BaseModel
# Global tool registry, keyed by tool name. register_tool stores each entry as
# {"callable": <the function>, "schema": <its OpenAI tool spec>}.
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = dict()
# --- type mapping ---
def _pytype_to_jsonschema(t):
origin = get_origin(t)
if origin is list or origin is List:
args = get_args(t)
item_type = args[0] if args else str
return {"type": "array", "items": _pytype_to_jsonschema(item_type)}
if inspect.isclass(t) and issubclass(t, BaseModel):
sch = t.schema()
return {"type": "object", **sch}
mapping = {
str: {"type": "string"},
int: {"type": "integer"},
float: {"type": "number"},
bool: {"type": "boolean"},
dict: {"type": "object"},
list: {"type": "array", "items": {"type": "string"}},
}
return mapping.get(t, {"type": "string"})
# --- docstring parser (Google style) - FIXED VERSION ---
def _parse_google_docstring(docstring: str):
if not docstring:
return {"description": "", "params": {}}
lines = [ln.rstrip() for ln in docstring.splitlines()]
# Find where Args/Arguments section starts
args_start = None
for i, line in enumerate(lines):
if line.strip().lower() in ("args:", "arguments:"):
args_start = i
break
# Find where Args section ends (Returns:, Raises:, or another section)
args_end = len(lines)
if args_start is not None:
for i in range(args_start + 1, len(lines)):
line = lines[i].strip().lower()
if line.endswith(':') and line.rstrip(':') in ('returns', 'return', 'raises', 'raise', 'yields', 'yield', 'examples', 'example', 'notes', 'note'):
args_end = i
break
# Build description from everything EXCEPT the Args section content
desc_lines = []
# Before Args
if args_start is not None:
for i in range(args_start):
if lines[i].strip():
desc_lines.append(lines[i].strip())
else:
# No Args section, include everything
for line in lines:
if line.strip():
desc_lines.append(line.strip())
# After Args section (Returns, examples, etc.)
if args_start is not None and args_end < len(lines):
for i in range(args_end, len(lines)):
if lines[i].strip():
desc_lines.append(lines[i].strip())
description = " ".join(desc_lines).strip()
# Parse parameters from Args section
params = {}
if args_start is not None:
i = args_start + 1
while i < args_end:
line = lines[i].strip()
if not line:
i += 1
continue
# Match parameter line: "param_name (type): description" or "param_name: description"
m = re.match(r'^(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line)
if m:
name = m.group(1)
desc = m.group(3)
# Collect continuation lines for this parameter
j = i + 1
while j < args_end:
next_line = lines[j].strip()
# Check if it's a new parameter or empty
if not next_line or re.match(r'^\w+\s*(?:\([^)]+\))?\s*:', next_line):
break
desc += " " + next_line
j += 1
params[name] = {"description": desc.strip(), "type": m.group(2)}
i = j
continue
i += 1
return {"description": description, "params": params}
# --- helper: make OpenAI-style function spec ---
def _wrap_openai_function_schema(name: str, description: str, parameters: dict):
"""Create OpenAI function calling format with 'function' wrapper"""
params = parameters.copy()
if params.get("type") != "object":
params = {"type": "object", "properties": params.get("properties", params), "required": params.get("required", [])}
params.setdefault("additionalProperties", False)
# Return in OpenAI function calling format with 'function' wrapper
return {
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": params
}
}
# --- decorator to register tools ---
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None):
    """Register a function in TOOL_REGISTRY together with its OpenAI tool schema.

    Usable bare (``@register_tool``) or with options
    (``@register_tool(name=..., description=..., schema=...)``).

    Args:
        func: The decorated function; ``None`` when called with options first.
        name: Registry/tool name; defaults to ``f.__name__``.
        description: Tool description; defaults to the description parsed
            from the function's docstring (empty string if none).
        schema: Explicit parameter schema. When omitted, one is built from the
            signature: each named parameter becomes a property (typed via
            ``_pytype_to_jsonschema``, unannotated -> ``str``), described from
            the docstring's Args section, and required unless it has a default.
            ``*args``/``**kwargs`` are not exposed as properties.

    Returns:
        The function unchanged, or (when ``func`` is None) a decorator that
        registers and returns its argument.
    """
    from typing import get_type_hints

    def _register(f):
        fname = name or f.__name__
        doc = _parse_google_docstring(f.__doc__)
        func_description = description or doc["description"] or ""
        if schema is not None:
            func_schema = schema
        else:
            sig = inspect.signature(f)
            # String annotations (e.g. under `from __future__ import annotations`)
            # must be resolved, or every such parameter would map to "string".
            # Best-effort: unresolvable forward refs keep their raw annotation.
            try:
                hints = get_type_hints(f)
            except Exception:
                hints = {}
            props = {}
            required = []
            for param_name, param in sig.parameters.items():
                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                    continue  # *args/**kwargs do not map to a single JSON property
                ann = param.annotation
                if ann is inspect.Parameter.empty:
                    ann = str
                elif isinstance(ann, str):
                    ann = hints.get(param_name, ann)
                prop_schema = _pytype_to_jsonschema(ann)
                param_doc = doc["params"].get(param_name)
                if param_doc is not None:
                    prop_schema["description"] = param_doc["description"]
                props[param_name] = prop_schema
                if param.default is inspect.Parameter.empty:
                    required.append(param_name)
            func_schema = {"type": "object", "properties": props, "required": required, "additionalProperties": False}
        TOOL_REGISTRY[fname] = {
            "callable": f,
            "schema": _wrap_openai_function_schema(fname, func_description, func_schema),
        }
        return f

    if func is None:
        return _register
    return _register(func)
# --- what to send to model ---
def get_tools(specific_tools: list[str] = None, exclude_tools: list[str] = None) -> List[dict]:
    """Return OpenAI-compatible tool specs (with the 'function' wrapper).

    Args:
        specific_tools: Tool name, or list of names, to return in that order.
            Unknown names are skipped.
        exclude_tools: Tool name, or list of names, to leave out; every other
            registered tool is returned in registry order.

    Returns:
        The schema dicts stored in TOOL_REGISTRY (the same objects, not copies).

    Raises:
        AssertionError: If both ``specific_tools`` and ``exclude_tools`` are given.
    """
    if specific_tools and exclude_tools:
        # Explicit raise instead of `assert` so the check survives `python -O`;
        # AssertionError is kept because that is what the former assert raised.
        raise AssertionError("Cannot specify both specific_tools and exclude_tools")
    if isinstance(specific_tools, str):
        specific_tools = [specific_tools]
    if isinstance(exclude_tools, str):
        # Without this, `name not in "some_tool"` would be a substring test.
        exclude_tools = [exclude_tools]

    if specific_tools:
        entries = (TOOL_REGISTRY.get(t) for t in specific_tools)
        return [entry["schema"] for entry in entries if entry]

    schemas = [entry["schema"] for entry in TOOL_REGISTRY.values()]
    if exclude_tools:
        excluded = set(exclude_tools)
        schemas = [s for s in schemas if s["function"]["name"] not in excluded]
    return schemas
# --- robust parser for arguments ---
def _decode_mapping(text: str):
    """Decode ``text`` as a dict, trying JSON first and then Python-literal syntax.

    Returns:
        The decoded dict, or ``None`` when neither decoder yields a dict.
    """
    try:
        value = json.loads(text)
    except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
        value = None
    if isinstance(value, dict):
        return value
    try:
        # literal_eval only evaluates literals (no names or calls), but it can
        # raise any of these on malformed or pathologically nested input.
        value = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return None
    return value if isinstance(value, dict) else None


def parse_function_call_arguments(raw) -> dict:
    """Best-effort conversion of model-supplied function-call arguments to a dict.

    Strategies, in order:
      1. ``dict`` input is returned as-is.
      2. Non-string input -> ``{"_raw_unexpected": str(type(raw)), "value": raw}``.
      3. The whole string decoded as a JSON / Python-literal dict.
      4. A string starting with SELECT/WITH -> ``{"sql_query": <stripped string>}``.
      5. The span from the first ``{`` to the last ``}`` decoded as a dict.
      6. Otherwise ``{"_raw": raw}``.

    Decoded values that are not dicts (numbers, lists, strings, sets, ...) are
    never returned, so the result always honours the ``-> dict`` contract.
    """
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        return {"_raw_unexpected": str(type(raw)), "value": raw}

    decoded = _decode_mapping(raw)
    if decoded is not None:
        return decoded

    stripped = raw.strip()
    if re.match(r'^(SELECT|WITH)\b', stripped, flags=re.IGNORECASE):
        return {"sql_query": stripped}

    # Greedy on purpose: spans first "{" to last "}" so nested objects stay intact.
    m = re.search(r'\{.*\}', raw, flags=re.DOTALL)
    if m:
        decoded = _decode_mapping(m.group(0))
        if decoded is not None:
            return decoded
    return {"_raw": raw}
# --- safe executor ---
def execute_tool(name: str, args: dict):
    """Execute a registered tool with model-supplied arguments (basic validation).

    Args:
        name: Key in TOOL_REGISTRY.
        args: Parsed arguments, e.g. from ``parse_function_call_arguments``.
            Never modified. Keys that are not parameters of the callable are
            ignored; a comma-separated string given for a list-annotated
            parameter is split into a list of stripped, non-empty items.

    Returns:
        Whatever the registered callable returns.

    Raises:
        RuntimeError: If ``name`` is not registered.
        ValueError: If ``sql_query`` is present but is not a string, does not
            start with SELECT/WITH, or contains more than one statement.
    """
    from typing import get_type_hints

    entry = TOOL_REGISTRY.get(name)
    if not entry:
        raise RuntimeError(f"Function {name} not registered")
    fn = entry["callable"]

    # simple SQL safety example: if function expects sql_query ensure SELECT
    if "sql_query" in args:
        q = args["sql_query"]
        if not isinstance(q, str):
            raise ValueError("sql_query must be a string")
        q = q.strip()
        if not re.match(r'^(SELECT|WITH)\b', q, flags=re.IGNORECASE):
            raise ValueError("Only SELECT/ WITH queries allowed in sql_query")
        q = q.rstrip("; \t\r\n")
        if ";" in q:
            # A remaining ';' means stacked statements ("SELECT 1; DROP ...").
            # Deliberately strict: this also rejects ';' inside string literals.
            # NOTE(review): a SELECT/WITH prefix alone does not guarantee
            # read-only SQL in every dialect (e.g. data-modifying CTEs in
            # PostgreSQL); a read-only DB connection is the real safeguard.
            raise ValueError("Only a single statement is allowed in sql_query")
        # Copy rather than mutate the caller's dict.
        args = {**args, "sql_query": q}

    sig = inspect.signature(fn)
    # Resolve string annotations so list coercion also works under PEP 563;
    # best-effort, unresolvable annotations are left as they are.
    try:
        hints = get_type_hints(fn)
    except Exception:
        hints = {}

    kwargs = {}
    for pname, param in sig.parameters.items():
        if pname not in args:
            continue
        val = args[pname]
        ann = param.annotation
        if ann is inspect.Parameter.empty:
            ann = None
        elif isinstance(ann, str):
            ann = hints.get(pname, ann)
        if get_origin(ann) in (list, List) and isinstance(val, str):
            # Models sometimes send "a, b, c" for a list-typed parameter.
            kwargs[pname] = [x.strip() for x in val.split(",") if x.strip()]
        else:
            kwargs[pname] = val
    return fn(**kwargs)