parent
3f67dd57d2
commit
b72df20b03
2 changed files with 1037 additions and 393 deletions
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,214 @@ |
||||
# assume your client already has: import inspect, json |
||||
from typing import Callable, Dict, Any |
||||
import inspect, json |
||||
import re |
||||
from pydantic import BaseModel |
||||
|
||||
TOOL_REGISTRY: Dict[str, Dict[str, Any]] = {} |
||||
|
||||
def _parse_google_docstring(docstring: str) -> Dict[str, Any]: |
||||
"""Parse Google-style docstring to extract description and parameter info.""" |
||||
if not docstring: |
||||
return {"description": "", "params": {}} |
||||
|
||||
# Split into lines and clean up |
||||
lines = [line.strip() for line in docstring.strip().split('\n')] |
||||
|
||||
# Find the main description (everything before Args:) |
||||
description_lines = [] |
||||
i = 0 |
||||
while i < len(lines): |
||||
if lines[i].lower().startswith('args:') or lines[i].lower().startswith('arguments:'): |
||||
break |
||||
description_lines.append(lines[i]) |
||||
i += 1 |
||||
|
||||
description = ' '.join(description_lines).strip() |
||||
|
||||
# Parse parameters section |
||||
params = {} |
||||
if i < len(lines): |
||||
i += 1 # Skip the "Args:" line |
||||
while i < len(lines): |
||||
line = lines[i] |
||||
if line.lower().startswith(('returns:', 'yields:', 'raises:', 'note:', 'example:')): |
||||
break |
||||
|
||||
# Match parameter format: param_name (type): description |
||||
match = re.match(r'^\s*(\w+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$', line) |
||||
if match: |
||||
param_name = match.group(1) |
||||
param_type = match.group(2) |
||||
param_desc = match.group(3) |
||||
|
||||
# Collect multi-line descriptions |
||||
j = i + 1 |
||||
while j < len(lines) and lines[j] and not re.match(r'^\s*\w+\s*(?:\([^)]+\))?\s*:', lines[j]): |
||||
param_desc += ' ' + lines[j].strip() |
||||
j += 1 |
||||
|
||||
params[param_name] = { |
||||
"description": param_desc.strip(), |
||||
"type": param_type.strip() if param_type else None |
||||
} |
||||
i = j - 1 |
||||
|
||||
i += 1 |
||||
|
||||
return {"description": description, "params": params} |
||||
|
||||
def _pytype_to_jsonschema(t): |
||||
# Very-small helper; extend as needed or use pydantic models for complex types |
||||
mapping = {str: {"type": "string"}, int: {"type": "integer"}, |
||||
float: {"type": "number"}, bool: {"type": "boolean"}, |
||||
dict: {"type": "object"}, list: {"type": "array"}} |
||||
return mapping.get(t, {"type": "string"}) # fallback to string |
||||
|
||||
def register_tool(func: Callable = None, *, name: str = None, description: str = None, schema: dict = None): |
||||
""" |
||||
Use as decorator or call directly: |
||||
@register_tool |
||||
def foo(x: int): ... |
||||
or |
||||
register_tool(func=myfunc, name="myfunc", schema=...) |
||||
""" |
||||
def _register(f): |
||||
fname = name or f.__name__ |
||||
|
||||
# Parse docstring for description and parameter info |
||||
docstring_info = _parse_google_docstring(f.__doc__) |
||||
func_description = description or docstring_info["description"] or "" |
||||
|
||||
# If explicit schema provided, use it |
||||
if schema is not None: |
||||
func_schema = schema |
||||
else: |
||||
sig = inspect.signature(f) |
||||
props = {} |
||||
required = [] |
||||
for param_name, param in sig.parameters.items(): |
||||
ann = param.annotation |
||||
# If user used a Pydantic BaseModel as a single arg, use its schema |
||||
if inspect.isclass(ann) and issubclass(ann, BaseModel): |
||||
func_schema = ann.schema() |
||||
# wrap into a single-arg object if necessary |
||||
props = func_schema.get("properties", {}) |
||||
required = func_schema.get("required", []) |
||||
# done early - for single-model param |
||||
break |
||||
|
||||
# Create property schema from type annotation |
||||
prop_schema = _pytype_to_jsonschema(ann) |
||||
|
||||
# Add description from docstring if available |
||||
if param_name in docstring_info["params"]: |
||||
prop_schema["description"] = docstring_info["params"][param_name]["description"] |
||||
|
||||
props[param_name] = prop_schema |
||||
if param.default is inspect._empty: |
||||
required.append(param_name) |
||||
|
||||
if 'func_schema' not in locals(): |
||||
func_schema = { |
||||
"type": "object", |
||||
"properties": props, |
||||
"required": required |
||||
} |
||||
|
||||
TOOL_REGISTRY[fname] = { |
||||
"callable": f, |
||||
"schema": { |
||||
"type": "function", |
||||
"function": { |
||||
"name": fname, |
||||
"description": func_description, |
||||
"parameters": func_schema |
||||
} |
||||
} |
||||
} |
||||
return f |
||||
|
||||
if func is None: |
||||
return _register |
||||
else: |
||||
return _register(func) |
||||
|
||||
def get_tools() -> list: |
||||
"""Return list of function schemas (JSON) to send to the model""" |
||||
return [v["schema"] for v in TOOL_REGISTRY.values()] |
||||
|
||||
def handle_function_call_and_inject_result(response_choice, messages): |
||||
""" |
||||
Given the model choice (response.choices[0]) and your messages list: |
||||
- extracts function/tool call |
||||
- executes the registered python callable |
||||
- appends the tool result as a tool message and returns it |
||||
""" |
||||
# Support different shapes: some SDKs use .message.tool_calls, others .message.function_call |
||||
msg = getattr(response_choice, "message", None) or (response_choice.get("message") if isinstance(response_choice, dict) else None) |
||||
func_name = None |
||||
func_args = None |
||||
# try tool_calls style |
||||
if msg: |
||||
tool_calls = getattr(msg, "tool_calls", None) or (msg.get("tool_calls") if isinstance(msg, dict) else None) |
||||
if tool_calls: |
||||
tc = tool_calls[0] |
||||
fn = getattr(tc, "function", None) or (tc.get("function") if isinstance(tc, dict) else None) |
||||
func_name = getattr(fn, "name", None) or (fn.get("name") if isinstance(fn, dict) else None) |
||||
func_args = getattr(fn, "arguments", None) or (fn.get("arguments") if isinstance(fn, dict) else None) |
||||
# fallback to function_call |
||||
if func_name is None: |
||||
fc = getattr(msg, "function_call", None) or (msg.get("function_call") if isinstance(msg, dict) else None) |
||||
if fc: |
||||
func_name = getattr(fc, "name", None) or fc.get("name") |
||||
args_raw = getattr(fc, "arguments", None) or fc.get("arguments") |
||||
# arguments are often a JSON string depending on SDK shape |
||||
if isinstance(args_raw, str): |
||||
try: |
||||
func_args = json.loads(args_raw) |
||||
except Exception: |
||||
func_args = None |
||||
else: |
||||
func_args = args_raw |
||||
|
||||
if not func_name: |
||||
return None # no function call found |
||||
|
||||
entry = TOOL_REGISTRY.get(func_name) |
||||
if not entry: |
||||
raise RuntimeError(f"Function {func_name} not registered") |
||||
|
||||
result = entry["callable"](**(func_args or {})) |
||||
# convert result to string/JSON for tool message |
||||
tool_content = result if isinstance(result, str) else json.dumps(result) |
||||
# append tool message so model can see the result |
||||
messages.append({"role": "tool", "name": func_name, "content": tool_content}) |
||||
return tool_content |
||||
|
||||
if __name__ == "__main__": |
||||
# Example usage and test |
||||
@register_tool |
||||
def add(x: int, y: int) -> int: |
||||
"""Add two integers |
||||
Args: |
||||
x (int): First integer |
||||
y (int): Second integer |
||||
Returns: |
||||
int: Sum of x and y |
||||
""" |
||||
return x + y |
||||
|
||||
@register_tool(name="echo", description="Echoes the input string") |
||||
def echo_message(message: str) -> str: |
||||
"""Echo the input message |
||||
Args: |
||||
message (str): The message to echo |
||||
Returns: |
||||
str: The echoed message |
||||
""" |
||||
return message |
||||
|
||||
print("Registered tools:") |
||||
import pprint |
||||
for info in get_tools(): |
||||
pprint.pprint(info) |
||||
Loading…
Reference in new issue