diff --git a/pkg/cli/mcp_server.go b/pkg/cli/mcp_server.go index bd9370ec21..44cd4d7e3c 100644 --- a/pkg/cli/mcp_server.go +++ b/pkg/cli/mcp_server.go @@ -2,6 +2,7 @@ package cli import ( "context" + "encoding/json" "fmt" "log" "net/http" @@ -13,12 +14,28 @@ import ( "github.com/githubnext/gh-aw/pkg/console" "github.com/githubnext/gh-aw/pkg/logger" "github.com/githubnext/gh-aw/pkg/workflow" + "github.com/modelcontextprotocol/go-sdk/jsonrpc" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/spf13/cobra" ) var mcpLog = logger.New("mcp:server") +// mcpErrorData marshals data to JSON for use in jsonrpc.Error.Data field. +// Returns nil if marshaling fails to avoid errors in error handling. +func mcpErrorData(v any) json.RawMessage { + if v == nil { + return nil + } + data, err := json.Marshal(v) + if err != nil { + // Log the error but return nil to avoid breaking error handling + mcpLog.Printf("Failed to marshal error data: %v", err) + return nil + } + return data +} + // NewMCPServerCommand creates the mcp-server command func NewMCPServerCommand() *cobra.Command { var port int @@ -137,11 +154,11 @@ Note: Output can be filtered using the jq parameter.`, // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } @@ -156,11 +173,11 @@ Note: Output can be filtered using the jq parameter.`, output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to execute status command", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } // Apply jq filter if provided @@ -168,11 +185,11 @@ Note: Output can be filtered using the jq parameter.`, if args.JqFilter != "" { filteredOutput, jqErr := ApplyJqFilter(outputStr, args.JqFilter) if jqErr != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error applying jq filter: %v", jqErr)}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "invalid jq filter expression", + Data: mcpErrorData(map[string]any{"error": jqErr.Error(), "filter": args.JqFilter}), + } } outputStr = filteredOutput } @@ -227,11 +244,11 @@ Note: Output can be filtered using the jq parameter.`, // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } @@ -239,21 +256,21 @@ Note: Output can be filtered using the jq parameter.`, if args.Zizmor || args.Poutine || args.Actionlint { // Check if Docker images are available; if not, start downloading and return retry message if err := CheckAndPrepareDockerImages(args.Zizmor, args.Poutine, args.Actionlint); err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: err.Error()}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "docker images not ready", + Data: mcpErrorData(err.Error()), + } } // Check for cancellation after Docker image preparation select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } } @@ -290,11 +307,11 @@ Note: Output can be filtered using the jq parameter.`, output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to compile workflows", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } // Apply jq filter if provided @@ -302,11 +319,11 @@ Note: Output can be filtered using the jq parameter.`, if args.JqFilter != "" { filteredOutput, jqErr := ApplyJqFilter(outputStr, args.JqFilter) if jqErr != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error applying jq filter: %v", jqErr)}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "invalid jq filter expression", + Data: mcpErrorData(map[string]any{"error": jqErr.Error(), "filter": args.JqFilter}), + } } outputStr = filteredOutput } @@ -374,21 +391,21 @@ to filter the output to a manageable size, or adjust the 'max_tokens' parameter. // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } // Validate firewall parameters if args.Firewall && args.NoFirewall { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Error: cannot specify both 'firewall' and 'no_firewall' parameters"}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "conflicting parameters: cannot specify both 'firewall' and 'no_firewall'", + Data: nil, + } } // Build command arguments @@ -440,11 +457,11 @@ to filter the output to a manageable size, or adjust the 'max_tokens' parameter. output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to download workflow logs", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } // Apply jq filter if provided @@ -452,11 +469,11 @@ to filter the output to a manageable size, or adjust the 'max_tokens' parameter. if args.JqFilter != "" { filteredOutput, err := ApplyJqFilter(outputStr, args.JqFilter) if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error applying jq filter: %v", err)}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "invalid jq filter expression", + Data: mcpErrorData(map[string]any{"error": err.Error(), "filter": args.JqFilter}), + } } outputStr = filteredOutput } @@ -497,11 +514,11 @@ Note: Output can be filtered using the jq parameter.`, // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } @@ -515,11 +532,11 @@ Note: Output can be filtered using the jq parameter.`, output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to audit workflow run", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output), "run_id": args.RunID}), + } } // Apply jq filter if provided @@ -527,11 +544,11 @@ Note: Output can be filtered using the jq parameter.`, if args.JqFilter != "" { filteredOutput, jqErr := ApplyJqFilter(outputStr, args.JqFilter) if jqErr != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error applying jq filter: %v", jqErr)}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "invalid jq filter expression", + Data: mcpErrorData(map[string]any{"error": jqErr.Error(), "filter": args.JqFilter}), + } } outputStr = filteredOutput } @@ -575,11 +592,11 @@ Returns formatted text output showing: // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } @@ -606,11 +623,11 @@ Returns formatted text output showing: output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to inspect MCP servers", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } return &mcp.CallToolResult{ @@ -634,21 +651,21 @@ Returns formatted text output showing: // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } // Validate required arguments if len(args.Workflows) == 0 { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: "Error: at least one workflow specification is required"}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInvalidParams, + Message: "missing required parameter: at least one workflow specification is required", + Data: nil, + } } // Build command arguments @@ -670,11 +687,11 @@ Returns formatted text output showing: output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to add workflows", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } return &mcp.CallToolResult{ @@ -713,11 +730,11 @@ Returns formatted text output showing: // Check for cancellation before starting select { case <-ctx.Done(): - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v", ctx.Err())}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "request cancelled", + Data: mcpErrorData(ctx.Err().Error()), + } default: } @@ -740,11 +757,11 @@ Returns formatted text output showing: output, err := cmd.CombinedOutput() if err != nil { - return &mcp.CallToolResult{ - Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("Error: %v\nOutput: %s", err, string(output))}, - }, - }, nil, nil + return nil, nil, &jsonrpc.Error{ + Code: jsonrpc.CodeInternalError, + Message: "failed to update workflows", + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + } } return &mcp.CallToolResult{ diff --git a/pkg/cli/mcp_server_error_codes_test.go b/pkg/cli/mcp_server_error_codes_test.go new file mode 100644 index 0000000000..76bb35937e --- /dev/null +++ b/pkg/cli/mcp_server_error_codes_test.go @@ -0,0 +1,215 @@ +//go:build integration + +package cli + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/githubnext/gh-aw/pkg/testutil" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// TestMCPServer_ErrorCodes_InvalidParams tests that InvalidParams error code is returned for parameter validation errors +func TestMCPServer_ErrorCodes_InvalidParams(t *testing.T) { + // Skip if the binary doesn't exist + binaryPath := "../../gh-aw" + if _, err := os.Stat(binaryPath); os.IsNotExist(err) { + t.Skip("Skipping test: gh-aw binary not found. Run 'make build' first.") + } + + // Get the current directory for proper path resolution + originalDir, _ := os.Getwd() + + // Create MCP client + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + // Start the MCP server as a subprocess + serverCmd := exec.Command(filepath.Join(originalDir, binaryPath), "mcp-server") + transport := &mcp.CommandTransport{Command: serverCmd} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Failed to connect to MCP server: %v", err) + } + defer session.Close() + + // Test 1: add tool with missing workflows parameter + t.Run("add_missing_workflows", func(t *testing.T) { + params := &mcp.CallToolParams{ + Name: "add", + Arguments: map[string]any{}, // Missing required workflows + } + + _, err := session.CallTool(ctx, params) + if err == nil { + t.Error("Expected error for missing workflows parameter, got nil") + return + } + + // The error message should contain the InvalidParams error message + errMsg := err.Error() + if !strings.Contains(errMsg, "missing required parameter") && !strings.Contains(errMsg, "missing properties") { + t.Errorf("Expected error message about missing parameter, got: %s", errMsg) + } else { + t.Logf("✓ Correct error for missing workflows: %s", errMsg) + } + }) + + // Test 2: logs tool with conflicting firewall parameters + t.Run("logs_conflicting_params", func(t *testing.T) { + params := &mcp.CallToolParams{ + Name: "logs", + Arguments: map[string]any{ + "firewall": true, + "no_firewall": true, // Conflicting with firewall + }, + } + + _, err := session.CallTool(ctx, params) + if err == nil { + t.Error("Expected error for conflicting parameters, got nil") + return + } + + // The error message should contain the conflicting parameters error + errMsg := err.Error() + if !strings.Contains(errMsg, "conflicting parameters") { + t.Errorf("Expected error message about conflicting parameters, got: %s", errMsg) + } else { + t.Logf("✓ Correct error for conflicting parameters: %s", errMsg) + } + }) + + // Test 3: invalid jq filter + t.Run("status_invalid_jq_filter", func(t *testing.T) { + // Create a temporary directory with a workflow file + tmpDir := testutil.TempDir(t, "test-*") + workflowsDir := filepath.Join(tmpDir, ".github", "workflows") + if err := os.MkdirAll(workflowsDir, 0755); err != nil { + t.Fatalf("Failed to create workflows directory: %v", err) + } + + // Create a test workflow file + workflowContent := `--- +on: push +engine: copilot +--- +# Test Workflow + +This is a test workflow. +` + workflowPath := filepath.Join(workflowsDir, "test.md") + if err := os.WriteFile(workflowPath, []byte(workflowContent), 0644); err != nil { + t.Fatalf("Failed to write workflow file: %v", err) + } + + // Initialize git repository in the temp directory + initCmd := exec.Command("git", "init") + initCmd.Dir = tmpDir + if err := initCmd.Run(); err != nil { + t.Fatalf("Failed to initialize git repository: %v", err) + } + + // Start new MCP server in the temp directory + serverCmd := exec.Command(filepath.Join(originalDir, binaryPath), "mcp-server") + serverCmd.Dir = tmpDir + transport := &mcp.CommandTransport{Command: serverCmd} + + ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel2() + + session2, err := client.Connect(ctx2, transport, nil) + if err != nil { + t.Fatalf("Failed to connect to MCP server: %v", err) + } + defer session2.Close() + + params := &mcp.CallToolParams{ + Name: "status", + Arguments: map[string]any{ + "jq": ".invalid[syntax", // Invalid jq filter + }, + } + + _, err = session2.CallTool(ctx2, params) + if err == nil { + t.Error("Expected error for invalid jq filter, got nil") + return + } + + // The error message should contain the invalid jq filter error + errMsg := err.Error() + if !strings.Contains(errMsg, "invalid jq filter") { + t.Errorf("Expected error message about invalid jq filter, got: %s", errMsg) + } else { + t.Logf("✓ Correct error for invalid jq filter: %s", errMsg) + } + }) +} + +// TestMCPServer_ErrorCodes_InternalError tests that InternalError code is returned for execution failures +func TestMCPServer_ErrorCodes_InternalError(t *testing.T) { + // Skip if the binary doesn't exist + binaryPath := "../../gh-aw" + if _, err := os.Stat(binaryPath); os.IsNotExist(err) { + t.Skip("Skipping test: gh-aw binary not found. Run 'make build' first.") + } + + // Get the current directory for proper path resolution + originalDir, _ := os.Getwd() + + // Create MCP client + client := mcp.NewClient(&mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, nil) + + // Start the MCP server as a subprocess + serverCmd := exec.Command(filepath.Join(originalDir, binaryPath), "mcp-server") + transport := &mcp.CommandTransport{Command: serverCmd} + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("Failed to connect to MCP server: %v", err) + } + defer session.Close() + + // Test: audit tool with invalid run_id (should cause internal error) + t.Run("audit_invalid_run_id", func(t *testing.T) { + params := &mcp.CallToolParams{ + Name: "audit", + Arguments: map[string]any{ + "run_id": int64(1), // Invalid run ID + }, + } + + _, err := session.CallTool(ctx, params) + if err == nil { + t.Error("Expected error for invalid run_id, got nil") + return + } + + // The error message should contain the failed audit error + errMsg := err.Error() + if !strings.Contains(errMsg, "failed to audit workflow run") { + t.Errorf("Expected error message about failed audit, got: %s", errMsg) + } else { + t.Logf("✓ Correct error for failed audit: %s", errMsg) + } + }) +}