Skip to content

Commit ac8ddfc

Browse files
mvdanbuger
authored andcommitted
all: use helper funcs for ctx's SessionData (TykTechnologies#734)
Also use a pointer type for more arguments and return values, as we need a pointer for the context value. This is because we need to be able to return nil (non-existing) in ctxGetSession. Besides, sticking SessionState into an interface{} will need a dereference anyway. Updates TykTechnologies#683.
1 parent 922b4e4 commit ac8ddfc

33 files changed

+346
-335
lines changed

api.go

+37-24
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func checkAndApplyTrialPeriod(keyName, apiId string, newSession *SessionState) {
9797
}
9898
}
9999

100-
func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) error {
100+
func doAddOrUpdate(keyName string, newSession *SessionState, dontReset bool) error {
101101
newSession.LastUpdated = strconv.Itoa(int(time.Now().Unix()))
102102

103103
if len(newSession.AccessRights) > 0 {
@@ -117,7 +117,7 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
117117
}).Error("Could not add key for this API ID, API doesn't exist.")
118118
return errors.New("API must be active to add keys")
119119
}
120-
checkAndApplyTrialPeriod(keyName, apiId, &newSession)
120+
checkAndApplyTrialPeriod(keyName, apiId, newSession)
121121

122122
// Lets reset keys if they are edited by admin
123123
if !apiSpec.DontSetQuotasOnCreate {
@@ -127,7 +127,7 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
127127
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
128128
}
129129

