diff --git a/config/paths.go b/config/paths.go index cd1c89f..e57ed0d 100644 --- a/config/paths.go +++ b/config/paths.go @@ -12,9 +12,17 @@ import ( ) const ( - RewriteIdHeader = "RewriteId" + PascalCaseRewriteIdHeader = "Rewriteid" + SnakeCaseRewriteIdHeader = "Rewrite_id" + KebabCaseRewriteIdHeader = "Rewrite-Id" ) +var rewriteIdHeaders = []string{ + PascalCaseRewriteIdHeader, + SnakeCaseRewriteIdHeader, + KebabCaseRewriteIdHeader, +} + type PathRewrite struct { RewrittenPath string PathConfiguration *shared.WiretapPathConfig @@ -94,13 +102,54 @@ func rewriteTaget(path string, pathConfig *shared.WiretapPathConfig, configurati } } +func getRewriteIdHeaderValues(req *http.Request) ([]string, bool) { + + // Let's first try to get the header with expected key names + for _, possibleHeaderKey := range rewriteIdHeaders { + + if rewriteIdHeaderValues, ok := req.Header[possibleHeaderKey]; ok { + return rewriteIdHeaderValues, true + } + + if rewriteIdHeaderValues, ok := req.Header[strings.ToLower(possibleHeaderKey)]; ok { + return rewriteIdHeaderValues, true + } + + } + + // Let's now try to ignore case ; this may produce collisions if a user has two headers with similar keys, + // but different capitalization. This is okay, as this is a last ditch effort to find any possible match + loweredHeaders := map[string][]string{} + + for headerKey, headerValues := range req.Header { + loweredKey := strings.ToLower(headerKey) + + if _, ok := loweredHeaders[loweredKey]; ok { + loweredHeaders[loweredKey] = append(loweredHeaders[loweredKey], headerValues...) + } else { + loweredHeaders[loweredKey] = headerValues + } + + } + + for _, possibleHeaderKey := range rewriteIdHeaders { + + if rewriteIdHeaderValues, ok := loweredHeaders[strings.ToLower(possibleHeaderKey)]; ok { + return rewriteIdHeaderValues, true + } + + } + + return []string{}, false +} + func FindPathWithRewriteId(paths []*shared.WiretapPathConfig, req *http.Request) *shared.WiretapPathConfig { if req == nil { return nil } - if rewriteIdHeaderValues, ok := req.Header[RewriteIdHeader]; ok { + if rewriteIdHeaderValues, ok := getRewriteIdHeaderValues(req); ok { for _, pathRewriteConfig := range paths { // Iterate through header values - since it's a multi-value field diff --git a/config/paths_test.go b/config/paths_test.go index c52a377..8efbccc 100644 --- a/config/paths_test.go +++ b/config/paths_test.go @@ -494,3 +494,96 @@ func TestValidationAllowList_NoPathsRegistered(t *testing.T) { assert.False(t, ignore) } + +func TestGetRewriteHeaderValues(t *testing.T) { + + expectedValue := []string{"ExpectedValue"} + + requestList := []*http.Request{ + { + Header: http.Header{ + "Rewriteid": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "Rewrite-Id": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "Rewrite_id": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "RewriteId": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "RewrIte-Id": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "rewriteid": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "rewrite-id": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{ + "rewrite_id": expectedValue, + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + } + + for _, request := range requestList { + actualValue, found := getRewriteIdHeaderValues(request) + assert.Equal(t, expectedValue, actualValue) + assert.True(t, found) + } + +} + +func TestGetRewriteHeaderValues_MissingHeader(t *testing.T) { + + requestList := []*http.Request{ + { + Header: http.Header{ + "Other-Header": []string{"another header"}, + "other-header": []string{"another another header"}, + }, + }, + { + Header: http.Header{}, + }, + } + + for _, request := range requestList { + actualValue, found := getRewriteIdHeaderValues(request) + assert.Equal(t, []string{}, actualValue) + assert.False(t, found) + } + +}