Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor basic auth middleware to support multiple auth headers #2695

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 59 additions & 31 deletions middleware/basic_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
package middleware

import (
"bytes"
"encoding/base64"
"errors"
"net/http"
"strconv"
"strings"
Expand All @@ -21,6 +23,12 @@ type BasicAuthConfig struct {
// Required.
Validator BasicAuthValidator

// HeaderValidationLimit limits the amount of authorization headers will be validated
// for valid credentials. Set this value to be higher from in an environment where multiple
// basic auth headers could be received.
// Default value 1.
HeaderValidationLimit int

// Realm is a string to define realm attribute of BasicAuth.
// Default value "Restricted".
Realm string
Expand All @@ -29,7 +37,7 @@ type BasicAuthConfig struct {
// BasicAuthValidator defines a function to validate BasicAuth credentials.
// The function should return a boolean indicating whether the credentials are valid,
// and an error if any error occurs during the validation process.
type BasicAuthValidator func(string, string, echo.Context) (bool, error)
type BasicAuthValidator func(user string, password string, c echo.Context) (bool, error)

const (
basic = "basic"
Expand All @@ -38,8 +46,9 @@ const (

// DefaultBasicAuthConfig is the default BasicAuth middleware config.
var DefaultBasicAuthConfig = BasicAuthConfig{
Skipper: DefaultSkipper,
Realm: defaultRealm,
Skipper: DefaultSkipper,
Realm: defaultRealm,
HeaderValidationLimit: 1,
}

// BasicAuth returns an BasicAuth middleware.
Expand All @@ -52,18 +61,30 @@ func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc {
return BasicAuthWithConfig(c)
}

// BasicAuthWithConfig returns an BasicAuth middleware with config.
// See `BasicAuth()`.
// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config.
func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
// Defaults
mw, err := config.ToMiddleware()
if err != nil {
panic(err)
}
return mw
}

// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration
func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
if config.Validator == nil {
panic("echo: basic-auth middleware requires a validator function")
return nil, errors.New("echo basic-auth middleware requires a validator function")
}
if config.Skipper == nil {
config.Skipper = DefaultBasicAuthConfig.Skipper
config.Skipper = DefaultSkipper
}
realm := defaultRealm
if config.Realm != "" && config.Realm != realm {
realm = strconv.Quote(config.Realm)
}
if config.Realm == "" {
config.Realm = defaultRealm
maxValidationAttemptCount := 1
if config.HeaderValidationLimit > 1 {
maxValidationAttemptCount = config.HeaderValidationLimit
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
Expand All @@ -72,40 +93,47 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
return next(c)
}

auth := c.Request().Header.Get(echo.HeaderAuthorization)
var lastError error
l := len(basic)
errCount := 0
// multiple auth headers is something that can happen in environments like
// corporate test environments that are secured by application proxy servers where
// front facing proxy is also configured to require own basic auth value and does auth checks.
// In that case middleware can receive multiple auth headers.
for _, auth := range c.Request().Header[echo.HeaderAuthorization] {
if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) {
continue
}
if errCount >= maxValidationAttemptCount {
break
}

if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
// Invalid base64 shouldn't be treated as error
// instead should be treated as invalid client input
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err)
b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:])
if errDecode != nil {
lastError = echo.NewHTTPError(http.StatusBadRequest).WithInternal(errDecode)
continue
}

cred := string(b)
for i := 0; i < len(cred); i++ {
if cred[i] == ':' {
// Verify credentials
valid, err := config.Validator(cred[:i], cred[i+1:], c)
if err != nil {
return err
} else if valid {
return next(c)
}
break
idx := bytes.IndexByte(b, ':')
if idx >= 0 {
valid, errValidate := config.Validator(string(b[:idx]), string(b[idx+1:]), c)
if errValidate != nil {
lastError = errValidate
} else if valid {
return next(c)
}
errCount++
}
}

realm := defaultRealm
if config.Realm != defaultRealm {
realm = strconv.Quote(config.Realm)
if lastError != nil {
return lastError
}

// Need to return `401` for browsers to pop-up login box.
c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm)
return echo.ErrUnauthorized
}
}
}, nil
}
116 changes: 89 additions & 27 deletions middleware/basic_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,45 @@ import (
)

func TestBasicAuth(t *testing.T) {

validator := func(u, p string, c echo.Context) (bool, error) {
if u == "joe" && p == "secret" {
return true, nil
}
return false, nil
}

req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()

userPassB64 := base64.StdEncoding.EncodeToString([]byte("joe:secret"))
req.Header.Set(echo.HeaderAuthorization, basic+" "+userPassB64)

e := echo.New()
c := e.NewContext(req, res)

h := BasicAuth(validator)(func(c echo.Context) error {
return c.String(http.StatusIMUsed, "test")
})
err := h(c)
assert.NoError(t, err)
assert.Equal(t, http.StatusIMUsed, res.Code)
}

