diff --git a/api/client.go b/api/client.go index ddbfea099..0e647b675 100644 --- a/api/client.go +++ b/api/client.go @@ -18,7 +18,6 @@ import ( "bytes" "context" "errors" - "io" "net" "net/http" "net/url" @@ -132,36 +131,26 @@ func (c *httpClient) Do(ctx context.Context, req *http.Request) (*http.Response, req = req.WithContext(ctx) } resp, err := c.client.Do(req) - defer func() { - if resp != nil { - _, _ = io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - } - }() - if err != nil { return nil, nil, err } var body []byte - done := make(chan struct{}) + done := make(chan error, 1) go func() { var buf bytes.Buffer - // TODO(bwplotka): Add LimitReader for too long err messages (e.g. limit by 1KB) - _, err = buf.ReadFrom(resp.Body) + _, err := buf.ReadFrom(resp.Body) body = buf.Bytes() - close(done) + done <- err }() select { case <-ctx.Done(): + resp.Body.Close() <-done - err = resp.Body.Close() - if err == nil { - err = ctx.Err() - } - case <-done: + return resp, nil, ctx.Err() + case err = <-done: + resp.Body.Close() + return resp, body, err } - - return resp, body, err } diff --git a/api/client_test.go b/api/client_test.go index 874387868..720084aa0 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -16,11 +16,13 @@ package api import ( "bytes" "context" + "errors" "fmt" "net/http" "net/http/httptest" "net/url" "testing" + "time" ) func TestConfig(t *testing.T) { @@ -116,6 +118,52 @@ func TestClientURL(t *testing.T) { } } +func TestDoContextCancellation(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("partial")) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + + <-r.Context().Done() + })) + + defer ts.Close() + + client, err := NewClient(Config{ + Address: ts.URL, + }) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + resp, body, err := client.Do(ctx, req) + elapsed := time.Since(start) + + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected error %v, got: %v", context.DeadlineExceeded, err) + } + if body != nil { + t.Errorf("expected no body due to cancellation, got: %q", string(body)) + } + if elapsed > 200*time.Millisecond { + t.Errorf("Do did not return promptly on cancellation: took %v", elapsed) + } + + if resp != nil && resp.Body != nil { + resp.Body.Close() + } +} + // Serve any http request with a response of N KB of spaces. type serveSpaces struct { sizeKB int