diff --git a/cors.go b/cors.go index 724f242..46a4171 100644 --- a/cors.go +++ b/cors.go @@ -30,10 +30,6 @@ import ( "github.com/rs/cors/internal" ) -var headerVaryOrigin = []string{"Origin"} -var headerOriginAll = []string{"*"} -var headerTrue = []string{"true"} - // Options is a configuration container to setup the CORS middleware. type Options struct { // AllowedOrigins is a list of origins a cross-domain request can be executed from. @@ -133,7 +129,7 @@ type Cors struct { allowCredentials bool allowPrivateNetwork bool optionPassthrough bool - preflightVary []string + preflightVary string } // New creates a new Cors handler with the provided options. @@ -229,9 +225,9 @@ func New(options Options) *Cors { // Pre-compute prefight Vary header to save allocations if c.allowPrivateNetwork { - c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network"} + c.preflightVary = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers, Access-Control-Request-Private-Network" } else { - c.preflightVary = []string{"Origin, Access-Control-Request-Method, Access-Control-Request-Headers"} + c.preflightVary = "Origin, Access-Control-Request-Method, Access-Control-Request-Headers" } // Precompute max-age @@ -337,11 +333,7 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { // Always set Vary headers // see https://github.com/rs/cors/issues/10, // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001 - if vary, found := headers["Vary"]; found { - headers["Vary"] = append(vary, c.preflightVary[0]) - } else { - headers["Vary"] = c.preflightVary - } + headers.Add("Vary", c.preflightVary) allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) if len(additionalVaryHeaders) > 0 { headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) @@ -372,7 +364,7 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { return } if c.allowedOriginsAll { - headers["Access-Control-Allow-Origin"] = headerOriginAll + headers.Set("Access-Control-Allow-Origin", "*") } else { headers["Access-Control-Allow-Origin"] = r.Header["Origin"] } @@ -385,10 +377,10 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { headers["Access-Control-Allow-Headers"] = reqHeaders } if c.allowCredentials { - headers["Access-Control-Allow-Credentials"] = headerTrue + headers.Set("Access-Control-Allow-Credentials", "true") } if c.allowPrivateNetwork && r.Header.Get("Access-Control-Request-Private-Network") == "true" { - headers["Access-Control-Allow-Private-Network"] = headerTrue + headers.Set("Access-Control-Allow-Private-Network", "true") } if len(c.maxAge) > 0 { headers["Access-Control-Max-Age"] = c.maxAge @@ -404,11 +396,7 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { allowed, additionalVaryHeaders := c.isOriginAllowed(r, origin) // Always set Vary, see https://github.com/rs/cors/issues/10 - if vary := headers["Vary"]; vary == nil { - headers["Vary"] = headerVaryOrigin - } else { - headers["Vary"] = append(vary, headerVaryOrigin[0]) - } + headers.Add("Vary", "Origin") if len(additionalVaryHeaders) > 0 { headers.Add("Vary", strings.Join(convert(additionalVaryHeaders, http.CanonicalHeaderKey), ", ")) } @@ -430,7 +418,7 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { return } if c.allowedOriginsAll { - headers["Access-Control-Allow-Origin"] = headerOriginAll + headers.Set("Access-Control-Allow-Origin", "*") } else { headers["Access-Control-Allow-Origin"] = r.Header["Origin"] } @@ -438,7 +426,7 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { headers["Access-Control-Expose-Headers"] = c.exposedHeaders } if c.allowCredentials { - headers["Access-Control-Allow-Credentials"] = headerTrue + headers.Set("Access-Control-Allow-Credentials", "true") } c.logf(" Actual response added headers: %v", headers) } diff --git a/cors_test.go b/cors_test.go index f537e0f..22bea21 100644 --- a/cors_test.go +++ b/cors_test.go @@ -8,6 +8,7 @@ import ( "regexp" "slices" "strings" + "sync" "testing" ) @@ -830,3 +831,59 @@ func TestAccessControlExposeHeadersPresence(t *testing.T) { } } + +var mutatingHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for _, k := range keys { + vv := w.Header()[k] + if len(vv) > 0 { + vv[0] = "oops!" + } + } +}) + +var keys = []string{ + "Access-Control-Allow-Credentials", + "Access-Control-Allow-Origin", + "Vary", +} + +// Note: run this test with -race +func TestSynchronizationBugWithPrelightRequest(t *testing.T) { + testSynchronizationBug(t, true) +} + +// Note: run this test with -race +func TestSynchronizationBugWithActualRequest(t *testing.T) { + testSynchronizationBug(t, false) +} + +func testSynchronizationBug(t *testing.T, preflight bool) { + t.Helper() + c := New(Options{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + AllowedMethods: []string{http.MethodPut}, + OptionsPassthrough: true, + }) + var req *http.Request + if preflight { + req = httptest.NewRequest(http.MethodOptions, "https://example.org", nil) + req.Header.Add("Access-Control-Request-Method", http.MethodPut) + } else { + req = httptest.NewRequest(http.MethodGet, "https://example.org", nil) + } + req.Header.Add("Origin", "https://example.com") + + // simulate concurrent requests + const n = 128 + var wg sync.WaitGroup + wg.Add(n) + for range n { + go func() { + defer wg.Done() + rec := httptest.NewRecorder() + c.Handler(mutatingHandler).ServeHTTP(rec, req) + }() + } + wg.Wait() +}