diff --git a/server/cancel_subagent_test.go b/server/cancel_subagent_test.go new file mode 100644 index 00000000..1a85f0c8 --- /dev/null +++ b/server/cancel_subagent_test.go @@ -0,0 +1,110 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "shelley.exe.dev/db" + "shelley.exe.dev/db/generated" + "shelley.exe.dev/llm" +) + +// TestCancelParentCancelsSubagents verifies that cancelling a parent conversation +// also cancels any active subagent conversations. +func TestCancelParentCancelsSubagents(t *testing.T) { + server, database, _ := newTestServer(t) + + // Create parent conversation + conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil, db.ConversationOptions{}) + if err != nil { + t.Fatalf("failed to create conversation: %v", err) + } + parentID := conversation.ConversationID + + // Start the parent conversation. The predictable model will: + // 1. See "subagent: worker bash: sleep 30" and invoke the subagent tool + // 2. The subagent will receive "bash: sleep 30" and start a long sleep + chatReq := ChatRequest{ + Message: "subagent: worker bash: sleep 30", + Model: "predictable", + } + chatBody, _ := json.Marshal(chatReq) + + req := httptest.NewRequest("POST", "/api/conversation/"+parentID+"/chat", strings.NewReader(string(chatBody))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + server.handleChatConversation(w, req, parentID) + if w.Code != http.StatusAccepted { + t.Fatalf("expected 202, got %d: %s", w.Code, w.Body.String()) + } + + // Wait for the subagent to exist and start working + var subagentID string + waitFor(t, 10*time.Second, func() bool { + subagents, err := database.GetSubagents(context.Background(), parentID) + if err != nil || len(subagents) == 0 { + return false + } + subagentID = subagents[0].ConversationID + return server.IsAgentWorking(subagentID) + }) + + t.Logf("subagent %s is working", subagentID) + + // Verify parent is also working (blocked on the subagent tool call) + if !server.IsAgentWorking(parentID) { + t.Fatal("expected parent to be working") + } + + // Cancel the parent + cancelReq := httptest.NewRequest("POST", "/api/conversation/"+parentID+"/cancel", nil) + cancelW := httptest.NewRecorder() + server.handleCancelConversation(cancelW, cancelReq, parentID) + if cancelW.Code != http.StatusOK { + t.Fatalf("cancel expected 200, got %d: %s", cancelW.Code, cancelW.Body.String()) + } + + // Wait for parent to stop working + waitFor(t, 5*time.Second, func() bool { + return !server.IsAgentWorking(parentID) + }) + + // The subagent must also stop working + waitFor(t, 5*time.Second, func() bool { + return !server.IsAgentWorking(subagentID) + }) + + // Verify subagent has a cancellation end-of-turn message + var subMsgs []generated.Message + err = database.Queries(context.Background(), func(q *generated.Queries) error { + var qerr error + subMsgs, qerr = q.ListMessages(context.Background(), subagentID) + return qerr + }) + if err != nil { + t.Fatalf("failed to list subagent messages: %v", err) + } + + foundEndTurn := false + for _, msg := range subMsgs { + if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil { + continue + } + var llmMsg llm.Message + if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil { + continue + } + if llmMsg.EndOfTurn { + foundEndTurn = true + break + } + } + if !foundEndTurn { + t.Error("expected subagent to have an end-of-turn message after parent cancellation") + } +} diff --git a/server/handlers.go b/server/handlers.go index 1ae85a39..4a10a4ef 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -908,6 +908,9 @@ func (s *Server) handleCancelConversation(w http.ResponseWriter, r *http.Request return } + // Cancel active subagent conversations first + s.cancelSubagents(ctx, conversationID) + // Cancel the conversation if err := manager.CancelConversation(ctx); err != nil { s.logger.Error("Failed to cancel conversation", "conversationID", conversationID, "error", err) diff --git a/server/server.go b/server/server.go index c2c65592..be017495 100644 --- a/server/server.go +++ b/server/server.go @@ -1126,6 +1126,28 @@ func (s *Server) getWorkingConversations() map[string]bool { return working } +// cancelSubagents cancels all active subagent conversations for the given parent. +func (s *Server) cancelSubagents(ctx context.Context, parentID string) { + subagents, err := s.db.GetSubagents(ctx, parentID) + if err != nil { + s.logger.Error("Failed to get subagents for cancellation", "parentID", parentID, "error", err) + return + } + + for _, sub := range subagents { + s.mu.Lock() + mgr, ok := s.activeConversations[sub.ConversationID] + s.mu.Unlock() + if !ok || !mgr.IsAgentWorking() { + continue + } + s.logger.Info("Cancelling subagent", "subagentID", sub.ConversationID, "parentID", parentID) + if err := mgr.CancelConversation(ctx); err != nil { + s.logger.Error("Failed to cancel subagent", "subagentID", sub.ConversationID, "error", err) + } + } +} + // IsAgentWorking returns whether the agent is currently working on the given conversation. // Returns false if the conversation doesn't have an active manager. func (s *Server) IsAgentWorking(conversationID string) bool {