Skip to content

Commit ebced03

Browse files
authored
Merge pull request #41 from yuuki/fix/sigint-handling-ephemeral-mode
Fix SIGINT handling in high-load ephemeral TCP/UDP mode
2 parents 7069cbe + f4c389d commit ebced03

File tree

4 files changed

+385
-17
lines changed

4 files changed

+385
-17
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,5 @@
22
.claude/settings.local.json
33

44
/tcpulse
5+
6+
coverage.out

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ OUT_DOCKER ?= tcpulse
88
all: build
99

1010
build: vet staticcheck
11-
$(GO) build -o $(OUT_BIN)
11+
CGO_ENABLED=0 $(GO) build -o $(OUT_BIN)
1212

1313
vet:
1414
$(GO) vet ./...

client.go

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,24 @@ func NewClient(config ClientConfig) *Client {
4141
}
4242

4343
func waitLim(ctx context.Context, rl ratelimit.Limiter) error {
44+
// Quick context check before any blocking operation
4445
select {
4546
case <-ctx.Done():
4647
return ctx.Err()
4748
default:
48-
done := make(chan struct{})
49-
go func() {
50-
rl.Take()
51-
close(done)
52-
}()
53-
select {
54-
case <-done:
55-
return nil
56-
case <-ctx.Done():
57-
return ctx.Err()
58-
}
49+
}
50+
51+
done := make(chan struct{})
52+
go func() {
53+
defer close(done)
54+
rl.Take()
55+
}()
56+
57+
select {
58+
case <-done:
59+
return nil
60+
case <-ctx.Done():
61+
return ctx.Err()
5962
}
6063
}
6164

@@ -97,6 +100,10 @@ func (c *Client) ConnectToAddresses(ctx context.Context, addrs []string) error {
97100
}
98101

99102
if err := eg.Wait(); err != nil {
103+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
104+
slog.Warn("context canceled", "error", err)
105+
return nil
106+
}
100107
return fmt.Errorf("connection error: %w", err)
101108
}
102109
return nil
@@ -135,12 +142,17 @@ func (c *Client) connectPersistent(ctx context.Context, addrport string) error {
135142
eg, ctx := errgroup.WithContext(ctx)
136143
for i := 0; i < int(c.config.Connections); i++ {
137144
eg.Go(func() error {
138-
conn, err := dialer.Dial("tcp", addrport)
145+
conn, err := dialer.DialContext(ctx, "tcp", addrport)
139146
if err != nil {
140147
return fmt.Errorf("dialing %q: %w", addrport, err)
141148
}
142149
defer conn.Close()
143150

151+
// Set deadlines based on context to make Read/Write operations interruptible
152+
if deadline, ok := ctx.Deadline(); ok {
153+
conn.SetDeadline(deadline)
154+
}
155+
144156
msgsTotal := int64(c.config.Rate) * int64(c.config.Duration.Seconds())
145157
limiter := ratelimit.New(int(c.config.Rate))
146158

@@ -197,17 +209,25 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
197209
limiter := ratelimit.New(int(c.config.Rate))
198210

199211
eg, ctx := errgroup.WithContext(ctx)
212+
ephemeralLoop:
200213
for i := int64(0); i < connTotal; i++ {
214+
// Check for context cancellation at the start of each iteration
215+
select {
216+
case <-ctx.Done():
217+
break ephemeralLoop
218+
default:
219+
}
220+
201221
if err := waitLim(ctx, limiter); err != nil {
202222
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
203-
break
223+
break ephemeralLoop
204224
}
205225
continue
206226
}
207227

208228
eg.Go(func() error {
209229
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
210-
conn, err := dialer.Dial("tcp", addrport)
230+
conn, err := dialer.DialContext(ctx, "tcp", addrport)
211231
if err != nil {
212232
if errors.Is(err, syscall.ETIMEDOUT) {
213233
slog.Warn("connection timeout", "addr", addrport)
@@ -217,6 +237,11 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
217237
}
218238
defer conn.Close()
219239

240+
// Set deadlines based on context to make Read/Write operations interruptible
241+
if deadline, ok := ctx.Deadline(); ok {
242+
conn.SetDeadline(deadline)
243+
}
244+
220245
if err := SetQuickAck(conn); err != nil {
221246
return fmt.Errorf("setting quick ack: %w", err)
222247
}
@@ -267,22 +292,36 @@ func (c *Client) connectUDP(ctx context.Context, addrport string) error {
267292
}
268293

269294
eg, ctx := errgroup.WithContext(ctx)
295+
udpLoop:
270296
for i := int64(0); i < connTotal; i++ {
297+
// Check for context cancellation at the start of each iteration
298+
select {
299+
case <-ctx.Done():
300+
break udpLoop
301+
default:
302+
}
303+
271304
if err := waitLim(ctx, limiter); err != nil {
272305
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
273-
break
306+
break udpLoop
274307
}
275308
continue
276309
}
277310

278311
eg.Go(func() error {
279312
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
280-
conn, err := net.Dial("udp4", addrport)
313+
var dialer net.Dialer
314+
conn, err := dialer.DialContext(ctx, "udp4", addrport)
281315
if err != nil {
282316
return fmt.Errorf("dialing UDP %q: %w", addrport, err)
283317
}
284318
defer conn.Close()
285319

320+
// Set deadlines based on context to make Read/Write operations interruptible
321+
if deadline, ok := ctx.Deadline(); ok {
322+
conn.SetDeadline(deadline)
323+
}
324+
286325
msgPtr := bufUDPPool.Get().(*[]byte)
287326
msg := *msgPtr
288327
defer bufUDPPool.Put(msgPtr)

0 commit comments

Comments
 (0)