From 1dbe6705f23ef219148b513dda4645b03be12f91 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Sun, 2 Feb 2025 13:01:25 +0100 Subject: [PATCH] fix: return `return_to` code if already authenticated --- selfservice/strategy/oidc/strategy.go | 18 +++++---- selfservice/strategy/oidc/strategy_test.go | 46 ++++++++++++++++------ 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index f04f06d35899..74098b87e904 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -365,14 +365,6 @@ func (s *Strategy) alreadyAuthenticated(ctx context.Context, w http.ResponseWrit if _, ok := f.(*settings.Flow); ok { // ignore this if it's a settings flow } else if !isForced(f) { - if flowID, ok := registrationOrLoginFlowID(f); ok { - if _, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode { - err := s.d.SessionTokenExchangePersister().UpdateSessionOnExchanger(ctx, flowID, sess.ID) - if err != nil { - return false, err - } - } - } returnTo := s.d.Config().SelfServiceBrowserDefaultReturnTo(ctx) if redirecter, ok := f.(flow.FlowWithRedirect); ok { r, err := x.SecureRedirectTo(r, returnTo, redirecter.SecureRedirectToOpts(ctx, s.d)...) @@ -380,6 +372,16 @@ func (s *Strategy) alreadyAuthenticated(ctx context.Context, w http.ResponseWrit returnTo = r } } + if flowID, ok := registrationOrLoginFlowID(f); ok { + if codes, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode { + if err := s.d.SessionTokenExchangePersister().UpdateSessionOnExchanger(ctx, flowID, sess.ID); err != nil { + return false, err + } + q := returnTo.Query() + q.Set("code", codes.ReturnToCode) + returnTo.RawQuery = q.Encode() + } + } http.Redirect(w, r, returnTo.String(), http.StatusSeeOther) return true, nil } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 927d0c27457b..3c7a2d3a1e91 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -202,8 +202,8 @@ func TestStrategy(t *testing.T) { return res, body } - makeAPICodeFlowRequest := func(t *testing.T, provider, action string) (returnToURL *url.URL) { - res, err := testhelpers.NewDebugClient(t).Post(action, "application/json", strings.NewReader(fmt.Sprintf(`{ + makeAPICodeFlowRequest := func(t *testing.T, provider, action string, cookieJar *cookiejar.Jar) (returnToURL *url.URL) { + res, err := http.Post(action, "application/json", strings.NewReader(fmt.Sprintf(`{ "method": "oidc", "provider": %q }`, provider))) @@ -212,7 +212,7 @@ func TestStrategy(t *testing.T) { var changeLocation flow.BrowserLocationChangeRequiredError require.NoError(t, json.NewDecoder(res.Body).Decode(&changeLocation)) - res, err = testhelpers.NewClientWithCookieJar(t, nil, nil).Get(changeLocation.RedirectBrowserTo) + res, err = testhelpers.NewClientWithCookieJar(t, cookieJar, nil).Get(changeLocation.RedirectBrowserTo) require.NoError(t, err) returnToURL = res.Request.URL @@ -839,12 +839,12 @@ func TestStrategy(t *testing.T) { t.Run("suite=API with session token exchange code", func(t *testing.T) { scope = []string{"openid"} - loginOrRegister := func(t *testing.T, flowID uuid.UUID, code string) { + loginOrRegister := func(t *testing.T, flowID uuid.UUID, code string, cookieJar *cookiejar.Jar) { _, err := exchangeCodeForToken(t, sessiontokenexchange.Codes{InitCode: code}) require.Error(t, err) action := assertFormValues(t, flowID, "valid") - returnToURL := makeAPICodeFlowRequest(t, "valid", action) + returnToURL := makeAPICodeFlowRequest(t, "valid", action, cookieJar) returnToCode := returnToURL.Query().Get("code") assert.NotEmpty(t, code, "code query param was empty in the return_to URL") @@ -857,18 +857,18 @@ func TestStrategy(t *testing.T) { assert.NotEmpty(t, codeResponse.Token) assert.Equal(t, subject, gjson.GetBytes(codeResponse.Session.Identity.Traits, "subject").String()) } - performRegistration := func(t *testing.T) { + performRegistration := func(t *testing.T, cookieJar *cookiejar.Jar) { f := newAPIRegistrationFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute) - loginOrRegister(t, f.ID, f.SessionTokenExchangeCode) + loginOrRegister(t, f.ID, f.SessionTokenExchangeCode, cookieJar) } - performLogin := func(t *testing.T) { + performLogin := func(t *testing.T, cookieJar *cookiejar.Jar) { f := newAPILoginFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute) - loginOrRegister(t, f.ID, f.SessionTokenExchangeCode) + loginOrRegister(t, f.ID, f.SessionTokenExchangeCode, cookieJar) } for _, tc := range []struct { name string - first, then func(*testing.T) + first, then func(*testing.T, *cookiejar.Jar) }{{ name: "login-twice", first: performLogin, then: performLogin, @@ -884,10 +884,30 @@ func TestStrategy(t *testing.T) { }} { t.Run("case="+tc.name, func(t *testing.T) { subject = tc.name + "-api-code-testing@ory.sh" - tc.first(t) - tc.then(t) + tc.first(t, nil) + tc.then(t, nil) }) } + + t.Run("case=should return exchange code even if already authenticated", func(t *testing.T) { + subject = "existing-session-api-code-testing@ory.sh" + jar := x.Must(cookiejar.New(nil)) + + t.Run("step=register and create a session", func(t *testing.T) { + returnTo := "/foo" + r := newBrowserLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute) + action := assertFormValues(t, r.ID, "valid") + + res, body := makeRequestWithCookieJar(t, "valid", action, url.Values{}, jar, nil) + assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo)) + assertIdentity(t, res, body) + }) + + t.Run("step=perform login and get exchange code", func(t *testing.T) { + performLogin(t, jar) + }) + }) + t.Run("case=should use redirect_to URL on failure", func(t *testing.T) { ctx := context.Background() subject = "existing-subject-api-code-testing@ory.sh" @@ -905,7 +925,7 @@ func TestStrategy(t *testing.T) { require.Error(t, err) action := assertFormValues(t, f.ID, "valid") - returnToURL := makeAPICodeFlowRequest(t, "valid", action) + returnToURL := makeAPICodeFlowRequest(t, "valid", action, nil) returnedFlow := returnToURL.Query().Get("flow") require.NotEmpty(t, returnedFlow, "flow query param was empty in the return_to URL")