@@ -27,7 +27,7 @@ async def test_argless_function():
27
27
assert tool .name == "argless_function"
28
28
29
29
result = await tool .on_invoke_tool (
30
- ToolContext (context = None , tool_name = tool .name , tool_call_id = "1" ), ""
30
+ ToolContext (context = None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" ), ""
31
31
)
32
32
assert result == "ok"
33
33
@@ -41,12 +41,15 @@ async def test_argless_with_context():
41
41
tool = function_tool (argless_with_context )
42
42
assert tool .name == "argless_with_context"
43
43
44
- result = await tool .on_invoke_tool (ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), "" )
44
+ result = await tool .on_invoke_tool (
45
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" ), ""
46
+ )
45
47
assert result == "ok"
46
48
47
49
# Extra JSON should not raise an error
48
50
result = await tool .on_invoke_tool (
49
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1}'
51
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = '{"a": 1}' ),
52
+ '{"a": 1}' ,
50
53
)
51
54
assert result == "ok"
52
55
@@ -61,18 +64,22 @@ async def test_simple_function():
61
64
assert tool .name == "simple_function"
62
65
63
66
result = await tool .on_invoke_tool (
64
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1}'
67
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = '{"a": 1}' ),
68
+ '{"a": 1}' ,
65
69
)
66
70
assert result == 6
67
71
68
72
result = await tool .on_invoke_tool (
69
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"a": 1, "b": 2}'
73
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = '{"a": 1, "b": 2}' ),
74
+ '{"a": 1, "b": 2}' ,
70
75
)
71
76
assert result == 3
72
77
73
78
# Missing required argument should raise an error
74
79
with pytest .raises (ModelBehaviorError ):
75
- await tool .on_invoke_tool (ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), "" )
80
+ await tool .on_invoke_tool (
81
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" ), ""
82
+ )
76
83
77
84
78
85
class Foo (BaseModel ):
@@ -101,7 +108,8 @@ async def test_complex_args_function():
101
108
}
102
109
)
103
110
result = await tool .on_invoke_tool (
104
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
111
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = valid_json ),
112
+ valid_json ,
105
113
)
106
114
assert result == "6 hello10 hello"
107
115
@@ -112,7 +120,8 @@ async def test_complex_args_function():
112
120
}
113
121
)
114
122
result = await tool .on_invoke_tool (
115
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
123
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = valid_json ),
124
+ valid_json ,
116
125
)
117
126
assert result == "3 hello10 hello"
118
127
@@ -124,14 +133,18 @@ async def test_complex_args_function():
124
133
}
125
134
)
126
135
result = await tool .on_invoke_tool (
127
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), valid_json
136
+ ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = valid_json ),
137
+ valid_json ,
128
138
)
129
139
assert result == "3 hello10 world"
130
140
131
141
# Missing required argument should raise an error
132
142
with pytest .raises (ModelBehaviorError ):
133
143
await tool .on_invoke_tool (
134
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"foo": {"a": 1}}'
144
+ ToolContext (
145
+ None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = '{"foo": {"a": 1}}'
146
+ ),
147
+ '{"foo": {"a": 1}}' ,
135
148
)
136
149
137
150
@@ -193,7 +206,10 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
193
206
assert tool .strict_json_schema
194
207
195
208
result = await tool .on_invoke_tool (
196
- ToolContext (None , tool_name = tool .name , tool_call_id = "1" ), '{"data": "hello"}'
209
+ ToolContext (
210
+ None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = '{"data": "hello"}'
211
+ ),
212
+ '{"data": "hello"}' ,
197
213
)
198
214
assert result == "hello_done"
199
215
@@ -209,7 +225,12 @@ async def run_function(ctx: RunContextWrapper[Any], args: str) -> str:
209
225
assert "additionalProperties" not in tool_not_strict .params_json_schema
210
226
211
227
result = await tool_not_strict .on_invoke_tool (
212
- ToolContext (None , tool_name = tool_not_strict .name , tool_call_id = "1" ),
228
+ ToolContext (
229
+ None ,
230
+ tool_name = tool_not_strict .name ,
231
+ tool_call_id = "1" ,
232
+ tool_arguments = '{"data": "hello", "bar": "baz"}' ,
233
+ ),
213
234
'{"data": "hello", "bar": "baz"}' ,
214
235
)
215
236
assert result == "hello_done"
@@ -221,7 +242,7 @@ def my_func(a: int, b: int = 5):
221
242
raise ValueError ("test" )
222
243
223
244
tool = function_tool (my_func )
224
- ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" )
245
+ ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" )
225
246
226
247
result = await tool .on_invoke_tool (ctx , "" )
227
248
assert "Invalid JSON" in str (result )
@@ -245,7 +266,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
245
266
return f"error_{ error .__class__ .__name__ } "
246
267
247
268
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
248
- ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" )
269
+ ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" )
249
270
250
271
result = await tool .on_invoke_tool (ctx , "" )
251
272
assert result == "error_ModelBehaviorError"
@@ -269,7 +290,7 @@ def custom_sync_error_function(ctx: RunContextWrapper[Any], error: Exception) ->
269
290
return f"error_{ error .__class__ .__name__ } "
270
291
271
292
tool = function_tool (my_func , failure_error_function = custom_sync_error_function )
272
- ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" )
293
+ ctx = ToolContext (None , tool_name = tool .name , tool_call_id = "1" , tool_arguments = "" )
273
294
274
295
result = await tool .on_invoke_tool (ctx , "" )
275
296
assert result == "error_ModelBehaviorError"
0 commit comments