Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changeset/patch-add-changeset-pr-8700.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 30 additions & 9 deletions pkg/awmg/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,39 @@ func parseGatewayConfig(data []byte) (*MCPGatewayServiceConfig, error) {

gatewayLog.Printf("Successfully parsed JSON configuration")

// Filter out internal workflow MCP servers (safeinputs and safeoutputs)
// These are used internally by the workflow and should not be proxied by the gateway
filteredServers := make(map[string]parser.MCPServerConfig)
// Apply environment variable expansion to all server configurations
// This supports ${VAR} or $VAR patterns in URLs, headers, and env values
expandedServers := make(map[string]parser.MCPServerConfig)
for name, serverConfig := range config.MCPServers {
if name == "safeinputs" || name == "safeoutputs" {
gatewayLog.Printf("Filtering out internal workflow server: %s", name)
fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Filtering out internal workflow server: %s", name)))
continue
// Expand URL field
if serverConfig.URL != "" {
serverConfig.URL = os.ExpandEnv(serverConfig.URL)
gatewayLog.Printf("Expanded URL for server %s: %s", name, serverConfig.URL)
}
filteredServers[name] = serverConfig

// Expand headers
if len(serverConfig.Headers) > 0 {
expandedHeaders := make(map[string]string)
for key, value := range serverConfig.Headers {
expandedHeaders[key] = os.ExpandEnv(value)
}
serverConfig.Headers = expandedHeaders
gatewayLog.Printf("Expanded %d headers for server %s", len(expandedHeaders), name)
}

// Expand environment variables
if len(serverConfig.Env) > 0 {
expandedEnv := make(map[string]string)
for key, value := range serverConfig.Env {
expandedEnv[key] = os.ExpandEnv(value)
}
serverConfig.Env = expandedEnv
gatewayLog.Printf("Expanded %d env vars for server %s", len(expandedEnv), name)
}

expandedServers[name] = serverConfig
}
config.MCPServers = filteredServers
config.MCPServers = expandedServers

return &config, nil
}
Expand Down
76 changes: 52 additions & 24 deletions pkg/awmg/gateway_rewrite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"github.com/githubnext/gh-aw/pkg/parser"
)

