diff --git a/client.go b/client.go index 998c3a87..744b9682 100644 --- a/client.go +++ b/client.go @@ -595,6 +595,25 @@ func (c *Client) SetCommonHeader(key, value string) *Client { return c } +// SetCommonHeaderNonCanonical set a header for all requests which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (c *Client) SetCommonHeaderNonCanonical(key, value string) *Client { + if c.Headers == nil { + c.Headers = make(http.Header) + } + c.Headers[key] = append(c.Headers[key], value) + return c +} + +// SetCommonHeadersNonCanonical set headers for all requests which key is a +// non-canonical key (keep case unchanged), only valid for HTTP/1.1. +func (c *Client) SetCommonHeadersNonCanonical(hdrs map[string]string) *Client { + for k, v := range hdrs { + c.SetCommonHeaderNonCanonical(k, v) + } + return c +} + // SetCommonContentType set the `Content-Type` header for all requests. func (c *Client) SetCommonContentType(ct string) *Client { c.SetCommonHeader(hdrContentTypeKey, ct) diff --git a/client_test.go b/client_test.go index 83123093..e4bce5b8 100644 --- a/client_test.go +++ b/client_test.go @@ -144,6 +144,11 @@ func TestSetCommonHeader(t *testing.T) { assertEqual(t, "my-value", c.Headers.Get("my-header")) } +func TestSetCommonHeaderNonCanonical(t *testing.T) { + c := tc().SetCommonHeaderNonCanonical("my-Header", "my-value") + assertEqual(t, "my-value", c.Headers["my-Header"][0]) +} + func TestSetCommonHeaders(t *testing.T) { c := tc().SetCommonHeaders(map[string]string{ "header1": "value1", @@ -153,6 +158,13 @@ func TestSetCommonHeaders(t *testing.T) { assertEqual(t, "value2", c.Headers.Get("header2")) } +func TestSetCommonHeadersNonCanonical(t *testing.T) { + c := tc().SetCommonHeadersNonCanonical(map[string]string{ + "my-Header": "my-value", + }) + assertEqual(t, "my-value", c.Headers["my-Header"][0]) +} + func TestSetCommonBasicAuth(t *testing.T) { c := tc().SetCommonBasicAuth("imroc", "123456") assertEqual(t, "Basic aW1yb2M6MTIzNDU2", c.Headers.Get("Authorization")) diff --git a/middleware.go b/middleware.go index 4434e5da..1bce9f6a 100644 --- a/middleware.go +++ b/middleware.go @@ -387,9 +387,11 @@ func parseRequestHeader(c *Client, r *Request) error { if r.Headers == nil { r.Headers = make(http.Header) } - for k := range c.Headers { - if r.Headers.Get(k) == "" { - r.Headers.Add(k, c.Headers.Get(k)) + for k, vs := range c.Headers { + for _, v := range vs { + if len(r.Headers[k]) == 0 { + r.Headers[k] = append(r.Headers[k], v) + } } } return nil diff --git a/request_test.go b/request_test.go index b0ca95aa..f88532d1 100644 --- a/request_test.go +++ b/request_test.go @@ -528,6 +528,13 @@ func TestSetHeaderNonCanonical(t *testing.T) { Get("/header") assertSuccess(t, resp, err) assertEqual(t, true, strings.Contains(resp.Dump(), key)) + + c.SetCommonHeaderNonCanonical(key, "test") + resp, err = c.R(). + EnableDumpWithoutResponse(). + Get("/header") + assertSuccess(t, resp, err) + assertEqual(t, true, strings.Contains(resp.Dump(), key)) } func TestQueryParam(t *testing.T) {