Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions agent.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package ssh

import (
"io"
"context"
"net"
"os"
"path"
"sync"

gossh "golang.org/x/crypto/ssh"
)
Expand Down Expand Up @@ -58,26 +57,13 @@ func ForwardAgentConnections(l net.Listener, s Session) {
return
}
go func(conn net.Conn) {
defer conn.Close()
channel, reqs, err := sshConn.OpenChannel(agentChannelType, nil)
if err != nil {
conn.Close()
return
}
defer channel.Close()
go gossh.DiscardRequests(reqs)
var wg sync.WaitGroup
wg.Add(2)
go func() {
io.Copy(conn, channel)
conn.(*net.UnixConn).CloseWrite()
wg.Done()
}()
go func() {
io.Copy(channel, conn)
channel.CloseWrite()
wg.Done()
}()
wg.Wait()
bicopy(context.Background(), channel, conn)
}(conn)
}
}
12 changes: 10 additions & 2 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ type Session interface {
// user for this session, in the form "key=value".
Environ() []string

// Exit sends an exit status and then closes the session.
// Exit sends an exit status. The caller is responsible for calling
// Close separately after any remaining I/O is complete.
Exit(code int) error

// Command returns a shell parsed slice of arguments that were provided by the
Expand Down Expand Up @@ -187,7 +188,12 @@ func (sess *session) Exit(code int) error {
if err != nil {
return err
}
return sess.Close()
// Per RFC 4254 Section 6.10, the channel needs to be closed with
// SSH_MSG_CHANNEL_CLOSE after the exit-status message. By not closing
// here, we allow the caller to complete remaining I/O (e.g. flushing
// output and sending EOF via CloseWrite) before closing the channel.
// https://datatracker.ietf.org/doc/html/rfc4254#section-6.10
return nil
}

func (sess *session) User() string {
Expand Down Expand Up @@ -272,6 +278,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
go func() {
sess.handler(sess)
sess.Exit(0)
sess.Close()
}()
case "subsystem":
if sess.handled {
Expand Down Expand Up @@ -306,6 +313,7 @@ func (sess *session) handleRequests(reqs <-chan *gossh.Request) {
go func() {
handler(sess)
sess.Exit(0)
sess.Close()
}()
case "env":
if sess.handled {
Expand Down
17 changes: 13 additions & 4 deletions streamlocal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,9 @@ func TestNewReverseUnixForwardingCallbackValidation(t *testing.T) {
func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) {
t.Parallel()

ctx, cancel := newContext(nil)
defer cancel()

dir := tempDirUnixSocket(t)
sockPath := filepath.Join(dir, "test.sock")

Expand All @@ -667,7 +670,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) {
cbNoUnlink := NewReverseUnixForwardingCallback(UnixForwardingOptions{
AllowAll: true,
})
_, err = cbNoUnlink(nil, sockPath)
_, err = cbNoUnlink(ctx, sockPath)
if err == nil {
t.Fatal("expected listen to fail on existing socket without BindUnlink")
}
Expand All @@ -677,7 +680,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) {
AllowAll: true,
BindUnlink: true,
})
newLn, err := cbUnlink(nil, sockPath)
newLn, err := cbUnlink(ctx, sockPath)
if err != nil {
t.Fatalf("expected listen to succeed with BindUnlink, got: %v", err)
}
Expand All @@ -687,6 +690,9 @@ func TestNewReverseUnixForwardingCallbackBindUnlink(t *testing.T) {
func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T) {
t.Parallel()

ctx, cancel := newContext(nil)
defer cancel()

dir := tempDirUnixSocket(t)
filePath := filepath.Join(dir, "regular.file")

Expand All @@ -700,7 +706,7 @@ func TestNewReverseUnixForwardingCallbackBindUnlinkSkipsNonSocket(t *testing.T)
AllowAll: true,
BindUnlink: true,
})
_, err := cb(nil, filePath)
_, err := cb(ctx, filePath)
if err == nil {
t.Fatal("expected listen to fail on regular file even with BindUnlink")
}
Expand Down Expand Up @@ -728,6 +734,9 @@ func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := newContext(nil)
defer cancel()

// Use a short /tmp path to stay under sun_path limits
// even when the test framework creates long temp paths.
dir, err := os.MkdirTemp("/tmp", "ssh-perm-")
Expand All @@ -741,7 +750,7 @@ func TestNewReverseUnixForwardingCallbackSocketPermissions(t *testing.T) {
AllowAll: true,
BindMask: tt.mask,
})
ln, err := cb(nil, sockPath)
ln, err := cb(ctx, sockPath)
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
Expand Down
79 changes: 53 additions & 26 deletions tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,42 +181,69 @@ func (h *ForwardedTCPHandler) HandleSSHRequest(ctx Context, srv *Server, req *go
}
}

// bicopy copies all of the data between the two connections and will close them
// after one or both of them are done writing. If the context is canceled, both
// of the connections will be closed.
// bicopy copies data bidirectionally between c1 and c2 until both directions
// complete or the context is canceled. When one direction finishes, it
// half-closes the write side of the destination to signal EOF to the peer
// per RFC 4254 Section 5.3, allowing the other direction to finish gracefully.
// If the context is canceled, both connections are force-closed.
// https://datatracker.ietf.org/doc/html/rfc4254#section-5.3
func bicopy(ctx context.Context, c1, c2 io.ReadWriteCloser) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

defer func() {
_ = c1.Close()
_ = c2.Close()
}()
defer c1.Close()
defer c2.Close()

var wg sync.WaitGroup
copyFunc := func(dst io.WriteCloser, src io.Reader) {
defer func() {
wg.Done()
// If one side of the copy fails, ensure the other one exits as
// well.
cancel()
}()
_, _ = io.Copy(dst, src)
}
wg.Go(func() {
defer halfCloseWrite(c1) // done writing to destination
defer halfCloseRead(c2) // done reading from source
_, _ = io.Copy(c1, c2)
})
wg.Go(func() {
defer halfCloseWrite(c2) // done writing to destination
defer halfCloseRead(c1) // done reading from source
_, _ = io.Copy(c2, c1)
})

wg.Add(2)
go copyFunc(c1, c2)
go copyFunc(c2, c1)

// Convert waitgroup to a channel so we can also wait on the context.
done := make(chan struct{})
go func() {
defer close(done)
wg.Wait()
close(done)
}()

select {
case <-ctx.Done():
case <-done:
return
case <-ctx.Done():
c1.Close()
c2.Close()
<-done
}
}

// halfCloseWrite signals EOF on the write side of c without fully closing
// the connection. This allows the peer to finish reading any buffered data
// and then close its side, which unblocks the other copy direction.
// All connection types used in SSH forwarding ([gossh.Channel], [net.TCPConn],
// [net.UnixConn]) support CloseWrite. For types that don't, this is a no-op
// and the deferred full Close in bicopy handles cleanup.
func halfCloseWrite(c io.ReadWriteCloser) {
type closeWriter interface {
CloseWrite() error
}
if cw, ok := c.(closeWriter); ok {
_ = cw.CloseWrite()
}
}

// halfCloseRead closes the read side of c without fully closing the
// connection. This releases kernel resources on the source once all data
// has been consumed. [net.TCPConn] and [net.UnixConn] support CloseRead;
// [gossh.Channel] does not, so this is a no-op for SSH channels and the
// deferred full Close in bicopy handles cleanup.
func halfCloseRead(c io.ReadWriteCloser) {
type closeReader interface {
CloseRead() error
}
if cr, ok := c.(closeReader); ok {
_ = cr.CloseRead()
}
}
Loading