From 9bc07048f372998be030de92fbf9cebbd81a3ebc Mon Sep 17 00:00:00 2001 From: Kian Parvin Date: Wed, 16 Apr 2025 12:51:25 +0200 Subject: [PATCH] fix: handle calls to Close() before Serve() This change prevents race conditions in programs that use the ssh server if they call Close() before Serve(), taking from improvements to the Go stdlib http server. --- server.go | 63 ++++++++++++++++++++------------------------------ server_test.go | 48 +++++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 46 deletions(-) diff --git a/server.go b/server.go index 7dbaa0f..70492b1 100644 --- a/server.go +++ b/server.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" gossh "golang.org/x/crypto/ssh" @@ -70,12 +71,12 @@ type Server struct { // handlers, but handle named subsystems. SubsystemHandlers map[string]SubsystemHandler + inShutdown atomic.Bool // true when server is in shutdown listenerWg sync.WaitGroup mu sync.RWMutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup - doneChan chan struct{} } func (srv *Server) ensureHostSigner() error { @@ -191,11 +192,20 @@ func (srv *Server) Handle(fn Handler) { // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { + srv.inShutdown.Store(true) srv.mu.Lock() defer srv.mu.Unlock() - srv.closeDoneChanLocked() err := srv.closeListenersLocked() + + // Unlock srv.mu while waiting for listenerWg. + // The group Add and Done calls are made with srv.mu held, + // to avoid adding a new listener in the window between + // us setting inShutdown above and waiting here. + srv.mu.Unlock() + srv.listenerWg.Wait() + srv.mu.Lock() + for c := range srv.conns { c.Close() delete(srv.conns, c) @@ -209,9 +219,9 @@ func (srv *Server) Close() error { // If the provided context expires before the shutdown is complete, // then the context's error is returned. func (srv *Server) Shutdown(ctx context.Context) error { + srv.inShutdown.Store(true) srv.mu.Lock() lnerr := srv.closeListenersLocked() - srv.closeDoneChanLocked() srv.mu.Unlock() finished := make(chan struct{}, 1) @@ -229,6 +239,10 @@ func (srv *Server) Shutdown(ctx context.Context) error { } } +func (s *Server) shuttingDown() bool { + return s.inShutdown.Load() +} + // Serve accepts incoming connections on the Listener l, creating a new // connection goroutine for each. The connection goroutines read requests and then // calls srv.Handler to handle sessions. @@ -245,15 +259,15 @@ func (srv *Server) Serve(l net.Listener) error { } var tempDelay time.Duration - srv.trackListener(l, true) + if !srv.trackListener(l, true) { + return ErrServerClosed + } defer srv.trackListener(l, false) for { conn, e := l.Accept() if e != nil { - select { - case <-srv.getDoneChan(): + if srv.shuttingDown() { return ErrServerClosed - default: } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { @@ -393,32 +407,6 @@ func (srv *Server) SetOption(option Option) error { return option(srv) } -func (srv *Server) getDoneChan() <-chan struct{} { - srv.mu.Lock() - defer srv.mu.Unlock() - - return srv.getDoneChanLocked() -} - -func (srv *Server) getDoneChanLocked() chan struct{} { - if srv.doneChan == nil { - srv.doneChan = make(chan struct{}) - } - return srv.doneChan -} - -func (srv *Server) closeDoneChanLocked() { - ch := srv.getDoneChanLocked() - select { - case <-ch: - // Already closed. Don't close again. - default: - // Safe to close here. We're the only closer, guarded - // by srv.mu. - close(ch) - } -} - func (srv *Server) closeListenersLocked() error { var err error for ln := range srv.listeners { @@ -430,7 +418,7 @@ func (srv *Server) closeListenersLocked() error { return err } -func (srv *Server) trackListener(ln net.Listener, add bool) { +func (srv *Server) trackListener(ln net.Listener, add bool) bool { srv.mu.Lock() defer srv.mu.Unlock() @@ -438,10 +426,8 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { srv.listeners = make(map[net.Listener]struct{}) } if add { - // If the *Server is being reused after a previous - // Close or Shutdown, reset its doneChan: - if len(srv.listeners) == 0 && len(srv.conns) == 0 { - srv.doneChan = nil + if srv.shuttingDown() { + return false } srv.listeners[ln] = struct{}{} srv.listenerWg.Add(1) @@ -449,6 +435,7 @@ func (srv *Server) trackListener(ln net.Listener, add bool) { delete(srv.listeners, ln) srv.listenerWg.Done() } + return true } func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { diff --git a/server_test.go b/server_test.go index 63fe694..7b0fd49 100644 --- a/server_test.go +++ b/server_test.go @@ -41,7 +41,7 @@ func TestServerShutdown(t *testing.T) { go func() { err := s.Serve(l) if err != nil && err != ErrServerClosed { - t.Fatal(err) + t.Error(err) } }() sessDone := make(chan struct{}) @@ -52,10 +52,10 @@ func TestServerShutdown(t *testing.T) { var stdout bytes.Buffer sess.Stdout = &stdout if err := sess.Run(""); err != nil { - t.Fatal(err) + t.Error(err) } if !bytes.Equal(stdout.Bytes(), testBytes) { - t.Fatalf("expected = %s; got %s", testBytes, stdout.Bytes()) + t.Errorf("expected = %s; got %s", testBytes, stdout.Bytes()) } }() @@ -64,7 +64,7 @@ func TestServerShutdown(t *testing.T) { defer close(srvDone) err := s.Shutdown(context.Background()) if err != nil { - t.Fatal(err) + t.Error(err) } }() @@ -90,7 +90,7 @@ func TestServerClose(t *testing.T) { go func() { err := s.Serve(l) if err != nil && err != ErrServerClosed { - t.Fatal(err) + t.Error(err) } }() @@ -103,14 +103,14 @@ func TestServerClose(t *testing.T) { defer close(clientDoneChan) <-closeDoneChan if err := sess.Run(""); err != nil && err != io.EOF { - t.Fatal(err) + t.Error(err) } }() go func() { err := s.Close() if err != nil { - t.Fatal(err) + t.Error(err) } close(closeDoneChan) }() @@ -120,12 +120,44 @@ func TestServerClose(t *testing.T) { case <-timeout: t.Error("timeout") return - case <-s.getDoneChan(): + case <-closeDoneChan: <-clientDoneChan return } } +func TestServerCloseBeforeServe(t *testing.T) { + l := newLocalListener() + s := &Server{} + + serveDoneChan := make(chan struct{}) + closeDoneChan := make(chan struct{}) + + go func() { + <-closeDoneChan + err := s.Serve(l) + if err != nil && err != ErrServerClosed { + t.Error(err) + } + close(serveDoneChan) + }() + + err := s.Close() + if err != nil { + t.Error(err) + } + close(closeDoneChan) + + timeout := time.After(1 * time.Second) + select { + case <-timeout: + t.Error("timeout") + return + case <-serveDoneChan: + return + } +} + func TestServerHandshakeTimeout(t *testing.T) { l := newLocalListener()