diff --git a/internal/httputil/httputil.go b/internal/httputil/httputil.go index db0dd5f127..761ad14df5 100644 --- a/internal/httputil/httputil.go +++ b/internal/httputil/httputil.go @@ -10,9 +10,17 @@ import ( "net/http" ) -// DefaultHTTPClient is the default HTTP client used across the driver. -var DefaultHTTPClient = &http.Client{ - Transport: http.DefaultTransport.(*http.Transport).Clone(), +var DefaultHTTPClient = &http.Client{} + +// NewHTTPClient will return the globally-defined DefaultHTTPClient, updating +// the transport if it differs from the http package DefaultTransport. +func NewHTTPClient() *http.Client { + client := DefaultHTTPClient + if _, ok := http.DefaultTransport.(*http.Transport); !ok { + client.Transport = http.DefaultTransport + } + + return client } // CloseIdleHTTPConnections closes any connections which were previously diff --git a/internal/httputil/httputil_test.go b/internal/httputil/httputil_test.go new file mode 100644 index 0000000000..124a704904 --- /dev/null +++ b/internal/httputil/httputil_test.go @@ -0,0 +1,41 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package httputil + +import ( + "net/http" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +type nonDefaultTransport struct{} + +func (*nonDefaultTransport) RoundTrip(*http.Request) (*http.Response, error) { return nil, nil } + +func TestDefaultHTTPClientTransport(t *testing.T) { + t.Run("default", func(t *testing.T) { + client := NewHTTPClient() + + val := assert.ObjectsAreEqual(http.DefaultClient, client) + + assert.True(t, val) + assert.Equal(t, DefaultHTTPClient, client) + }) + + t.Run("non-default global transport", func(t *testing.T) { + http.DefaultTransport = &nonDefaultTransport{} + + client := NewHTTPClient() + + val := assert.ObjectsAreEqual(&nonDefaultTransport{}, client.Transport) + + assert.True(t, val) + assert.Equal(t, DefaultHTTPClient, client) + assert.NotEqual(t, http.DefaultClient, client) // Sanity Check + }) +} diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 925325c6c5..5e2f186cfb 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -302,7 +302,7 @@ type ClientOptions struct { // Client creates a new ClientOptions instance. func Client() *ClientOptions { opts := &ClientOptions{} - opts = opts.SetHTTPClient(httputil.DefaultHTTPClient) + opts = opts.SetHTTPClient(httputil.NewHTTPClient()) return opts } diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index 907584d5f0..2fe91a4399 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -525,6 +525,22 @@ func TestClientOptions(t *testing.T) { }) } +type nonDefaultTransport struct{} + +func (*nonDefaultTransport) RoundTrip(*http.Request) (*http.Response, error) { return nil, nil } + +func TestClientHTTPTransport(t *testing.T) { + t.Run("Default client", func(t *testing.T) { + got := Client().HTTPClient + assert.Equal(t, http.DefaultClient, got) + }) + t.Run("Non-default global transport", func(t *testing.T) { + http.DefaultTransport = &nonDefaultTransport{} + got := Client().HTTPClient.Transport + assert.Equal(t, &nonDefaultTransport{}, got) + }) +} + func createCertPool(t *testing.T, paths ...string) *x509.CertPool { t.Helper()