@@ -24,12 +24,15 @@ import (
24
24
"github.com/stretchr/testify/assert"
25
25
"go.uber.org/mock/gomock"
26
26
27
+ "github.com/cloudwego/eino/callbacks"
28
+ chatmodel "github.com/cloudwego/eino/components/model"
27
29
"github.com/cloudwego/eino/components/prompt"
28
30
"github.com/cloudwego/eino/compose"
29
31
"github.com/cloudwego/eino/flow/agent"
30
32
"github.com/cloudwego/eino/internal/generic"
31
33
"github.com/cloudwego/eino/internal/mock/components/model"
32
34
"github.com/cloudwego/eino/schema"
35
+ template "github.com/cloudwego/eino/utils/callbacks"
33
36
)
34
37
35
38
func TestHostMultiAgent (t * testing.T ) {
@@ -48,6 +51,14 @@ func TestHostMultiAgent(t *testing.T) {
48
51
49
52
specialist2 := & Specialist {
50
53
Invokable : func (ctx context.Context , input []* schema.Message , opts ... agent.AgentOption ) (* schema.Message , error ) {
54
+ agentOpts := agent .GetImplSpecificOptions (& specialist2Options {}, opts ... )
55
+ if agentOpts .mockOutput != nil {
56
+ return & schema.Message {
57
+ Role : schema .Assistant ,
58
+ Content : * agentOpts .mockOutput ,
59
+ }, nil
60
+ }
61
+
51
62
return & schema.Message {
52
63
Role : schema .Assistant ,
53
64
Content : "specialist2 invoke answer" ,
@@ -92,11 +103,18 @@ func TestHostMultiAgent(t *testing.T) {
92
103
Content : "direct answer" ,
93
104
}
94
105
95
- mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (directAnswerMsg , nil ).Times (1 )
106
+ mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
107
+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
108
+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
109
+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
110
+ return directAnswerMsg , nil
111
+ }).
112
+ Times (1 )
96
113
97
114
mockCallback := & mockAgentCallback {}
98
115
99
- out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
116
+ out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ),
117
+ WithAgentModelOptions (hostMA .HostNodeKey (), chatmodel .WithTemperature (0.7 )))
100
118
assert .NoError (t , err )
101
119
assert .Equal (t , "direct answer" , out .Content )
102
120
assert .Empty (t , mockCallback .infos )
@@ -164,11 +182,18 @@ func TestHostMultiAgent(t *testing.T) {
164
182
}
165
183
166
184
mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
167
- mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
185
+ mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
186
+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
187
+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
188
+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
189
+ return specialistMsg , nil
190
+ }).
191
+ Times (1 )
168
192
169
193
mockCallback := & mockAgentCallback {}
170
194
171
- out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ))
195
+ out , err := hostMA .Generate (ctx , nil , WithAgentCallbacks (mockCallback ),
196
+ WithAgentModelOptions (specialist1 .Name , chatmodel .WithTemperature (0.7 )))
172
197
assert .NoError (t , err )
173
198
assert .Equal (t , "specialist 1 answer" , out .Content )
174
199
assert .Equal (t , []* HandOffInfo {
@@ -379,16 +404,41 @@ func TestHostMultiAgent(t *testing.T) {
379
404
},
380
405
}
381
406
382
- specialistMsg := & schema.Message {
407
+ specialist1Msg := & schema.Message {
383
408
Role : schema .Assistant ,
384
409
Content : "Beijing" ,
385
410
}
386
411
387
- mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (1 )
388
- mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (specialistMsg , nil ).Times (1 )
412
+ mockHostLLM .EXPECT ().Generate (gomock .Any (), gomock .Any ()).Return (handOffMsg , nil ).Times (2 )
413
+ mockSpecialistLLM1 .EXPECT ().Generate (gomock .Any (), gomock .Any (), gomock .Any ()).
414
+ DoAndReturn (func (_ context.Context , input []* schema.Message , opts ... chatmodel.Option ) (* schema.Message , error ) {
415
+ modelOpts := chatmodel .GetCommonOptions (& chatmodel.Options {}, opts ... )
416
+ assert .Equal (t , * modelOpts .Temperature , float32 (0.7 ))
417
+ return specialist1Msg , nil
418
+ }).
419
+ Times (1 )
389
420
390
421
mockCallback := & mockAgentCallback {}
391
422
423
+ var hostOutput , specialist1Output , specialist2Output string
424
+ hostModelCallback := template .NewHandlerHelper ().ChatModel (& template.ModelCallbackHandler {
425
+ OnEnd : func (ctx context.Context , runInfo * callbacks.RunInfo , output * chatmodel.CallbackOutput ) context.Context {
426
+ hostOutput = output .Message .ToolCalls [0 ].Function .Name
427
+ return ctx
428
+ },
429
+ }).Handler ()
430
+ specialist1ModelCallback := template .NewHandlerHelper ().ChatModel (& template.ModelCallbackHandler {
431
+ OnEnd : func (ctx context.Context , runInfo * callbacks.RunInfo , output * chatmodel.CallbackOutput ) context.Context {
432
+ specialist1Output = output .Message .Content
433
+ return ctx
434
+ },
435
+ }).Handler ()
436
+ specialist2LambdaCallback := template .NewHandlerHelper ().Lambda (callbacks .NewHandlerBuilder ().OnEndFn (
437
+ func (ctx context.Context , info * callbacks.RunInfo , output callbacks.CallbackOutput ) context.Context {
438
+ specialist2Output = output .(* schema.Message ).Content
439
+ return ctx
440
+ }).Build ()).Handler ()
441
+
392
442
hostMA , err := NewMultiAgent (ctx , & MultiAgentConfig {
393
443
Host : Host {
394
444
ChatModel : mockHostLLM ,
@@ -409,7 +459,14 @@ func TestHostMultiAgent(t *testing.T) {
409
459
Compile (ctx )
410
460
assert .NoError (t , err )
411
461
412
- out , err := fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, compose .WithCallbacks (ConvertCallbackHandlers (mockCallback )).DesignateNodeWithPath (compose .NewNodePath ("host_ma_node" , hostMA .HostNodeKey ())))
462
+ convertedOptions := ConvertOptions (compose .NewNodePath ("host_ma_node" ), WithAgentCallbacks (mockCallback ),
463
+ WithAgentModelOptions (specialist1 .Name , chatmodel .WithTemperature (0.7 )),
464
+ WithAgentModelCallbacks (hostMA .HostNodeKey (), hostModelCallback ),
465
+ WithAgentModelCallbacks (specialist1 .Name , specialist1ModelCallback ),
466
+ WithSpecialistLambdaCallbacks (specialist2 .Name , specialist2LambdaCallback ),
467
+ WithSpecialistLambdaOptions (specialist2 .Name , withSpecialist2MockOutput ("mock_city_name" )))
468
+
469
+ out , err := fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, convertedOptions ... )
413
470
assert .NoError (t , err )
414
471
assert .Equal (t , "Beijing" , out .Content )
415
472
assert .Equal (t , []* HandOffInfo {
@@ -418,6 +475,28 @@ func TestHostMultiAgent(t *testing.T) {
418
475
Argument : `{"reason": "specialist 1 is the best"}` ,
419
476
},
420
477
}, mockCallback .infos )
478
+ assert .Equal (t , hostOutput , specialist1 .Name )
479
+ assert .Equal (t , specialist1Output , out .Content )
480
+ assert .Equal (t , specialist2Output , "" )
481
+
482
+ handOffMsg .ToolCalls [0 ].Function .Name = specialist2 .Name
483
+ handOffMsg .ToolCalls [0 ].Function .Arguments = `{"reason": "specialist 2 is even better"}`
484
+
485
+ out , err = fullGraph .Invoke (ctx , map [string ]any {"country_name" : "China" }, convertedOptions ... )
486
+ assert .NoError (t , err )
487
+ assert .Equal (t , "mock_city_name" , out .Content )
488
+ assert .Equal (t , []* HandOffInfo {
489
+ {
490
+ ToAgentName : specialist1 .Name ,
491
+ Argument : `{"reason": "specialist 1 is the best"}` ,
492
+ },
493
+ {
494
+ ToAgentName : specialist2 .Name ,
495
+ Argument : `{"reason": "specialist 2 is even better"}` ,
496
+ },
497
+ }, mockCallback .infos )
498
+ assert .Equal (t , hostOutput , specialist2 .Name )
499
+ assert .Equal (t , specialist2Output , "mock_city_name" )
421
500
})
422
501
}
423
502
@@ -429,3 +508,13 @@ func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) co
429
508
m .infos = append (m .infos , info )
430
509
return ctx
431
510
}
511
+
512
+ type specialist2Options struct {
513
+ mockOutput * string
514
+ }
515
+
516
+ func withSpecialist2MockOutput (mockOutput string ) agent.AgentOption {
517
+ return agent .WrapImplSpecificOptFn (func (o * specialist2Options ) {
518
+ o .mockOutput = & mockOutput
519
+ })
520
+ }
0 commit comments