Skip to content

Commit 703784f

Browse files
authored
feat: add ping and pong received callbacks (#509)
This change adds two optional callbacks to both `DialOptions` and `AcceptOptions`. These callbacks are invoked synchronously when a ping or pong frame is received, allowing advanced users to log or inspect payloads for metrics or debugging. If the callback needs to perform more complex work or reuse the payload outside the callback, it is recommended to perform processing in a separate goroutine. The boolean return value of `OnPingReceived` is used to determine if the subsequent pong frame should be sent. If `false` is returned, the pong frame is not sent. Fixes #246
1 parent aec630d commit 703784f

File tree

5 files changed

+135
-5
lines changed

5 files changed

+135
-5
lines changed

accept.go

+19
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package websocket
55

66
import (
77
"bytes"
8+
"context"
89
"crypto/sha1"
910
"encoding/base64"
1011
"errors"
@@ -62,6 +63,22 @@ type AcceptOptions struct {
6263
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
6364
// for CompressionContextTakeover.
6465
CompressionThreshold int
66+
67+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
68+
//
69+
// The payload contains the application data of the ping frame.
70+
// If the callback returns false, the subsequent pong frame will not be sent.
71+
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
72+
OnPingReceived func(ctx context.Context, payload []byte) bool
73+
74+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
75+
//
76+
// The payload contains the application data of the pong frame.
77+
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
78+
//
79+
// Unlike OnPingReceived, this callback does not return a value because a pong frame
80+
// is a response to a ping and does not trigger any further frame transmission.
81+
OnPongReceived func(ctx context.Context, payload []byte)
6582
}
6683

6784
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
@@ -156,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
156173
client: false,
157174
copts: copts,
158175
flateThreshold: opts.CompressionThreshold,
176+
onPingReceived: opts.OnPingReceived,
177+
onPongReceived: opts.OnPongReceived,
159178

160179
br: brw.Reader,
161180
bw: brw.Writer,

conn.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ type Conn struct {
8383
closeMu sync.Mutex // Protects following.
8484
closed chan struct{}
8585

86-
pingCounter atomic.Int64
87-
activePingsMu sync.Mutex
88-
activePings map[string]chan<- struct{}
86+
pingCounter atomic.Int64
87+
activePingsMu sync.Mutex
88+
activePings map[string]chan<- struct{}
89+
onPingReceived func(context.Context, []byte) bool
90+
onPongReceived func(context.Context, []byte)
8991
}
9092

9193
type connConfig struct {
@@ -94,6 +96,8 @@ type connConfig struct {
9496
client bool
9597
copts *compressionOptions
9698
flateThreshold int
99+
onPingReceived func(context.Context, []byte) bool
100+
onPongReceived func(context.Context, []byte)
97101

98102
br *bufio.Reader
99103
bw *bufio.Writer
@@ -114,8 +118,10 @@ func newConn(cfg connConfig) *Conn {
114118
writeTimeout: make(chan context.Context),
115119
timeoutLoopDone: make(chan struct{}),
116120

117-
closed: make(chan struct{}),
118-
activePings: make(map[string]chan<- struct{}),
121+
closed: make(chan struct{}),
122+
activePings: make(map[string]chan<- struct{}),
123+
onPingReceived: cfg.onPingReceived,
124+
onPongReceived: cfg.onPongReceived,
119125
}
120126

121127
c.readMu = newMu(c)

conn_test.go

+79
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,85 @@ func TestConn(t *testing.T) {
9797
assert.Contains(t, err, "failed to wait for pong")
9898
})
9999

100+
t.Run("pingReceivedPongReceived", func(t *testing.T) {
101+
var pingReceived1, pongReceived1 bool
102+
var pingReceived2, pongReceived2 bool
103+
tt, c1, c2 := newConnTest(t,
104+
&websocket.DialOptions{
105+
OnPingReceived: func(ctx context.Context, payload []byte) bool {
106+
pingReceived1 = true
107+
return true
108+
},
109+
OnPongReceived: func(ctx context.Context, payload []byte) {
110+
pongReceived1 = true
111+
},
112+
}, &websocket.AcceptOptions{
113+
OnPingReceived: func(ctx context.Context, payload []byte) bool {
114+
pingReceived2 = true
115+
return true
116+
},
117+
OnPongReceived: func(ctx context.Context, payload []byte) {
118+
pongReceived2 = true
119+
},
120+
},
121+
)
122+
123+
c1.CloseRead(tt.ctx)
124+
c2.CloseRead(tt.ctx)
125+
126+
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
127+
defer cancel()
128+
129+
err := c1.Ping(ctx)
130+
assert.Success(t, err)
131+
132+
c1.CloseNow()
133+
c2.CloseNow()
134+
135+
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
136+
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
137+
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
138+
})
139+
140+
t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
141+
var pingReceived1, pongReceived1 bool
142+
var pingReceived2, pongReceived2 bool
143+
tt, c1, c2 := newConnTest(t,
144+
&websocket.DialOptions{
145+
OnPingReceived: func(ctx context.Context, payload []byte) bool {
146+
pingReceived1 = true
147+
return false
148+
},
149+
OnPongReceived: func(ctx context.Context, payload []byte) {
150+
pongReceived1 = true
151+
},
152+
}, &websocket.AcceptOptions{
153+
OnPingReceived: func(ctx context.Context, payload []byte) bool {
154+
pingReceived2 = true
155+
return false
156+
},
157+
OnPongReceived: func(ctx context.Context, payload []byte) {
158+
pongReceived2 = true
159+
},
160+
},
161+
)
162+
163+
c1.CloseRead(tt.ctx)
164+
c2.CloseRead(tt.ctx)
165+
166+
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
167+
defer cancel()
168+
169+
err := c1.Ping(ctx)
170+
assert.Contains(t, err, "failed to wait for pong")
171+
172+
c1.CloseNow()
173+
c2.CloseNow()
174+
175+
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
176+
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
177+
})
178+
100179
t.Run("concurrentWrite", func(t *testing.T) {
101180
tt, c1, c2 := newConnTest(t, nil, nil)
102181

dial.go

+18
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ type DialOptions struct {
4848
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
4949
// for CompressionContextTakeover.
5050
CompressionThreshold int
51+
52+
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
53+
//
54+
// The payload contains the application data of the ping frame.
55+
// If the callback returns false, the subsequent pong frame will not be sent.
56+
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
57+
OnPingReceived func(ctx context.Context, payload []byte) bool
58+
59+
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
60+
//
61+
// The payload contains the application data of the pong frame.
62+
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
63+
//
64+
// Unlike OnPingReceived, this callback does not return a value because a pong frame
65+
// is a response to a ping and does not trigger any further frame transmission.
66+
OnPongReceived func(ctx context.Context, payload []byte)
5167
}
5268

5369
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
@@ -163,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
163179
client: true,
164180
copts: copts,
165181
flateThreshold: opts.CompressionThreshold,
182+
onPingReceived: opts.OnPingReceived,
183+
onPongReceived: opts.OnPongReceived,
166184
br: getBufioReader(rwc),
167185
bw: getBufioWriter(rwc),
168186
}), resp, nil

read.go

+8
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,16 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
312312

313313
switch h.opcode {
314314
case opPing:
315+
if c.onPingReceived != nil {
316+
if !c.onPingReceived(ctx, b) {
317+
return nil
318+
}
319+
}
315320
return c.writeControl(ctx, opPong, b)
316321
case opPong:
322+
if c.onPongReceived != nil {
323+
c.onPongReceived(ctx, b)
324+
}
317325
c.activePingsMu.Lock()
318326
pong, ok := c.activePings[string(b)]
319327
c.activePingsMu.Unlock()

0 commit comments

Comments
 (0)