Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
.claude/settings.local.json

/tcpulse

coverage.out
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ OUT_DOCKER ?= tcpulse
all: build

build: vet staticcheck
$(GO) build -o $(OUT_BIN)
CGO_ENABLED=0 $(GO) build -o $(OUT_BIN)

vet:
$(GO) vet ./...
Expand Down
71 changes: 55 additions & 16 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,24 @@ func NewClient(config ClientConfig) *Client {
}

func waitLim(ctx context.Context, rl ratelimit.Limiter) error {
// Quick context check before any blocking operation
select {
case <-ctx.Done():
return ctx.Err()
default:
done := make(chan struct{})
go func() {
rl.Take()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

done := make(chan struct{})
go func() {
defer close(done)
rl.Take()
}()

select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

Expand Down Expand Up @@ -97,6 +100,10 @@ func (c *Client) ConnectToAddresses(ctx context.Context, addrs []string) error {
}

if err := eg.Wait(); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
slog.Warn("context canceled", "error", err)
return nil
}
return fmt.Errorf("connection error: %w", err)
}
return nil
Expand Down Expand Up @@ -135,12 +142,17 @@ func (c *Client) connectPersistent(ctx context.Context, addrport string) error {
eg, ctx := errgroup.WithContext(ctx)
for i := 0; i < int(c.config.Connections); i++ {
eg.Go(func() error {
conn, err := dialer.Dial("tcp", addrport)
conn, err := dialer.DialContext(ctx, "tcp", addrport)
if err != nil {
return fmt.Errorf("dialing %q: %w", addrport, err)
}
defer conn.Close()

// Set deadlines based on context to make Read/Write operations interruptible
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}

msgsTotal := int64(c.config.Rate) * int64(c.config.Duration.Seconds())
limiter := ratelimit.New(int(c.config.Rate))

Expand Down Expand Up @@ -197,17 +209,25 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
limiter := ratelimit.New(int(c.config.Rate))

eg, ctx := errgroup.WithContext(ctx)
ephemeralLoop:
for i := int64(0); i < connTotal; i++ {
// Check for context cancellation at the start of each iteration
select {
case <-ctx.Done():
break ephemeralLoop
default:
}

if err := waitLim(ctx, limiter); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
break
break ephemeralLoop
}
continue
}

eg.Go(func() error {
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
conn, err := dialer.Dial("tcp", addrport)
conn, err := dialer.DialContext(ctx, "tcp", addrport)
if err != nil {
if errors.Is(err, syscall.ETIMEDOUT) {
slog.Warn("connection timeout", "addr", addrport)
Expand All @@ -217,6 +237,11 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
}
defer conn.Close()

// Set deadlines based on context to make Read/Write operations interruptible
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}

if err := SetQuickAck(conn); err != nil {
return fmt.Errorf("setting quick ack: %w", err)
}
Expand Down Expand Up @@ -267,22 +292,36 @@ func (c *Client) connectUDP(ctx context.Context, addrport string) error {
}

eg, ctx := errgroup.WithContext(ctx)
udpLoop:
for i := int64(0); i < connTotal; i++ {
// Check for context cancellation at the start of each iteration
select {
case <-ctx.Done():
break udpLoop
default:
}

if err := waitLim(ctx, limiter); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
break
break udpLoop
}
continue
}

eg.Go(func() error {
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
conn, err := net.Dial("udp4", addrport)
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "udp4", addrport)
if err != nil {
return fmt.Errorf("dialing UDP %q: %w", addrport, err)
}
defer conn.Close()

// Set deadlines based on context to make Read/Write operations interruptible
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}

msgPtr := bufUDPPool.Get().(*[]byte)
msg := *msgPtr
defer bufUDPPool.Put(msgPtr)
Expand Down
Loading