-
Notifications
You must be signed in to change notification settings - Fork 1.8k
feat: add ALLOW_SIGNUP + ALLOWED_EMAIL_* for self-hosted instances #1098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
b049a3b
47449f6
c448dfc
bf6e5ce
c7fa597
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ package handler | |
|
|
||
| import ( | ||
| "context" | ||
| "errors" | ||
| "crypto/rand" | ||
| "crypto/subtle" | ||
| "encoding/binary" | ||
|
|
@@ -22,6 +23,18 @@ import ( | |
| db "github.com/multica-ai/multica/server/pkg/db/generated" | ||
| ) | ||
|
|
||
| // SignupError represents signup restriction errors | ||
| type SignupError struct { | ||
| Message string | ||
| } | ||
|
|
||
| func (e SignupError) Error() string { | ||
| return e.Message | ||
| } | ||
|
|
||
| var ErrSignupProhibited = SignupError{Message: "user registration is disabled on this self-hosted instance"} | ||
| var ErrEmailNotAllowed = SignupError{Message: "email address or domain not allowed on this instance"} | ||
|
|
||
| type UserResponse struct { | ||
| ID string `json:"id"` | ||
| Name string `json:"name"` | ||
|
|
@@ -78,23 +91,68 @@ func (h *Handler) issueJWT(user db.User) (string, error) { | |
|
|
||
| func (h *Handler) findOrCreateUser(ctx context.Context, email string) (db.User, error) { | ||
| user, err := h.Queries.GetUserByEmail(ctx, email) | ||
| if err != nil { | ||
| if !isNotFound(err) { | ||
| return db.User{}, err | ||
| } | ||
| name := email | ||
| if at := strings.Index(email, "@"); at > 0 { | ||
| name = email[:at] | ||
| if err == nil { | ||
| return user, nil | ||
| } | ||
| if !isNotFound(err) { | ||
| return db.User{}, err | ||
| } | ||
|
|
||
| // New user creation path. Check if signups are allowed. | ||
| if err := h.checkSignupAllowed(email); err != nil { | ||
| return db.User{}, err | ||
| } | ||
|
Comment on lines
+101
to
+104
|
||
|
|
||
| name := email | ||
| if at := strings.Index(email, "@"); at > 0 { | ||
| name = email[:at] | ||
| } | ||
| return h.Queries.CreateUser(ctx, db.CreateUserParams{ | ||
| Name: name, | ||
| Email: email, | ||
| }) | ||
| } | ||
|
|
||
| func (h *Handler) checkSignupAllowed(email string) error { | ||
| if os.Getenv("ALLOW_SIGNUP") == "false" { | ||
| return ErrSignupProhibited | ||
| } | ||
|
|
||
| allowedDomainsStr := os.Getenv("ALLOWED_EMAIL_DOMAINS") | ||
| allowedEmailsStr := os.Getenv("ALLOWED_EMAILS") | ||
|
|
||
|
azaanaliraza marked this conversation as resolved.
|
||
| if allowedDomainsStr == "" && allowedEmailsStr == "" { | ||
| return nil | ||
| } | ||
|
|
||
| allowed := false | ||
| emailLower := strings.ToLower(email) | ||
|
|
||
| if allowedDomainsStr != "" { | ||
| for _, domain := range strings.Split(allowedDomainsStr, ",") { | ||
| domain = strings.TrimSpace(domain) | ||
| if domain != "" && strings.HasSuffix(emailLower, "@"+strings.ToLower(domain)) { | ||
| allowed = true | ||
| break | ||
| } | ||
| } | ||
| user, err = h.Queries.CreateUser(ctx, db.CreateUserParams{ | ||
| Name: name, | ||
| Email: email, | ||
| }) | ||
| if err != nil { | ||
| return db.User{}, err | ||
| } | ||
|
|
||
| if !allowed && allowedEmailsStr != "" { | ||
| for _, allowedEmail := range strings.Split(allowedEmailsStr, ",") { | ||
| allowedEmail = strings.TrimSpace(allowedEmail) | ||
| if allowedEmail != "" && strings.EqualFold(emailLower, allowedEmail) { | ||
| allowed = true | ||
| break | ||
| } | ||
| } | ||
|
azaanaliraza marked this conversation as resolved.
|
||
| } | ||
|
azaanaliraza marked this conversation as resolved.
|
||
| return user, nil | ||
|
|
||
| if !allowed { | ||
| return ErrEmailNotAllowed | ||
| } | ||
|
|
||
| return nil | ||
| } | ||
|
|
||
| func (h *Handler) SendCode(w http.ResponseWriter, r *http.Request) { | ||
|
|
@@ -110,6 +168,24 @@ func (h *Handler) SendCode(w http.ResponseWriter, r *http.Request) { | |
| return | ||
| } | ||
|
|
||
| // Short-circuit: If not already a user, check if signups are allowed for this email | ||
| _, err := h.Queries.GetUserByEmail(r.Context(), email) | ||
| if err != nil { | ||
| if !isNotFound(err) { | ||
| writeError(w, http.StatusInternalServerError, "failed to look up user") | ||
| return | ||
| } | ||
| if err := h.checkSignupAllowed(email); err != nil { | ||
| msg := "user registration is disabled" | ||
| var signupErr SignupError | ||
| if errors.As(err, &signupErr) { | ||
| msg = signupErr.Message | ||
| } | ||
| writeError(w, http.StatusForbidden, msg) | ||
| return | ||
| } | ||
| } | ||
|
|
||
| // Rate limit: max 1 code per 60 seconds per email | ||
| latest, err := h.Queries.GetLatestCodeByEmail(r.Context(), email) | ||
| if err == nil && time.Since(latest.CreatedAt.Time) < 60*time.Second { | ||
|
|
@@ -180,6 +256,11 @@ func (h *Handler) VerifyCode(w http.ResponseWriter, r *http.Request) { | |
|
|
||
| user, err := h.findOrCreateUser(r.Context(), email) | ||
| if err != nil { | ||
| var signupErr SignupError | ||
| if errors.As(err, &signupErr) { | ||
| writeError(w, http.StatusForbidden, signupErr.Error()) | ||
| return | ||
| } | ||
| writeError(w, http.StatusInternalServerError, "failed to create user") | ||
| return | ||
| } | ||
|
|
@@ -336,6 +417,11 @@ func (h *Handler) GoogleLogin(w http.ResponseWriter, r *http.Request) { | |
|
|
||
| user, err := h.findOrCreateUser(r.Context(), email) | ||
| if err != nil { | ||
| var signupErr SignupError | ||
| if errors.As(err, &signupErr) { | ||
| writeError(w, http.StatusForbidden, signupErr.Error()) | ||
| return | ||
| } | ||
| writeError(w, http.StatusInternalServerError, "failed to create user") | ||
| return | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.