diff --git a/Dockerfile b/Dockerfile index 062214c..7c213ae 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ RUN make FROM alpine:3.12 RUN addgroup -S appgroup && adduser -S appuser -G appgroup WORKDIR /app -COPY --from=builder /app/banner.txt /app/fastauth /app/startup.sql ./ +COPY --from=builder /app/login.html /app/banner.txt /app/fastauth /app/startup.sql ./ RUN chown -R appuser:appgroup /app USER appuser ENTRYPOINT ["./fastauth"] diff --git a/db.go b/db.go index 08d52a3..2394409 100644 --- a/db.go +++ b/db.go @@ -35,14 +35,14 @@ func dbSelect(email string) (*dbRes, error) { return &res, nil } -func insertUser(salt []byte, email string, dk []byte, emailToken string, refreshToken string) error { - stmt, err := db.Prepare("INSERT INTO auth (email, password, role, salt, emailToken, refreshToken) VALUES (?, ?, 'USR', ?, ?, ?)") +func insertUser(salt []byte, email string, dk []byte, role string, emailToken string, refreshToken string) error { + stmt, err := db.Prepare("INSERT INTO auth (email, password, role, salt, emailToken, refreshToken) VALUES (?, ?, ?, ?, ?, ?)") if err != nil { return fmt.Errorf("prepare INSERT INTO auth for %v statement failed: %v", email, err) } defer stmt.Close() - res, err := stmt.Exec(email, dk, salt, emailToken, refreshToken) + res, err := stmt.Exec(email, dk, []byte(role), salt, emailToken, refreshToken) return handleErr(res, err, "INSERT INTO auth", email) } @@ -181,7 +181,7 @@ func handleErr(res sql.Result, err error, info string, email string) error { ///////// Setup -func addInitialUser(username string, password string) error { +func addInitialUserWithRole(username string, password string, role string) error { res, err := dbSelect(username) if res == nil || err != nil { salt := []byte{0} @@ -189,7 +189,7 @@ func addInitialUser(username string, password string) error { if err != nil { return err } - err = insertUser(salt, username, dk, "emailToken", "refreshToken") + err = insertUser(salt, username, dk, role, "emailToken", "refreshToken") if err != nil { return err } @@ -235,13 +235,22 @@ func setupDB() { //add user for development users := strings.Split(options.Users, ";") for _, user := range users { - userpw := strings.Split(user, ":") - if len(userpw) == 2 { - err := addInitialUser(userpw[0], userpw[1]) + userPwRole := strings.Split(user, ":") + if len(userPwRole) == 2 { + role := "USR" + err := addInitialUserWithRole(userPwRole[0], userPwRole[1], role) if err == nil { - log.Printf("insterted user %v", userpw[0]) + log.Printf("insterted user %v", userPwRole[0]) } else { - log.Printf("could not insert %v", userpw[0]) + log.Printf("could not insert %v", userPwRole[0]) + } + } else if len(userPwRole) == 3 { + role := userPwRole[2] + err := addInitialUserWithRole(userPwRole[0], userPwRole[1], role) + if err == nil { + log.Printf("insterted user %v", userPwRole[0]) + } else { + log.Printf("could not insert %v", userPwRole[0]) } } else { log.Printf("username and password need to be seperated by ':'") diff --git a/fastauth.go b/fastauth.go index 4625b00..b85b337 100644 --- a/fastauth.go +++ b/fastauth.go @@ -17,6 +17,7 @@ import ( "github.com/dimiro1/banner" "github.com/gorilla/handlers" "github.com/gorilla/mux" + "github.com/gorilla/schema" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ldap "github.com/vjeantet/ldapserver" @@ -26,6 +27,7 @@ import ( "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" "hash/crc64" + "io/ioutil" "log" rnd "math/rand" "net" @@ -58,44 +60,46 @@ const ( ) type Opts struct { - Dev string - Issuer string - Port int - Ldap int - DBPath string - DBDriver string - UrlEmail string - UrlSMS string - Audience string - ExpireAccess int - ExpireRefresh int - ExpireCode int - HS256 string - EdDSA string - RS256 string - OAuthUser string - OAuthPass string - ResetRefresh bool - RefreshCookiePath string - Users string - UserEndpoints bool - OauthEndpoints bool - LdapServer bool - DetailedError bool - Limiter bool + Dev string + Issuer string + Port int + Ldap int + DBPath string + DBDriver string + UrlEmail string + UrlSMS string + Audience string + ExpireAccess int + ExpireRefresh int + ExpireCode int + HS256 string + EdDSA string + RS256 string + OAuthUser string + OAuthPass string + ResetRefresh bool + Users string + UserEndpoints bool + OauthEndpoints bool + LdapServer bool + DetailedError bool + Limiter bool + Redirects string + PasswordFlow bool + Scope string } func NewOpts() *Opts { opts := &Opts{} flag.StringVar(&opts.Dev, "dev", LookupEnv("DEV"), "Dev settings with initial secret") - flag.StringVar(&opts.Issuer, "issuer", LookupEnv("ISSUER"), "name of issuer") + flag.StringVar(&opts.Issuer, "issuer", LookupEnv("ISSUER"), "name of issuer, default in dev is my-issuer") flag.IntVar(&opts.Port, "port", LookupEnvInt("PORT"), "listening HTTP port") flag.IntVar(&opts.Ldap, "ldap", LookupEnvInt("LDAP"), "listening LDAP port") flag.StringVar(&opts.DBPath, "db-path", LookupEnv("DB_PATH"), "DB path") flag.StringVar(&opts.DBDriver, "db-driver", LookupEnv("DB_DRIVER"), "DB driver") flag.StringVar(&opts.UrlEmail, "email-url", LookupEnv("EMAIL_URL"), "Email service URL") flag.StringVar(&opts.UrlSMS, "sms-url", LookupEnv("SMS_URL"), "SMS service URL") - flag.StringVar(&opts.Audience, "audience", LookupEnv("SMS_URL"), "Audience") + flag.StringVar(&opts.Audience, "audience", LookupEnv("AUDIENCE"), "Audience, default in dev is my-audience") flag.IntVar(&opts.ExpireAccess, "expire-access", LookupEnvInt("EXPIRE_ACCESS"), "Access token expiration in seconds") flag.IntVar(&opts.ExpireRefresh, "expire-refresh", LookupEnvInt("EXPIRE_REFRESH"), "Refresh token expiration in seconds") flag.IntVar(&opts.ExpireCode, "expire-code", LookupEnvInt("EXPIRE_CODE"), "Authtoken flow expiration in seconds") @@ -103,13 +107,15 @@ func NewOpts() *Opts { flag.StringVar(&opts.RS256, "rs256", LookupEnv("RS256"), "RS256 key") flag.StringVar(&opts.EdDSA, "eddsa", LookupEnv("EDDSA"), "EdDSA key") flag.BoolVar(&opts.ResetRefresh, "reset-refresh", LookupEnv("RESET_REFRESH") != "", "Reset refresh token when setting the token") - flag.StringVar(&opts.RefreshCookiePath, "refresh-cookie-path", LookupEnv("REFRESH_COOKIE_PATH"), "Refresh cookie path, default is /refresh") flag.StringVar(&opts.Users, "users", LookupEnv("USERS"), "add these initial users. E.g, -users tom@test.ch:pw123;test@test.ch:123pw") flag.BoolVar(&opts.UserEndpoints, "user-endpoints", LookupEnv("USER_ENDPOINTS") != "", "Enable user-facing endpoints. In dev mode these are enabled by default") flag.BoolVar(&opts.OauthEndpoints, "oauth-enpoints", LookupEnv("OAUTH_ENDPOINTS") != "", "Enable oauth-facing endpoints. In dev mode these are enabled by default") flag.BoolVar(&opts.LdapServer, "ldap-server", LookupEnv("LDAP_SERVER") != "", "Enable ldap server. In dev mode these are enabled by default") flag.BoolVar(&opts.DetailedError, "details", LookupEnv("DETAILS") != "", "Enable detailed errors") flag.BoolVar(&opts.Limiter, "limiter", LookupEnv("LIMITER") != "", "Enable limiter, disabled in dev mode") + flag.StringVar(&opts.Redirects, "redir", LookupEnv("REDIR"), "add client redirects. E.g, -redir clientId1:http://blabla;clientId2:http://blublu") + flag.StringVar(&opts.Redirects, "pwflow", LookupEnv("PWFLOW"), "enable password flow, default disabled") + flag.StringVar(&opts.Scope, "scope", LookupEnv("SCOPE"), "scope, default in dev is my-scope") flag.Usage = func() { fmt.Fprintf(flag.CommandLine.Output(), "Usage of %s:\n", os.Args[0]) @@ -130,13 +136,14 @@ func defaultOpts(opts *Opts) { opts.ExpireRefresh = setDefaultInt(opts.ExpireRefresh, 7*24*60*60) //7days opts.ExpireCode = setDefaultInt(opts.ExpireCode, 60) //1min opts.ResetRefresh = false - opts.RefreshCookiePath = setDefault(opts.RefreshCookiePath, "/refresh") + opts.PasswordFlow = false if opts.Dev != "" { - opts.Issuer = setDefault(opts.Issuer, "DevIssuer") + opts.Scope = setDefault(opts.Scope, "my-scope") + opts.Audience = setDefault(opts.Audience, "my-audience") + opts.Issuer = setDefault(opts.Issuer, "my-issuer") opts.UrlEmail = setDefault(opts.UrlEmail, "http://localhost:8080/send/email/{action}/{email}/{token}") opts.UrlSMS = setDefault(opts.UrlSMS, "http://localhost:8080/send/sms/{sms}/{token}") - opts.Audience = setDefault(opts.Audience, "DevAudience") opts.HS256 = base32.StdEncoding.EncodeToString([]byte(opts.Dev)) h := crc64.MakeTable(0xC96C5795D7870F42) @@ -144,7 +151,10 @@ func defaultOpts(opts *Opts) { if err != nil { log.Fatalf("cannot generate rsa key %v", err) } - encPrivRSA := x509.MarshalPKCS1PrivateKey(rsaPrivKey) + encPrivRSA, err := x509.MarshalPKCS8PrivateKey(rsaPrivKey) + if err != nil { + log.Fatalf("cannot generate rsa key %v", err) + } opts.RS256 = base32.StdEncoding.EncodeToString(encPrivRSA) _, edPrivKey, err := ed25519.GenerateKey(rnd.New(rnd.NewSource(int64(crc64.Checksum([]byte(opts.Dev), h))))) @@ -153,14 +163,16 @@ func defaultOpts(opts *Opts) { } opts.EdDSA = base32.StdEncoding.EncodeToString(edPrivKey) - opts.OAuthUser = setDefault(opts.OAuthUser, "user") - opts.OAuthPass = setDefault(opts.OAuthPass, "pass") - opts.OauthEndpoints = true opts.UserEndpoints = true opts.LdapServer = true opts.DetailedError = true opts.Limiter = false + opts.PasswordFlow = true + + if opts.Users == "" { + opts.Users = "tom:123" + } log.Printf("DEV mode active, key is %v, hex(%v)", opts.Dev, opts.HS256) log.Printf("DEV mode active, rsa is hex(%v)", opts.RS256) @@ -182,18 +194,19 @@ func defaultOpts(opts *Opts) { } if opts.RS256 != "" { - rsa, err := base32.StdEncoding.DecodeString(opts.RS256) + rsaDec, err := base32.StdEncoding.DecodeString(opts.RS256) if err != nil { log.Fatalf("cannot decode %v", opts.RS256) } - privRSA, err = x509.ParsePKCS1PrivateKey(rsa) + i, err := x509.ParsePKCS8PrivateKey(rsaDec) + privRSA = i.(*rsa.PrivateKey) if err != nil { - log.Fatalf("cannot decode %v", rsa) + log.Fatalf("cannot decode %v", rsaDec) } k := jose.JSONWebKey{Key: privRSA.Public()} kid, err := k.Thumbprint(crypto.SHA256) if err != nil { - log.Fatalf("cannot decode %v", rsa) + log.Fatalf("cannot decode %v", rsaDec) } privRSAKid = hex.EncodeToString(kid) } @@ -248,13 +261,26 @@ func LookupEnvInt(key string) int { } type Credentials struct { - Email string `json:"email,omitempty"` - Password string `json:"password"` - TOTP string `json:"totp,omitempty"` + Email string `json:"email,omitempty" schema:"email"` + Password string `json:"password" schema:"password,required"` + TOTP string `json:"totp,omitempty" schema:"totp"` + //here comes oauth, leave empty on regular login + //If you want to use oauth, you need to configure + //client-id with a matching redirect-uri from the + //command line + ClientId string `json:"client_id,omitempty" schema:"client_id"` + ResponseType string `json:"response_type,omitempty" schema:"response_type"` + State string `json:"state,omitempty" schema:"state"` + Scope string `json:"scope" schema:"scope"` + RedirectUri string `json:"redirect_uri,omitempty" schema:"redirect_uri"` + CodeChallenge string `json:"code_challenge,omitempty" schema:"code_challenge"` + CodeCodeChallengeMethod string `json:"code_challenge_method,omitempty" schema:"code_challenge_method"` } type TokenClaims struct { - Role string `json:"role,omitempty"` + Role string `json:"role,omitempty"` + Scope string `json:"scope,omitempty"` + ClientID string `json:"client_id,omitempty"` jwt.Claims } type RefreshClaims struct { @@ -265,8 +291,8 @@ type RefreshClaims struct { type CodeClaims struct { ExpiresAt int64 `json:"exp,omitempty"` Subject string `json:"role,omitempty"` - CodeChallenge string `json:"code-challenge,omitempty"` - CodeCodeChallengeMethod string `json:"code-challenge-method,omitempty"` + CodeChallenge string `json:"code_challenge,omitempty"` + CodeCodeChallengeMethod string `json:"code_challenge_method,omitempty"` } type ProvisioningUri struct { @@ -329,34 +355,6 @@ func jwtAuth(next func(w http.ResponseWriter, r *http.Request, claims *TokenClai } } -func refresh(w http.ResponseWriter, r *http.Request) { - //https://medium.com/monstar-lab-bangladesh-engineering/jwt-auth-in-go-part-2-refresh-tokens-d334777ca8a0 - - //check if refresh token matches - c, err := r.Cookie("refresh") - if err != nil { - writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-refresh-01, cookie not found: %v", err) - return - } - accessToken, refreshToken, expiresAt, err := refresh0(c.Value) - if err != nil { - writeErr(w, http.StatusUnauthorized, "invalid_request", "blocked", "ERR-refresh-02 %v", err) - return - } - w.Header().Set("Token", accessToken) - - cookie := http.Cookie{ - Name: "refresh", - Value: refreshToken, - Path: options.RefreshCookiePath, - HttpOnly: true, - Secure: options.Dev == "", - Expires: time.Unix(expiresAt, 0), - } - w.Header().Set("Set-Cookie", cookie.String()) - w.WriteHeader(http.StatusOK) -} - func checkRefreshToken(token string) (*RefreshClaims, error) { tok, err := jwt.ParseSigned(token) if err != nil { @@ -388,45 +386,6 @@ func checkRefreshToken(token string) (*RefreshClaims, error) { return refreshClaims, nil } -func refresh0(token string) (string, string, int64, error) { - refreshClaims, err := checkRefreshToken(token) - if err != nil { - return "", "", 0, fmt.Errorf("ERR-refresh-02, could not parse claims %v", err) - } - - result, err := dbSelect(refreshClaims.Subject) - if err != nil { - return "", "", 0, fmt.Errorf("ERR-refresh-03, DB select, %v err %v", refreshClaims.Subject, err) - } - - if result.emailVerified == nil || result.emailVerified.Unix() == 0 { - return "", "", 0, fmt.Errorf("ERR-refresh-04, user %v no email verified: %v", refreshClaims.Subject, err) - } - - if result.refreshToken == nil || refreshClaims.Token != *result.refreshToken { - return "", "", 0, fmt.Errorf("ERR-refresh-05, refresh token mismatch %v != %v", refreshClaims.Token, *result.refreshToken) - } - - encodedAccessToken, err := encodeAccessToken(string(result.role), refreshClaims.Subject) - if err != nil { - return "", "", 0, fmt.Errorf("ERR-refresh-06, cannot set access token for %v, %v", refreshClaims.Subject, err) - } - - refreshToken := *result.refreshToken - if options.ResetRefresh { - refreshToken, err = resetRefreshToken(refreshToken) - if err != nil { - return "", "", 0, fmt.Errorf("ERR-refresh-07, cannot reset access token for %v, %v", refreshClaims.Subject, err) - } - } - - encodedRefreshToken, expiresAt, err := encodeRefreshToken(refreshClaims.Subject, refreshToken) - if err != nil { - return "", "", 0, fmt.Errorf("ERR-refresh-08, cannot set refresh token for %v, %v", refreshClaims.Subject, err) - } - return encodedAccessToken, encodedRefreshToken, expiresAt, nil -} - func confirmEmail(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) token := vars["token"] @@ -479,7 +438,7 @@ func signup(w http.ResponseWriter, r *http.Request) { refreshToken := base32.StdEncoding.EncodeToString(rnd[32:48]) - err = insertUser(salt, cred.Email, dk, emailToken, refreshToken) + err = insertUser(salt, cred.Email, dk, "USR", emailToken, refreshToken) if err != nil { writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-signup-06, insert user failed: %v", err) return @@ -538,10 +497,32 @@ func checkEmailPassword(email string, password string) (*dbRes, string, error) { func login(w http.ResponseWriter, r *http.Request) { var cred Credentials - err := json.NewDecoder(r.Body).Decode(&cred) + + //https://medium.com/@xoen/golang-read-from-an-io-readwriter-without-loosing-its-content-2c6911805361 + var bodyCopy []byte + var err error + if r.Body != nil { + bodyCopy, err = ioutil.ReadAll(r.Body) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-login-01, cannot parse POST data %v", err) + return + } + } + + r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyCopy)) + err = json.NewDecoder(r.Body).Decode(&cred) if err != nil { - writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-login-01, cannot parse JSON credentials %v", err) - return + r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyCopy)) + err = r.ParseForm() + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-login-01, cannot parse POST data %v", err) + return + } + err = schema.NewDecoder().Decode(&cred, r.PostForm) + if err != nil { + writeErr(w, http.StatusBadRequest, "invalid_request", "blocked", "ERR-login-02, cannot populate POST data %v", err) + return + } } result, errString, err := checkEmailPassword(cred.Email, cred.Password) @@ -580,39 +561,18 @@ func login(w http.ResponseWriter, r *http.Request) { } } - encodedAccessToken, err := encodeAccessToken(string(result.role), cred.Email) + //return the code flow + encoded, _, err := encodeCodeToken(cred.Email, cred.CodeChallenge, cred.CodeCodeChallengeMethod) if err != nil { - writeErr(w, http.StatusInternalServerError, "invalid_request", "blocked", "ERR-login-11, cannot set access token for %v, %v", cred.Email, err) + writeErr(w, http.StatusInternalServerError, "invalid_request", "blocked", "ERR-login-14, cannot set refresh token for %v, %v", cred.Email, err) return } - refreshToken := *result.refreshToken - if options.ResetRefresh { - refreshToken, err = resetRefreshToken(refreshToken) - if err != nil { - writeErr(w, http.StatusInternalServerError, "invalid_request", "blocked", "ERR-login-12, cannot reset access token for %v, %v", cred.Email, err) - return - } - } + //encodedAccessToken, err := encodeAccessToken(string(result.role), "tom", options.Scope, options.Audience, options.Issuer) + //log.Printf("accesstokeen: [%v]\n", encodedAccessToken) - encodedRefreshToken, expiresAt, err := encodeRefreshToken(cred.Email, refreshToken) - if err != nil { - writeErr(w, http.StatusInternalServerError, "invalid_request", "blocked", "ERR-login-13, cannot set refresh token for %v, %v", cred.Email, err) - return - } - - w.Header().Set("Token", encodedAccessToken) - - cookie := http.Cookie{ - Name: "refresh", - Value: encodedRefreshToken, - Path: options.RefreshCookiePath, - HttpOnly: true, - Secure: options.Dev == "", - Expires: time.Unix(expiresAt, 0), - } - w.Header().Set("Set-Cookie", cookie.String()) - w.WriteHeader(http.StatusOK) + w.Header().Set("Location", cred.RedirectUri+"?code="+encoded) + w.WriteHeader(303) } func displayEmail(w http.ResponseWriter, r *http.Request) { @@ -920,17 +880,16 @@ func serverLdap() (*ldap.Server, <-chan bool) { func serverRest() (*http.Server, <-chan bool, error) { tokenExp = time.Second * time.Duration(options.ExpireAccess) refreshExp = time.Second * time.Duration(options.ExpireRefresh) - codeExp = time.Second * time.Duration(options.ExpireRefresh) + codeExp = time.Second * time.Duration(options.ExpireCode) router := mux.NewRouter() router.Use(func(next http.Handler) http.Handler { - return handlers.CombinedLoggingHandler(os.Stdout, next) + return handlers.LoggingHandler(os.Stdout, next) }) if options.UserEndpoints { router.HandleFunc("/login", login).Methods("POST") router.HandleFunc("/signup", signup).Methods("POST") - router.HandleFunc("/refresh", refresh).Methods("POST") router.HandleFunc("/reset/{email}", resetEmail).Methods("POST") router.HandleFunc("/confirm/signup/{email}/{token}", confirmEmail).Methods("GET") router.HandleFunc("/confirm/reset/{email}/{token}", confirmReset).Methods("POST") @@ -952,13 +911,20 @@ func serverRest() (*http.Server, <-chan bool, error) { } if options.OauthEndpoints { + router.HandleFunc("/oauth/login", login).Methods("POST") router.HandleFunc("/oauth/token", oauth).Methods("POST") router.HandleFunc("/oauth/revoke", revoke).Methods("POST") - router.HandleFunc("/oauth/authorize", authorize).Methods("POST") + router.HandleFunc("/oauth/authorize", authorize).Methods("GET") router.HandleFunc("/oauth/.well-known/jwks.json", jwkFunc).Methods("GET") + router.HandleFunc("/authen/logout", logout).Methods("GET") } + router.PathPrefix("/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Printf("no route matched for: %v", r.URL) + w.WriteHeader(http.StatusNotFound) + }) + s := &http.Server{ Addr: ":" + strconv.Itoa(options.Port), Handler: limit(router), @@ -1065,12 +1031,15 @@ func newTOTP(secret string) *gotp.TOTP { return gotp.NewTOTP(secret, 6, 30, hasher) } -func encodeAccessToken(role string, subject string) (string, error) { +func encodeAccessToken(role string, subject string, scope string, audience string, issuer string) (string, error) { tokenClaims := &TokenClaims{ - Role: role, + Role: role, + Scope: scope, Claims: jwt.Claims{ - Expiry: jwt.NewNumericDate(time.Now().Add(tokenExp)), - Subject: subject, + Expiry: jwt.NewNumericDate(time.Now().Add(tokenExp)), + Subject: subject, + Audience: []string{audience}, + Issuer: issuer, }, } var sig jose.Signer diff --git a/fastauth_test.go b/fastauth_test.go index 51d83c1..fa7403e 100644 --- a/fastauth_test.go +++ b/fastauth_test.go @@ -3,12 +3,15 @@ package main import ( "bytes" "context" + "crypto/sha256" + "encoding/base64" "encoding/json" "github.com/stretchr/testify/assert" "github.com/xlzd/gotp" "io/ioutil" "log" "net/http" + "net/url" "os" "strings" "testing" @@ -102,8 +105,8 @@ func TestLogin(t *testing.T) { resp = doConfirm("tom@test.ch", token) assert.Equal(t, http.StatusOK, resp.StatusCode) - resp = doLogin("tom@test.ch", "testtest", "") - assert.Equal(t, http.StatusOK, resp.StatusCode) + resp = doLogin("tom@test.ch", "testtest", "", "") + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) resp.Body.Close() shutdown() @@ -111,18 +114,18 @@ func TestLogin(t *testing.T) { func TestLoginFalse(t *testing.T) { shutdown := mainTest(&Opts{Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true"}) - resp := doAll("tom@test.ch", "testtest") + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) - resp = doLogin("tom@test.ch", "testtest", "") - assert.Equal(t, http.StatusOK, resp.StatusCode) + resp = doLogin("tom@test.ch", "testtest", "", "0123456789012345678901234567890123456789012") + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) - resp = doLogin("tom@test.ch", "testtest2", "") + resp = doLogin("tom@test.ch", "testtest2", "", "0123456789012345678901234567890123456789012") bodyBytes, _ := ioutil.ReadAll(resp.Body) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.True(t, strings.Index(string(bodyBytes), "ERR-checkEmail-06, user tom@test.ch password mismatch") > 0) - resp = doLogin("tom@test.ch2", "testtest", "") + resp = doLogin("tom@test.ch2", "testtest", "", "0123456789012345678901234567890123456789012") bodyBytes, _ = ioutil.ReadAll(resp.Body) assert.Equal(t, http.StatusBadRequest, resp.StatusCode) assert.True(t, strings.Index(string(bodyBytes), "ERR-checkEmail-01, DB select, tom@test.ch2 err sql: no rows in result set") > 0) @@ -133,25 +136,18 @@ func TestLoginFalse(t *testing.T) { func TestRefresh(t *testing.T) { shutdown := mainTest(&Opts{Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true", ExpireRefresh: 10}) - resp := doAll("tom@test.ch", "testtest") - assert.Equal(t, http.StatusOK, resp.StatusCode) - token1 := resp.Header.Get("Token") - - time.Sleep(time.Second) - cookie := resp.Cookies()[0] - resp = doRefresh(cookie.Value) + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) - token2 := resp.Header.Get("Token") - assert.NotEqual(t, token1, token2) - - resp.Body.Close() + oauth := OAuth{} + json.NewDecoder(resp.Body).Decode(&oauth) + assert.NotEqual(t, "", oauth.AccessToken) shutdown() } func TestReset(t *testing.T) { shutdown := mainTest(&Opts{Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true", ExpireRefresh: 1}) - resp := doAll("tom@test.ch", "testtest") + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) resp = doReset("tom@test.ch") @@ -162,8 +158,8 @@ func TestReset(t *testing.T) { resp = doConfirmReset("tom@test.ch", token, "testtest2") assert.Equal(t, http.StatusOK, resp.StatusCode) - resp = doLogin("tom@test.ch", "testtest2", "") - assert.Equal(t, http.StatusOK, resp.StatusCode) + resp = doLogin("tom@test.ch", "testtest2", "", "0123456789012345678901234567890123456789012") + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) resp.Body.Close() shutdown() @@ -172,7 +168,7 @@ func TestReset(t *testing.T) { func TestResetFailed(t *testing.T) { shutdown := mainTest(&Opts{Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true", ExpireRefresh: 1}) - resp := doAll("tom@test.ch", "testtest") + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) resp = doReset("tom@test.ch") @@ -183,7 +179,7 @@ func TestResetFailed(t *testing.T) { resp = doConfirmReset("tom@test.ch", token, "testtest2") assert.Equal(t, http.StatusOK, resp.StatusCode) - resp = doLogin("tom@test.ch", "testtest", "") + resp = doLogin("tom@test.ch", "testtest", "", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusBadRequest, resp.StatusCode) resp.Body.Close() @@ -192,11 +188,12 @@ func TestResetFailed(t *testing.T) { func TestTOTP(t *testing.T) { shutdown := mainTest(&Opts{Issuer: "FFFS", Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true"}) - resp := doAll("tom@test.ch", "testtest") + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) - token := resp.Header.Get("Token") + oauth := OAuth{} + json.NewDecoder(resp.Body).Decode(&oauth) - resp = doTOTP(token) + resp = doTOTP(oauth.AccessToken) p := ProvisioningUri{} bodyBytes, _ := ioutil.ReadAll(resp.Body) json.Unmarshal(bodyBytes, &p) @@ -205,7 +202,7 @@ func TestTOTP(t *testing.T) { totp := newTOTP(secret[0]) conf := totp.Now() - resp = doTOTPConfirm(conf, token) + resp = doTOTPConfirm(conf, oauth.AccessToken) assert.Equal(t, http.StatusOK, resp.StatusCode) resp.Body.Close() @@ -214,16 +211,18 @@ func TestTOTP(t *testing.T) { func TestLoginTOTP(t *testing.T) { shutdown := mainTest(&Opts{Issuer: "FFFS", Port: testPort, DBPath: testDBPath, UrlEmail: testUrl + "/send/email/{action}/{email}/{token}", Dev: "true"}) - resp := doAll("tom@test.ch", "testtest") + resp := doAll("tom@test.ch", "testtest", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusOK, resp.StatusCode) - token := resp.Header.Get("Token") - totp := doAllTOTP(token) + oauth := OAuth{} + json.NewDecoder(resp.Body).Decode(&oauth) - resp = doLogin("tom@test.ch", "testtest", totp.Now()) - assert.Equal(t, http.StatusOK, resp.StatusCode) + totp := doAllTOTP(oauth.AccessToken) - resp = doLogin("tom@test.ch", "testtest", "") + resp = doLogin("tom@test.ch", "testtest", totp.Now(), "0123456789012345678901234567890123456789012") + assert.Equal(t, http.StatusSeeOther, resp.StatusCode) + + resp = doLogin("tom@test.ch", "testtest", "", "0123456789012345678901234567890123456789012") assert.Equal(t, http.StatusForbidden, resp.StatusCode) resp.Body.Close() @@ -277,34 +276,59 @@ func doReset(email string) *http.Response { return resp } -func doAll(email string, pass string) *http.Response { +func doAll(email string, pass string, secret string) *http.Response { resp := doSignup(email, pass) token := token(email) resp = doConfirm(email, token) - resp = doLogin(email, pass, "") + resp = doLogin(email, pass, "", secret) + code := resp.Header.Get("Location")[6:] + resp = doCode(code, secret) return resp } -func doRefresh(cookie string) *http.Response { - req, _ := http.NewRequest("POST", testUrl+"/refresh", nil) - req.Header.Set("Content-Type", "application/json") - c := http.Cookie{Name: "refresh", Value: cookie, Path: "/refresh", Secure: false, HttpOnly: true} - req.AddCookie(&c) +func doCode(codeToken string, codeVerifier string) *http.Response { + data := url.Values{} + data.Set("grant_type", "authorization_code") + data.Set("code", codeToken) + data.Set("code_verifier", codeVerifier) + req, _ := http.NewRequest("POST", testUrl+"/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") resp, _ := http.DefaultClient.Do(req) return resp } -func doLogin(email string, pass string, totp string) *http.Response { +func doRefresh(refreshToken string) *http.Response { + data := url.Values{} + data.Set("grant_type", "refresh_token") + data.Set("refresh_token", refreshToken) + req, _ := http.NewRequest("POST", testUrl+"/oauth/token", strings.NewReader(data.Encode())) + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + resp, _ := http.DefaultClient.Do(req) + return resp +} + +func doLogin(email string, pass string, totp string, secret string) *http.Response { + h := sha256.Sum256([]byte(secret)) data := Credentials{ - Email: email, - Password: pass, - TOTP: totp, + Email: email, + Password: pass, + TOTP: totp, + CodeChallenge: base64.RawURLEncoding.EncodeToString(h[:]), + CodeCodeChallengeMethod: "S256", } + + //do not follow redirects: https://stackoverflow.com/questions/23297520/how-can-i-make-the-go-http-client-not-follow-redirects-automatically + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + payloadBytes, _ := json.Marshal(data) body := bytes.NewReader(payloadBytes) - req, _ := http.NewRequest("POST", testUrl+"/login", body) + req, _ := http.NewRequest(http.MethodPost, testUrl+"/login", body) req.Header.Set("Content-Type", "application/json") - resp, _ := http.DefaultClient.Do(req) + resp, _ := client.Do(req) return resp } @@ -359,6 +383,23 @@ func getForgotEmailToken(email string) (string, error) { return forgetEmailToken, nil } +func TestSecret(t *testing.T) { + h := sha256.Sum256([]byte("test")) + s := base64.RawURLEncoding.EncodeToString(h[:]) + assert.Equal(t, "n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg", s) +} + +func TestGetAttrDN(t *testing.T) { + + assert.Equal(t, + getAttrDN("CN=tom,OU=P_Internal,OU=P_Users,DC=test,DC=ch", "cn"), + "tom") + + assert.Equal(t, + getAttrDN("CN=tom,OU=P_Internal,OU=P_Users,DC=test,DC=ch", "cn"), + getAttrDN("cn=tom,ou=P_Internal,ou=P_Users,dc=test,dc=ch", "CN")) +} + func mainTest(opts *Opts) func() { defaultOpts(opts) options = opts diff --git a/go.mod b/go.mod index f3091d4..c25966c 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/google/go-cmp v0.5.2 // indirect github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 + github.com/gorilla/schema v1.2.0 github.com/kr/pretty v0.1.0 // indirect github.com/lib/pq v1.8.0 github.com/lor00x/goldap v0.0.0-20180618054307-a546dffdd1a3 diff --git a/go.sum b/go.sum index c3dddf6..09db0d0 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= +github.com/gorilla/schema v1.2.0 h1:YufUaxZYCKGFuAq3c96BOhjgd5nmXiOY9NGzF247Tsc= +github.com/gorilla/schema v1.2.0/go.mod h1:kgLaKoK1FELgZqMAVxx/5cbj0kT+57qxUrAlIO2eleU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/ldap.go b/ldap.go index e881cae..6d7ed66 100644 --- a/ldap.go +++ b/ldap.go @@ -46,7 +46,7 @@ func getAttrDN(dn string, atyp string) string { for _, rdn := range parsedDN.RDNs { for _, rdnAttr := range rdn.Attributes { log.Printf("found attr %v", rdnAttr.Type) - if rdnAttr.Type == atyp { + if strings.ToLower(rdnAttr.Type) == strings.ToLower(atyp) { return rdnAttr.Value } } @@ -82,21 +82,17 @@ func handleSearch(w ldap.ResponseWriter, m *ldap.Message) { return } - _, err := dbSelect(cn) + dbRes, err := dbSelect(cn) if err != nil { res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultUnwillingToPerform) w.Write(res) return } - var e message.SearchResultEntry - if strings.Index(string(r.BaseObject()), "cn") >= 0 { - e = ldap.NewSearchResultEntry(string(r.BaseObject())) - w.Write(e) - } else { - e = ldap.NewSearchResultEntry("cn=" + cn + ", " + string(r.BaseObject())) - w.Write(e) - } + e := ldap.NewSearchResultEntry("cn=" + cn + ", " + string(r.BaseObject())) + e.AddAttribute("cn", message.AttributeValue(string(dbRes.role))) + w.Write(e) + res := ldap.NewSearchResultDoneResponse(ldap.LDAPResultSuccess) w.Write(res) } diff --git a/login.html b/login.html new file mode 100644 index 0000000..3d2061c --- /dev/null +++ b/login.html @@ -0,0 +1,136 @@ + + +
+ +