Skip to content

Commit

Permalink
Add DefaultRedirectPolicy
Browse files Browse the repository at this point in the history
  • Loading branch information
rosahaj committed Dec 1, 2024
1 parent 24b0c84 commit 87ad469
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 13 deletions.
16 changes: 3 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,20 +321,10 @@ func (c *Client) GetTLSClientConfig() *tls.Config {
return c.TLSClientConfig
}

func (c *Client) defaultCheckRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
if c.DebugLog {
c.log.Debugf("<redirect> %s %s", req.Method, req.URL.String())
}
return nil
}

// SetRedirectPolicy set the RedirectPolicy which controls the behavior of receiving redirect
// responses (usually responses with 301 and 302 status code), see the predefined
// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, MaxRedirectPolicy, NoRedirectPolicy,
// SameDomainRedirectPolicy and SameHostRedirectPolicy.
// AllowedDomainRedirectPolicy, AllowedHostRedirectPolicy, DefaultRedirectPolicy, MaxRedirectPolicy,
// NoRedirectPolicy, SameDomainRedirectPolicy and SameHostRedirectPolicy.
func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client {
if len(policies) == 0 {
return c
Expand Down Expand Up @@ -1565,7 +1555,7 @@ func C() *Client {
xmlUnmarshal: xml.Unmarshal,
cookiejarFactory: memoryCookieJarFactory,
}
httpClient.CheckRedirect = c.defaultCheckRedirect
c.SetRedirectPolicy(DefaultRedirectPolicy())
c.initCookieJar()

c.initTransport()
Expand Down
4 changes: 4 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ func TestRedirect(t *testing.T) {
tests.AssertNotNil(t, err)
tests.AssertContains(t, err.Error(), "stopped after 3 redirects", true)

_, err = tc().SetRedirectPolicy(MaxRedirectPolicy(20)).SetRedirectPolicy(DefaultRedirectPolicy()).R().Get("/unlimited-redirect")
tests.AssertNotNil(t, err)
tests.AssertContains(t, err.Error(), "stopped after 10 redirects", true)

_, err = tc().SetRedirectPolicy(SameDomainRedirectPolicy()).R().Get("/redirect-to-other")
tests.AssertNotNil(t, err)
tests.AssertContains(t, err.Error(), "different domain name is not allowed", true)
Expand Down
5 changes: 5 additions & 0 deletions redirect.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ func MaxRedirectPolicy(noOfRedirect int) RedirectPolicy {
}
}

// DefaultRedirectPolicy allows up to 10 redirects
func DefaultRedirectPolicy() RedirectPolicy {
return MaxRedirectPolicy(10)
}

// NoRedirectPolicy disable redirect behaviour
func NoRedirectPolicy() RedirectPolicy {
return func(req *http.Request, via []*http.Request) error {
Expand Down

0 comments on commit 87ad469

Please sign in to comment.