Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 42 additions & 4 deletions agent/remoteagent/a2a_agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ func TestRemoteAgent_ADK2ADK(t *testing.T) {
name string
remoteEvents []*session.Event
wantResponses []model.LLMResponse
wantEscalate bool
wantTransfer string
}{
{
name: "text streaming",
Expand Down Expand Up @@ -256,6 +258,34 @@ func TestRemoteAgent_ADK2ADK(t *testing.T) {
{TurnComplete: true},
},
},
{
name: "escalation",
remoteEvents: []*session.Event{
{
LLMResponse: model.LLMResponse{Content: genai.NewContentFromText("stop", genai.RoleModel)},
Actions: session.EventActions{Escalate: true},
},
},
wantResponses: []model.LLMResponse{
{Content: genai.NewContentFromText("stop", genai.RoleModel), Partial: true},
{TurnComplete: true},
},
wantEscalate: true,
},
{
name: "transfer",
remoteEvents: []*session.Event{
{
LLMResponse: model.LLMResponse{Content: genai.NewContentFromText("stop", genai.RoleModel)},
Actions: session.EventActions{TransferToAgent: "a-2"},
},
},
wantResponses: []model.LLMResponse{
{Content: genai.NewContentFromText("stop", genai.RoleModel), Partial: true},
{TurnComplete: true},
},
wantTransfer: "a-2",
},
}

ignoreFields := []cmp.Option{
Expand All @@ -272,19 +302,27 @@ func TestRemoteAgent_ADK2ADK(t *testing.T) {
ictx := newInvocationContext(t, []*session.Event{newUserHello()})
gotEvents, err := runAndCollect(ictx, remoteAgent)
if err != nil {
t.Errorf("agent.Run() error = %v", err)
t.Fatalf("agent.Run() error = %v", err)
}
gotResponses := toLLMResponses(gotEvents)
if diff := cmp.Diff(tc.wantResponses, gotResponses, ignoreFields...); diff != "" {
t.Errorf("agent.Run() wrong result (+got,-want):\ngot = %+v\nwant = %+v\ndiff = %s", gotResponses, tc.wantResponses, diff)
t.Fatalf("agent.Run() wrong result (+got,-want):\ngot = %+v\nwant = %+v\ndiff = %s", gotResponses, tc.wantResponses, diff)
}
var lastActions *session.EventActions
for _, event := range gotEvents {
if _, ok := event.CustomMetadata[adka2a.ToADKMetaKey("response")]; !ok {
t.Errorf("event.CustomMetadata = %v, want meta[%q] = original a2a event", event.CustomMetadata, adka2a.ToADKMetaKey("response"))
t.Fatalf("event.CustomMetadata = %v, want meta[%q] = original a2a event", event.CustomMetadata, adka2a.ToADKMetaKey("response"))
}
if _, ok := event.CustomMetadata[adka2a.ToADKMetaKey("request")]; !ok {
t.Errorf("event.CustomMetadata = %v, want meta[%q] = original a2a request", event.CustomMetadata, adka2a.ToADKMetaKey("request"))
t.Fatalf("event.CustomMetadata = %v, want meta[%q] = original a2a request", event.CustomMetadata, adka2a.ToADKMetaKey("request"))
}
lastActions = &event.Actions
}
if tc.wantEscalate != lastActions.Escalate {
t.Fatalf("lastActions.Escalate = %v, want %v", lastActions.Escalate, tc.wantEscalate)
}
if tc.wantTransfer != lastActions.TransferToAgent {
t.Fatalf("lastActions.TransferToAgent = %v, want %v", lastActions.TransferToAgent, tc.wantTransfer)
}
})
}
Expand Down
18 changes: 17 additions & 1 deletion server/adka2a/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ import (
var (
customMetaTaskIDKey = ToADKMetaKey("task_id")
customMetaContextIDKey = ToADKMetaKey("context_id")

metadataEscalateKey = ToA2AMetaKey("escalate")
metadataTransferToAgentKey = ToA2AMetaKey("transfer_to_agent")
)

// NewRemoteAgentEvent create a new Event authored by the agent running in the provided invocation context.
Expand Down Expand Up @@ -54,7 +57,9 @@ func EventToMessage(event *session.Event) (*a2a.Message, error) {
role = a2a.MessageRoleAgent
}

return a2a.NewMessage(role, parts...), nil
msg := a2a.NewMessage(role, parts...)
msg.Metadata = setActionsMeta(msg.Metadata, event.Actions)
return msg, nil
}

// ToSessionEvent converts the provided a2a event to session event authored by the agent running in the provided invocation context.
Expand Down Expand Up @@ -154,6 +159,7 @@ func messageToEvent(ctx agent.InvocationContext, msg *a2a.Message) (*session.Eve
if msg.TaskID != "" || msg.ContextID != "" {
event.CustomMetadata = ToCustomMetadata(msg.TaskID, msg.ContextID)
}
event.Actions = toEventActions(msg)
return event, nil
}

