diff --git a/.gitignore b/.gitignore index dfdb708..4342e9f 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ .claude/settings.local.json /tcpulse + +coverage.out diff --git a/Makefile b/Makefile index edc2a62..d5e8961 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ OUT_DOCKER ?= tcpulse all: build build: vet staticcheck - $(GO) build -o $(OUT_BIN) + CGO_ENABLED=0 $(GO) build -o $(OUT_BIN) vet: $(GO) vet ./... diff --git a/client.go b/client.go index 97052bd..c353b59 100644 --- a/client.go +++ b/client.go @@ -41,21 +41,24 @@ func NewClient(config ClientConfig) *Client { } func waitLim(ctx context.Context, rl ratelimit.Limiter) error { + // Quick context check before any blocking operation select { case <-ctx.Done(): return ctx.Err() default: - done := make(chan struct{}) - go func() { - rl.Take() - close(done) - }() - select { - case <-done: - return nil - case <-ctx.Done(): - return ctx.Err() - } + } + + done := make(chan struct{}) + go func() { + defer close(done) + rl.Take() + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() } } @@ -97,6 +100,10 @@ func (c *Client) ConnectToAddresses(ctx context.Context, addrs []string) error { } if err := eg.Wait(); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + slog.Warn("context canceled", "error", err) + return nil + } return fmt.Errorf("connection error: %w", err) } return nil @@ -135,12 +142,17 @@ func (c *Client) connectPersistent(ctx context.Context, addrport string) error { eg, ctx := errgroup.WithContext(ctx) for i := 0; i < int(c.config.Connections); i++ { eg.Go(func() error { - conn, err := dialer.Dial("tcp", addrport) + conn, err := dialer.DialContext(ctx, "tcp", addrport) if err != nil { return fmt.Errorf("dialing %q: %w", addrport, err) } defer conn.Close() + // Set deadlines based on context to make Read/Write operations interruptible + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } + msgsTotal := int64(c.config.Rate) * int64(c.config.Duration.Seconds()) limiter := ratelimit.New(int(c.config.Rate)) @@ -197,17 +209,25 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error { limiter := ratelimit.New(int(c.config.Rate)) eg, ctx := errgroup.WithContext(ctx) +ephemeralLoop: for i := int64(0); i < connTotal; i++ { + // Check for context cancellation at the start of each iteration + select { + case <-ctx.Done(): + break ephemeralLoop + default: + } + if err := waitLim(ctx, limiter); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - break + break ephemeralLoop } continue } eg.Go(func() error { return measureTime(addrport, c.config.MergeResultsEachHost, func() error { - conn, err := dialer.Dial("tcp", addrport) + conn, err := dialer.DialContext(ctx, "tcp", addrport) if err != nil { if errors.Is(err, syscall.ETIMEDOUT) { slog.Warn("connection timeout", "addr", addrport) @@ -217,6 +237,11 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error { } defer conn.Close() + // Set deadlines based on context to make Read/Write operations interruptible + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } + if err := SetQuickAck(conn); err != nil { return fmt.Errorf("setting quick ack: %w", err) } @@ -267,22 +292,36 @@ func (c *Client) connectUDP(ctx context.Context, addrport string) error { } eg, ctx := errgroup.WithContext(ctx) +udpLoop: for i := int64(0); i < connTotal; i++ { + // Check for context cancellation at the start of each iteration + select { + case <-ctx.Done(): + break udpLoop + default: + } + if err := waitLim(ctx, limiter); err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - break + break udpLoop } continue } eg.Go(func() error { return measureTime(addrport, c.config.MergeResultsEachHost, func() error { - conn, err := net.Dial("udp4", addrport) + var dialer net.Dialer + conn, err := dialer.DialContext(ctx, "udp4", addrport) if err != nil { return fmt.Errorf("dialing UDP %q: %w", addrport, err) } defer conn.Close() + // Set deadlines based on context to make Read/Write operations interruptible + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } + msgPtr := bufUDPPool.Get().(*[]byte) msg := *msgPtr defer bufUDPPool.Put(msgPtr) diff --git a/client_test.go b/client_test.go index 8dd305a..9e2d452 100644 --- a/client_test.go +++ b/client_test.go @@ -1249,6 +1249,169 @@ func TestMeasureTimeWithPanic(t *testing.T) { }) } +// TestWaitLimContextCancellation tests the improved waitLim function for responsiveness to context cancellation +func TestWaitLimContextCancellation(t *testing.T) { + tests := []struct { + name string + ctxTimeout time.Duration + expectError bool + expectQuickExit bool + }{ + { + name: "immediate_cancellation", + ctxTimeout: 10 * time.Millisecond, // Slightly longer to ensure rate limiter blocks + expectError: true, + expectQuickExit: true, + }, + { + name: "normal_operation", + ctxTimeout: 100 * time.Millisecond, + expectError: false, + expectQuickExit: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), tt.ctxTimeout) + defer cancel() + + limiter := ratelimit.New(1) // Very slow rate to force blocking + if tt.expectError { + // Use up the token bucket to force waiting + limiter.Take() + } + start := time.Now() + err := waitLim(ctx, limiter) + elapsed := time.Since(start) + + if tt.expectQuickExit && elapsed > 50*time.Millisecond { + t.Errorf("Expected quick exit but took %v", elapsed) + } + + if tt.expectError { + if err == nil { + t.Error("Expected context cancellation error, got nil") + } + } else if err != nil { + t.Errorf("Expected no error, got %v", err) + } + }) + } +} + +// TestConnectEphemeralContextCancellationLoop tests that the ephemeral connection loop responds to context cancellation +func TestConnectEphemeralContextCancellationLoop(t *testing.T) { + // Use a mock server that never accepts connections to test cancellation behavior + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + // Don't call Accept() so connections will timeout/fail + addr := listener.Addr().String() + + client := NewClient(ClientConfig{ + Protocol: "tcp", + ConnectFlavor: flavorEphemeral, + Rate: 100, // High rate to create many goroutines + Duration: 10 * time.Second, // Long duration + MessageBytes: 32, + }) + + // Cancel context after short time to test responsiveness + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + err = client.connectEphemeral(ctx, addr) + elapsed := time.Since(start) + + // Should exit quickly due to context cancellation, not wait for full duration + if elapsed > 500*time.Millisecond { + t.Errorf("Expected quick exit due to context cancellation, but took %v", elapsed) + } + + // Should get some error (context cancellation or connection error), the key is that it exits quickly + if err == nil { + t.Error("Expected some error due to context cancellation or connection issues") + } +} + +// TestConnectUDPContextCancellationLoop tests that the UDP connection loop responds to context cancellation +func TestConnectUDPContextCancellationLoop(t *testing.T) { + // Use a non-routable address that should timeout rather than immediately fail + addr := "192.0.2.1:80" // Test network that should timeout + + client := NewClient(ClientConfig{ + Protocol: "udp", + Rate: 100, // High rate to create many goroutines + Duration: 10 * time.Second, // Long duration + MessageBytes: 32, + }) + + // Cancel context after short time to test responsiveness + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + err := client.connectUDP(ctx, addr) + elapsed := time.Since(start) + + // Should exit quickly due to context cancellation, not wait for full duration + if elapsed > 500*time.Millisecond { + t.Errorf("Expected quick exit due to context cancellation, but took %v", elapsed) + } + + // Should get context cancellation error or connection error, but exit quickly + if err == nil { + t.Error("Expected some error due to context cancellation or connection issues") + } + + // The important thing is that it exits quickly, not the specific error type + // since UDP connection behavior can vary depending on network configuration +} + +// TestConnectPersistentContextCancellationQuick tests that persistent connections respond to context cancellation quickly +func TestConnectPersistentContextCancellationQuick(t *testing.T) { + // Use a mock server that never accepts connections + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + addr := listener.Addr().String() + + client := NewClient(ClientConfig{ + Protocol: "tcp", + ConnectFlavor: flavorPersistent, + Connections: 5, // Multiple connections + Rate: 10, + Duration: 10 * time.Second, // Long duration + MessageBytes: 32, + }) + + // Cancel context after short time to test responsiveness + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + connErr := client.connectPersistent(ctx, addr) + elapsed := time.Since(start) + + // Should exit quickly due to context cancellation + if elapsed > 500*time.Millisecond { + t.Errorf("Expected quick exit due to context cancellation, but took %v", elapsed) + } + + // Should get context cancellation or connection error + if connErr == nil { + t.Error("Expected error due to context cancellation or connection failure") + } +} + func TestWaitLimWithSlowRateLimit(t *testing.T) { // Test with very slow rate limiter and context timeout limiter := ratelimit.New(1) // 1 per second @@ -1381,3 +1544,167 @@ func TestMetricsCleanupBetweenTests(t *testing.T) { // Clean up unregisterTimer(key, addr, false) } + +// TestDialContextWithTimeout tests that DialContext respects context timeouts +func TestDialContextWithTimeout(t *testing.T) { + // Use a non-routable address that will timeout + addr := "192.0.2.1:80" // Test network address that should timeout + + // Very short timeout to ensure quick failure + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + dialer := net.Dialer{} + start := time.Now() + _, err := dialer.DialContext(ctx, "tcp", addr) + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected timeout error, got nil") + } + + if elapsed > 100*time.Millisecond { + t.Errorf("Expected quick timeout, but took %v", elapsed) + } + + if !strings.Contains(err.Error(), "context deadline exceeded") && !strings.Contains(err.Error(), "timeout") { + t.Errorf("Expected context deadline or timeout error, got %v", err) + } +} + +// TestConnectionDeadlineHandling tests that connection deadlines are properly set +func TestConnectionDeadlineHandling(t *testing.T) { + // Create a mock server that accepts connections but doesn't respond + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + defer listener.Close() + + // Accept connections but don't read/write + go func() { + for { + conn, err := listener.Accept() + if err != nil { + return + } + // Keep connection open but don't read/write + time.Sleep(1 * time.Second) + conn.Close() + } + }() + + addr := listener.Addr().String() + + // Create context with short deadline + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + dialer := net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + t.Fatalf("Failed to connect: %v", err) + } + defer conn.Close() + + // Set deadline based on context + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } + + // Try to read - should timeout quickly due to deadline + start := time.Now() + buf := make([]byte, 10) + _, err = conn.Read(buf) + elapsed := time.Since(start) + + if err == nil { + t.Error("Expected timeout error on read, got nil") + } + + if elapsed > 100*time.Millisecond { + t.Errorf("Expected quick timeout on read, but took %v", elapsed) + } +} + +// TestEphemeralLoopBreakOnCancellation tests that the ephemeral loop properly breaks on context cancellation +func TestEphemeralLoopBreakOnCancellation(t *testing.T) { + // This test verifies the labeled break functionality works correctly + // Use a valid but unresponsive address + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to create listener: %v", err) + } + listener.Close() // Close immediately so connections will fail quickly + + addr := listener.Addr().String() + + client := NewClient(ClientConfig{ + Protocol: "tcp", + ConnectFlavor: flavorEphemeral, + Rate: 1000, // Very high rate + Duration: 30 * time.Second, // Long duration + MessageBytes: 32, + }) + + // Short context timeout + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + _ = client.connectEphemeral(ctx, addr) + elapsed := time.Since(start) + + // Should exit much faster than the 30-second duration due to context cancellation + if elapsed > 200*time.Millisecond { + t.Errorf("Expected loop to break quickly on context cancellation, but took %v", elapsed) + } +} + +// TestUDPLoopBreakOnCancellation tests that the UDP loop properly breaks on context cancellation +func TestUDPLoopBreakOnCancellation(t *testing.T) { + // Use port 0 which should fail connections quickly + addr := "127.0.0.1:0" + + client := NewClient(ClientConfig{ + Protocol: "udp", + Rate: 1000, // Very high rate + Duration: 30 * time.Second, // Long duration + MessageBytes: 32, + }) + + // Short context timeout + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + start := time.Now() + _ = client.connectUDP(ctx, addr) + elapsed := time.Since(start) + + // Should exit much faster than the 30-second duration due to context cancellation + if elapsed > 200*time.Millisecond { + t.Errorf("Expected UDP loop to break quickly on context cancellation, but took %v", elapsed) + } +} + +// TestWaitLimRateLimitingBehavior tests that waitLim properly rate limits while being responsive to cancellation +func TestWaitLimRateLimitingBehavior(t *testing.T) { + limiter := ratelimit.New(5) // 5 per second + ctx := context.Background() + + // Take several tokens quickly + start := time.Now() + for i := 0; i < 3; i++ { + err := waitLim(ctx, limiter) + if err != nil { + t.Errorf("Unexpected error in waitLim: %v", err) + } + } + elapsed := time.Since(start) + + // Should take at least some time due to rate limiting + expectedMinDuration := 400 * time.Millisecond // 3 tokens at 5/sec should take ~400ms + if elapsed < expectedMinDuration { + t.Errorf("Expected rate limiting to take at least %v, but took %v", expectedMinDuration, elapsed) + } +}