Skip to content

Commit da11dc8

Browse files
authored
NET-1134:move oauth from CE build block to pro (#2919)
* move oauth from CE build block to pro * move oauth code and api handler under pro * move common func back to auth from pro/auth * change log level to Info for information logs * fix import issue
1 parent 7eb1cf4 commit da11dc8

File tree

16 files changed

+336
-314
lines changed

16 files changed

+336
-314
lines changed

auth/auth.go

Lines changed: 34 additions & 273 deletions
Original file line numberDiff line numberDiff line change
@@ -3,156 +3,25 @@ package auth
33
import (
44
"encoding/base64"
55
"encoding/json"
6-
"errors"
76
"fmt"
8-
"net/http"
9-
"strings"
10-
"time"
117

12-
"golang.org/x/crypto/bcrypt"
13-
"golang.org/x/exp/slog"
14-
"golang.org/x/oauth2"
15-
16-
"github.com/gorilla/websocket"
178
"github.com/gravitl/netmaker/logger"
189
"github.com/gravitl/netmaker/logic"
19-
"github.com/gravitl/netmaker/logic/pro/netcache"
2010
"github.com/gravitl/netmaker/models"
21-
"github.com/gravitl/netmaker/servercfg"
11+
"golang.org/x/crypto/bcrypt"
12+
"golang.org/x/exp/slog"
13+
"golang.org/x/oauth2"
2214
)
2315

2416
// == consts ==
2517
const (
26-
init_provider = "initprovider"
27-
get_user_info = "getuserinfo"
28-
handle_callback = "handlecallback"
29-
handle_login = "handlelogin"
30-
google_provider_name = "google"
31-
azure_ad_provider_name = "azure-ad"
32-
github_provider_name = "github"
33-
oidc_provider_name = "oidc"
34-
verify_user = "verifyuser"
35-
user_signin_length = 16
36-
node_signin_length = 64
37-
headless_signin_length = 32
18+
node_signin_length = 64
3819
)
3920

40-
// OAuthUser - generic OAuth strategy user
41-
type OAuthUser struct {
42-
Name string `json:"name" bson:"name"`
43-
Email string `json:"email" bson:"email"`
44-
Login string `json:"login" bson:"login"`
45-
UserPrincipalName string `json:"userPrincipalName" bson:"userPrincipalName"`
46-
AccessToken string `json:"accesstoken" bson:"accesstoken"`
47-
}
48-
4921
var (
5022
auth_provider *oauth2.Config
51-
upgrader = websocket.Upgrader{}
5223
)
5324

54-
func getCurrentAuthFunctions() map[string]interface{} {
55-
var authInfo = servercfg.GetAuthProviderInfo()
56-
var authProvider = authInfo[0]
57-
switch authProvider {
58-
case google_provider_name:
59-
return google_functions
60-
case azure_ad_provider_name:
61-
return azure_ad_functions
62-
case github_provider_name:
63-
return github_functions
64-
case oidc_provider_name:
65-
return oidc_functions
66-
default:
67-
return nil
68-
}
69-
}
70-
71-
// InitializeAuthProvider - initializes the auth provider if any is present
72-
func InitializeAuthProvider() string {
73-
var functions = getCurrentAuthFunctions()
74-
if functions == nil {
75-
return ""
76-
}
77-
logger.Log(0, "setting oauth secret")
78-
var err = logic.SetAuthSecret(logic.RandomString(64))
79-
if err != nil {
80-
logger.FatalLog("failed to set auth_secret", err.Error())
81-
}
82-
var authInfo = servercfg.GetAuthProviderInfo()
83-
var serverConn = servercfg.GetAPIHost()
84-
if strings.Contains(serverConn, "localhost") || strings.Contains(serverConn, "127.0.0.1") {
85-
serverConn = "http://" + serverConn
86-
logger.Log(1, "localhost OAuth detected, proceeding with insecure http redirect: (", serverConn, ")")
87-
} else {
88-
serverConn = "https://" + serverConn
89-
logger.Log(1, "external OAuth detected, proceeding with https redirect: ("+serverConn+")")
90-
}
91-
92-
if authInfo[0] == "oidc" {
93-
functions[init_provider].(func(string, string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2], authInfo[3])
94-
return authInfo[0]
95-
}
96-
97-
functions[init_provider].(func(string, string, string))(serverConn+"/api/oauth/callback", authInfo[1], authInfo[2])
98-
return authInfo[0]
99-
}
100-
101-
// HandleAuthCallback - handles oauth callback
102-
// Note: not included in API reference as part of the OAuth process itself.
103-
func HandleAuthCallback(w http.ResponseWriter, r *http.Request) {
104-
if auth_provider == nil {
105-
handleOauthNotConfigured(w)
106-
return
107-
}
108-
var functions = getCurrentAuthFunctions()
109-
if functions == nil {
110-
return
111-
}
112-
state, _ := getStateAndCode(r)
113-
_, err := netcache.Get(state) // if in netcache proceeed with node registration login
114-
if err == nil || errors.Is(err, netcache.ErrExpired) {
115-
switch len(state) {
116-
case node_signin_length:
117-
logger.Log(1, "proceeding with host SSO callback")
118-
HandleHostSSOCallback(w, r)
119-
case headless_signin_length:
120-
logger.Log(1, "proceeding with headless SSO callback")
121-
HandleHeadlessSSOCallback(w, r)
122-
default:
123-
logger.Log(1, "invalid state length: ", fmt.Sprintf("%d", len(state)))
124-
}
125-
} else { // handle normal login
126-
functions[handle_callback].(func(http.ResponseWriter, *http.Request))(w, r)
127-
}
128-
}
129-
130-
// swagger:route GET /api/oauth/login nodes HandleAuthLogin
131-
//
132-
// Handles OAuth login.
133-
//
134-
// Schemes: https
135-
//
136-
// Security:
137-
// oauth
138-
// Responses:
139-
// 200: okResponse
140-
func HandleAuthLogin(w http.ResponseWriter, r *http.Request) {
141-
if auth_provider == nil {
142-
handleOauthNotConfigured(w)
143-
return
144-
}
145-
var functions = getCurrentAuthFunctions()
146-
if functions == nil {
147-
return
148-
}
149-
if servercfg.GetFrontendURL() == "" {
150-
handleOauthNotConfigured(w)
151-
return
152-
}
153-
functions[handle_login].(func(http.ResponseWriter, *http.Request))(w, r)
154-
}
155-
15625
// IsOauthUser - returns
15726
func IsOauthUser(user *models.User) error {
15827
var currentValue, err = FetchPassValue("")
@@ -163,81 +32,30 @@ func IsOauthUser(user *models.User) error {
16332
return bCryptErr
16433
}
16534

166-
// HandleHeadlessSSO - handles the OAuth login flow for headless interfaces such as Netmaker CLI via websocket
167-
func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) {
168-
conn, err := upgrader.Upgrade(w, r, nil)
169-
if err != nil {
170-
logger.Log(0, "error during connection upgrade for headless sign-in:", err.Error())
171-
return
172-
}
173-
if conn == nil {
174-
logger.Log(0, "failed to establish web-socket connection during headless sign-in")
175-
return
176-
}
177-
defer conn.Close()
35+
func FetchPassValue(newValue string) (string, error) {
17836

179-
req := &netcache.CValue{User: "", Pass: ""}
180-
stateStr := logic.RandomString(headless_signin_length)
181-
if err = netcache.Set(stateStr, req); err != nil {
182-
logger.Log(0, "Failed to process sso request -", err.Error())
183-
return
37+
type valueHolder struct {
38+
Value string `json:"value" bson:"value"`
18439
}
185-
186-
timeout := make(chan bool, 1)
187-
answer := make(chan string, 1)
188-
defer close(answer)
189-
defer close(timeout)
190-
191-
if auth_provider == nil {
192-
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
193-
logger.Log(0, "error during message writing:", err.Error())
194-
}
195-
return
40+
newValueHolder := valueHolder{}
41+
var currentValue, err = logic.FetchAuthSecret()
42+
if err != nil {
43+
return "", err
19644
}
197-
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
198-
if err = conn.WriteMessage(websocket.TextMessage, []byte(redirectUrl)); err != nil {
199-
logger.Log(0, "error during message writing:", err.Error())
45+
var unmarshErr = json.Unmarshal([]byte(currentValue), &newValueHolder)
46+
if unmarshErr != nil {
47+
return "", unmarshErr
20048
}
20149

202-
go func() {
203-
for {
204-
cachedReq, err := netcache.Get(stateStr)
205-
if err != nil {
206-
if strings.Contains(err.Error(), "expired") {
207-
logger.Log(0, "timeout occurred while waiting for SSO")
208-
timeout <- true
209-
break
210-
}
211-
continue
212-
} else if cachedReq.Pass != "" {
213-
logger.Log(0, "SSO process completed for user ", cachedReq.User)
214-
answer <- cachedReq.Pass
215-
break
216-
}
217-
time.Sleep(500) // try it 2 times per second to see if auth is completed
218-
}
219-
}()
220-
221-
select {
222-
case result := <-answer:
223-
if err = conn.WriteMessage(websocket.TextMessage, []byte(result)); err != nil {
224-
logger.Log(0, "Error during message writing:", err.Error())
225-
}
226-
case <-timeout:
227-
logger.Log(0, "Authentication server time out for headless SSO login")
228-
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
229-
logger.Log(0, "Error during message writing:", err.Error())
230-
}
231-
}
232-
if err = netcache.Del(stateStr); err != nil {
233-
logger.Log(0, "failed to remove SSO cache entry", err.Error())
234-
}
235-
if err = conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
236-
logger.Log(0, "write close:", err.Error())
50+
var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
51+
if b64Err != nil {
52+
logger.Log(0, "could not decode pass")
53+
return "", nil
23754
}
55+
return string(b64CurrentValue), nil
23856
}
23957

240-
// == private methods ==
58+
// == private ==
24159

24260
func addUser(email string) error {
24361
var hasSuperAdmin, err = logic.HasSuperAdmin()
@@ -247,7 +65,7 @@ func addUser(email string) error {
24765
} // generate random password to adapt to current model
24866
var newPass, fetchErr = FetchPassValue("")
24967
if fetchErr != nil {
250-
slog.Error("failed to get password", "error", err.Error())
68+
slog.Error("failed to get password", "error", fetchErr.Error())
25169
return fetchErr
25270
}
25371
var newUser = models.User{
@@ -273,77 +91,20 @@ func addUser(email string) error {
27391
return nil
27492
}
27593

276-
func FetchPassValue(newValue string) (string, error) {
277-
278-
type valueHolder struct {
279-
Value string `json:"value" bson:"value"`
280-
}
281-
newValueHolder := valueHolder{}
282-
var currentValue, err = logic.FetchAuthSecret()
283-
if err != nil {
284-
return "", err
285-
}
286-
var unmarshErr = json.Unmarshal([]byte(currentValue), &newValueHolder)
287-
if unmarshErr != nil {
288-
return "", unmarshErr
289-
}
290-
291-
var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value)
292-
if b64Err != nil {
293-
logger.Log(0, "could not decode pass")
294-
return "", nil
295-
}
296-
return string(b64CurrentValue), nil
297-
}
298-
299-
func getStateAndCode(r *http.Request) (string, string) {
300-
var state, code string
301-
if r.FormValue("state") != "" && r.FormValue("code") != "" {
302-
state = r.FormValue("state")
303-
code = r.FormValue("code")
304-
} else if r.URL.Query().Get("state") != "" && r.URL.Query().Get("code") != "" {
305-
state = r.URL.Query().Get("state")
306-
code = r.URL.Query().Get("code")
307-
}
308-
309-
return state, code
310-
}
311-
312-
func (user *OAuthUser) getUserName() string {
313-
var userName string
314-
if user.Email != "" {
315-
userName = user.Email
316-
} else if user.Login != "" {
317-
userName = user.Login
318-
} else if user.UserPrincipalName != "" {
319-
userName = user.UserPrincipalName
320-
} else if user.Name != "" {
321-
userName = user.Name
322-
}
323-
return userName
324-
}
94+
func isUserIsAllowed(username, network string, shouldAddUser bool) (*models.User, error) {
32595

326-
func isStateCached(state string) bool {
327-
_, err := netcache.Get(state)
328-
return err == nil || strings.Contains(err.Error(), "expired")
329-
}
330-
331-
// isEmailAllowed - checks if email is allowed to signup
332-
func isEmailAllowed(email string) bool {
333-
allowedDomains := servercfg.GetAllowedEmailDomains()
334-
domains := strings.Split(allowedDomains, ",")
335-
if len(domains) == 1 && domains[0] == "*" {
336-
return true
337-
}
338-
emailParts := strings.Split(email, "@")
339-
if len(emailParts) < 2 {
340-
return false
341-
}
342-
baseDomainOfEmail := emailParts[1]
343-
for _, domain := range domains {
344-
if domain == baseDomainOfEmail {
345-
return true
96+
user, err := logic.GetUser(username)
97+
if err != nil && shouldAddUser { // user must not exist, so try to make one
98+
if err = addUser(username); err != nil {
99+
logger.Log(0, "failed to add user", username, "during a node SSO network join on network", network)
100+
// response := returnErrTemplate(user.UserName, "failed to add user", state, reqKeyIf)
101+
// w.WriteHeader(http.StatusInternalServerError)
102+
// w.Write(response)
103+
return nil, fmt.Errorf("failed to add user to system")
346104
}
105+
logger.Log(0, "user", username, "was added during a node SSO network join on network", network)
106+
user, _ = logic.GetUser(username)
347107
}
348-
return false
108+
109+
return user, nil
349110
}

auth/host_session.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ func SessionHandler(conn *websocket.Conn) {
121121
return
122122
}
123123
logger.Log(0, "user registration attempted with host:", registerMessage.RegisterHost.Name, "via SSO")
124-
redirectUrl = fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
124+
redirectUrl := fmt.Sprintf("https://%s/api/oauth/register/%s", servercfg.GetAPIConnString(), stateStr)
125125
err = conn.WriteMessage(messageType, []byte(redirectUrl))
126126
if err != nil {
127127
logger.Log(0, "error during message writing:", err.Error())

controllers/node.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func Authorize(hostAllowed, networkCheck bool, authNetwork string, next http.Han
202202
}
203203

204204
isnetadmin := issuperadmin || isadmin
205-
if errN == nil && (issuperadmin || isadmin) {
205+
if issuperadmin || isadmin {
206206
nodeID = "mastermac"
207207
isAuthorized = true
208208
r.Header.Set("ismasterkey", "yes")

controllers/user.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,6 @@ func userHandlers(r *mux.Router) {
3232
r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete)
3333
r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet)
3434
r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet)
35-
r.HandleFunc("/api/oauth/login", auth.HandleAuthLogin).Methods(http.MethodGet)
36-
r.HandleFunc("/api/oauth/callback", auth.HandleAuthCallback).Methods(http.MethodGet)
37-
r.HandleFunc("/api/oauth/headless", auth.HandleHeadlessSSO)
38-
r.HandleFunc("/api/oauth/register/{regKey}", auth.RegisterHostSSO).Methods(http.MethodGet)
3935
r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers))).Methods(http.MethodGet)
4036
r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers))).Methods(http.MethodDelete)
4137
r.HandleFunc("/api/users_pending/user/{username}", logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser))).Methods(http.MethodDelete)
@@ -119,7 +115,7 @@ func authenticateUser(response http.ResponseWriter, request *http.Request) {
119115
successJSONResponse, jsonError := json.Marshal(successResponse)
120116
if jsonError != nil {
121117
logger.Log(0, username,
122-
"error marshalling resp: ", err.Error())
118+
"error marshalling resp: ", jsonError.Error())
123119
logic.ReturnErrorResponse(response, request, errorResponse)
124120
return
125121
}

0 commit comments

Comments
 (0)