diff --git a/README.md b/README.md index 089ebb1..b4ad89e 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,21 @@ socket.OnErrorOption(func(err error) socket.ErrorAction { }) ``` +## Connection Management + +```go +// Gracefully close the connection +conn.Close() + +// Check if connection is closed +if conn.IsClosed() { + // Handle closed connection +} + +// Get remote address +addr := conn.Addr() +``` + ## Write Methods Three ways to send messages: @@ -132,6 +147,8 @@ conn.WriteBlocking(ctx, msg) conn.WriteTimeout(msg, 5*time.Second) ``` +All write methods return `ErrConnectionClosed` if the connection is closed. + ## Custom Logger Implement the `Logger` interface or use `slog`: diff --git a/conn.go b/conn.go index f484572..6b96920 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "context" "errors" "net" + "sync/atomic" "time" "golang.org/x/sync/errgroup" @@ -21,6 +22,9 @@ var ( ErrInvalidOnMessage = errors.New("invalid on message callback") ) +// ErrConnectionClosed is returned when operating on a closed connection. +var ErrConnectionClosed = errors.New("connection closed") + // Conn represents a client connection to a TCP server. // It manages the underlying TCP connection, message encoding/decoding, // and provides read/write loops for asynchronous communication. @@ -32,6 +36,8 @@ type Conn struct { opts options sendMsg chan []byte + closed atomic.Bool + cancel context.CancelFunc } // Default configuration values. @@ -112,6 +118,7 @@ func newClientConnWithOptions(c *net.TCPConn, opts options) *Conn { func (c *Conn) Run(ctx context.Context) error { c.logger.Debug("connection started", "addr", c.Addr()) + ctx, c.cancel = context.WithCancel(ctx) group, child := errgroup.WithContext(ctx) group.Go(func() error { @@ -123,7 +130,7 @@ func (c *Conn) Run(ctx context.Context) error { }) err := group.Wait() - c.close() + c.closeConn() if err != nil { c.logger.Debug("connection closed with error", "addr", c.Addr(), "error", err) @@ -134,13 +141,35 @@ func (c *Conn) Run(ctx context.Context) error { return err } +// Close gracefully closes the connection. +// It cancels the context and closes the underlying TCP connection. +// Safe to call multiple times. +func (c *Conn) Close() error { + if c.closed.Swap(true) { + return nil // already closed + } + if c.cancel != nil { + c.cancel() + } + return c.rawConn.Close() +} + +// IsClosed returns true if the connection has been closed. +func (c *Conn) IsClosed() bool { + return c.closed.Load() +} + // ErrBufferFull is returned when the send buffer is full and cannot accept more messages. var ErrBufferFull = errors.New("send buffer full") // Write sends a message through the connection without blocking. // The message is encoded using the configured codec and queued for sending. -// Returns ErrBufferFull if the send buffer is full. +// Returns ErrBufferFull if the send buffer is full, or ErrConnectionClosed if closed. func (c *Conn) Write(message Message) error { + if c.closed.Load() { + return ErrConnectionClosed + } + bytes, err := c.opts.codec.Encode(message) if err != nil { return err @@ -156,7 +185,12 @@ func (c *Conn) Write(message Message) error { // WriteBlocking sends a message through the connection, blocking until the message // is queued or the context is canceled. +// Returns ErrConnectionClosed if the connection is closed. func (c *Conn) WriteBlocking(ctx context.Context, message Message) error { + if c.closed.Load() { + return ErrConnectionClosed + } + bytes, err := c.opts.codec.Encode(message) if err != nil { return err @@ -171,8 +205,13 @@ func (c *Conn) WriteBlocking(ctx context.Context, message Message) error { } // WriteTimeout sends a message through the connection with a timeout. -// Returns ErrBufferFull if the message cannot be queued within the timeout. +// Returns ErrBufferFull if the message cannot be queued within the timeout, +// or ErrConnectionClosed if the connection is closed. func (c *Conn) WriteTimeout(message Message, timeout time.Duration) error { + if c.closed.Load() { + return ErrConnectionClosed + } + bytes, err := c.opts.codec.Encode(message) if err != nil { return err @@ -251,7 +290,8 @@ func (c *Conn) write(data []byte) error { return nil } -// close closes the underlying TCP connection. -func (c *Conn) close() { +// closeConn marks the connection as closed and closes the underlying TCP connection. +func (c *Conn) closeConn() { + c.closed.Store(true) c.rawConn.Close() } diff --git a/conn_test.go b/conn_test.go index 6765418..3925864 100644 --- a/conn_test.go +++ b/conn_test.go @@ -683,7 +683,14 @@ func TestConn_close(t *testing.T) { t.Fatalf("NewConn failed: %v", err) } - conn.close() + if err := conn.Close(); err != nil { + t.Errorf("Close failed: %v", err) + } + + // Verify IsClosed returns true + if !conn.IsClosed() { + t.Error("expected IsClosed to return true after Close") + } // Verify connection is closed by trying to write _, err = serverConn.Write([]byte("test"))