diff --git a/agent.go b/agent.go index 99e84c1..75704fd 100644 --- a/agent.go +++ b/agent.go @@ -1,11 +1,10 @@ package ssh import ( - "io" + "context" "net" "os" "path" - "sync" gossh "golang.org/x/crypto/ssh" ) @@ -58,26 +57,13 @@ func ForwardAgentConnections(l net.Listener, s Session) { return } go func(conn net.Conn) { - defer conn.Close() channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil) if err != nil { + conn.Close() return } - defer channel.Close() go gossh.DiscardRequests(reqs) - var wg sync.WaitGroup - wg.Add(2) - go func() { - io.Copy(conn, channel) - conn.(*net.UnixConn).CloseWrite() - wg.Done() - }() - go func() { - io.Copy(channel, conn) - channel.CloseWrite() - wg.Done() - }() - wg.Wait() + bicopy(context.Background(), channel, conn) }(conn) } } diff --git a/session.go b/session.go index a8936dc..7502650 100644 --- a/session.go +++ b/session.go @@ -35,7 +35,8 @@ type Session interface { // user for this session, in the form "key=value". Environ() []string - // Exit sends an exit status and then closes the session. + // Exit sends an exit status. The caller is responsible for calling + // Close separately after any remaining I/O is complete. Exit(code int) error // Command returns a shell parsed slice of arguments that were provided by the @@ -187,7 +188,12 @@ func (sess *session) Exit(code int) error { if err != nil { return err } - return sess.Close() + // Per RFC 4254 Section 6.10, the channel needs to be closed with + // SSH_MSG_CHANNEL_CLOSE after the exit-status message. By not closing + // here, we allow the caller to complete remaining I/O (e.g. flushing + // output and sending EOF via CloseWrite) before closing the channel. + // https://datatracker.ietf.org/doc/html/rfc4254#section-6.10 + return nil } func (sess *session) User() string { @@ -272,6 +278,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { go func() { sess.handler(sess) sess.Exit(0) + sess.Close() }() case "subsystem": if sess.handled { @@ -306,6 +313,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) { go func() { handler(sess) sess.Exit(0) + sess.Close() }() case "env": if sess.handled { diff --git a/streamlocal_test.go b/streamlocal_test.go index f912bff..4745624 100644 --- a/streamlocal_test.go +++ b/streamlocal_test.go @@ -652,6 +652,9 @@ func TestNewReverseUnixForwardingCallbackValidation(t *testing.T) { func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { t.Parallel() + ctx, cancel := newContext(nil) + defer cancel() + dir := tempDirUnixSocket(t) sockPath := filepath.Join(dir, "test.sock") @@ -667,7 +670,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { cbNoUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{ AllowAll: true, }) - _, err = cbNoUnlink(nil, sockPath) + _, err = cbNoUnlink(ctx, sockPath) if err == nil { t.Fatal("expected listen to fail on existing socket without BindUnlink") } @@ -677,7 +680,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { AllowAll: true, BindUnlink: true, }) - newLn, err := cbUnlink(nil, sockPath) + newLn, err := cbUnlink(ctx, sockPath) if err != nil { t.Fatalf("expected listen to succeed with BindUnlink, got: %v", err) } @@ -687,6 +690,9 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) { func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T) { t.Parallel() + ctx, cancel := newContext(nil) + defer cancel() + dir := tempDirUnixSocket(t) filePath := filepath.Join(dir, "regular.file") @@ -700,7 +706,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T) AllowAll: true, BindUnlink: true, }) - _, err := cb(nil, filePath) + _, err := cb(ctx, filePath) if err == nil { t.Fatal("expected listen to fail on regular file even with BindUnlink") } @@ -728,6 +734,9 @@ func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() + ctx, cancel := newContext(nil) + defer cancel() + // Use a short /tmp path to stay under sun_path limits // even when the test framework creates long temp paths. dir, err := os.MkdirTemp("/tmp", "ssh-perm-") @@ -741,7 +750,7 @@ func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) { AllowAll: true, BindMask: tt.mask, }) - ln, err := cb(nil, sockPath) + ln, err := cb(ctx, sockPath) if err != nil { t.Fatalf("failed to listen: %v", err) } diff --git a/tcpip.go b/tcpip.go index 843704a..0d834ae 100644 --- a/tcpip.go +++ b/tcpip.go @@ -181,42 +181,69 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go } } -// bicopy copies all of the data between the two connections and will close them -// after one or both of them are done writing. If the context is canceled, both -// of the connections will be closed. +// bicopy copies data bidirectionally between c1 and c2 until both directions +// complete or the context is canceled. When one direction finishes, it +// half-closes the write side of the destination to signal EOF to the peer +// per RFC 4254 Section 5.3, allowing the other direction to finish gracefully. +// If the context is canceled, both connections are force-closed. +// https://datatracker.ietf.org/doc/html/rfc4254#section-5.3 func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - defer func() { - _ = c1.Close() - _ = c2.Close() - }() + defer c1.Close() + defer c2.Close() var wg sync.WaitGroup - copyFunc := func(dst io.WriteCloser, src io.Reader) { - defer func() { - wg.Done() - // If one side of the copy fails, ensure the other one exits as - // well. - cancel() - }() - _, _ = io.Copy(dst, src) - } + wg.Go(func() { + defer halfCloseWrite(c1) // done writing to destination + defer halfCloseRead(c2) // done reading from source + _, _ = io.Copy(c1, c2) + }) + wg.Go(func() { + defer halfCloseWrite(c2) // done writing to destination + defer halfCloseRead(c1) // done reading from source + _, _ = io.Copy(c2, c1) + }) - wg.Add(2) - go copyFunc(c1, c2) - go copyFunc(c2, c1) - - // Convert waitgroup to a channel so we can also wait on the context. done := make(chan struct{}) go func() { - defer close(done) wg.Wait() + close(done) }() select { - case <-ctx.Done(): case <-done: + return + case <-ctx.Done(): + c1.Close() + c2.Close() + <-done + } +} + +// halfCloseWrite signals EOF on the write side of c without fully closing +// the connection. This allows the peer to finish reading any buffered data +// and then close its side, which unblocks the other copy direction. +// All connection types used in SSH forwarding ([gossh.Channel], [net.TCPConn], +// [net.UnixConn]) support CloseWrite. For types that don't, this is a no-op +// and the deferred full Close in bicopy handles cleanup. +func halfCloseWrite(c io.ReadWriteCloser) { + type closeWriter interface { + CloseWrite() error + } + if cw, ok := c.(closeWriter); ok { + _ = cw.CloseWrite() + } +} + +// halfCloseRead closes the read side of c without fully closing the +// connection. This releases kernel resources on the source once all data +// has been consumed. [net.TCPConn] and [net.UnixConn] support CloseRead; +// [gossh.Channel] does not, so this is a no-op for SSH channels and the +// deferred full Close in bicopy handles cleanup. +func halfCloseRead(c io.ReadWriteCloser) { + type closeReader interface { + CloseRead() error + } + if cr, ok := c.(closeReader); ok { + _ = cr.CloseRead() } } diff --git a/tcpip_test.go b/tcpip_test.go index f6a9fa9..7e4e32a 100644 --- a/tcpip_test.go +++ b/tcpip_test.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" "sync/atomic" "testing" "time" @@ -173,3 +174,231 @@ func TestReverseTCPForwardingRespectsCallback(t *testing.T) { t.Fatalf("Expected callback to be called once but it was called %d times", called) } } + +// newTCPConnPair creates a pair of connected TCP connections using a +// localhost listener. The returned connections support half-close via +// [net.TCPConn.CloseWrite], making them suitable for testing bicopy. +func newTCPConnPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + var serverConn net.Conn + var acceptErr error + var wg sync.WaitGroup + wg.Go(func() { + serverConn, acceptErr = ln.Accept() + }) + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + wg.Wait() + if acceptErr != nil { + t.Fatal(acceptErr) + } + return clientConn, serverConn +} + +func TestBicopyNormal(t *testing.T) { + t.Parallel() + + // ext1 <--TCP--> c1 <-- bicopy --> c2 <--TCP--> ext2 + ext1, c1 := newTCPConnPair(t) + ext2, c2 := newTCPConnPair(t) + defer ext1.Close() + defer ext2.Close() + + done := make(chan struct{}) + go func() { + bicopy(context.Background(), c1, c2) + close(done) + }() + + // ext1 sends data; ext2 should receive it via bicopy + msg := []byte("hello through bicopy") + go func() { + ext1.Write(msg) + ext1.(*net.TCPConn).CloseWrite() + }() + + buf, err := io.ReadAll(ext2) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(buf, msg) { + t.Fatalf("got %q, want %q", buf, msg) + } + + // Close ext2's write side so bicopy's other direction finishes + ext2.(*net.TCPConn).CloseWrite() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("bicopy did not complete in time") + } +} + +func TestBicopyContextCancel(t *testing.T) { + t.Parallel() + + ext1, c1 := newTCPConnPair(t) + ext2, c2 := newTCPConnPair(t) + defer ext1.Close() + defer ext2.Close() + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + bicopy(ctx, c1, c2) + close(done) + }() + + // Cancel the context; bicopy should force-close both and return + cancel() + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("bicopy did not complete after context cancellation") + } +} + +func TestBicopyHalfClosePropagation(t *testing.T) { + t.Parallel() + + // Verify that when one side finishes sending, the other side + // can still complete its transfer (no premature teardown). + // This is the key property that half-close provides over the + // old cancel-on-first-direction-complete approach. + + // ext1 <--TCP--> c1 <-- bicopy --> c2 <--TCP--> ext2 + ext1, c1 := newTCPConnPair(t) + ext2, c2 := newTCPConnPair(t) + defer ext1.Close() + defer ext2.Close() + + done := make(chan struct{}) + go func() { + bicopy(context.Background(), c1, c2) + close(done) + }() + + // ext1 sends a message and immediately closes its write side (fast direction) + fastMsg := []byte("fast side done") + ext1.Write(fastMsg) + ext1.(*net.TCPConn).CloseWrite() + + // ext2 reads the fast message + buf := make([]byte, len(fastMsg)) + if _, err := io.ReadFull(ext2, buf); err != nil { + t.Fatalf("ReadFull: %v", err) + } + if !bytes.Equal(buf, fastMsg) { + t.Fatalf("got %q, want %q", buf, fastMsg) + } + + // ext2 sends a reply after a delay (slow direction). + // With the old bicopy (cancel on first direction complete), + // this data would be lost. With half-close, it gets through. + slowMsg := []byte("slow side reply") + go func() { + time.Sleep(50 * time.Millisecond) + ext2.Write(slowMsg) + ext2.(*net.TCPConn).CloseWrite() + }() + + // ext1 should receive the slow reply + reply, err := io.ReadAll(ext1) + if err != nil { + t.Fatalf("ReadAll reply: %v", err) + } + if !bytes.Equal(reply, slowMsg) { + t.Fatalf("got reply %q, want %q", reply, slowMsg) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("bicopy did not complete in time") + } +} + +// opaqueRWC wraps an [io.ReadWriteCloser] to hide any half-close methods +// (CloseWrite, CloseRead) from type assertions. This simulates a connection +// type that does not support half-close, which is useful for testing that +// halfCloseWrite and halfCloseRead are safe no-ops rather than falling back +// to a full Close that would break the other direction. +type opaqueRWC struct { + io.ReadWriteCloser +} + +func TestHalfCloseWriteNoOpPreservesConnection(t *testing.T) { + t.Parallel() + + // Verify that halfCloseWrite is a no-op for types without CloseWrite + // support, leaving the connection open for continued reading. A + // c.Close() fallback would break this: the connection would be fully + // closed and the subsequent ReadAll would fail. + c1, c2 := net.Pipe() + defer c2.Close() + wrapped := &opaqueRWC{c1} + defer wrapped.Close() + + halfCloseWrite(wrapped) + + // Connection must still be readable after the no-op halfCloseWrite. + go func() { + c2.Write([]byte("still works")) + c2.Close() + }() + + buf, err := io.ReadAll(wrapped) + if err != nil { + t.Fatalf("ReadAll after halfCloseWrite should succeed: %v", err) + } + if string(buf) != "still works" { + t.Fatalf("got %q, want %q", string(buf), "still works") + } +} + +func TestHalfCloseReadNoOpPreservesConnection(t *testing.T) { + t.Parallel() + + // Verify that halfCloseRead is a no-op for types without CloseRead + // support, leaving the connection open for continued writing. A + // c.Close() fallback would break this: the connection would be fully + // closed and the subsequent Write would fail. + c1, c2 := net.Pipe() + defer c2.Close() + wrapped := &opaqueRWC{c1} + + halfCloseRead(wrapped) + + // Connection must still be writable after the no-op halfCloseRead. + done := make(chan []byte, 1) + go func() { + buf, _ := io.ReadAll(c2) + done <- buf + }() + + msg := []byte("still works") + if _, err := wrapped.Write(msg); err != nil { + t.Fatalf("Write after halfCloseRead should succeed: %v", err) + } + wrapped.Close() + + select { + case buf := <-done: + if string(buf) != string(msg) { + t.Fatalf("got %q, want %q", string(buf), string(msg)) + } + case <-time.After(time.Second): + t.Fatal("timed out waiting for data") + } +}