diff --git a/const.go b/const.go index 193e839..4f1589f 100644 --- a/const.go +++ b/const.go @@ -74,3 +74,18 @@ func (ccm CodeChallengeMethod) Validate(cc, ver string) bool { return false } } + +// AuthorizeRequestMethod the type of authorization request method +type AuthorizeRequestMethod string + +const ( + AuthorizeRequestGet AuthorizeRequestMethod = "GET" + AuthorizeRequestPost AuthorizeRequestMethod = "POST" +) + +func (ar AuthorizeRequestMethod) String() string { + if ar == AuthorizeRequestGet || ar == AuthorizeRequestPost { + return string(ar) + } + return "" +} diff --git a/errors/response.go b/errors/response.go index c8d5902..d53ac19 100644 --- a/errors/response.go +++ b/errors/response.go @@ -35,16 +35,19 @@ func (r *Response) SetHeader(key, value string) { // https://tools.ietf.org/html/rfc6749#section-5.2 var ( ErrInvalidRequest = errors.New("invalid_request") + ErrMissingClientID = errors.New("missing_client_id") + ErrInvalidRequestMethod = errors.New("invalid_request_method") ErrUnauthorizedClient = errors.New("unauthorized_client") ErrAccessDenied = errors.New("access_denied") ErrUnsupportedResponseType = errors.New("unsupported_response_type") + ErrMissingResponseType = errors.New("missing_response_type") ErrInvalidScope = errors.New("invalid_scope") ErrServerError = errors.New("server_error") ErrTemporarilyUnavailable = errors.New("temporarily_unavailable") ErrInvalidClient = errors.New("invalid_client") ErrInvalidGrant = errors.New("invalid_grant") ErrUnsupportedGrantType = errors.New("unsupported_grant_type") - ErrCodeChallengeRquired = errors.New("invalid_request") + ErrCodeChallengeRequired = errors.New("invalid_request") ErrUnsupportedCodeChallengeMethod = errors.New("invalid_request") ErrInvalidCodeChallengeLen = errors.New("invalid_request") ) @@ -52,33 +55,39 @@ var ( // Descriptions error description var Descriptions = map[error]string{ ErrInvalidRequest: "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed", + ErrMissingClientID: "The request is missing client_id", + ErrInvalidRequestMethod: "The request method is invalid, unknown, or malformed", ErrUnauthorizedClient: "The client is not authorized to request an authorization code using this method", ErrAccessDenied: "The resource owner or authorization server denied the request", ErrUnsupportedResponseType: "The authorization server does not support obtaining an authorization code using this method", + ErrMissingResponseType: "The requested response type is empty", ErrInvalidScope: "The requested scope is invalid, unknown, or malformed", ErrServerError: "The authorization server encountered an unexpected condition that prevented it from fulfilling the request", ErrTemporarilyUnavailable: "The authorization server is currently unable to handle the request due to a temporary overloading or maintenance of the server", ErrInvalidClient: "Client authentication failed", ErrInvalidGrant: "The provided authorization grant (e.g., authorization code, resource owner credentials) or refresh token is invalid, expired, revoked, does not match the redirection URI used in the authorization request, or was issued to another client", ErrUnsupportedGrantType: "The authorization grant type is not supported by the authorization server", - ErrCodeChallengeRquired: "PKCE is required. code_challenge is missing", + ErrCodeChallengeRequired: "PKCE is required. code_challenge is missing", ErrUnsupportedCodeChallengeMethod: "Selected code_challenge_method not supported", - ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 charachters long", + ErrInvalidCodeChallengeLen: "Code challenge length must be between 43 and 128 characters long", } // StatusCodes response error HTTP status code var StatusCodes = map[error]int{ ErrInvalidRequest: 400, + ErrMissingClientID: 400, + ErrInvalidRequestMethod: 400, ErrUnauthorizedClient: 401, ErrAccessDenied: 403, ErrUnsupportedResponseType: 401, + ErrMissingResponseType: 400, ErrInvalidScope: 400, ErrServerError: 500, ErrTemporarilyUnavailable: 503, ErrInvalidClient: 401, ErrInvalidGrant: 401, ErrUnsupportedGrantType: 401, - ErrCodeChallengeRquired: 400, + ErrCodeChallengeRequired: 400, ErrUnsupportedCodeChallengeMethod: 400, ErrInvalidCodeChallengeLen: 400, } diff --git a/server/config.go b/server/config.go index 3bbb884..e9bad11 100644 --- a/server/config.go +++ b/server/config.go @@ -9,12 +9,13 @@ import ( // Config configuration parameters type Config struct { - TokenType string // token type - AllowGetAccessRequest bool // to allow GET requests for the token - AllowedResponseTypes []oauth2.ResponseType // allow the authorization type - AllowedGrantTypes []oauth2.GrantType // allow the grant type - AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod - ForcePKCE bool + TokenType string // token type + AllowGetAccessRequest bool // to allow GET requests for the token + AllowedResponseTypes []oauth2.ResponseType // allow the authorization type + AllowedGrantTypes []oauth2.GrantType // allow the grant type + AllowedCodeChallengeMethods []oauth2.CodeChallengeMethod + AllowedAuthorizeRequestMethods []oauth2.AuthorizeRequestMethod //allowed `authorize request methods` + ForcePKCE bool } // NewConfig create to configuration instance @@ -32,6 +33,10 @@ func NewConfig() *Config { oauth2.CodeChallengePlain, oauth2.CodeChallengeS256, }, + AllowedAuthorizeRequestMethods: []oauth2.AuthorizeRequestMethod{ + oauth2.AuthorizeRequestGet, + oauth2.AuthorizeRequestPost, + }, } } diff --git a/server/server.go b/server/server.go index df19d1f..5bddddf 100755 --- a/server/server.go +++ b/server/server.go @@ -143,14 +143,19 @@ func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface return u.String(), nil } -// CheckResponseType check allows response type -func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool { - for _, art := range s.Config.AllowedResponseTypes { - if art == rt { - return true +// CheckResponseType checks for an allowed response type +func (s *Server) CheckResponseType(responseType oauth2.ResponseType) error { + if responseType.String() == "" { + return errors.ErrMissingResponseType + } + + for _, rType := range s.Config.AllowedResponseTypes { + if rType == responseType { + return nil } } - return false + + return errors.ErrUnsupportedResponseType } // CheckCodeChallengeMethod checks for allowed code challenge method @@ -163,50 +168,73 @@ func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool { return false } +// CheckAuthorizeRequestMethod checks for allowed code challenge method +func (s *Server) CheckAuthorizeRequestMethod(requestMethod oauth2.AuthorizeRequestMethod) bool { + for _, method := range s.Config.AllowedAuthorizeRequestMethods { + if method == requestMethod { + return true + } + } + return false +} + +// CheckCodeChallenge checks if the Code Challenge is valid +func (s *Server) CheckCodeChallenge(codeChallenge string, isForcePKCE bool) error { + if isForcePKCE && codeChallenge == "" { + return errors.ErrCodeChallengeRequired + } + if len(codeChallenge) > 0 && len(codeChallenge) < 43 || len(codeChallenge) > 128 { + return errors.ErrInvalidCodeChallengeLen + } + return nil +} + // ValidationAuthorizeRequest the authorization request validation func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) { + if r == nil { + return nil, errors.ErrInvalidRequest + } + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") - if !(r.Method == "GET" || r.Method == "POST") || - clientID == "" { - return nil, errors.ErrInvalidRequest + if clientID == "" { + return nil, errors.ErrMissingClientID } - resType := oauth2.ResponseType(r.FormValue("response_type")) - if resType.String() == "" { - return nil, errors.ErrUnsupportedResponseType - } else if allowed := s.CheckResponseType(resType); !allowed { - return nil, errors.ErrUnauthorizedClient + if isMethodAllowed := s.CheckAuthorizeRequestMethod(oauth2.AuthorizeRequestMethod(r.Method)); !isMethodAllowed { + return nil, errors.ErrInvalidRequestMethod } - cc := r.FormValue("code_challenge") - if cc == "" && s.Config.ForcePKCE { - return nil, errors.ErrCodeChallengeRquired + responseType := oauth2.ResponseType(r.FormValue("response_type")) + if err := s.CheckResponseType(responseType); err != nil { + return nil, err } - if cc != "" && (len(cc) < 43 || len(cc) > 128) { - return nil, errors.ErrInvalidCodeChallengeLen + + codeChallenge := r.FormValue("code_challenge") + if err := s.CheckCodeChallenge(codeChallenge, s.Config.ForcePKCE); err != nil { + return nil, err } - ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) - // set default - if ccm == "" { - ccm = oauth2.CodeChallengePlain + codeChallengeMethod := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method")) + // Default to plain method if not specified + if codeChallengeMethod == "" { + codeChallengeMethod = oauth2.CodeChallengePlain } - if ccm != "" && !s.CheckCodeChallengeMethod(ccm) { + if !s.CheckCodeChallengeMethod(codeChallengeMethod) { return nil, errors.ErrUnsupportedCodeChallengeMethod } - req := &AuthorizeRequest{ + return &AuthorizeRequest{ RedirectURI: redirectURI, - ResponseType: resType, + ResponseType: responseType, ClientID: clientID, State: r.FormValue("state"), Scope: r.FormValue("scope"), Request: r, - CodeChallenge: cc, - CodeChallengeMethod: ccm, - } - return req, nil + CodeChallenge: codeChallenge, + CodeChallengeMethod: codeChallengeMethod, + }, nil } // GetAuthorizeToken get authorization token(code)