Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ socket.OnErrorOption(func(err error) socket.ErrorAction {
})
```

## Connection Management

```go
// Gracefully close the connection
conn.Close()

// Check if connection is closed
if conn.IsClosed() {
// Handle closed connection
}

// Get remote address
addr := conn.Addr()
```

## Write Methods

Three ways to send messages:
Expand All @@ -132,6 +147,8 @@ conn.WriteBlocking(ctx, msg)
conn.WriteTimeout(msg, 5*time.Second)
```

All write methods return `ErrConnectionClosed` if the connection is closed.

## Custom Logger

Implement the `Logger` interface or use `slog`:
Expand Down
50 changes: 45 additions & 5 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"errors"
"net"
"sync/atomic"
"time"

"golang.org/x/sync/errgroup"
Expand All @@ -21,6 +22,9 @@ var (
ErrInvalidOnMessage = errors.New("invalid on message callback")
)

// ErrConnectionClosed is returned when operating on a closed connection.
var ErrConnectionClosed = errors.New("connection closed")

// Conn represents a client connection to a TCP server.
// It manages the underlying TCP connection, message encoding/decoding,
// and provides read/write loops for asynchronous communication.
Expand All @@ -32,6 +36,8 @@ type Conn struct {
opts options

sendMsg chan []byte
closed atomic.Bool
cancel context.CancelFunc
}

// Default configuration values.
Expand Down Expand Up @@ -112,6 +118,7 @@ func newClientConnWithOptions(c *net.TCPConn, opts options) *Conn {
func (c *Conn) Run(ctx context.Context) error {
c.logger.Debug("connection started", "addr", c.Addr())

ctx, c.cancel = context.WithCancel(ctx)
group, child := errgroup.WithContext(ctx)

group.Go(func() error {
Expand All @@ -123,7 +130,7 @@ func (c *Conn) Run(ctx context.Context) error {
})

err := group.Wait()
c.close()
c.closeConn()

if err != nil {
c.logger.Debug("connection closed with error", "addr", c.Addr(), "error", err)
Expand All @@ -134,13 +141,35 @@ func (c *Conn) Run(ctx context.Context) error {
return err
}

// Close gracefully closes the connection.
// It cancels the context and closes the underlying TCP connection.
// Safe to call multiple times.
func (c *Conn) Close() error {
if c.closed.Swap(true) {
return nil // already closed
}
if c.cancel != nil {
c.cancel()
}
return c.rawConn.Close()
}

// IsClosed returns true if the connection has been closed.
func (c *Conn) IsClosed() bool {
return c.closed.Load()
}

// ErrBufferFull is returned when the send buffer is full and cannot accept more messages.
var ErrBufferFull = errors.New("send buffer full")

// Write sends a message through the connection without blocking.
// The message is encoded using the configured codec and queued for sending.
// Returns ErrBufferFull if the send buffer is full.
// Returns ErrBufferFull if the send buffer is full, or ErrConnectionClosed if closed.
func (c *Conn) Write(message Message) error {
if c.closed.Load() {
return ErrConnectionClosed
}

bytes, err := c.opts.codec.Encode(message)
if err != nil {
return err
Expand All @@ -156,7 +185,12 @@ func (c *Conn) Write(message Message) error {

// WriteBlocking sends a message through the connection, blocking until the message
// is queued or the context is canceled.
// Returns ErrConnectionClosed if the connection is closed.
func (c *Conn) WriteBlocking(ctx context.Context, message Message) error {
if c.closed.Load() {
return ErrConnectionClosed
}

bytes, err := c.opts.codec.Encode(message)
if err != nil {
return err
Expand All @@ -171,8 +205,13 @@ func (c *Conn) WriteBlocking(ctx context.Context, message Message) error {
}

// WriteTimeout sends a message through the connection with a timeout.
// Returns ErrBufferFull if the message cannot be queued within the timeout.
// Returns ErrBufferFull if the message cannot be queued within the timeout,
// or ErrConnectionClosed if the connection is closed.
func (c *Conn) WriteTimeout(message Message, timeout time.Duration) error {
if c.closed.Load() {
return ErrConnectionClosed
}

bytes, err := c.opts.codec.Encode(message)
if err != nil {
return err
Expand Down Expand Up @@ -251,7 +290,8 @@ func (c *Conn) write(data []byte) error {
return nil
}

// close closes the underlying TCP connection.
func (c *Conn) close() {
// closeConn marks the connection as closed and closes the underlying TCP connection.
func (c *Conn) closeConn() {
c.closed.Store(true)
c.rawConn.Close()
}
9 changes: 8 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@ func TestConn_close(t *testing.T) {
t.Fatalf("NewConn failed: %v", err)
}

conn.close()
if err := conn.Close(); err != nil {
t.Errorf("Close failed: %v", err)
}

// Verify IsClosed returns true
if !conn.IsClosed() {
t.Error("expected IsClosed to return true after Close")
}

// Verify connection is closed by trying to write
_, err = serverConn.Write([]byte("test"))
Expand Down
Loading