Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

martian: cancel request context on downstream disconnection #1004

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions http_proxy_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package forwarder

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"errors"
Expand Down Expand Up @@ -50,6 +51,7 @@ func (hp *HTTPProxy) errorResponse(req *http.Request, err error) *http.Response
handleMartianErrorStatus,
handleAuthenticationError,
handleDenyError,
handleContextCancelationError,
handleStatusText,
}

Expand Down Expand Up @@ -225,6 +227,16 @@ func handleDenyError(req *http.Request, err error) (code int, msg, label string)
return
}

func handleContextCancelationError(_ *http.Request, err error) (code int, msg, label string) {
if errors.Is(err, context.Canceled) {
code = http.StatusInternalServerError
msg = fmt.Sprintf("request context canceled")
label = "request_ctx_canceled"
}

return
}

// There is a difference between sending HTTP and HTTPS requests in the presence of an upstream proxy.
// For HTTPS client issues a CONNECT request to the proxy and then sends the original request.
// In case the proxy responds with status code 4XX or 5XX to the CONNECT request, the client interprets it as URL error.
Expand Down
16 changes: 7 additions & 9 deletions internal/martian/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,18 +82,16 @@ func isClosedConnError(err error) bool {

// isCloseable reports whether err is an error that indicates the client connection should be closed.
func isCloseable(err error) bool {
if errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrUnexpectedEOF) ||
errors.Is(err, io.ErrClosedPipe) {
return true
if err == nil {
return false
}

var neterr net.Error
if ok := errors.As(err, &neterr); ok && neterr.Timeout() {
return true
}

return strings.Contains(err.Error(), "tls:")
return errors.Is(err, io.EOF) ||
errors.Is(err, io.ErrUnexpectedEOF) ||
errors.Is(err, io.ErrClosedPipe) ||
(errors.As(err, &neterr) && !neterr.Timeout()) ||
strings.Contains(err.Error(), "tls:")
}

type ErrorStatus struct { //nolint:errname // ErrorStatus is a type name not a variable.
Expand Down
1 change: 1 addition & 0 deletions internal/martian/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ func (p *Proxy) handleLoop(conn net.Conn) {
}

pc := newProxyConn(p, conn)
defer pc.Close()

if err := pc.maybeHandshakeTLS(); err != nil {
log.Errorf(context.TODO(), "failed to do TLS handshake: %v", err)
Expand Down
89 changes: 80 additions & 9 deletions internal/martian/proxy_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,55 @@ import (

type proxyConn struct {
*Proxy
brw *bufio.ReadWriter
conn net.Conn
secure bool
cs tls.ConnectionState
brw *bufio.ReadWriter
conn net.Conn
readSem chan struct{}
ctx context.Context
cancelctx context.CancelFunc
secure bool
cs tls.ConnectionState
}

func newProxyConn(p *Proxy, conn net.Conn) *proxyConn {
return &proxyConn{
Proxy: p,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
conn: conn,
ctx, cancel := context.WithCancel(p.BaseContext)

pc := &proxyConn{
Proxy: p,
brw: bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)),
conn: conn,
readSem: make(chan struct{}, 1),
ctx: ctx,
cancelctx: cancel,
}
go pc.backgroundRead()

return pc
}

func (p *proxyConn) takeReadSem() bool {
select {
case p.readSem <- struct{}{}:
return true
case <-p.ctx.Done():
return false
}
}

func (p *proxyConn) releaseReadSem() {
<-p.readSem
}

func (p *proxyConn) backgroundRead() {
for {
if ok := p.takeReadSem(); !ok {
return
}
if _, err := p.brw.Peek(1); isClosedConnError(err) {
p.cancelctx()
p.releaseReadSem()
return
}
p.releaseReadSem()
}
}

Expand Down Expand Up @@ -82,6 +120,13 @@ func (p *proxyConn) readRequest() (*http.Request, error) {
log.Errorf(context.TODO(), "can't set idle deadline: %v", deadlineErr)
}

// Take read semaphore only after idle timeout has been enabled
// as it might be blocked on a background read.
// It'll be released on roundTrip, or never in case of a tunnel.
if ok := p.takeReadSem(); !ok {
return nil, fmt.Errorf("failed to take read semaphore")
}

// Wait for the connection to become readable before trying to
// read the next request. This prevents a ReadHeaderTimeout or
// ReadTimeout from starting until the first bytes of the next request
Expand Down Expand Up @@ -114,7 +159,7 @@ func (p *proxyConn) readRequest() (*http.Request, error) {
if p.secure {
req.TLS = &p.cs
}
req = req.WithContext(withTraceID(p.BaseContext, newTraceID(req.Header.Get(p.RequestIDHeader))))
req = req.WithContext(withTraceID(p.ctx, newTraceID(req.Header.Get(p.RequestIDHeader))))

// Adjust the read deadline if necessary.
if !hdrDeadline.Equal(wholeReqDeadline) {
Expand Down Expand Up @@ -295,6 +340,19 @@ func (p *proxyConn) tunnel(name string, res *http.Response, crw io.ReadWriteClos
return nil
}

type onCloseBody struct {
io.ReadCloser
onClose func()
}

func (b *onCloseBody) Close() error {
err := b.ReadCloser.Close()
if b.onClose != nil {
b.onClose()
}
return err
}

func (p *proxyConn) handle() error {
req, err := p.readRequest()
p.traceReadRequest(req, err)
Expand Down Expand Up @@ -346,6 +404,14 @@ func (p *proxyConn) handle() error {
req.Header.Set("Upgrade", reqUpType)
}

// Wrap body to start backgroundRead after it's been consumed.
// That allows to cancel request context
// when downstream connection disconnects mid-request.
req.Body = &onCloseBody{
ReadCloser: req.Body,
onClose: p.releaseReadSem,
}

// perform the HTTP roundtrip
res, err := p.roundTrip(req)
if err != nil {
Expand Down Expand Up @@ -533,3 +599,8 @@ func writeHeaderOnlyResponse(w io.Writer, res *http.Response) error {

return nil
}

func (p *proxyConn) Close() error {
p.cancelctx()
return p.conn.Close()
}
63 changes: 63 additions & 0 deletions internal/martian/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2045,3 +2045,66 @@ func TestReadHeaderConnectionReset(t *testing.T) {
t.Fatalf("conn.Read(): got %v, want io.EOF", err)
}
}

func TestCancelRequestOnDisconnect(t *testing.T) {
t.Parallel()

l, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("Failed to create http listener: %v", err)
}
defer l.Close()

donec := make(chan struct{})

go func() {
t.Logf("Waiting for server side connection")
conn, err := l.Accept()
if err != nil {
t.Errorf("Got error while accepting connection on destination listener: %v", err)
return
}
t.Logf("Accepted server side connection")

// Read request and hang, simulating a long computation on the server side.
buf := make([]byte, 16384)
if _, err := conn.Read(buf); err != nil {
t.Errorf("Error reading: %v", err)
return
}
<-donec

conn.Close()
}()

h := testHelper{
Proxy: func(p *Proxy) {
p.ResponseModifier = ResponseModifierFunc(func(res *http.Response) error {
warning := res.Header.Get("Warning")
t.Logf("Warning: %v", warning)
if !strings.Contains(warning, "context canceled") {
t.Errorf("Context canceled warning not found in response")
}
close(donec)
return nil
})
},
}
conn, cancel := h.proxyConn(t)
defer cancel()
defer conn.Close()

request := "GET / HTTP/1.1\r\n" + fmt.Sprintf("Host: %s\r\n\r\n", l.Addr())
if _, err = conn.Write([]byte(request)); err != nil {
t.Fatalf("conn.Write(): got %v, want no error", err)
}

// Disconnect mid-request.
conn.Close()

select {
case <-donec:
case <-time.After(2 * time.Second):
t.Fatalf("timeout")
}
}
Loading