From 4fe5050404fc6ed0bd0cbe77cbef63afff429709 Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Fri, 24 Jan 2025 16:41:23 +0100 Subject: [PATCH 1/5] martian: cancel request context when downstream disconnects during round trip --- internal/martian/proxy_conn.go | 59 ++++++++++++++++++++++++++----- internal/martian/proxy_test.go | 63 ++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+), 8 deletions(-) diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index 6eb86fae..dfecbbde 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -28,6 +28,7 @@ import ( "net/http" "strconv" "strings" + "sync" "time" "github.com/saucelabs/forwarder/internal/martian/log" @@ -37,17 +38,24 @@ import ( type proxyConn struct { *Proxy - brw *bufio.ReadWriter - conn net.Conn - secure bool - cs tls.ConnectionState + brw *bufio.ReadWriter + conn net.Conn + mu sync.Mutex + ctx context.Context + cancelctx context.CancelFunc + secure bool + cs tls.ConnectionState } func newProxyConn(p *Proxy, conn net.Conn) *proxyConn { + ctx, cancel := context.WithCancel(p.BaseContext) + return &proxyConn{ - Proxy: p, - brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), - conn: conn, + Proxy: p, + brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), + conn: conn, + ctx: ctx, + cancelctx: cancel, } } @@ -82,6 +90,11 @@ func (p *proxyConn) readRequest() (*http.Request, error) { log.Errorf(context.TODO(), "can't set idle deadline: %v", deadlineErr) } + // Take lock only after idle timeout has been enabled + // as it might be blocked on p.cancelRequestOnDisconnect. + p.mu.Lock() + defer p.mu.Unlock() + // Wait for the connection to become readable before trying to // read the next request. This prevents a ReadHeaderTimeout or // ReadTimeout from starting until the first bytes of the next request @@ -114,7 +127,7 @@ func (p *proxyConn) readRequest() (*http.Request, error) { if p.secure { req.TLS = &p.cs } - req = req.WithContext(withTraceID(p.BaseContext, newTraceID(req.Header.Get(p.RequestIDHeader)))) + req = req.WithContext(withTraceID(p.ctx, newTraceID(req.Header.Get(p.RequestIDHeader)))) // Adjust the read deadline if necessary. if !hdrDeadline.Equal(wholeReqDeadline) { @@ -295,6 +308,28 @@ func (p *proxyConn) tunnel(name string, res *http.Response, crw io.ReadWriteClos return nil } +type onCloseBody struct { + io.ReadCloser + onClose func() +} + +func (b *onCloseBody) Close() error { + err := b.ReadCloser.Close() + if b.onClose != nil { + b.onClose() + } + return err +} + +func (p *proxyConn) cancelRequestOnDisconnect() { + p.mu.Lock() + defer p.mu.Unlock() + + if _, err := p.brw.Peek(1); err != nil { + p.cancelctx() + } +} + func (p *proxyConn) handle() error { req, err := p.readRequest() p.traceReadRequest(req, err) @@ -346,6 +381,14 @@ func (p *proxyConn) handle() error { req.Header.Set("Upgrade", reqUpType) } + // Wrap body to start backgroundRead after it's been consumed. + // That allows to cancel request context + // when downstream connection disconnects mid-request. + req.Body = &onCloseBody{ + ReadCloser: req.Body, + onClose: func() { go p.cancelRequestOnDisconnect() }, + } + // perform the HTTP roundtrip res, err := p.roundTrip(req) if err != nil { diff --git a/internal/martian/proxy_test.go b/internal/martian/proxy_test.go index 126feb00..3f3e6b94 100644 --- a/internal/martian/proxy_test.go +++ b/internal/martian/proxy_test.go @@ -2045,3 +2045,66 @@ func TestReadHeaderConnectionReset(t *testing.T) { t.Fatalf("conn.Read(): got %v, want io.EOF", err) } } + +func TestCancelRequestOnDisconnect(t *testing.T) { + t.Parallel() + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to create http listener: %v", err) + } + defer l.Close() + + donec := make(chan struct{}) + + go func() { + t.Logf("Waiting for server side connection") + conn, err := l.Accept() + if err != nil { + t.Errorf("Got error while accepting connection on destination listener: %v", err) + return + } + t.Logf("Accepted server side connection") + + // Read request and hang, simulating a long computation on the server side. + buf := make([]byte, 16384) + if _, err := conn.Read(buf); err != nil { + t.Errorf("Error reading: %v", err) + return + } + <-donec + + conn.Close() + }() + + h := testHelper{ + Proxy: func(p *Proxy) { + p.ResponseModifier = ResponseModifierFunc(func(res *http.Response) error { + warning := res.Header.Get("Warning") + t.Logf("Warning: %v", warning) + if !strings.Contains(warning, "context canceled") { + t.Errorf("Context canceled warning not found in response") + } + close(donec) + return nil + }) + }, + } + conn, cancel := h.proxyConn(t) + defer cancel() + defer conn.Close() + + request := "GET / HTTP/1.1\r\n" + fmt.Sprintf("Host: %s\r\n\r\n", l.Addr()) + if _, err = conn.Write([]byte(request)); err != nil { + t.Fatalf("conn.Write(): got %v, want no error", err) + } + + // Disconnect mid-request. + conn.Close() + + select { + case <-donec: + case <-time.After(2 * time.Second): + t.Fatalf("timeout") + } +} From 3eb02a255c645f9752b37e193f068ac343d3502c Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Mon, 27 Jan 2025 11:05:21 +0100 Subject: [PATCH 2/5] martian: rework checking downstream conn to background read This improves performance by limiting the number of backgroundRead goroutines to 1. The read access is controlled by a read semaphore. --- internal/martian/proxy.go | 1 + internal/martian/proxy_conn.go | 62 ++++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/internal/martian/proxy.go b/internal/martian/proxy.go index 8fd3078e..ce16d155 100644 --- a/internal/martian/proxy.go +++ b/internal/martian/proxy.go @@ -323,6 +323,7 @@ func (p *Proxy) handleLoop(conn net.Conn) { } pc := newProxyConn(p, conn) + defer pc.Close() if err := pc.maybeHandshakeTLS(); err != nil { log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err) diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index dfecbbde..a593181b 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -28,7 +28,6 @@ import ( "net/http" "strconv" "strings" - "sync" "time" "github.com/saucelabs/forwarder/internal/martian/log" @@ -40,7 +39,7 @@ type proxyConn struct { *Proxy brw *bufio.ReadWriter conn net.Conn - mu sync.Mutex + readSem chan struct{} ctx context.Context cancelctx context.CancelFunc secure bool @@ -50,13 +49,44 @@ type proxyConn struct { func newProxyConn(p *Proxy, conn net.Conn) *proxyConn { ctx, cancel := context.WithCancel(p.BaseContext) - return &proxyConn{ + pc := &proxyConn{ Proxy: p, brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)), conn: conn, + readSem: make(chan struct{}, 1), ctx: ctx, cancelctx: cancel, } + go pc.backgroundRead() + + return pc +} + +func (p *proxyConn) takeReadSem() bool { + select { + case p.readSem <- struct{}{}: + return true + case <-p.ctx.Done(): + return false + } +} + +func (p *proxyConn) releaseReadSem() { + <-p.readSem +} + +func (p *proxyConn) backgroundRead() { + for { + if ok := p.takeReadSem(); !ok { + return + } + if _, err := p.brw.Peek(1); err != nil { + p.cancelctx() + p.releaseReadSem() + return + } + p.releaseReadSem() + } } func (p *proxyConn) maybeHandshakeTLS() error { @@ -90,10 +120,12 @@ func (p *proxyConn) readRequest() (*http.Request, error) { log.Errorf(context.TODO(), "can't set idle deadline: %v", deadlineErr) } - // Take lock only after idle timeout has been enabled - // as it might be blocked on p.cancelRequestOnDisconnect. - p.mu.Lock() - defer p.mu.Unlock() + // Take read semaphore only after idle timeout has been enabled + // as it might be blocked on a background read. + // It'll be released on roundTrip, or never in case of a tunnel. + if ok := p.takeReadSem(); !ok { + return nil, fmt.Errorf("failed to take read semaphore") + } // Wait for the connection to become readable before trying to // read the next request. This prevents a ReadHeaderTimeout or @@ -321,15 +353,6 @@ func (b *onCloseBody) Close() error { return err } -func (p *proxyConn) cancelRequestOnDisconnect() { - p.mu.Lock() - defer p.mu.Unlock() - - if _, err := p.brw.Peek(1); err != nil { - p.cancelctx() - } -} - func (p *proxyConn) handle() error { req, err := p.readRequest() p.traceReadRequest(req, err) @@ -386,7 +409,7 @@ func (p *proxyConn) handle() error { // when downstream connection disconnects mid-request. req.Body = &onCloseBody{ ReadCloser: req.Body, - onClose: func() { go p.cancelRequestOnDisconnect() }, + onClose: p.releaseReadSem, } // perform the HTTP roundtrip @@ -576,3 +599,8 @@ func writeHeaderOnlyResponse(w io.Writer, res *http.Response) error { return nil } + +func (p *proxyConn) Close() error { + p.cancelctx() + return p.conn.Close() +} From 0c1030d28f0c3f3128d0e25407a3cd372c59898f Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Fri, 24 Jan 2025 17:20:47 +0100 Subject: [PATCH 3/5] http_proxy_errors: handle context cancelation --- http_proxy_errors.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/http_proxy_errors.go b/http_proxy_errors.go index 5b086bfb..20f6eab8 100644 --- a/http_proxy_errors.go +++ b/http_proxy_errors.go @@ -8,6 +8,7 @@ package forwarder import ( "bytes" + "context" "crypto/tls" "crypto/x509" "errors" @@ -50,6 +51,7 @@ func (hp *HTTPProxy) errorResponse(req *http.Request, err error) *http.Response handleMartianErrorStatus, handleAuthenticationError, handleDenyError, + handleContextCancelationError, handleStatusText, } @@ -225,6 +227,16 @@ func handleDenyError(req *http.Request, err error) (code int, msg, label string) return } +func handleContextCancelationError(_ *http.Request, err error) (code int, msg, label string) { + if errors.Is(err, context.Canceled) { + code = http.StatusInternalServerError + msg = fmt.Sprintf("request context canceled") + label = "request_ctx_canceled" + } + + return +} + // There is a difference between sending HTTP and HTTPS requests in the presence of an upstream proxy. // For HTTPS client issues a CONNECT request to the proxy and then sends the original request. // In case the proxy responds with status code 4XX or 5XX to the CONNECT request, the client interprets it as URL error. From 72335aa852a77fcf090f48e52abcacfab6580641 Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Mon, 27 Jan 2025 11:11:02 +0100 Subject: [PATCH 4/5] martian: cancel conn ctx only on ClosedConnError --- internal/martian/proxy_conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/martian/proxy_conn.go b/internal/martian/proxy_conn.go index a593181b..625a762d 100644 --- a/internal/martian/proxy_conn.go +++ b/internal/martian/proxy_conn.go @@ -80,7 +80,7 @@ func (p *proxyConn) backgroundRead() { if ok := p.takeReadSem(); !ok { return } - if _, err := p.brw.Peek(1); err != nil { + if _, err := p.brw.Peek(1); isClosedConnError(err) { p.cancelctx() p.releaseReadSem() return From dd7b111a2529331157069ba766f46809ff6eab66 Mon Sep 17 00:00:00 2001 From: Hubert Grochowski Date: Mon, 27 Jan 2025 11:20:09 +0100 Subject: [PATCH 5/5] martian: fix and simplify isCloseable Notice that it used to return neterr.Timeout error as closable. This was wrong as we want to retry on any timeout error. This is why handle loop checks for maxConsecutiveErrors. --- internal/martian/errors.go | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/internal/martian/errors.go b/internal/martian/errors.go index cdcfc0cc..3c4902e0 100644 --- a/internal/martian/errors.go +++ b/internal/martian/errors.go @@ -82,18 +82,16 @@ func isClosedConnError(err error) bool { // isCloseable reports whether err is an error that indicates the client connection should be closed. func isCloseable(err error) bool { - if errors.Is(err, io.EOF) || - errors.Is(err, io.ErrUnexpectedEOF) || - errors.Is(err, io.ErrClosedPipe) { - return true + if err == nil { + return false } var neterr net.Error - if ok := errors.As(err, &neterr); ok && neterr.Timeout() { - return true - } - - return strings.Contains(err.Error(), "tls:") + return errors.Is(err, io.EOF) || + errors.Is(err, io.ErrUnexpectedEOF) || + errors.Is(err, io.ErrClosedPipe) || + (errors.As(err, &neterr) && !neterr.Timeout()) || + strings.Contains(err.Error(), "tls:") } type ErrorStatus struct { //nolint:errname // ErrorStatus is a type name not a variable.