Skip to content

Commit a682a11

Browse files
feat: add more AgentOption for react agent and a converter function to compose.Option
Change-Id: I3a069cd40b5fff9d701b35c4f1d383c303e2387f
1 parent 2d75c5c commit a682a11

File tree

11 files changed

+424
-50
lines changed

11 files changed

+424
-50
lines changed

components/tool/option.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
package tool
1818

1919
// Option defines call option for InvokableTool or StreamableTool component, which is part of component interface signature.
20-
// Each tool implementation could define its own options struct and option funcs within its own package,
21-
// then wrap the impl specific option funcs into this type, before passing to InvokableRun or StreamableRun.
20+
// Each tool implementation could define its own options struct and option functions within its own package,
21+
// then wrap the impl specific option functions into this type, before passing to InvokableRun or StreamableRun.
2222
type Option struct {
2323
implSpecificOptFn any
2424
}

compose/graph_call_options.go

+18
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,24 @@ func (o Option) DesignateNodeWithPath(path ...*NodePath) Option {
8080
return o
8181
}
8282

83+
// DesignateNodePrependPath prepends the prefix to the path of the node(s) to which the option will be applied to.
84+
// Useful when you already have an Option designated to a graph's node, and now you want to add this graph as a subgraph.
85+
// e.g.
86+
// Your subgraph has a Node with key "A", and your subgraph's NodeKey is "sub_graph", you can specify option to A using:
87+
//
88+
// option := WithCallbacks(...).DesignateNode("A").DesignateNodePrependPath("sub_graph")
89+
// Note: as an End User, you probably don't need to use this method, as DesignateNodeWithPath will be sufficient in most use cases.
90+
// Note: as a Flow author, if you define your own Option type, and at the same time your flow can be exported to graph and added as GraphNode,
91+
// you can use this method to prepend your Option's designated path with the GraphNode's path.
92+
func (o Option) DesignateNodePrependPath(prefix *NodePath) Option {
93+
for i := range o.paths {
94+
p := o.paths[i]
95+
p.path = append(prefix.path, p.path...)
96+
}
97+
98+
return o
99+
}
100+
83101
// WithEmbeddingOption is a functional option type for embedding component.
84102
// e.g.
85103
//

flow/agent/agent_option.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ package agent
1919
import "github.com/cloudwego/eino/compose"
2020

2121
// AgentOption is the common option type for various agent and multi-agent implementations.
22-
// For options intended to use with underlying graph or components, use WithComposeOptions to specify.
2322
// For options intended to use with particular agent/multi-agent implementations, use WrapImplSpecificOptFn to specify.
2423
type AgentOption struct {
2524
implSpecificOptFn any
2625
composeOptions []compose.Option
2726
}
2827

2928
// GetComposeOptions returns all compose options from the given agent options.
29+
// Deprecated
3030
func GetComposeOptions(opts ...AgentOption) []compose.Option {
3131
var result []compose.Option
3232
for _, opt := range opts {
@@ -37,6 +37,7 @@ func GetComposeOptions(opts ...AgentOption) []compose.Option {
3737
}
3838

3939
// WithComposeOptions returns an agent option that specifies compose options.
40+
// Deprecated: use option functions defined by each agent flow implementation instead.
4041
func WithComposeOptions(opts ...compose.Option) AgentOption {
4142
return AgentOption{
4243
composeOptions: opts,

flow/agent/multiagent/host/callback.go

+73-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type HandOffInfo struct {
3939
}
4040

4141
// ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler.
42+
// Deprecated: use ConvertOptions to convert agent.AgentOption to compose.Option when adding MultiAgent's Graph to another Graph.
4243
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
4344
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
4445
if output == nil || info == nil {
@@ -121,5 +122,76 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
121122
}
122123

123124
handlers := agentOptions.agentCallbacks
124-
return ConvertCallbackHandlers(handlers...)
125+
126+
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
127+
if output == nil || info == nil {
128+
return ctx
129+
}
130+
131+
msg := output.Message
132+
if msg == nil || msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
133+
return ctx
134+
}
135+
136+
agentName := msg.ToolCalls[0].Function.Name
137+
argument := msg.ToolCalls[0].Function.Arguments
138+
139+
for _, cb := range handlers {
140+
ctx = cb.OnHandOff(ctx, &HandOffInfo{
141+
ToAgentName: agentName,
142+
Argument: argument,
143+
})
144+
}
145+
146+
return ctx
147+
}
148+
149+
onChatModelEndWithStreamOutput := func(ctx context.Context, info *callbacks.RunInfo, output *schema.StreamReader[*model.CallbackOutput]) context.Context {
150+
if output == nil || info == nil {
151+
return ctx
152+
}
153+
154+
defer output.Close()
155+
156+
var msgs []*schema.Message
157+
for {
158+
oneOutput, err := output.Recv()
159+
if err == io.EOF {
160+
break
161+
}
162+
if err != nil {
163+
return ctx
164+
}
165+
166+
msg := oneOutput.Message
167+
if msg == nil {
168+
continue
169+
}
170+
171+
msgs = append(msgs, msg)
172+
}
173+
174+
msg, err := schema.ConcatMessages(msgs)
175+
if err != nil {
176+
return ctx
177+
}
178+
179+
if msg.Role != schema.Assistant || len(msg.ToolCalls) == 0 {
180+
return ctx
181+
}
182+
183+
for _, cb := range handlers {
184+
ctx = cb.OnHandOff(ctx, &HandOffInfo{
185+
ToAgentName: msg.ToolCalls[0].Function.Name,
186+
Argument: msg.ToolCalls[0].Function.Arguments,
187+
})
188+
}
189+
190+
return ctx
191+
}
192+
193+
return template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
194+
OnEnd: onChatModelEnd,
195+
OnEndWithStreamOutput: onChatModelEndWithStreamOutput,
196+
}).Handler()
125197
}