// TestRewriteMCPConfigForGateway_PreservesNonProxiedServers tests that
// servers not being proxied (like safeinputs/safeoutputs) are preserved unchanged
func TestRewriteMCPConfigForGateway_PreservesNonProxiedServers(t *testing.T) {
// TestRewriteMCPConfigForGateway_ProxiesSafeInputsAndSafeOutputs tests that
// safeinputs and safeoutputs servers ARE proxied through the gateway (rewritten)
func TestRewriteMCPConfigForGateway_ProxiesSafeInputsAndSafeOutputs(t *testing.T) {
// Create a temporary config file
tmpDir := t.TempDir()
configFile := filepath.Join(tmpDir, "test-config.json")
Expand Down Expand Up @@ -44,9 +44,21 @@ func TestRewriteMCPConfigForGateway_PreservesNonProxiedServers(t *testing.T) {
t.Fatalf("Failed to write config file: %v", err)
}

// Gateway config only includes external server (github), not internal servers
// Gateway config includes ALL servers (including safeinputs/safeoutputs)
gatewayConfig := &MCPGatewayServiceConfig{
MCPServers: map[string]parser.MCPServerConfig{
"safeinputs": {
BaseMCPServerConfig: types.BaseMCPServerConfig{
Command: "gh",
Args: []string{"aw", "mcp-server", "--mode", "safe-inputs"},
},
},
"safeoutputs": {
BaseMCPServerConfig: types.BaseMCPServerConfig{
Command: "gh",
Args: []string{"aw", "mcp-server", "--mode", "safe-outputs"},
},
},
"github": {
BaseMCPServerConfig: types.BaseMCPServerConfig{
Command: "docker",
Expand Down Expand Up @@ -81,45 +93,61 @@ func TestRewriteMCPConfigForGateway_PreservesNonProxiedServers(t *testing.T) {
t.Fatal("mcpServers not found or wrong type")
}

// Should have all 3 servers: 2 preserved + 1 rewritten
// Should have all 3 servers, all rewritten
if len(mcpServers) != 3 {
t.Errorf("Expected 3 servers in rewritten config, got %d", len(mcpServers))
}

// Verify safeinputs is preserved with original command/args
// Verify safeinputs points to gateway (rewritten)
safeinputs, ok := mcpServers["safeinputs"].(map[string]any)
if !ok {
t.Fatal("safeinputs server not found")
}

safeinputsCommand, ok := safeinputs["command"].(string)
if !ok || safeinputsCommand != "gh" {
t.Errorf("Expected safeinputs to preserve original command 'gh', got '%v'", safeinputsCommand)
safeinputsURL, ok := safeinputs["url"].(string)
if !ok {
t.Fatal("safeinputs server should have url (rewritten)")
}

safeinputsArgs, ok := safeinputs["args"].([]any)
if !ok {
t.Error("Expected safeinputs to have args array")
} else if len(safeinputsArgs) < 3 {
t.Errorf("Expected safeinputs to have at least 3 args, got %d", len(safeinputsArgs))
expectedURL := "http://localhost:8080/mcp/safeinputs"
if safeinputsURL != expectedURL {
t.Errorf("Expected safeinputs URL %s, got %s", expectedURL, safeinputsURL)
}

// Verify safeoutputs is preserved with original command/args
safeinputsType, ok := safeinputs["type"].(string)
if !ok || safeinputsType != "http" {
t.Errorf("Expected safeinputs to have type 'http', got %v", safeinputsType)
}

// Verify safeinputs does NOT have command/args (was rewritten)
if _, hasCommand := safeinputs["command"]; hasCommand {
t.Error("Rewritten safeinputs server should not have 'command' field")
}

// Verify safeoutputs points to gateway (rewritten)
safeoutputs, ok := mcpServers["safeoutputs"].(map[string]any)
if !ok {
t.Fatal("safeoutputs server not found")
}

safeoutputsCommand, ok := safeoutputs["command"].(string)
if !ok || safeoutputsCommand != "gh" {
t.Errorf("Expected safeoutputs to preserve original command 'gh', got '%v'", safeoutputsCommand)
safeoutputsURL, ok := safeoutputs["url"].(string)
if !ok {
t.Fatal("safeoutputs server should have url (rewritten)")
}

expectedURL = "http://localhost:8080/mcp/safeoutputs"
if safeoutputsURL != expectedURL {
t.Errorf("Expected safeoutputs URL %s, got %s", expectedURL, safeoutputsURL)
}

safeoutputsArgs, ok := safeoutputs["args"].([]any)
if !ok {
t.Error("Expected safeoutputs to have args array")
} else if len(safeoutputsArgs) < 3 {
t.Errorf("Expected safeoutputs to have at least 3 args, got %d", len(safeoutputsArgs))
safeoutputsType, ok := safeoutputs["type"].(string)
if !ok || safeoutputsType != "http" {
t.Errorf("Expected safeoutputs to have type 'http', got %v", safeoutputsType)
}

// Verify safeoutputs does NOT have command/args (was rewritten)
if _, hasCommand := safeoutputs["command"]; hasCommand {
t.Error("Rewritten safeoutputs server should not have 'command' field")
}

// Verify github server points to gateway (was rewritten)
Expand All @@ -133,7 +161,7 @@ func TestRewriteMCPConfigForGateway_PreservesNonProxiedServers(t *testing.T) {
t.Fatal("github server should have url (rewritten)")
}

expectedURL := "http://localhost:8080/mcp/github"
expectedURL = "http://localhost:8080/mcp/github"
if githubURL != expectedURL {
t.Errorf("Expected github URL %s, got %s", expectedURL, githubURL)
}
Expand Down
60 changes: 40 additions & 20 deletions pkg/awmg/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ func TestMergeConfigs_EmptyOverride(t *testing.T) {
}
}

func TestParseGatewayConfig_FiltersInternalServers(t *testing.T) {
func TestParseGatewayConfig_IncludesSafeInputsAndSafeOutputs(t *testing.T) {
// Create a config with safeinputs, safeoutputs, and other servers
configJSON := `{
"mcpServers": {
Expand Down Expand Up @@ -454,13 +454,13 @@ func TestParseGatewayConfig_FiltersInternalServers(t *testing.T) {
t.Fatalf("Failed to parse config: %v", err)
}

// Verify that safeinputs and safeoutputs are filtered out
if _, exists := config.MCPServers["safeinputs"]; exists {
t.Error("safeinputs should be filtered out")
// Verify that safeinputs and safeoutputs are included (not filtered)
if _, exists := config.MCPServers["safeinputs"]; !exists {
t.Error("safeinputs should be included")
}

if _, exists := config.MCPServers["safeoutputs"]; exists {
t.Error("safeoutputs should be filtered out")
if _, exists := config.MCPServers["safeoutputs"]; !exists {
t.Error("safeoutputs should be included")
}

// Verify that other servers are kept
Expand All @@ -472,23 +472,29 @@ func TestParseGatewayConfig_FiltersInternalServers(t *testing.T) {
t.Error("custom-server should be kept")
}

// Verify server count
if len(config.MCPServers) != 2 {
t.Errorf("Expected 2 servers after filtering, got %d", len(config.MCPServers))
// Verify server count - all 4 servers should be present
if len(config.MCPServers) != 4 {
t.Errorf("Expected 4 servers, got %d", len(config.MCPServers))
}
}

func TestParseGatewayConfig_OnlyInternalServers(t *testing.T) {
// Create a config with only safeinputs and safeoutputs
func TestParseGatewayConfig_TemplateSubstitution(t *testing.T) {
// Set environment variables for testing
t.Setenv("TEST_PORT", "3000")
t.Setenv("TEST_API_KEY", "test-secret-key")
t.Setenv("TEST_ENV_VALUE", "test-value")

configJSON := `{
"mcpServers": {
"safeinputs": {
"command": "node",
"args": ["/tmp/gh-aw/safeinputs/mcp-server.cjs"]
},
"safeoutputs": {
"command": "node",
"args": ["/tmp/gh-aw/safeoutputs/mcp-server.cjs"]
"type": "http",
"url": "http://localhost:${TEST_PORT}",
"headers": {
"Authorization": "Bearer ${TEST_API_KEY}"
},
"env": {
"CUSTOM_VAR": "${TEST_ENV_VALUE}"
}
}
}
}`
Expand All @@ -498,9 +504,23 @@ func TestParseGatewayConfig_OnlyInternalServers(t *testing.T) {
t.Fatalf("Failed to parse config: %v", err)
}

// Verify that all internal servers are filtered out, resulting in 0 servers
if len(config.MCPServers) != 0 {
t.Errorf("Expected 0 servers after filtering internal servers, got %d", len(config.MCPServers))
// Verify URL expansion
safeinputs := config.MCPServers["safeinputs"]
expectedURL := "http://localhost:3000"
if safeinputs.URL != expectedURL {
t.Errorf("Expected URL %s, got %s", expectedURL, safeinputs.URL)
}

// Verify headers expansion
expectedAuth := "Bearer test-secret-key"
if safeinputs.Headers["Authorization"] != expectedAuth {
t.Errorf("Expected Authorization header %s, got %s", expectedAuth, safeinputs.Headers["Authorization"])
}

// Verify env expansion
expectedEnvValue := "test-value"
if safeinputs.Env["CUSTOM_VAR"] != expectedEnvValue {
t.Errorf("Expected env CUSTOM_VAR=%s, got %s", expectedEnvValue, safeinputs.Env["CUSTOM_VAR"])
}
}

Expand Down
Loading