Skip to content

Commit de878a2

Browse files
committed
feat: Add tool guardrails to function_tool decorator args (ref #2218)
1 parent ba55bbd commit de878a2

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/agents/tool.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,8 @@ def function_tool(
687687
failure_error_function: ToolErrorFunction | None = None,
688688
strict_mode: bool = True,
689689
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
690+
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
691+
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
690692
) -> FunctionTool:
691693
"""Overload for usage as @function_tool (no parentheses)."""
692694
...
@@ -702,6 +704,8 @@ def function_tool(
702704
failure_error_function: ToolErrorFunction | None = None,
703705
strict_mode: bool = True,
704706
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
707+
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
708+
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
705709
) -> Callable[[ToolFunction[...]], FunctionTool]:
706710
"""Overload for usage as @function_tool(...)."""
707711
...
@@ -717,6 +721,8 @@ def function_tool(
717721
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
718722
strict_mode: bool = True,
719723
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
724+
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
725+
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
720726
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
721727
"""
722728
Decorator to create a FunctionTool from a function. By default, we will:
@@ -748,6 +754,8 @@ def function_tool(
748754
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
749755
context and agent and returns whether the tool is enabled. Disabled tools are hidden
750756
from the LLM at runtime.
757+
tool_input_guardrails: Optional list of guardrails to run before invoking the tool.
758+
tool_output_guardrails: Optional list of guardrails to run after the tool returns.
751759
"""
752760

753761
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@@ -845,6 +853,8 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
845853
on_invoke_tool=_on_invoke_tool,
846854
strict_json_schema=strict_mode,
847855
is_enabled=is_enabled,
856+
tool_input_guardrails=tool_input_guardrails,
857+
tool_output_guardrails=tool_output_guardrails,
848858
)
849859

850860
# If func is actually a callable, we were used as @function_tool with no parentheses

tests/test_function_tool.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
FunctionTool,
1212
ModelBehaviorError,
1313
RunContextWrapper,
14+
ToolGuardrailFunctionOutput,
15+
ToolInputGuardrailData,
16+
ToolOutputGuardrailData,
1417
function_tool,
18+
tool_input_guardrail,
19+
tool_output_guardrail,
1520
)
1621
from agents.tool import default_tool_error_function
1722
from agents.tool_context import ToolContext
@@ -96,6 +101,21 @@ def complex_args_function(foo: Foo, bar: Bar, baz: str = "hello"):
96101
return f"{foo.a + foo.b} {bar['x']}{bar['y']} {baz}"
97102

98103

104+
@tool_input_guardrail
105+
def reject_args_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
106+
"""Reject tool calls for test purposes."""
107+
return ToolGuardrailFunctionOutput.reject_content(
108+
message="blocked",
109+
output_info={"tool": data.context.tool_name},
110+
)
111+
112+
113+
@tool_output_guardrail
114+
def allow_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
115+
"""Allow tool outputs for test purposes."""
116+
return ToolGuardrailFunctionOutput.allow(output_info={"echo": data.output})
117+
118+
99119
@pytest.mark.asyncio
100120
async def test_complex_args_function():
101121
tool = function_tool(complex_args_function, failure_error_function=None)
@@ -359,3 +379,26 @@ def boom() -> None:
359379
ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}")
360380
result = await boom.on_invoke_tool(ctx, "{}")
361381
assert result.startswith("handled:")
382+
383+
384+
def test_function_tool_accepts_guardrail_arguments():
385+
tool = function_tool(
386+
simple_function,
387+
tool_input_guardrails=[reject_args_guardrail],
388+
tool_output_guardrails=[allow_output_guardrail],
389+
)
390+
391+
assert tool.tool_input_guardrails == [reject_args_guardrail]
392+
assert tool.tool_output_guardrails == [allow_output_guardrail]
393+
394+
395+
def test_function_tool_decorator_accepts_guardrail_arguments():
396+
@function_tool(
397+
tool_input_guardrails=[reject_args_guardrail],
398+
tool_output_guardrails=[allow_output_guardrail],
399+
)
400+
def guarded(a: int) -> int:
401+
return a
402+
403+
assert guarded.tool_input_guardrails == [reject_args_guardrail]
404+
assert guarded.tool_output_guardrails == [allow_output_guardrail]

0 commit comments

Comments
 (0)