flow/agent/multiagent/host/compose_test.go

+97-8
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,15 @@ import (
2424
"github.com/stretchr/testify/assert"
2525
"go.uber.org/mock/gomock"
2626

27+
"github.com/cloudwego/eino/callbacks"
28+
chatmodel "github.com/cloudwego/eino/components/model"
2729
"github.com/cloudwego/eino/components/prompt"
2830
"github.com/cloudwego/eino/compose"
2931
"github.com/cloudwego/eino/flow/agent"
3032
"github.com/cloudwego/eino/internal/generic"
3133
"github.com/cloudwego/eino/internal/mock/components/model"
3234
"github.com/cloudwego/eino/schema"
35+
template "github.com/cloudwego/eino/utils/callbacks"
3336
)
3437

3538
func TestHostMultiAgent(t *testing.T) {
@@ -48,6 +51,14 @@ func TestHostMultiAgent(t *testing.T) {
4851

4952
specialist2 := &Specialist{
5053
Invokable: func(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
54+
agentOpts := agent.GetImplSpecificOptions(&specialist2Options{}, opts...)
55+
if agentOpts.mockOutput != nil {
56+
return &schema.Message{
57+
Role: schema.Assistant,
58+
Content: *agentOpts.mockOutput,
59+
}, nil
60+
}
61+
5162
return &schema.Message{
5263
Role: schema.Assistant,
5364
Content: "specialist2 invoke answer",
@@ -92,11 +103,18 @@ func TestHostMultiAgent(t *testing.T) {
92103
Content: "direct answer",
93104
}
94105

95-
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(directAnswerMsg, nil).Times(1)
106+
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
107+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
108+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
109+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
110+
return directAnswerMsg, nil
111+
}).
112+
Times(1)
96113

97114
mockCallback := &mockAgentCallback{}
98115

99-
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
116+
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
117+
WithAgentModelOptions(hostMA.HostNodeKey(), chatmodel.WithTemperature(0.7)))
100118
assert.NoError(t, err)
101119
assert.Equal(t, "direct answer", out.Content)
102120
assert.Empty(t, mockCallback.infos)
@@ -164,11 +182,18 @@ func TestHostMultiAgent(t *testing.T) {
164182
}
165183

166184
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
167-
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
185+
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
186+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
187+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
188+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
189+
return specialistMsg, nil
190+
}).
191+
Times(1)
168192

169193
mockCallback := &mockAgentCallback{}
170194

171-
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback))
195+
out, err := hostMA.Generate(ctx, nil, WithAgentCallbacks(mockCallback),
196+
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)))
172197
assert.NoError(t, err)
173198
assert.Equal(t, "specialist 1 answer", out.Content)
174199
assert.Equal(t, []*HandOffInfo{
@@ -379,16 +404,41 @@ func TestHostMultiAgent(t *testing.T) {
379404
},
380405
}
381406

382-
specialistMsg := &schema.Message{
407+
specialist1Msg := &schema.Message{
383408
Role: schema.Assistant,
384409
Content: "Beijing",
385410
}
386411

387-
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
388-
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)
412+
mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(2)
413+
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any(), gomock.Any()).
414+
DoAndReturn(func(_ context.Context, input []*schema.Message, opts ...chatmodel.Option) (*schema.Message, error) {
415+
modelOpts := chatmodel.GetCommonOptions(&chatmodel.Options{}, opts...)
416+
assert.Equal(t, *modelOpts.Temperature, float32(0.7))
417+
return specialist1Msg, nil
418+
}).
419+
Times(1)
389420

