Skip to content

Commit 5e1db14

Browse files
authored
Add tool call arguments in ToolContext for RunHooks (#1765)
## Background Currently, the `RunHooks` lifecycle (`on_tool_start`, `on_tool_end`) exposes the `Tool` and `ToolContext`, but does not include the actual arguments passed to the tool call. resolves #939 ## Solution This implementation is inspired by [PR #1598](#1598). * Add a new `tool_arguments` field to `ToolContext` and populate it via from_agent_context with tool_call.arguments. * Update `lifecycle_example.py` to demonstrate tool_arguments in hooks * Unlike the proposal in [PR #253](#253), this solution is not expected to introduce breaking changes, making it easier to adopt.
1 parent 4007cba commit 5e1db14

File tree

6 files changed

+74
-26
lines changed

6 files changed

+74
-26
lines changed

examples/basic/lifecycle_example.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ async def on_agent_end(self, context: RunContextWrapper, agent: Agent, output: A
4646
async def on_tool_start(self, context: RunContextWrapper, agent: Agent, tool: Tool) -> None:
4747
self.event_counter += 1
4848
print(
49-
f"### {self.event_counter}: Tool {tool.name} started. Usage: {self._usage_to_str(context.usage)}"
49+
f"### {self.event_counter}: Tool {tool.name} started. name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined]
5050
)
5151

5252
async def on_tool_end(
5353
self, context: RunContextWrapper, agent: Agent, tool: Tool, result: str
5454
) -> None:
5555
self.event_counter += 1
5656
print(
57-
f"### {self.event_counter}: Tool {tool.name} ended with result {result}. Usage: {self._usage_to_str(context.usage)}"
57+
f"### {self.event_counter}: Tool {tool.name} finished. result={result}, name={context.tool_name}, call_id={context.tool_call_id}, args={context.tool_arguments}. Usage: {self._usage_to_str(context.usage)}" # type: ignore[attr-defined]
5858
)
5959

6060
async def on_handoff(
@@ -128,19 +128,19 @@ async def main() -> None:
128128
### 1: Agent Start Agent started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
129129
### 2: LLM started. Usage: 0 requests, 0 input tokens, 0 output tokens, 0 total tokens
130130
### 3: LLM ended. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
131-
### 4: Tool random_number started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
132-
### 5: Tool random_number ended with result 69. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
131+
### 4: Tool random_number started. name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
132+
### 5: Tool random_number finished. result=107, name=random_number, call_id=call_IujmDZYiM800H0hy7v17VTS0, args={"max":250}. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
133133
### 6: LLM started. Usage: 1 requests, 143 input tokens, 15 output tokens, 158 total tokens
134134
### 7: LLM ended. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
135135
### 8: Handoff from Start Agent to Multiply Agent. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
136136
### 9: Agent Multiply Agent started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
137137
### 10: LLM started. Usage: 2 requests, 310 input tokens, 29 output tokens, 339 total tokens
138138
### 11: LLM ended. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
139-
### 12: Tool multiply_by_two started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
140-
### 13: Tool multiply_by_two ended with result 138. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
139+
### 12: Tool multiply_by_two started. name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
140+
### 13: Tool multiply_by_two finished. result=214, name=multiply_by_two, call_id=call_KhHvTfsgaosZsfi741QvzgYw, args={"x":107}. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
141141
### 14: LLM started. Usage: 3 requests, 472 input tokens, 45 output tokens, 517 total tokens
142142
### 15: LLM ended. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
143-
### 16: Agent Multiply Agent ended with output number=138. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
143+
### 16: Agent Multiply Agent ended with output number=214. Usage: 4 requests, 660 input tokens, 56 output tokens, 716 total tokens
144144
Done!
145145
146146
"""

src/agents/realtime/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
408408
usage=self._context_wrapper.usage,
409409
tool_name=event.name,
410410
tool_call_id=event.call_id,
411+
tool_arguments=event.arguments,
411412
)
412413
result = await func_tool.on_invoke_tool(tool_context, event.arguments)
413414

@@ -432,6 +433,7 @@ async def _handle_tool_call(self, event: RealtimeModelToolCallEvent) -> None:
432433
usage=self._context_wrapper.usage,
433434
tool_name=event.name,
434435
tool_call_id=event.call_id,
436+
tool_arguments=event.arguments,
435437
)
436438

437439
# Execute the handoff to get the new agent

src/agents/tool_context.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ def _assert_must_pass_tool_name() -> str:
1414
raise ValueError("tool_name must be passed to ToolContext")
1515

1616

17+
def _assert_must_pass_tool_arguments() -> str:
18+
raise ValueError("tool_arguments must be passed to ToolContext")
19+
20+
1721
@dataclass
1822
class ToolContext(RunContextWrapper[TContext]):
1923
"""The context of a tool call."""
@@ -24,6 +28,9 @@ class ToolContext(RunContextWrapper[TContext]):
2428
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
2529
"""The ID of the tool call."""
2630

31+
tool_arguments: str = field(default_factory=_assert_must_pass_tool_arguments)
32+
"""The raw arguments string of the tool call."""
33+
2734
@classmethod
2835
def from_agent_context(
2936
cls,
@@ -39,4 +46,10 @@ def from_agent_context(
3946
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
4047
}
4148
tool_name = tool_call.name if tool_call is not None else _assert_must_pass_tool_name()
42-
return cls(tool_name=tool_name, tool_call_id=tool_call_id, **base_values)
49+
tool_args = (
50+
tool_call.arguments if tool_call is not None else _assert_must_pass_tool_arguments()
51+
)
52+
53+
return cls(
54+
tool_name=tool_name, tool_call_id=tool_call_id, tool_arguments=tool_args, **base_values
55+
)

tests/test_agent_as_tool.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,12 @@ async def fake_run(
277277
)
278278

279279
assert isinstance(tool, FunctionTool)
280-
tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1")
280+
tool_context = ToolContext(
281+
context=None,
282+
tool_name="story_tool",
283+
tool_call_id="call_1",
284+
tool_arguments='{"input": "hello"}',
285+
)
281286
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')
282287

283288
assert output == "Hello world"
@@ -374,7 +379,12 @@ async def extractor(result) -> str:
374379
)
375380

376381
assert isinstance(tool, FunctionTool)
377-
tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2")
382+
tool_context = ToolContext(
383+
context=None,
384+
tool_name="summary_tool",
385+
tool_call_id="call_2",
386+
tool_arguments='{"input": "summarize this"}',
387+
)
378388
output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}')
379389

380390
assert output == "custom output"

tests/test_function_tool.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def test_argless_function():
2727
assert tool.name == "argless_function"
2828

2929
result = await tool.on_invoke_tool(
30-
ToolContext(context=None, tool_name=tool.name, tool_call_id="1"), ""
30+
ToolContext(context=None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
3131
)
3232
assert result == "ok"
3333

@@ -41,12 +41,15 @@ async def test_argless_with_context():
4141
tool = function_tool(argless_with_context)
4242
assert tool.name == "argless_with_context"
4343

44-
result = await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
44+
result = await tool.on_invoke_tool(
45+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
46+
)
4547
assert result == "ok"
4648

4749
# Extra JSON should not raise an error
4850
result = await tool.on_invoke_tool(
49-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
51+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
52+
'{"a": 1}',
5053
)
5154
assert result == "ok"
5255

@@ -61,18 +64,22 @@ async def test_simple_function():
6164
assert tool.name == "simple_function"
6265

6366
result = await tool.on_invoke_tool(
64-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1}'
67+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1}'),
68+
'{"a": 1}',
6569
)
6670
assert result == 6
6771

6872
result = await tool.on_invoke_tool(
69-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"a": 1, "b": 2}'
73+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"a": 1, "b": 2}'),
74+
'{"a": 1, "b": 2}',
7075
)
7176
assert result == 3
7277

7378
# Missing required argument should raise an error
7479
with pytest.raises(ModelBehaviorError):
75-
await tool.on_invoke_tool(ToolContext(None, tool_name=tool.name, tool_call_id="1"), "")
80+
await tool.on_invoke_tool(
81+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=""), ""
82+
)
7683

7784

7885
class Foo(BaseModel):
@@ -101,7 +108,8 @@ async def test_complex_args_function():
101108
}
102109
)
103110
result = await tool.on_invoke_tool(
104-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
111+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
112+
valid_json,
105113
)
106114
assert result == "6 hello10 hello"
107115

@@ -112,7 +120,8 @@ async def test_complex_args_function():
112120
}
113121
)
114122
result = await tool.on_invoke_tool(
115-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
123+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
124+
valid_json,
116125
)
117126
assert result == "3 hello10 hello"
118127

@@ -124,14 +133,18 @@ async def test_complex_args_function():
124133
}
125134
)
126135
result = await tool.on_invoke_tool(
127-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), valid_json
136+
ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments=valid_json),
137+
valid_json,
128138
)
129139
assert result == "3 hello10 world"
130140

131141
# Missing required argument should raise an error
132142
with pytest.raises(ModelBehaviorError):
133143
await tool.on_invoke_tool(
134-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"foo": {"a": 1}}'
144+
ToolContext(
145+
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"foo": {"a": 1}}'
146+
),
147+
'{"foo": {"a": 1}}',
135148
)
136149

137150

@@ -193,7 +206,10 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
193206
assert tool.strict_json_schema
194207

195208
result = await tool.on_invoke_tool(
196-
ToolContext(None, tool_name=tool.name, tool_call_id="1"), '{"data": "hello"}'
209+
ToolContext(
210+
None, tool_name=tool.name, tool_call_id="1", tool_arguments='{"data": "hello"}'
211+
),
212+
'{"data": "hello"}',
197213
)
198214
assert result == "hello_done"
199215

@@ -209,7 +225,12 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
209225
assert "additionalProperties" not in tool_not_strict.params_json_schema
210226

211227
result = await tool_not_strict.on_invoke_tool(
212-
ToolContext(None, tool_name=tool_not_strict.name, tool_call_id="1"),
228+
ToolContext(
229+
None,
230+
tool_name=tool_not_strict.name,
231+
tool_call_id="1",
232+
tool_arguments='{"data": "hello", "bar": "baz"}',
233+
),
213234
'{"data": "hello", "bar": "baz"}',
214235
)
215236
assert result == "hello_done"
@@ -221,7 +242,7 @@ def my_func(a: int, b: int = 5):
221242
raise ValueError("test")
222243

223244
tool = function_tool(my_func)
224-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
245+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
225246

226247
result = await tool.on_invoke_tool(ctx, "")
227248
assert "Invalid JSON" in str(result)
@@ -245,7 +266,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
245266
return f"error_{error.__class__.__name__}"
246267

247268
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
248-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
269+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
249270

250271
result = await tool.on_invoke_tool(ctx, "")
251272
assert result == "error_ModelBehaviorError"
@@ -269,7 +290,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
269290
return f"error_{error.__class__.__name__}"
270291

271292
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
272-
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1")
293+
ctx = ToolContext(None, tool_name=tool.name, tool_call_id="1", tool_arguments="")
273294

274295
result = await tool.on_invoke_tool(ctx, "")
275296
assert result == "error_ModelBehaviorError"

tests/test_function_tool_decorator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def __init__(self):
1616

1717

1818
def ctx_wrapper() -> ToolContext[DummyContext]:
19-
return ToolContext(context=DummyContext(), tool_name="dummy", tool_call_id="1")
19+
return ToolContext(
20+
context=DummyContext(), tool_name="dummy", tool_call_id="1", tool_arguments=""
21+
)
2022

2123

2224
@function_tool

0 commit comments

Comments
 (0)