diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0e8f65..60fee1e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,17 +2,14 @@ name: CI on: push: - branches: [main, master] + branches: [master] pull_request: - branches: [main, master] + branches: [master] jobs: test: name: Test runs-on: ubuntu-latest - strategy: - matrix: - go-version: ['1.23', '1.24'] steps: - name: Checkout code @@ -21,7 +18,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: ${{ matrix.go-version }} + go-version-file: 'go.mod' - name: Download dependencies run: go mod download @@ -31,7 +28,6 @@ jobs: - name: Upload coverage uses: codecov/codecov-action@v4 - if: matrix.go-version == '1.24' with: files: ./coverage.out fail_ci_if_error: false @@ -47,7 +43,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.24' + go-version-file: 'go.mod' - name: Run golangci-lint uses: golangci/golangci-lint-action@v6 @@ -66,10 +62,10 @@ jobs: - name: Setup Go uses: actions/setup-go@v5 with: - go-version: '1.24' + go-version-file: 'go.mod' - name: Build run: go build -v ./... - name: Build example - run: go build -v ./example/... + run: go build -v -o bin/echo ./example/... diff --git a/.gitignore b/.gitignore index 43309f8..d223dd4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Build output +bin/ + # Binaries for programs and plugins *.exe *.exe~ @@ -8,6 +11,28 @@ # Test binary, build with `go test -c` *.test -# Output of the go coverage tool, specifically when used with LiteIDE +# Output of the go coverage tool *.out -.idea +coverage.out + +# Profiling +*.prof + +# Debug binary (dlv) +__debug_bin* + +# Go workspace +go.work +go.work.sum + +# IDE +.idea/ +.vscode/ +.claude/ + +# macOS +.DS_Store + +# Environment +.env +.env.local \ No newline at end of file diff --git a/.golangci.yml b/.golangci.yml index 69836c4..5e17013 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,5 +1,3 @@ -version: "2" - run: timeout: 5m modules-download-mode: readonly @@ -14,24 +12,18 @@ linters: - misspell - unconvert - unparam - exclusions: - presets: - - std-error-handling - rules: - - path: _test\.go - linters: - - unparam - - errcheck - -formatters: - enable: - gofmt - goimports - settings: - goimports: - local-prefixes: - - github.com/Zereker/socket + +issues: + exclude-rules: + - path: _test\.go + linters: + - unparam + - errcheck linters-settings: misspell: locale: US + goimports: + local-prefixes: github.com/Zereker/socket diff --git a/README.md b/README.md index b4ad89e..e179305 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ A simple, high-performance TCP server framework for Go. - **Simple API** - Easy to use with functional options pattern - **Custom Codec** - Pluggable message encoding/decoding via `io.Reader` - **Graceful Shutdown** - Context-based cancellation support -- **Heartbeat** - Automatic read/write deadline management +- **Idle Timeout** - Automatic read/write deadline management for connection health - **Error Handling** - Flexible error handling with `Disconnect` or `Continue` actions - **Structured Logging** - Built-in `slog` integration @@ -99,11 +99,13 @@ func main() { | `CustomCodecOption(codec)` | Set message codec (required) | - | | `OnMessageOption(handler)` | Set message handler (required) | - | | `OnErrorOption(handler)` | Set error handler | Disconnect on error | -| `HeartbeatOption(duration)` | Set heartbeat interval | 30s | +| `IdleTimeoutOption(duration)` | Set idle timeout for read/write deadlines | 30s | | `BufferSizeOption(size)` | Set send channel buffer size | 1 | | `MessageMaxSize(size)` | Set max message size | 1MB | | `LoggerOption(logger)` | Set custom logger | slog default | +> **Note:** The idle timeout sets TCP read/write deadlines but does not send ping/pong packets. For active connection health checking, implement heartbeat messages in your application protocol. + ## Error Handling Control how errors are handled with `OnErrorOption`: @@ -134,21 +136,40 @@ addr := conn.Addr() ## Write Methods -Three ways to send messages: +Three ways to send messages with different blocking behaviors: ```go -// Non-blocking write, returns ErrBufferFull if channel is full -conn.Write(msg) +// Non-blocking write (fire-and-forget) +// Returns ErrBufferFull immediately if channel is full +// Best for: non-critical data, custom backpressure handling +err := conn.Write(msg) +if errors.Is(err, socket.ErrBufferFull) { + // Handle backpressure: drop, retry, or use blocking write +} // Blocking write with context cancellation +// Waits for buffer space, respects context timeout/cancellation +// Best for: critical messages that must be delivered +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() conn.WriteBlocking(ctx, msg) // Write with timeout +// Waits up to the specified duration for buffer space +// Best for: simple timeout without context management conn.WriteTimeout(msg, 5*time.Second) ``` All write methods return `ErrConnectionClosed` if the connection is closed. +### Backpressure Handling + +When `ErrBufferFull` is returned, it indicates the receiver is not consuming messages fast enough. Strategies: +- **Drop**: Acceptable for metrics, heartbeats, or non-critical updates +- **Retry with backoff**: For important but delay-tolerant messages +- **Switch to blocking**: Use `WriteBlocking` when delivery is critical +- **Flow control**: Implement application-level rate limiting + ## Custom Logger Implement the `Logger` interface or use `slog`: diff --git a/conn.go b/conn.go index 6b96920..62a2de4 100644 --- a/conn.go +++ b/conn.go @@ -1,12 +1,13 @@ // Package socket provides a simple TCP server framework for Go. // It supports custom message encoding/decoding, asynchronous I/O operations, -// and connection management with heartbeat monitoring. +// and connection management with idle timeout monitoring. package socket import ( "bufio" "context" "errors" + "io" "net" "sync/atomic" "time" @@ -20,18 +21,50 @@ var ( ErrInvalidCodec = errors.New("invalid codec callback") // ErrInvalidOnMessage is returned when no message handler is provided. ErrInvalidOnMessage = errors.New("invalid on message callback") + // ErrMessageTooLarge is returned when a message exceeds the maximum allowed size. + ErrMessageTooLarge = errors.New("message too large") ) // ErrConnectionClosed is returned when operating on a closed connection. var ErrConnectionClosed = errors.New("connection closed") +// limitedReader wraps a reader and returns ErrMessageTooLarge when the limit is exceeded. +type limitedReader struct { + r io.Reader + remaining int64 +} + +func newLimitedReader(r io.Reader, limit int64) *limitedReader { + return &limitedReader{r: r, remaining: limit} +} + +func (l *limitedReader) Read(p []byte) (n int, err error) { + if l.remaining <= 0 { + return 0, ErrMessageTooLarge + } + if int64(len(p)) > l.remaining { + p = p[:l.remaining] + } + n, err = l.r.Read(p) + l.remaining -= int64(n) + return +} + +// reset resets the limit counter for reuse with a new message. +// Only remaining is reset because the underlying reader (bufio.Reader) +// maintains its own buffer state and continues reading from where it left off. +func (l *limitedReader) reset(limit int64) { + l.remaining = limit +} + // Conn represents a client connection to a TCP server. // It manages the underlying TCP connection, message encoding/decoding, // and provides read/write loops for asynchronous communication. type Conn struct { - rawConn *net.TCPConn - reader *bufio.Reader - logger Logger + rawConn *net.TCPConn + reader *bufio.Reader + limitedReader *limitedReader + logger Logger opts options @@ -79,8 +112,8 @@ func checkOptions(opts *options) error { return ErrInvalidOnMessage } - if opts.heartbeat <= 0 { - opts.heartbeat = time.Second * 30 + if opts.idleTimeout <= 0 { + opts.idleTimeout = time.Second * 30 } if opts.codec == nil { @@ -100,12 +133,14 @@ func checkOptions(opts *options) error { // newClientConnWithOptions creates a new Conn with the given options. func newClientConnWithOptions(c *net.TCPConn, opts options) *Conn { + reader := bufio.NewReaderSize(c, opts.maxReadLength) cc := &Conn{ - rawConn: c, - reader: bufio.NewReaderSize(c, opts.maxReadLength), - logger: opts.logger, - opts: opts, - sendMsg: make(chan []byte, opts.bufferSize), + rawConn: c, + reader: reader, + limitedReader: newLimitedReader(reader, int64(opts.maxReadLength)), + logger: opts.logger, + opts: opts, + sendMsg: make(chan []byte, opts.bufferSize), } return cc @@ -116,7 +151,11 @@ func newClientConnWithOptions(c *net.TCPConn, opts options) *Conn { // and blocks until an error occurs or the context is canceled. // The connection is automatically closed when Run returns. func (c *Conn) Run(ctx context.Context) error { - c.logger.Debug("connection started", "addr", c.Addr()) + c.logger.Info("connection established", "addr", c.Addr()) + c.logger.Debug("connection options", "addr", c.Addr(), + "buffer_size", c.opts.bufferSize, + "max_read_length", c.opts.maxReadLength, + "idle_timeout", c.opts.idleTimeout) ctx, c.cancel = context.WithCancel(ctx) group, child := errgroup.WithContext(ctx) @@ -132,10 +171,10 @@ func (c *Conn) Run(ctx context.Context) error { err := group.Wait() c.closeConn() - if err != nil { - c.logger.Debug("connection closed with error", "addr", c.Addr(), "error", err) + if err != nil && !errors.Is(err, context.Canceled) { + c.logger.Info("connection closed with error", "addr", c.Addr(), "error", err) } else { - c.logger.Debug("connection closed", "addr", c.Addr()) + c.logger.Info("connection closed", "addr", c.Addr()) } return err @@ -160,11 +199,28 @@ func (c *Conn) IsClosed() bool { } // ErrBufferFull is returned when the send buffer is full and cannot accept more messages. +// This error indicates backpressure - the receiver is not consuming messages fast enough. +// Recommended handling strategies: +// - Drop the message (for non-critical data like metrics) +// - Use WriteBlocking or WriteTimeout to wait for buffer space +// - Implement application-level flow control var ErrBufferFull = errors.New("send buffer full") -// Write sends a message through the connection without blocking. +// Write sends a message through the connection without blocking (fire-and-forget). // The message is encoded using the configured codec and queued for sending. -// Returns ErrBufferFull if the send buffer is full, or ErrConnectionClosed if closed. +// +// Returns: +// - nil: message was successfully queued (not yet sent) +// - ErrBufferFull: send buffer is full, message was NOT queued +// - ErrConnectionClosed: connection is closed +// - encoding error: if codec.Encode fails +// +// Use this method when: +// - You can tolerate message loss under backpressure +// - You have your own retry/backpressure logic +// - Low latency is critical and blocking is unacceptable +// +// For guaranteed delivery, use WriteBlocking or WriteTimeout instead. func (c *Conn) Write(message Message) error { if c.closed.Load() { return ErrConnectionClosed @@ -184,8 +240,19 @@ func (c *Conn) Write(message Message) error { } // WriteBlocking sends a message through the connection, blocking until the message -// is queued or the context is canceled. -// Returns ErrConnectionClosed if the connection is closed. +// is queued or the context is canceled. This is the safest write method for +// guaranteed delivery. +// +// Returns: +// - nil: message was successfully queued +// - context.Canceled or context.DeadlineExceeded: context was canceled +// - ErrConnectionClosed: connection is closed +// - encoding error: if codec.Encode fails +// +// Use this method when: +// - Message delivery is critical +// - You have proper timeout handling via context +// - Blocking is acceptable for your use case func (c *Conn) WriteBlocking(ctx context.Context, message Message) error { if c.closed.Load() { return ErrConnectionClosed @@ -205,8 +272,17 @@ func (c *Conn) WriteBlocking(ctx context.Context, message Message) error { } // WriteTimeout sends a message through the connection with a timeout. -// Returns ErrBufferFull if the message cannot be queued within the timeout, -// or ErrConnectionClosed if the connection is closed. +// This provides a middle ground between Write (non-blocking) and WriteBlocking. +// +// Returns: +// - nil: message was successfully queued +// - ErrBufferFull: timeout expired before message could be queued +// - ErrConnectionClosed: connection is closed +// - encoding error: if codec.Encode fails +// +// Use this method when: +// - You want to wait for buffer space but with a time limit +// - You don't have an existing context to pass func (c *Conn) WriteTimeout(message Message, timeout time.Duration) error { if c.closed.Load() { return ErrConnectionClosed @@ -233,15 +309,19 @@ func (c *Conn) Addr() net.Addr { // readLoop continuously reads from the connection and processes messages. // It decodes incoming data using the configured codec and calls the message handler. // Returns when the context is canceled or an unrecoverable error occurs. +// Messages exceeding maxReadLength will return ErrMessageTooLarge. func (c *Conn) readLoop(ctx context.Context) error { for { select { case <-ctx.Done(): return ctx.Err() default: - _ = c.rawConn.SetReadDeadline(time.Now().Add(c.opts.heartbeat * 2)) + _ = c.rawConn.SetReadDeadline(time.Now().Add(c.opts.idleTimeout * 2)) + + // Reset the limit for each message + c.limitedReader.reset(int64(c.opts.maxReadLength)) - message, err := c.opts.codec.Decode(c.reader) + message, err := c.opts.codec.Decode(c.limitedReader) if err != nil { c.logger.Debug("read error", "addr", c.Addr(), "error", err) if c.opts.onError(err) == Disconnect { @@ -276,7 +356,7 @@ func (c *Conn) writeLoop(ctx context.Context) error { // If an error occurs and onError returns true, the error is propagated. // Otherwise, the error is suppressed and writing continues. func (c *Conn) write(data []byte) error { - _ = c.rawConn.SetWriteDeadline(time.Now().Add(c.opts.heartbeat * 2)) + _ = c.rawConn.SetWriteDeadline(time.Now().Add(c.opts.idleTimeout * 2)) _, err := c.rawConn.Write(data) diff --git a/conn_test.go b/conn_test.go index 3925864..797b827 100644 --- a/conn_test.go +++ b/conn_test.go @@ -161,7 +161,7 @@ func TestNewConn_WithAllOptions(t *testing.T) { OnMessageOption(onMessage), OnErrorOption(onError), BufferSizeOption(10), - HeartbeatOption(time.Minute), + IdleTimeoutOption(time.Minute), MessageMaxSize(2048), ) @@ -173,8 +173,8 @@ func TestNewConn_WithAllOptions(t *testing.T) { t.Errorf("bufferSize = %d, want 10", conn.opts.bufferSize) } - if conn.opts.heartbeat != time.Minute { - t.Errorf("heartbeat = %v, want %v", conn.opts.heartbeat, time.Minute) + if conn.opts.idleTimeout != time.Minute { + t.Errorf("idleTimeout = %v, want %v", conn.opts.idleTimeout, time.Minute) } if conn.opts.maxReadLength != 2048 { @@ -204,8 +204,8 @@ func TestCheckOptions_DefaultValues(t *testing.T) { t.Errorf("maxReadLength = %d, want %d", opts.maxReadLength, defaultMaxPackageLength) } - if opts.heartbeat != time.Second*30 { - t.Errorf("heartbeat = %v, want %v", opts.heartbeat, time.Second*30) + if opts.idleTimeout != time.Second*30 { + t.Errorf("idleTimeout = %v, want %v", opts.idleTimeout, time.Second*30) } if opts.onError == nil { @@ -455,7 +455,7 @@ func TestConn_Run_ReadWrite(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -510,7 +510,7 @@ func TestConn_Run_DecodeError(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -551,7 +551,7 @@ func TestConn_Run_OnMessageError(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -589,7 +589,7 @@ func TestConn_Run_WriteLoop(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -636,7 +636,7 @@ func TestConn_Run_ReadError_OnErrorReturnsContinue(t *testing.T) { CustomCodecOption(codec), OnMessageOption(onMessage), OnErrorOption(onError), - HeartbeatOption(time.Millisecond*100), + IdleTimeoutOption(time.Millisecond*100), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -708,7 +708,7 @@ func TestNewClientConnWithOptions(t *testing.T) { codec: &mockCodec{}, onMessage: func(msg Message) error { return nil }, bufferSize: 5, - heartbeat: time.Minute, + idleTimeout: time.Minute, maxReadLength: 4096, logger: defaultLogger(), } @@ -719,8 +719,8 @@ func TestNewClientConnWithOptions(t *testing.T) { t.Error("rawConn not set correctly") } - if conn.opts.heartbeat != time.Minute { - t.Errorf("heartbeat = %v, want %v", conn.opts.heartbeat, time.Minute) + if conn.opts.idleTimeout != time.Minute { + t.Errorf("idleTimeout = %v, want %v", conn.opts.idleTimeout, time.Minute) } if cap(conn.sendMsg) != 5 { @@ -737,7 +737,7 @@ func TestConn_WriteLoop_WriteError(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -778,7 +778,7 @@ func TestConn_Write_OnErrorReturnsContinue(t *testing.T) { CustomCodecOption(codec), OnMessageOption(onMessage), OnErrorOption(onError), - HeartbeatOption(time.Millisecond*100), + IdleTimeoutOption(time.Millisecond*100), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -825,7 +825,7 @@ func TestConn_WriteLoop_ContextCanceled(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Millisecond*100), + IdleTimeoutOption(time.Millisecond*100), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -867,7 +867,7 @@ func TestConn_Write_Success(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Second*5), + IdleTimeoutOption(time.Second*5), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -1013,7 +1013,7 @@ func TestConn_write_ErrorWithOnErrorDisconnect(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Millisecond*50), + IdleTimeoutOption(time.Millisecond*50), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -1039,7 +1039,7 @@ func TestConn_writeLoop_WriteError_Direct(t *testing.T) { conn, err := NewConn(serverConn, CustomCodecOption(codec), OnMessageOption(onMessage), - HeartbeatOption(time.Millisecond*50), + IdleTimeoutOption(time.Millisecond*50), ) if err != nil { t.Fatalf("NewConn failed: %v", err) @@ -1063,3 +1063,124 @@ func TestConn_writeLoop_WriteError_Direct(t *testing.T) { t.Error("writeLoop should return error when write fails") } } + +func TestConn_ReadLoop_MessageTooLarge(t *testing.T) { + serverConn, clientConn := createTestTCPPair(t) + defer serverConn.Close() + defer clientConn.Close() + + // Set a small max message size + maxSize := 10 + var receivedErr error + + codec := &mockCodec{ + decodeFunc: func(r io.Reader) (Message, error) { + // Try to read more than the limit + buf := make([]byte, maxSize+10) + _, err := io.ReadFull(r, buf) + if err != nil { + return nil, err + } + return mockMessage{body: buf}, nil + }, + } + + conn, err := NewConn(serverConn, + CustomCodecOption(codec), + OnMessageOption(func(msg Message) error { return nil }), + MessageMaxSize(maxSize), + OnErrorOption(func(err error) ErrorAction { + receivedErr = err + return Disconnect + }), + IdleTimeoutOption(time.Second), + ) + if err != nil { + t.Fatalf("NewConn failed: %v", err) + } + + // Send data larger than the limit + go func() { + largeData := make([]byte, maxSize+10) + clientConn.Write(largeData) + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + runErr := conn.Run(ctx) + + // Should get ErrMessageTooLarge + if !errors.Is(receivedErr, ErrMessageTooLarge) { + t.Errorf("expected ErrMessageTooLarge, got: %v (runErr: %v)", receivedErr, runErr) + } +} + +func TestLimitedReader(t *testing.T) { + t.Run("within limit", func(t *testing.T) { + data := []byte("hello") + lr := newLimitedReader(NewBytesReader(data), 10) + + buf := make([]byte, 10) + n, err := lr.Read(buf) + if err != nil && err != io.EOF { + t.Fatalf("unexpected error: %v", err) + } + if n != 5 { + t.Errorf("expected 5 bytes, got %d", n) + } + }) + + t.Run("exceeds limit", func(t *testing.T) { + data := []byte("hello world") + lr := newLimitedReader(NewBytesReader(data), 5) + + buf := make([]byte, 10) + n, err := lr.Read(buf) + if err != nil { + t.Fatalf("first read error: %v", err) + } + if n != 5 { + t.Errorf("expected 5 bytes, got %d", n) + } + + // Second read should return ErrMessageTooLarge + _, err = lr.Read(buf) + if !errors.Is(err, ErrMessageTooLarge) { + t.Errorf("expected ErrMessageTooLarge, got: %v", err) + } + }) + + t.Run("reset", func(t *testing.T) { + data := []byte("hello") + lr := newLimitedReader(NewBytesReader(data), 3) + + buf := make([]byte, 3) + lr.Read(buf) + + // After reset, should be able to read again + lr.reset(10) + if lr.remaining != 10 { + t.Errorf("remaining should be 10, got %d", lr.remaining) + } + }) +} + +// NewBytesReader creates a simple bytes reader for testing +type bytesReader struct { + data []byte + pos int +} + +func NewBytesReader(data []byte) *bytesReader { + return &bytesReader{data: data} +} + +func (r *bytesReader) Read(p []byte) (n int, err error) { + if r.pos >= len(r.data) { + return 0, io.EOF + } + n = copy(p, r.data[r.pos:]) + r.pos += n + return n, nil +} diff --git a/option.go b/option.go index 594ceb5..acfbbf0 100644 --- a/option.go +++ b/option.go @@ -26,7 +26,7 @@ type options struct { bufferSize int // size of buffered channel maxReadLength int // maximum size of a single message - heartbeat time.Duration // heartbeat interval for read/write deadlines + idleTimeout time.Duration // idle timeout for read/write deadlines } // Option is a function that configures connection options. @@ -48,11 +48,17 @@ func BufferSizeOption(size int) Option { } } -// HeartbeatOption returns an Option that sets the heartbeat interval. -// This determines the read/write deadline timeout (heartbeat * 2). -func HeartbeatOption(heartbeat time.Duration) Option { +// IdleTimeoutOption returns an Option that sets the idle timeout for connections. +// If no data is received within the timeout period, the connection will be closed. +// The actual read/write deadline is set to idleTimeout * 2 to allow for some network latency. +// Default is 30 seconds. +// +// Note: This is NOT a heartbeat mechanism that sends ping/pong packets. +// It only sets TCP read/write deadlines. If you need active connection health checking, +// implement ping/pong messages in your application protocol. +func IdleTimeoutOption(timeout time.Duration) Option { return func(o *options) { - o.heartbeat = heartbeat + o.idleTimeout = timeout } } diff --git a/option_test.go b/option_test.go index d1506ba..43053a8 100644 --- a/option_test.go +++ b/option_test.go @@ -28,15 +28,15 @@ func TestBufferSizeOption(t *testing.T) { } } -func TestHeartbeatOption(t *testing.T) { - heartbeat := time.Minute * 5 - opt := HeartbeatOption(heartbeat) +func TestIdleTimeoutOption(t *testing.T) { + timeout := time.Minute * 5 + opt := IdleTimeoutOption(timeout) var opts options opt(&opts) - if opts.heartbeat != heartbeat { - t.Errorf("heartbeat = %v, want %v", opts.heartbeat, heartbeat) + if opts.idleTimeout != timeout { + t.Errorf("idleTimeout = %v, want %v", opts.idleTimeout, timeout) } } @@ -112,7 +112,7 @@ func TestOptions_MultipleOptions(t *testing.T) { logger := &mockLogger{} onMessage := func(msg Message) error { return nil } onError := func(err error) ErrorAction { return Continue } - heartbeat := time.Second * 45 + idleTimeout := time.Second * 45 bufferSize := 50 maxSize := 8192 @@ -121,7 +121,7 @@ func TestOptions_MultipleOptions(t *testing.T) { CustomCodecOption(codec), OnMessageOption(onMessage), OnErrorOption(onError), - HeartbeatOption(heartbeat), + IdleTimeoutOption(idleTimeout), BufferSizeOption(bufferSize), MessageMaxSize(maxSize), LoggerOption(logger), @@ -140,8 +140,8 @@ func TestOptions_MultipleOptions(t *testing.T) { if opts.onError == nil { t.Error("onError not set") } - if opts.heartbeat != heartbeat { - t.Errorf("heartbeat = %v, want %v", opts.heartbeat, heartbeat) + if opts.idleTimeout != idleTimeout { + t.Errorf("idleTimeout = %v, want %v", opts.idleTimeout, idleTimeout) } if opts.bufferSize != bufferSize { t.Errorf("bufferSize = %d, want %d", opts.bufferSize, bufferSize) diff --git a/server.go b/server.go index abe64ce..da23168 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package socket import ( "context" "errors" + "log/slog" "net" "sync" "time" @@ -18,32 +19,85 @@ type Handler interface { // Server represents a TCP server that listens for incoming connections. type Server struct { - listener *net.TCPListener + listener *net.TCPListener + logger Logger + shutdownTimeout time.Duration - mu sync.Mutex - shutdown bool + mu sync.Mutex + shutdown bool + shutdownNow chan struct{} // signals immediate shutdown, bypassing timeout +} + +// ServerOption configures a Server. +type ServerOption func(*Server) + +// ServerLoggerOption sets the logger for the server. +func ServerLoggerOption(logger Logger) ServerOption { + return func(s *Server) { + s.logger = logger + } +} + +// ServerShutdownTimeoutOption sets the graceful shutdown timeout. +// When the context is canceled, the server will wait up to this duration +// before closing the listener. This gives existing connections time to complete. +// Default is 0 (immediate shutdown). +// +// Note: This only delays listener closure. For full graceful shutdown with +// connection draining, track connections at the application level and cancel +// them with the context passed to Conn.Run(). +func ServerShutdownTimeoutOption(timeout time.Duration) ServerOption { + return func(s *Server) { + s.shutdownTimeout = timeout + } } // New creates a new TCP server bound to the specified address. // Returns an error if the address cannot be bound. -func New(addr *net.TCPAddr) (*Server, error) { +func New(addr *net.TCPAddr, opts ...ServerOption) (*Server, error) { listener, err := net.ListenTCP(addr.Network(), addr) if err != nil { return nil, err } - return &Server{ - listener: listener, - }, nil + s := &Server{ + listener: listener, + logger: slog.Default(), + shutdownNow: make(chan struct{}), + } + + for _, opt := range opts { + opt(s) + } + + return s, nil } // Serve starts accepting connections and dispatching them to the handler. // It blocks until the context is canceled or an unrecoverable error occurs. // When the context is canceled, it stops accepting new connections gracefully. +// If ServerShutdownTimeoutOption is set, the server waits up to the specified +// duration before stopping, allowing existing handlers to complete. Call Close() +// to bypass the timeout and shut down immediately. func (s *Server) Serve(ctx context.Context, handler Handler) error { + s.logger.Info("server started", "addr", s.listener.Addr()) + // Start a goroutine to handle context cancellation go func() { <-ctx.Done() + + // Wait for shutdown timeout if configured, but allow early exit via Close() + if s.shutdownTimeout > 0 { + s.logger.Info("graceful shutdown initiated", "timeout", s.shutdownTimeout) + select { + case <-time.After(s.shutdownTimeout): + // Timeout expired, proceed with shutdown + case <-s.shutdownNow: + // Close() was called, skip remaining timeout + s.logger.Debug("shutdown timeout bypassed via Close()") + } + } + s.mu.Lock() s.shutdown = true s.mu.Unlock() @@ -59,6 +113,7 @@ func (s *Server) Serve(ctx context.Context, handler Handler) error { s.mu.Unlock() if isShutdown { + s.logger.Info("server stopped", "addr", s.listener.Addr()) return ctx.Err() } @@ -67,20 +122,31 @@ func (s *Server) Serve(ctx context.Context, handler Handler) error { if errors.As(err, &netErr) && netErr.Timeout() { continue } + s.logger.Error("accept error", "error", err) return err } + s.logger.Debug("accepted connection", "remote_addr", conn.RemoteAddr()) _ = conn.SetNoDelay(true) go handler.Handle(conn) } } // Close stops the server by closing the underlying listener. +// If a shutdown timeout is configured, Close() bypasses the remaining timeout. // Any blocked Accept calls will return with an error. func (s *Server) Close() error { s.mu.Lock() s.shutdown = true s.mu.Unlock() + + // Signal to bypass any pending shutdown timeout + select { + case s.shutdownNow <- struct{}{}: + default: + // Channel already has a signal or no one is listening + } + return s.listener.Close() } diff --git a/server_test.go b/server_test.go index 0aabcd8..9d22199 100644 --- a/server_test.go +++ b/server_test.go @@ -243,3 +243,18 @@ func TestServer_Serve_ContextCanceled(t *testing.T) { t.Fatal("timeout waiting for Serve to return") } } + +func TestServerLoggerOption(t *testing.T) { + logger := &mockLogger{} + addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} + + server, err := New(addr, ServerLoggerOption(logger)) + if err != nil { + t.Fatalf("New failed: %v", err) + } + defer server.Close() + + if server.logger != logger { + t.Error("logger not set correctly") + } +}