390421
mockCallback := &mockAgentCallback{}
391422

423+
var hostOutput, specialist1Output, specialist2Output string
424+
hostModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
425+
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
426+
hostOutput = output.Message.ToolCalls[0].Function.Name
427+
return ctx
428+
},
429+
}).Handler()
430+
specialist1ModelCallback := template.NewHandlerHelper().ChatModel(&template.ModelCallbackHandler{
431+
OnEnd: func(ctx context.Context, runInfo *callbacks.RunInfo, output *chatmodel.CallbackOutput) context.Context {
432+
specialist1Output = output.Message.Content
433+
return ctx
434+
},
435+
}).Handler()
436+
specialist2LambdaCallback := template.NewHandlerHelper().Lambda(callbacks.NewHandlerBuilder().OnEndFn(
437+
func(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
438+
specialist2Output = output.(*schema.Message).Content
439+
return ctx
440+
}).Build()).Handler()
441+
392442
hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
393443
Host: Host{
394444
ChatModel: mockHostLLM,
@@ -409,7 +459,14 @@ func TestHostMultiAgent(t *testing.T) {
409459
Compile(ctx)
410460
assert.NoError(t, err)
411461

412-
out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey())))
462+
convertedOptions := ConvertOptions(compose.NewNodePath("host_ma_node"), WithAgentCallbacks(mockCallback),
463+
WithAgentModelOptions(specialist1.Name, chatmodel.WithTemperature(0.7)),
464+
WithAgentModelCallbacks(hostMA.HostNodeKey(), hostModelCallback),
465+
WithAgentModelCallbacks(specialist1.Name, specialist1ModelCallback),
466+
WithSpecialistLambdaCallbacks(specialist2.Name, specialist2LambdaCallback),
467+
WithSpecialistLambdaOptions(specialist2.Name, withSpecialist2MockOutput("mock_city_name")))
468+
469+
out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
413470
assert.NoError(t, err)
414471
assert.Equal(t, "Beijing", out.Content)
415472
assert.Equal(t, []*HandOffInfo{
@@ -418,6 +475,28 @@ func TestHostMultiAgent(t *testing.T) {
418475
Argument: `{"reason": "specialist 1 is the best"}`,
419476
},
420477
}, mockCallback.infos)
478+
assert.Equal(t, hostOutput, specialist1.Name)
479+
assert.Equal(t, specialist1Output, out.Content)
480+
assert.Equal(t, specialist2Output, "")
481+
482+
handOffMsg.ToolCalls[0].Function.Name = specialist2.Name
483+
handOffMsg.ToolCalls[0].Function.Arguments = `{"reason": "specialist 2 is even better"}`
484+
485+
out, err = fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, convertedOptions...)
486+
assert.NoError(t, err)
487+
assert.Equal(t, "mock_city_name", out.Content)
488+
assert.Equal(t, []*HandOffInfo{
489+
{
490+
ToAgentName: specialist1.Name,
491+
Argument: `{"reason": "specialist 1 is the best"}`,
492+
},
493+
{
494+
ToAgentName: specialist2.Name,
495+
Argument: `{"reason": "specialist 2 is even better"}`,
496+
},
497+
}, mockCallback.infos)
498+
assert.Equal(t, hostOutput, specialist2.Name)
499+
assert.Equal(t, specialist2Output, "mock_city_name")
421500
})
422501
}
423502

@@ -429,3 +508,13 @@ func (m *mockAgentCallback) OnHandOff(ctx context.Context, info *HandOffInfo) co
429508
m.infos = append(m.infos, info)
430509
return ctx
431510
}
511+
512+
type specialist2Options struct {
513+
mockOutput *string
514+
}
515+
516+
func withSpecialist2MockOutput(mockOutput string) agent.AgentOption {
517+
return agent.WrapImplSpecificOptFn(func(o *specialist2Options) {
518+
o.mockOutput = &mockOutput
519+
})
520+
}

0 commit comments

Comments
 (0)