130-
err := apiSpec.SessionManager.UpdateSession(keyName, newSession, getLifetime(apiSpec, &newSession))
130+
err := apiSpec.SessionManager.UpdateSession(keyName, newSession, getLifetime(apiSpec, newSession))
131131
if err != nil {
132132
return err
133133
}
@@ -145,8 +145,8 @@ func doAddOrUpdate(keyName string, newSession SessionState, dontReset bool) erro
145145
spec.SessionManager.ResetQuota(keyName, newSession)
146146
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
147147
}
148-
checkAndApplyTrialPeriod(keyName, spec.APIID, &newSession)
149-
err := spec.SessionManager.UpdateSession(keyName, newSession, getLifetime(spec, &newSession))
148+
checkAndApplyTrialPeriod(keyName, spec.APIID, newSession)
149+
err := spec.SessionManager.UpdateSession(keyName, newSession, getLifetime(spec, newSession))
150150
if err != nil {
151151
return err
152152
}
@@ -248,7 +248,7 @@ func handleAddOrUpdate(keyName string, r *http.Request) ([]byte, int) {
248248

249249
}
250250
suppressReset := r.FormValue("suppress_reset") == "1"
251-
if err := doAddOrUpdate(keyName, newSession, suppressReset); err != nil {
251+
if err := doAddOrUpdate(keyName, &newSession, suppressReset); err != nil {
252252
return createError("Failed to create key, ensure security settings are correct."), 500
253253
}
254254

@@ -398,7 +398,7 @@ func handleDeleteKey(keyName, apiID string) ([]byte, int) {
398398
// Go through ALL managed API's and delete the key
399399
for _, spec := range ApiSpecRegister {
400400
spec.SessionManager.RemoveSession(keyName)
401-
spec.SessionManager.ResetQuota(keyName, SessionState{})
401+
spec.SessionManager.ResetQuota(keyName, &SessionState{})
402402
}
403403

404404
log.WithFields(logrus.Fields{
@@ -421,7 +421,7 @@ func handleDeleteKey(keyName, apiID string) ([]byte, int) {
421421
}
422422

423423
sessionManager.RemoveSession(keyName)
424-
sessionManager.ResetQuota(keyName, SessionState{})
424+
sessionManager.ResetQuota(keyName, &SessionState{})
425425

426426
statusObj := APIModifyKeySuccess{keyName, "ok", "deleted"}
427427
responseMessage, err = json.Marshal(&statusObj)
@@ -920,9 +920,9 @@ func orgHandler(w http.ResponseWriter, r *http.Request) {
920920
}
921921

922922
func handleOrgAddOrUpdate(keyName string, r *http.Request) ([]byte, int) {
923-
var newSession SessionState
923+
newSession := new(SessionState)
924924

925-
if err := json.NewDecoder(r.Body).Decode(&newSession); err != nil {
925+
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
926926
log.Error("Couldn't decode new session object: ", err)
927927
return createError("Request malformed"), 400
928928
}
@@ -1146,8 +1146,8 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
11461146
return
11471147
}
11481148

1149-
var newSession SessionState
1150-
if err := json.NewDecoder(r.Body).Decode(&newSession); err != nil {
1149+
newSession := new(SessionState)
1150+
if err := json.NewDecoder(r.Body).Decode(newSession); err != nil {
11511151
log.WithFields(logrus.Fields{
11521152
"prefix": "api",
11531153
"status": "fail",
@@ -1168,14 +1168,14 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
11681168
for apiID := range newSession.AccessRights {
11691169
apiSpec := GetSpecForApi(apiID)
11701170
if apiSpec != nil {
1171-
checkAndApplyTrialPeriod(newKey, apiID, &newSession)
1171+
checkAndApplyTrialPeriod(newKey, apiID, newSession)
11721172
// If we have enabled HMAC checking for keys, we need to generate a secret for the client to use
11731173
if !apiSpec.DontSetQuotasOnCreate {
11741174
// Reset quota by default
11751175
apiSpec.SessionManager.ResetQuota(newKey, newSession)
11761176
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
11771177
}
1178-
err := apiSpec.SessionManager.UpdateSession(newKey, newSession, getLifetime(apiSpec, &newSession))
1178+
err := apiSpec.SessionManager.UpdateSession(newKey, newSession, getLifetime(apiSpec, newSession))
11791179
if err != nil {
11801180
responseMessage := createError("Failed to create key - " + err.Error())
11811181
doJSONWrite(w, 403, responseMessage)
@@ -1209,13 +1209,13 @@ func createKeyHandler(w http.ResponseWriter, r *http.Request) {
12091209
}).Warning("No API Access Rights set on key session, adding key to all APIs.")
12101210

12111211
for _, spec := range ApiSpecRegister {
1212-
checkAndApplyTrialPeriod(newKey, spec.APIID, &newSession)
1212+
checkAndApplyTrialPeriod(newKey, spec.APIID, newSession)
12131213
if !spec.DontSetQuotasOnCreate {
12141214
// Reset quote by default
12151215
spec.SessionManager.ResetQuota(newKey, newSession)
12161216
newSession.QuotaRenews = time.Now().Unix() + newSession.QuotaRenewalRate
12171217
}
1218-
err := spec.SessionManager.UpdateSession(newKey, newSession, getLifetime(spec, &newSession))
1218+
err := spec.SessionManager.UpdateSession(newKey, newSession, getLifetime(spec, newSession))
12191219
if err != nil {
12201220
responseMessage := createError("Failed to create key - " + err.Error())
12211221
doJSONWrite(w, 403, responseMessage)
@@ -1740,20 +1740,19 @@ func healthCheckhandler(w http.ResponseWriter, r *http.Request) {
17401740

17411741
func UserRatesCheck() http.HandlerFunc {
17421742
return func(w http.ResponseWriter, r *http.Request) {
1743-
sessionState := context.Get(r, SessionData)
1744-
if sessionState == nil {
1743+
session := ctxGetSession(r)
1744+
if session == nil {
17451745
responseMessage := createError("Health checks are not enabled for this node")
17461746
doJSONWrite(w, 405, responseMessage)
17471747
return
17481748
}
17491749

1750-
userSession := sessionState.(SessionState)
17511750
returnSession := PublicSessionState{}
1752-
returnSession.Quota.QuotaRenews = userSession.QuotaRenews
1753-
returnSession.Quota.QuotaRemaining = userSession.QuotaRemaining
1754-
returnSession.Quota.QuotaMax = userSession.QuotaMax
1755-
returnSession.RateLimit.Rate = userSession.Rate
1756-
returnSession.RateLimit.Per = userSession.Per
1751+
returnSession.Quota.QuotaRenews = session.QuotaRenews
1752+
returnSession.Quota.QuotaRemaining = session.QuotaRemaining
1753+
returnSession.Quota.QuotaMax = session.QuotaMax
1754+
returnSession.RateLimit.Rate = session.Rate
1755+
returnSession.RateLimit.Per = session.Per
17571756

17581757
responseMessage, err := json.Marshal(returnSession)
17591758
if err != nil {
@@ -1836,3 +1835,17 @@ func ctxSetData(r *http.Request, m map[string]interface{}) {
18361835
}
18371836
context.Set(r, ContextData, m)
18381837
}
1838+
1839+
func ctxGetSession(r *http.Request) *SessionState {
1840+
if v := context.Get(r, SessionData); v != nil {
1841+
return v.(*SessionState)
1842+
}
1843+
return nil
1844+
}
1845+
1846+
func ctxSetSession(r *http.Request, s *SessionState) {
1847+
if s == nil {
1848+
panic("setting a nil context SessionData")
1849+
}
1850+
context.Set(r, SessionData, s)
1851+
}

api_test.go

+25-8
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@ func TestHealthCheckEndpoint(t *testing.T) {
9393
}
9494
}
9595

96-
func createSampleSession() SessionState {
97-
return SessionState{
96+
func createSampleSession() *SessionState {
97+
return &SessionState{
9898
Rate: 5.0,
9999
Allowance: 5.0,
100100
LastCheck: time.Now().Unix(),
@@ -119,7 +119,7 @@ func TestApiHandler(t *testing.T) {
119119

120120
for _, uri := range uris {
121121
sampleKey := createSampleSession()
122-
body, _ := json.Marshal(&sampleKey)
122+
body, _ := json.Marshal(sampleKey)
123123

124124
recorder := httptest.NewRecorder()
125125

@@ -154,7 +154,7 @@ func TestApiHandler(t *testing.T) {
154154
func TestApiHandlerGetSingle(t *testing.T) {
155155
uri := "/tyk/apis/1"
156156
sampleKey := createSampleSession()
157-
body, _ := json.Marshal(&sampleKey)
157+
body, _ := json.Marshal(sampleKey)
158158

159159
recorder := httptest.NewRecorder()
160160

@@ -272,7 +272,7 @@ func TestKeyHandlerNewKey(t *testing.T) {
272272
for _, api_id := range []string{"1", "none", ""} {
273273
uri := "/tyk/keys/1234"
274274
sampleKey := createSampleSession()
275-
body, _ := json.Marshal(&sampleKey)
275+
body, _ := json.Marshal(sampleKey)
276276

277277
recorder := httptest.NewRecorder()
278278
param := make(url.Values)
@@ -309,7 +309,7 @@ func TestKeyHandlerUpdateKey(t *testing.T) {
309309
for _, api_id := range []string{"1", "none", ""} {
310310
uri := "/tyk/keys/1234"
311311
sampleKey := createSampleSession()
312-
body, _ := json.Marshal(&sampleKey)
312+
body, _ := json.Marshal(sampleKey)
313313

314314
recorder := httptest.NewRecorder()
315315
param := make(url.Values)
@@ -378,7 +378,7 @@ func TestKeyHandlerGetKey(t *testing.T) {
378378
func createKey() {
379379
uri := "/tyk/keys/1234"
380380
sampleKey := createSampleSession()
381-
body, _ := json.Marshal(&sampleKey)
381+
body, _ := json.Marshal(sampleKey)
382382

383383
recorder := httptest.NewRecorder()
384384
req, _ := http.NewRequest("POST", uri, bytes.NewReader(body))
@@ -429,7 +429,7 @@ func TestCreateKeyHandlerCreateNewKey(t *testing.T) {
429429
uri := "/tyk/keys/create"
430430

431431
sampleKey := createSampleSession()
432-
body, _ := json.Marshal(&sampleKey)
432+
body, _ := json.Marshal(sampleKey)
433433

434434
recorder := httptest.NewRecorder()
435435
param := make(url.Values)
@@ -715,3 +715,20 @@ func TestContextData(t *testing.T) {
715715
}()
716716
ctxSetData(r, nil)
717717
}
718+
719+
func TestContextSession(t *testing.T) {
720+
r := new(http.Request)
721+
if ctxGetSession(r) != nil {
722+
t.Fatal("expected ctxGetSession to return nil")
723+
}
724+
ctxSetSession(r, &SessionState{})
725+
if ctxGetSession(r) == nil {
726+
t.Fatal("expected ctxGetSession to return non-nil")
727+
}
728+
defer func() {
729+
if r := recover(); r == nil {
730+
t.Fatal("expected ctxSetSession of zero val to panic")
731+
}
732+
}()
733+
ctxSetSession(r, nil)
734+
}

auth_manager.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ type AuthorisationHandler interface {
2525
// SessionState objects, not identity
2626
type SessionHandler interface {
2727
Init(store StorageHandler)
28-
UpdateSession(keyName string, session SessionState, resetTTLTo int64) error
28+
UpdateSession(keyName string, session *SessionState, resetTTLTo int64) error
2929
RemoveSession(keyName string)
3030
GetSessionDetail(keyName string) (SessionState, bool)
3131
GetSessions(filter string) []string
3232
GetStore() StorageHandler
33-
ResetQuota(string, SessionState)
33+
ResetQuota(string, *SessionState)
3434
}
3535

3636
// DefaultAuthorisationManager implements AuthorisationHandler,
@@ -87,7 +87,7 @@ func (b *DefaultSessionManager) GetStore() StorageHandler {
8787
return b.Store
8888
}
8989

90-
func (b *DefaultSessionManager) ResetQuota(keyName string, session SessionState) {
90+
func (b *DefaultSessionManager) ResetQuota(keyName string, session *SessionState) {
9191

9292
rawKey := QuotaKeyPrefix + publicHash(keyName)
9393
log.WithFields(logrus.Fields{
@@ -105,7 +105,7 @@ func (b *DefaultSessionManager) ResetQuota(keyName string, session SessionState)
105105
}
106106

107107
// UpdateSession updates the session state in the storage engine
108-
func (b *DefaultSessionManager) UpdateSession(keyName string, session SessionState, resetTTLTo int64) error {
108+
func (b *DefaultSessionManager) UpdateSession(keyName string, session *SessionState, resetTTLTo int64) error {
109109
if !session.HasChanged() {
110110
log.Debug("Session has not changed, not updating")
111111
return nil

coprocess.go

+7-8
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,9 @@ func (c *CoProcessor) GetObjectFromRequest(r *http.Request) *coprocess.Object {
112112

113113
// Encode the session object (if not a pre-process & not a custom key check):
114114
if c.HookType != coprocess.HookType_Pre && c.HookType != coprocess.HookType_CustomKeyCheck {
115-
session := context.Get(r, SessionData)
115+
session := ctxGetSession(r)
116116
if session != nil {
117-
sessionState := session.(SessionState)
118-
object.Session = ProtoSessionState(sessionState)
117+
object.Session = ProtoSessionState(session)
119118
}
120119
}
121120

@@ -294,17 +293,17 @@ func (m *CoProcessMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Requ
294293
return errors.New("Key not authorised"), 403
295294
}
296295

297-
returnedSessionState := TykSessionState(returnObject.Session)
296+
returnedSession := TykSessionState(returnObject.Session)
298297

299298
if extractor == nil {
300-
sessionLifetime := getLifetime(m.Spec, &returnedSessionState)
299+
sessionLifetime := getLifetime(m.Spec, returnedSession)
301300
// This API is not using the ID extractor, but we've got a session:
302-
m.Spec.SessionManager.UpdateSession(authHeaderValue, returnedSessionState, sessionLifetime)
303-
context.Set(r, SessionData, returnedSessionState)
301+
m.Spec.SessionManager.UpdateSession(authHeaderValue, returnedSession, sessionLifetime)
302+
ctxSetSession(r, returnedSession)
304303
context.Set(r, AuthHeaderValue, authHeaderValue)
305304
} else {
306305
// The CP middleware did setup a session, we should pass it to the ID extractor (caching):
307-
extractor.PostProcess(r, returnedSessionState, sessionID)
306+
extractor.PostProcess(r, returnedSession, sessionID)
308307
}
309308
}
310309

0 commit comments

Comments
 (0)