func TestBasicAuthPanic(t *testing.T) {
assert.PanicsWithError(t, "echo basic-auth middleware requires a validator function", func() {
BasicAuth(nil)
})
}

func TestBasicAuthWithConfig(t *testing.T) {
e := echo.New()

exampleSecret := base64.StdEncoding.EncodeToString([]byte("joe:secret"))
mockValidator := func(u, p string, c echo.Context) (bool, error) {
if u == "error" {
return false, errors.New("validator_error")
}
if u == "joe" && p == "secret" {
return true, nil
}
Expand All @@ -27,56 +63,83 @@ func TestBasicAuth(t *testing.T) {

tests := []struct {
name string
authHeader string
authHeader []string
config *BasicAuthConfig
expectedCode int
expectedAuth string
skipperResult bool
expectedErr bool
expectedErr string
expectedErrMsg string
}{
{
name: "Valid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
authHeader: []string{basic + " " + exampleSecret},
expectedCode: http.StatusOK,
},
{
name: "Case-insensitive header scheme",
authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")),
authHeader: []string{strings.ToUpper(basic) + " " + exampleSecret},
expectedCode: http.StatusOK,
},
{
name: "Invalid credentials",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password"))},
expectedCode: http.StatusUnauthorized,
expectedAuth: basic + ` realm="someRealm"`,
expectedErr: true,
expectedErr: "code=401, message=Unauthorized",
expectedErrMsg: "Unauthorized",
},
{
name: "validator errors out at 2 tries",
authHeader: []string{
basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")),
basic + " " + base64.StdEncoding.EncodeToString([]byte("error:secret")),
},
config: &BasicAuthConfig{
HeaderValidationLimit: 2,
Validator: mockValidator,
},
expectedCode: http.StatusUnauthorized,
expectedAuth: "",
expectedErr: "validator_error",
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid credentials, default realm",
authHeader: []string{basic + " " + exampleSecret},
expectedCode: http.StatusOK,
expectedAuth: basic + ` realm="Restricted"`,
},
{
name: "Invalid base64 string",
authHeader: basic + " invalidString",
authHeader: []string{basic + " invalidString"},
expectedCode: http.StatusBadRequest,
expectedErr: true,
expectedErr: "code=400, message=Bad Request, internal=illegal base64 data at input byte 12",
expectedErrMsg: "Bad Request",
},
{
name: "Missing Authorization header",
expectedCode: http.StatusUnauthorized,
expectedErr: true,
expectedErr: "code=401, message=Unauthorized",
expectedErrMsg: "Unauthorized",
},
{
name: "Invalid Authorization header",
authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")),
authHeader: []string{base64.StdEncoding.EncodeToString([]byte("invalid"))},
expectedCode: http.StatusUnauthorized,
expectedErr: true,
expectedErr: "code=401, message=Unauthorized",
expectedErrMsg: "Unauthorized",
},
{
name: "Skipped Request",
authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")),
expectedCode: http.StatusOK,
skipperResult: true,
name: "Skipped Request",
authHeader: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip"))},
expectedCode: http.StatusOK,
config: &BasicAuthConfig{
Validator: mockValidator,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return true
},
},
},
}

Expand All @@ -87,26 +150,25 @@ func TestBasicAuth(t *testing.T) {
res := httptest.NewRecorder()
c := e.NewContext(req, res)

if tt.authHeader != "" {
req.Header.Set(echo.HeaderAuthorization, tt.authHeader)
for _, h := range tt.authHeader {
req.Header.Add(echo.HeaderAuthorization, h)
}

h := BasicAuthWithConfig(BasicAuthConfig{
config := BasicAuthConfig{
Validator: mockValidator,
Realm: "someRealm",
Skipper: func(c echo.Context) bool {
return tt.skipperResult
},
})(func(c echo.Context) error {
}
if tt.config != nil {
config = *tt.config
}
h := BasicAuthWithConfig(config)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})

err := h(c)

if tt.expectedErr {
var he *echo.HTTPError
errors.As(err, &he)
assert.Equal(t, tt.expectedCode, he.Code)
if tt.expectedErr != "" {
assert.EqualError(t, err, tt.expectedErr)
if tt.expectedAuth != "" {
assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate))
}
Expand Down
Loading