Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(router): handle repeated headers in response #1537

Merged
merged 4 commits into from
Feb 19, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 83 additions & 5 deletions router-tests/header_propagation_test.go
Original file line number Diff line number Diff line change
@@ -18,7 +18,9 @@ func TestHeaderPropagation(t *testing.T) {
const (
customHeader = "X-Custom-Header"
employeeVal = "employee-value"
employeeVal2 = "employee-value-2"
hobbyVal = "hobby-value"
hobbyVal2 = "hobby-value-2"
)

const queryEmployeeWithHobby = `{
@@ -97,20 +99,20 @@ func TestHeaderPropagation(t *testing.T) {
}
}

setSubgraphPropagateHeader := func(header, valA, valB string) testenv.SubgraphsConfig {
setSubgraphPropagateHeader := func(header string, valA, valB []string) testenv.SubgraphsConfig {
return testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header, valA)
w.Header()[header] = valA
handler.ServeHTTP(w, r)
})
},
},
Hobbies: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set(header, valB)
w.Header()[header] = valB
handler.ServeHTTP(w, r)
})
},
@@ -140,11 +142,12 @@ func TestHeaderPropagation(t *testing.T) {
}

cacheOptions := func(cacheControlEmployees, cacheControlHobbies string) testenv.SubgraphsConfig {
return setSubgraphPropagateHeader("Cache-Control", cacheControlEmployees, cacheControlHobbies)
return setSubgraphPropagateHeader("Cache-Control", []string{cacheControlEmployees}, []string{cacheControlHobbies})
}

var (
subgraphsPropagateCustomHeader = setSubgraphPropagateHeader(customHeader, employeeVal, hobbyVal)
subgraphsPropagateCustomHeader = setSubgraphPropagateHeader(customHeader, []string{employeeVal}, []string{hobbyVal})
subgraphsPropagateRepeatedCustomHeader = setSubgraphPropagateHeader(customHeader, []string{employeeVal, employeeVal2}, []string{hobbyVal, hobbyVal2})
)

t.Run(" no propagate", func(t *testing.T) {
@@ -234,6 +237,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names last write wins", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmLastWrite, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "hobby-value,hobby-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

// Test for the First Write Wins Algorithm
@@ -283,6 +301,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names first write wins", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: partial(config.ResponseHeaderRuleAlgorithmFirstWrite, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "employee-value,employee-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

// Test for the Append Algorithm
@@ -332,6 +365,21 @@ func TestHeaderPropagation(t *testing.T) {
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})

t.Run("repeated header names append headers", func(t *testing.T) {
t.Parallel()
testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmAppend, customHeader, ""),
Subgraphs: subgraphsPropagateRepeatedCustomHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
ch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, "employee-value,employee-value-2,hobby-value,hobby-value-2", ch)
require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
})

t.Run("Cache Control Propagation", func(t *testing.T) {
@@ -718,4 +766,34 @@ func TestHeaderPropagation(t *testing.T) {
})
})
})

t.Run("header name canonicalization", func(t *testing.T) {
t.Parallel()
nonCanonicalCustomHeader := "x-Custom-header"
subgraphsNonCanonicalHeader := testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
Middleware: func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header()[nonCanonicalCustomHeader] = []string{employeeVal}
handler.ServeHTTP(w, r)
})
},
},
}

testenv.Run(t, &testenv.Config{
RouterOptions: global(config.ResponseHeaderRuleAlgorithmAppend, nonCanonicalCustomHeader, ""),
Subgraphs: subgraphsNonCanonicalHeader,
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: queryEmployeeWithHobby,
})
cch := strings.Join(res.Response.Header.Values(customHeader), ",")
require.Equal(t, employeeVal, cch)
ncch := strings.Join(res.Response.Header[nonCanonicalCustomHeader], ",")
require.Equal(t, "", ncch)

require.Equal(t, `{"data":{"employee":{"id":1,"hobbies":[{},{"name":"Counter Strike"},{},{},{}]}}}`, res.Body)
})
})
}
29 changes: 16 additions & 13 deletions router/core/header_rule_engine.go
Original file line number Diff line number Diff line change
@@ -310,11 +310,11 @@ func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropaga
return
}

value := res.Header.Get(rule.Named)
if value != "" {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, value)
values := res.Header.Values(rule.Named)
if len(values) > 0 {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, values)
} else if rule.Default != "" {
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, rule.Default)
h.applyResponseRuleKeyValue(res, propagation, rule, rule.Named, []string{rule.Default})
}

return
@@ -325,31 +325,34 @@ func (h *HeaderPropagation) applyResponseRule(propagation *responseHeaderPropaga
if slices.Contains(ignoredHeaders, name) {
continue
}
h.applyResponseRuleKeyValue(res, propagation, rule, name, res.Header.Get(name))
values := res.Header.Values(name)
h.applyResponseRuleKeyValue(res, propagation, rule, name, values)
}
}
}
} else if rule.Algorithm == config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl {
// Explicitly apply the CacheControl algorithm on the headers
h.applyResponseRuleKeyValue(res, propagation, rule, "", "")
h.applyResponseRuleKeyValue(res, propagation, rule, "", []string{""})
}
}

func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key, value string) {
func (h *HeaderPropagation) applyResponseRuleKeyValue(res *http.Response, propagation *responseHeaderPropagation, rule *config.ResponseHeaderRule, key string, values []string) {
// Since we'll be setting the header map directly, we need to canonicalize the key
key = http.CanonicalHeaderKey(key)
switch rule.Algorithm {
case config.ResponseHeaderRuleAlgorithmFirstWrite:
propagation.m.Lock()
if val := propagation.header.Get(key); val == "" {
propagation.header.Set(key, value)
propagation.header[key] = values
}
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmLastWrite:
propagation.m.Lock()
propagation.header.Set(key, value)
propagation.header[key] = values
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmAppend:
propagation.m.Lock()
propagation.header.Add(key, value)
propagation.header[key] = append(propagation.header[key], values...)
propagation.m.Unlock()
case config.ResponseHeaderRuleAlgorithmMostRestrictiveCacheControl:
h.applyResponseRuleMostRestrictiveCacheControl(res, propagation, rule)
@@ -408,9 +411,9 @@ func (h *HeaderPropagation) applyRequestRule(ctx RequestContext, request *http.R
return
}

value := ctx.Request().Header.Get(rule.Named)
if value != "" {
request.Header.Set(rule.Named, ctx.Request().Header.Get(rule.Named))
values := ctx.Request().Header.Values(rule.Named)
if len(values) > 0 {
request.Header[http.CanonicalHeaderKey(rule.Named)] = values
} else if rule.Default != "" {
request.Header.Set(rule.Named, rule.Default)
}
38 changes: 37 additions & 1 deletion router/core/header_rule_engine_test.go
Original file line number Diff line number Diff line change
@@ -58,6 +58,42 @@ func TestPropagateHeaderRule(t *testing.T) {
assert.Equal(t, "test3", updatedClientReq.Header.Get("X-Test-3"))
})

t.Run("Should propagate repeated header names", func(t *testing.T) {

ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{
Request: []*config.RequestHeaderRule{
{
Operation: "propagate",
Named: "X-Test-1",
},
},
},
})
assert.Nil(t, err)

rr := httptest.NewRecorder()

clientReq, err := http.NewRequest("POST", "http://localhost", nil)
require.NoError(t, err)
clientReq.Header.Add("X-Test-1", "test1")
clientReq.Header.Add("X-Test-1", "test2")

originReq, err := http.NewRequest("POST", "http://localhost", nil)
assert.Nil(t, err)

updatedClientReq, _ := ht.OnOriginRequest(originReq, &requestContext{
logger: zap.NewNop(),
responseWriter: rr,
request: clientReq,
operation: &operationContext{},
subgraphResolver: NewSubgraphResolver(nil),
})

assert.Len(t, updatedClientReq.Header, 1)
assert.Equal(t, []string{"test1", "test2"}, updatedClientReq.Header.Values("X-Test-1"))
})

t.Run("Should propagate based on matching regex / matching", func(t *testing.T) {
ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{
@@ -217,7 +253,7 @@ func TestPropagateHeaderRule(t *testing.T) {

})

t.Run("Should handle nil resonses", func(t *testing.T) {
t.Run("Should handle nil responses", func(t *testing.T) {
ht, err := NewHeaderPropagation(&config.HeaderRules{
All: &config.GlobalHeaderRule{},
})