Expand Down Expand Up @@ -212,6 +218,7 @@ func taskToEvent(ctx agent.InvocationContext, task *a2a.Task) (*session.Event, e
if !task.Status.State.Terminal() && task.Status.State != a2a.TaskStateInputRequired {
event.Partial = true
}
event.Actions = toEventActions(task)
return event, nil
}

Expand All @@ -233,6 +240,7 @@ func finalTaskStatusUpdateToEvent(ctx agent.InvocationContext, update *a2a.TaskS
event.Content = genai.NewContentFromParts(parts, genai.RoleModel)
}
event.CustomMetadata = ToCustomMetadata(update.TaskID, update.ContextID)
event.Actions = toEventActions(update)
event.TurnComplete = true
return event, nil
}
Expand Down Expand Up @@ -263,3 +271,11 @@ func toGenAIRole(role a2a.MessageRole) genai.Role {
return genai.RoleModel
}
}

func toEventActions(event a2a.Event) session.EventActions {
meta := event.Meta()
var result session.EventActions
result.Escalate, _ = meta[metadataEscalateKey].(bool)
result.TransferToAgent, _ = meta[metadataTransferToAgentKey].(string)
return result
}
16 changes: 11 additions & 5 deletions server/adka2a/events_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func TestToSessionEvent(t *testing.T) {
Parts: []a2a.Part{a2a.TextPart{Text: "foo"}},
TaskID: taskID,
ContextID: contextID,
Metadata: map[string]any{metadataEscalateKey: true, metadataTransferToAgentKey: "a-2"},
},
want: &session.Event{
LLMResponse: model.LLMResponse{
Expand All @@ -56,8 +57,9 @@ func TestToSessionEvent(t *testing.T) {
customMetaContextIDKey: contextID,
},
},
Author: agentName,
Branch: branch,
Author: agentName,
Branch: branch,
Actions: session.EventActions{Escalate: true, TransferToAgent: "a-2"},
},
},
{
Expand Down Expand Up @@ -99,6 +101,7 @@ func TestToSessionEvent(t *testing.T) {
State: a2a.TaskStateCompleted,
Message: a2a.NewMessage(a2a.MessageRoleAgent, a2a.TextPart{Text: "done"}),
},
Metadata: map[string]any{metadataEscalateKey: true},
},
want: &session.Event{
LLMResponse: model.LLMResponse{
Expand All @@ -119,8 +122,9 @@ func TestToSessionEvent(t *testing.T) {
customMetaContextIDKey: contextID,
},
},
Author: agentName,
Branch: branch,
Author: agentName,
Branch: branch,
Actions: session.EventActions{Escalate: true},
},
},
{
Expand Down Expand Up @@ -157,7 +161,8 @@ func TestToSessionEvent(t *testing.T) {
},
},
},
Status: a2a.TaskStatus{State: a2a.TaskStateInputRequired},
Status: a2a.TaskStatus{State: a2a.TaskStateInputRequired},
Metadata: map[string]any{metadataEscalateKey: true},
},
want: &session.Event{
LLMResponse: model.LLMResponse{
Expand All @@ -178,6 +183,7 @@ func TestToSessionEvent(t *testing.T) {
LongRunningToolIDs: []string{"get_weather"},
Author: agentName,
Branch: branch,
Actions: session.EventActions{Escalate: true},
},
},
{
Expand Down
16 changes: 16 additions & 0 deletions server/adka2a/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,19 @@ func toEventMeta(meta invocationMeta, event *session.Event) (map[string]any, err

return result, nil
}

func setActionsMeta(meta map[string]any, actions session.EventActions) map[string]any {
if actions.TransferToAgent == "" && !actions.Escalate { // if meta was nil, it should remain nil
return meta
}
if meta == nil {
meta = map[string]any{}
}
if actions.Escalate {
meta[metadataEscalateKey] = true
}
if actions.TransferToAgent != "" {
meta[metadataTransferToAgentKey] = actions.TransferToAgent
}
return meta
}
22 changes: 18 additions & 4 deletions server/adka2a/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,15 @@ type eventProcessor struct {
reqCtx *a2asrv.RequestContext
meta invocationMeta

// Created once the first TaskArtifactUpdateEvent is sent. Used for subsequent artifact updates.
// terminalActions is used to keep track of escalate and agent transfer actions on processed events.
// It is then gets passed to caller through with metadata of a terminal event.
// This is done to make sure the caller processes it, since intermediate events without parts might be ignored.
terminalActions session.EventActions

// responseID is created once the first TaskArtifactUpdateEvent is sent. Used for subsequent artifact updates.
responseID a2a.ArtifactID

// We don't send terminal events during processing because we don't want A2A server to stop reading from the queue
// until the whole ADK response is saved as an A2A artifact.
// terminalEvents is used to postpone sending a terminal event until the whole ADK response is saved as an A2A artifact.
// The highest-priority terminal event from this map is going to be send as the final Task status update, in the order of priority:
// - failed
// - input_required
Expand All @@ -55,6 +59,8 @@ func (p *eventProcessor) process(_ context.Context, event *session.Event) (*a2a.
return nil, nil
}

p.updateTerminalActions(event)

eventMeta, err := toEventMeta(p.meta, event)
if err != nil {
return nil, err
Expand Down Expand Up @@ -108,14 +114,15 @@ func (p *eventProcessor) makeTerminalEvents() []a2a.Event {

for _, s := range []a2a.TaskState{a2a.TaskStateFailed, a2a.TaskStateInputRequired} {
if ev, ok := p.terminalEvents[s]; ok {
ev.Metadata = setActionsMeta(ev.Metadata, p.terminalActions)
result = append(result, ev)
return result
}
}

ev := a2a.NewStatusUpdateEvent(p.reqCtx, a2a.TaskStateCompleted, nil)
ev.Metadata = p.meta.eventMeta
ev.Final = true
ev.Metadata = setActionsMeta(p.meta.eventMeta, p.terminalActions)
result = append(result, ev)
return result
}
Expand All @@ -132,6 +139,13 @@ func (p *eventProcessor) makeTaskFailedEvent(cause error, event *session.Event)
return toTaskFailedUpdateEvent(p.reqCtx, cause, meta)
}

func (p *eventProcessor) updateTerminalActions(event *session.Event) {
p.terminalActions.Escalate = p.terminalActions.Escalate || event.Actions.Escalate
if event.Actions.TransferToAgent != "" {
p.terminalActions.TransferToAgent = event.Actions.TransferToAgent
}
}

func toTaskFailedUpdateEvent(task a2a.TaskInfoProvider, cause error, meta map[string]any) *a2a.TaskStatusUpdateEvent {
msg := a2a.NewMessageForTask(a2a.MessageRoleAgent, task, a2a.TextPart{Text: cause.Error()})
ev := a2a.NewStatusUpdateEvent(task, a2a.TaskStateFailed, msg)
Expand Down
53 changes: 52 additions & 1 deletion server/adka2a/processor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestEventProcessor_Process(t *testing.T) {
{
name: "skip if no response",
events: []*session.Event{
{ID: "125", InvocationID: "345", Actions: session.EventActions{Escalate: true}},
{ID: "125", InvocationID: "345"},
{ID: "127", InvocationID: "345", Branch: "b", Author: "a"},
},
terminal: []a2a.Event{newFinalStatusUpdate(task, a2a.TaskStateCompleted, nil)},
Expand Down Expand Up @@ -244,6 +244,57 @@ func TestEventProcessor_Process(t *testing.T) {
newFinalStatusUpdate(task, a2a.TaskStateInputRequired, nil),
},
},
{
name: "actions in completed event meta",
events: []*session.Event{
{ID: "125", InvocationID: "345", Actions: session.EventActions{Escalate: true, TransferToAgent: "a-2"}},
},
terminal: []a2a.Event{
&a2a.TaskStatusUpdateEvent{
TaskID: task.ID,
ContextID: task.ContextID,
Status: a2a.TaskStatus{State: a2a.TaskStateCompleted},
Metadata: map[string]any{metadataEscalateKey: true, metadataTransferToAgentKey: "a-2"},
Final: true,
},
},
},
{
name: "last agent transfer is returned",
events: []*session.Event{
{ID: "125", InvocationID: "345", Actions: session.EventActions{TransferToAgent: "a-2"}},
{ID: "126", InvocationID: "346", Actions: session.EventActions{TransferToAgent: "a-3"}},
},
terminal: []a2a.Event{
&a2a.TaskStatusUpdateEvent{
TaskID: task.ID,
ContextID: task.ContextID,
Status: a2a.TaskStatus{State: a2a.TaskStateCompleted},
Metadata: map[string]any{metadataTransferToAgentKey: "a-3"},
Final: true,
},
},
},
{
name: "actions not overwritten by subsequent events",
events: []*session.Event{
{
LLMResponse: modelResponseFromParts(genai.NewPartFromText("The answer is")),
Actions: session.EventActions{Escalate: true, TransferToAgent: "a-2"},
},
{LLMResponse: model.LLMResponse{ErrorCode: "1", ErrorMessage: "failed"}},
},
processed: []*a2a.TaskArtifactUpdateEvent{
a2a.NewArtifactEvent(task, a2a.TextPart{Text: "The answer is"}),
},
terminal: []a2a.Event{
newArtifactLastChunkEvent(task),
toTaskFailedUpdateEvent(
task, errorFromResponse(&model.LLMResponse{ErrorCode: "1", ErrorMessage: "failed"}),
map[string]any{ToA2AMetaKey("error_code"): "1", metadataEscalateKey: true, metadataTransferToAgentKey: "a-2"},
),
},
},
}

for _, tc := range testCases {
Expand Down