From e08dd4b49f1e681e3330ad16927ed7a7503b0ccd Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 30 Oct 2024 19:34:16 +0200 Subject: [PATCH 1/3] Refactor basic auth middleware to support multiple auth headers --- middleware/basic_auth.go | 68 ++++++++++++--------- middleware/basic_auth_test.go | 112 ++++++++++++++++++++++++++-------- 2 files changed, 125 insertions(+), 55 deletions(-) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 9285f29fd..e2d723113 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,7 +4,9 @@ package middleware import ( + "bytes" "encoding/base64" + "errors" "net/http" "strconv" "strings" @@ -52,18 +54,26 @@ func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { return BasicAuthWithConfig(c) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } - if config.Realm == "" { - config.Realm = defaultRealm + realm := defaultRealm + if config.Realm != "" && config.Realm != realm { + realm = strconv.Quote(config.Realm) } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -72,40 +82,42 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + // multiple auth headers is something that can happen in environments like + // corporate test environments that are secured application proxy servers where + // front facing proxy is configured to require own basic auth value and your application + // also requires basic auth headers from clients. + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode) + continue } - - cred := string(b) - for i := 0; i < len(cred); i++ { - if cred[i] == ':' { - // Verify credentials - valid, err := config.Validator(cred[:i], cred[i+1:], c) - if err != nil { - return err - } else if valid { - return next(c) - } - break + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c) + if errValidate != nil { + lastError = errValidate + } else if valid { + return next(c) } } } - realm := defaultRealm - if config.Realm != defaultRealm { - realm = strconv.Quote(config.Realm) + if lastError != nil { + return lastError } // Need to return `401` for browsers to pop-up login box. c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index b3abfa172..6780b6b6c 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -16,9 +16,45 @@ import ( ) func TestBasicAuth(t *testing.T) { + + validator := func(u, p string, c echo.Context) (bool, error) { + if u == "joe" && p == "secret" { + return true, nil + } + return false, nil + } + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + + userPassB64 := base64.StdEncoding.EncodeToString([]byte("joe:secret")) + req.Header.Set(echo.HeaderAuthorization, basic+" "+userPassB64) + + e := echo.New() + c := e.NewContext(req, res) + + h := BasicAuth(validator)(func(c echo.Context) error { + return c.String(http.StatusIMUsed, "test") + }) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusIMUsed, res.Code) +} + +func TestBasicAuthPanic(t *testing.T) { + assert.PanicsWithError(t, "echo basic-auth middleware requires a validator function", func() { + BasicAuth(nil) + }) +} + +func TestBasicAuthWithConfig(t *testing.T) { e := echo.New() + exampleSecret := base64.StdEncoding.EncodeToString([]byte("joe:secret")) mockValidator := func(u, p string, c echo.Context) (bool, error) { + if u == "error" { + return false, errors.New("validator_error") + } if u == "joe" && p == "secret" { return true, nil } @@ -27,56 +63,79 @@ func TestBasicAuth(t *testing.T) { tests := []struct { name string - authHeader string + authHeader []string + config *BasicAuthConfig expectedCode int expectedAuth string - skipperResult bool - expectedErr bool + expectedErr string expectedErrMsg string }{ { name: "Valid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + authHeader: []string{basic + " " + exampleSecret}, expectedCode: http.StatusOK, }, { name: "Case-insensitive header scheme", - authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + authHeader: []string{strings.ToUpper(basic) + " " + exampleSecret}, expectedCode: http.StatusOK, }, { name: "Invalid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))}, expectedCode: http.StatusUnauthorized, expectedAuth: basic + ` realm="someRealm"`, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, + { + name: "validator errors out", + authHeader: []string{ + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), + basic + " " + base64.StdEncoding.EncodeToString([]byte("error:secret")), + }, + expectedCode: http.StatusUnauthorized, + expectedAuth: "", + expectedErr: "validator_error", + expectedErrMsg: "Unauthorized", + }, + { + name: "Invalid credentials, default realm", + authHeader: []string{basic + " " + exampleSecret}, + expectedCode: http.StatusOK, + expectedAuth: basic + ` realm="Restricted"`, + }, { name: "Invalid base64 string", - authHeader: basic + " invalidString", + authHeader: []string{basic + " invalidString"}, expectedCode: http.StatusBadRequest, - expectedErr: true, + expectedErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 12", expectedErrMsg: "Bad Request", }, { name: "Missing Authorization header", expectedCode: http.StatusUnauthorized, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, { name: "Invalid Authorization header", - authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), + authHeader: []string{base64.StdEncoding.EncodeToString([]byte("invalid"))}, expectedCode: http.StatusUnauthorized, - expectedErr: true, + expectedErr: "code=401, message=Unauthorized", expectedErrMsg: "Unauthorized", }, { - name: "Skipped Request", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), - expectedCode: http.StatusOK, - skipperResult: true, + name: "Skipped Request", + authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))}, + expectedCode: http.StatusOK, + config: &BasicAuthConfig{ + Validator: mockValidator, + Realm: "someRealm", + Skipper: func(c echo.Context) bool { + return true + }, + }, }, } @@ -87,26 +146,25 @@ func TestBasicAuth(t *testing.T) { res := httptest.NewRecorder() c := e.NewContext(req, res) - if tt.authHeader != "" { - req.Header.Set(echo.HeaderAuthorization, tt.authHeader) + for _, h := range tt.authHeader { + req.Header.Add(echo.HeaderAuthorization, h) } - h := BasicAuthWithConfig(BasicAuthConfig{ + config := BasicAuthConfig{ Validator: mockValidator, Realm: "someRealm", - Skipper: func(c echo.Context) bool { - return tt.skipperResult - }, - })(func(c echo.Context) error { + } + if tt.config != nil { + config = *tt.config + } + h := BasicAuthWithConfig(config)(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) err := h(c) - if tt.expectedErr { - var he *echo.HTTPError - errors.As(err, &he) - assert.Equal(t, tt.expectedCode, he.Code) + if tt.expectedErr != "" { + assert.EqualError(t, err, tt.expectedErr) if tt.expectedAuth != "" { assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) } From 0e4fdd6677cf62d6c439f5d0c84787f3617bc28e Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 30 Oct 2024 20:12:33 +0200 Subject: [PATCH 2/3] Refactor basic auth middleware to support multiple auth headers --- middleware/basic_auth.go | 27 +++++++++++++++++++++------ middleware/basic_auth_test.go | 6 +++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index e2d723113..b0ad5ff07 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -23,6 +23,12 @@ type BasicAuthConfig struct { // Required. Validator BasicAuthValidator + // HeaderValidationLimit limits the amount of authorization headers will be validated + // for valid credentials. Set this value to be higher from in an environment where multiple + // basic auth headers could be received. + // Default value 1. + HeaderValidationLimit int + // Realm is a string to define realm attribute of BasicAuth. // Default value "Restricted". Realm string @@ -31,7 +37,7 @@ type BasicAuthConfig struct { // BasicAuthValidator defines a function to validate BasicAuth credentials. // The function should return a boolean indicating whether the credentials are valid, // and an error if any error occurs during the validation process. -type BasicAuthValidator func(string, string, echo.Context) (bool, error) +type BasicAuthValidator func(user string, password string, c echo.Context) (bool, error) const ( basic = "basic" @@ -40,8 +46,9 @@ const ( // DefaultBasicAuthConfig is the default BasicAuth middleware config. var DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, + Skipper: DefaultSkipper, + Realm: defaultRealm, + HeaderValidationLimit: 1, } // BasicAuth returns an BasicAuth middleware. @@ -75,6 +82,10 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Realm != "" && config.Realm != realm { realm = strconv.Quote(config.Realm) } + maxValidationAttemptCount := 1 + if config.HeaderValidationLimit > 1 { + maxValidationAttemptCount = config.HeaderValidationLimit + } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { @@ -84,14 +95,17 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastError error l := len(basic) + errCount := 0 // multiple auth headers is something that can happen in environments like - // corporate test environments that are secured application proxy servers where - // front facing proxy is configured to require own basic auth value and your application - // also requires basic auth headers from clients. + // corporate test environments that are secured by application proxy servers where + // front facing proxy is also configured to require own basic auth value and does auth checks. for _, auth := range c.Request().Header[echo.HeaderAuthorization] { if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { continue } + if errCount >= maxValidationAttemptCount { + break + } // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input @@ -108,6 +122,7 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } else if valid { return next(c) } + errCount++ } } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 6780b6b6c..d27cbf241 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -89,11 +89,15 @@ func TestBasicAuthWithConfig(t *testing.T) { expectedErrMsg: "Unauthorized", }, { - name: "validator errors out", + name: "validator errors out at 2 tries", authHeader: []string{ basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), basic + " " + base64.StdEncoding.EncodeToString([]byte("error:secret")), }, + config: &BasicAuthConfig{ + HeaderValidationLimit: 2, + Validator: mockValidator, + }, expectedCode: http.StatusUnauthorized, expectedAuth: "", expectedErr: "validator_error", From 4ab78e4b979acd50206a2d71b7aa30b11c0a25b8 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Wed, 30 Oct 2024 20:14:38 +0200 Subject: [PATCH 3/3] Refactor basic auth middleware to support multiple auth headers --- middleware/basic_auth.go | 1 + 1 file changed, 1 insertion(+) diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index b0ad5ff07..8f23bf548 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -99,6 +99,7 @@ func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // multiple auth headers is something that can happen in environments like // corporate test environments that are secured by application proxy servers where // front facing proxy is also configured to require own basic auth value and does auth checks. + // In that case middleware can receive multiple auth headers. for _, auth := range c.Request().Header[echo.HeaderAuthorization] { if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { continue