@@ -41,21 +41,24 @@ func NewClient(config ClientConfig) *Client {
4141}
4242
4343func waitLim (ctx context.Context , rl ratelimit.Limiter ) error {
44+ // Quick context check before any blocking operation
4445 select {
4546 case <- ctx .Done ():
4647 return ctx .Err ()
4748 default :
48- done := make (chan struct {})
49- go func () {
50- rl .Take ()
51- close (done )
52- }()
53- select {
54- case <- done :
55- return nil
56- case <- ctx .Done ():
57- return ctx .Err ()
58- }
49+ }
50+
51+ done := make (chan struct {})
52+ go func () {
53+ defer close (done )
54+ rl .Take ()
55+ }()
56+
57+ select {
58+ case <- done :
59+ return nil
60+ case <- ctx .Done ():
61+ return ctx .Err ()
5962 }
6063}
6164
@@ -97,6 +100,10 @@ func (c *Client) ConnectToAddresses(ctx context.Context, addrs []string) error {
97100 }
98101
99102 if err := eg .Wait (); err != nil {
103+ if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
104+ slog .Warn ("context canceled" , "error" , err )
105+ return nil
106+ }
100107 return fmt .Errorf ("connection error: %w" , err )
101108 }
102109 return nil
@@ -135,12 +142,17 @@ func (c *Client) connectPersistent(ctx context.Context, addrport string) error {
135142 eg , ctx := errgroup .WithContext (ctx )
136143 for i := 0 ; i < int (c .config .Connections ); i ++ {
137144 eg .Go (func () error {
138- conn , err := dialer .Dial ( "tcp" , addrport )
145+ conn , err := dialer .DialContext ( ctx , "tcp" , addrport )
139146 if err != nil {
140147 return fmt .Errorf ("dialing %q: %w" , addrport , err )
141148 }
142149 defer conn .Close ()
143150
151+ // Set deadlines based on context to make Read/Write operations interruptible
152+ if deadline , ok := ctx .Deadline (); ok {
153+ conn .SetDeadline (deadline )
154+ }
155+
144156 msgsTotal := int64 (c .config .Rate ) * int64 (c .config .Duration .Seconds ())
145157 limiter := ratelimit .New (int (c .config .Rate ))
146158
@@ -197,17 +209,25 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
197209 limiter := ratelimit .New (int (c .config .Rate ))
198210
199211 eg , ctx := errgroup .WithContext (ctx )
212+ ephemeralLoop:
200213 for i := int64 (0 ); i < connTotal ; i ++ {
214+ // Check for context cancellation at the start of each iteration
215+ select {
216+ case <- ctx .Done ():
217+ break ephemeralLoop
218+ default :
219+ }
220+
201221 if err := waitLim (ctx , limiter ); err != nil {
202222 if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
203- break
223+ break ephemeralLoop
204224 }
205225 continue
206226 }
207227
208228 eg .Go (func () error {
209229 return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
210- conn , err := dialer .Dial ( "tcp" , addrport )
230+ conn , err := dialer .DialContext ( ctx , "tcp" , addrport )
211231 if err != nil {
212232 if errors .Is (err , syscall .ETIMEDOUT ) {
213233 slog .Warn ("connection timeout" , "addr" , addrport )
@@ -217,6 +237,11 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
217237 }
218238 defer conn .Close ()
219239
240+ // Set deadlines based on context to make Read/Write operations interruptible
241+ if deadline , ok := ctx .Deadline (); ok {
242+ conn .SetDeadline (deadline )
243+ }
244+
220245 if err := SetQuickAck (conn ); err != nil {
221246 return fmt .Errorf ("setting quick ack: %w" , err )
222247 }
@@ -267,22 +292,36 @@ func (c *Client) connectUDP(ctx context.Context, addrport string) error {
267292 }
268293
269294 eg , ctx := errgroup .WithContext (ctx )
295+ udpLoop:
270296 for i := int64 (0 ); i < connTotal ; i ++ {
297+ // Check for context cancellation at the start of each iteration
298+ select {
299+ case <- ctx .Done ():
300+ break udpLoop
301+ default :
302+ }
303+
271304 if err := waitLim (ctx , limiter ); err != nil {
272305 if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
273- break
306+ break udpLoop
274307 }
275308 continue
276309 }
277310
278311 eg .Go (func () error {
279312 return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
280- conn , err := net .Dial ("udp4" , addrport )
313+ var dialer net.Dialer
314+ conn , err := dialer .DialContext (ctx , "udp4" , addrport )
281315 if err != nil {
282316 return fmt .Errorf ("dialing UDP %q: %w" , addrport , err )
283317 }
284318 defer conn .Close ()
285319
320+ // Set deadlines based on context to make Read/Write operations interruptible
321+ if deadline , ok := ctx .Deadline (); ok {
322+ conn .SetDeadline (deadline )
323+ }
324+
286325 msgPtr := bufUDPPool .Get ().(* []byte )
287326 msg := * msgPtr
288327 defer bufUDPPool .Put (msgPtr )
0 commit comments