Skip to content

Commit 11bda98

Browse files
FrauElstermafredri
andauthored
fix: avoid writing messages after close and improve handshake (#476)
Co-authored-by: Mathias Fredriksson <[email protected]>
1 parent 1253b77 commit 11bda98

File tree

5 files changed

+252
-65
lines changed

5 files changed

+252
-65
lines changed

close.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
100100
func (c *Conn) Close(code StatusCode, reason string) (err error) {
101101
defer errd.Wrap(&err, "failed to close WebSocket")
102102

103-
if !c.casClosing() {
103+
if c.casClosing() {
104104
err = c.waitGoroutines()
105105
if err != nil {
106106
return err
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
133133
func (c *Conn) CloseNow() (err error) {
134134
defer errd.Wrap(&err, "failed to immediately close WebSocket")
135135

136-
if !c.casClosing() {
136+
if c.casClosing() {
137137
err = c.waitGoroutines()
138138
if err != nil {
139139
return err
@@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
329329
}
330330

331331
func (c *Conn) casClosing() bool {
332-
c.closeMu.Lock()
333-
defer c.closeMu.Unlock()
334-
if !c.closing {
335-
c.closing = true
336-
return true
337-
}
338-
return false
332+
return c.closing.Swap(true)
339333
}
340334

341335
func (c *Conn) isClosed() bool {

conn.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,19 @@ type Conn struct {
6969
writeHeaderBuf [8]byte
7070
writeHeader header
7171

72+
// Close handshake state.
73+
closeStateMu sync.RWMutex
74+
closeReceivedErr error
75+
closeSentErr error
76+
77+
// CloseRead state.
7278
closeReadMu sync.Mutex
7379
closeReadCtx context.Context
7480
closeReadDone chan struct{}
7581

82+
closing atomic.Bool
83+
closeMu sync.Mutex // Protects following.
7684
closed chan struct{}
77-
closeMu sync.Mutex
78-
closing bool
7985

8086
pingCounter atomic.Int64
8187
activePingsMu sync.Mutex

conn_test.go

+148-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"errors"
99
"fmt"
1010
"io"
11+
"net"
1112
"net/http"
1213
"net/http/httptest"
1314
"os"
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
460461
}
461462

462463
func BenchmarkConn(b *testing.B) {
463-
var benchCases = []struct {
464+
benchCases := []struct {
464465
name string
465466
mode websocket.CompressionMode
466467
}{
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
625626
}()
626627
}
627628
}
629+
630+
func TestConnClosePropagation(t *testing.T) {
631+
t.Parallel()
632+
633+
want := []byte("hello")
634+
keepWriting := func(c *websocket.Conn) <-chan error {
635+
return xsync.Go(func() error {
636+
for {
637+
err := c.Write(context.Background(), websocket.MessageText, want)
638+
if err != nil {
639+
return err
640+
}
641+
}
642+
})
643+
}
644+
keepReading := func(c *websocket.Conn) <-chan error {
645+
return xsync.Go(func() error {
646+
for {
647+
_, got, err := c.Read(context.Background())
648+
if err != nil {
649+
return err
650+
}
651+
if !bytes.Equal(want, got) {
652+
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
653+
}
654+
}
655+
})
656+
}
657+
checkReadErr := func(t *testing.T, err error) {
658+
// Check read error (output depends on when read is called in relation to connection closure).
659+
var ce websocket.CloseError
660+
if errors.As(err, &ce) {
661+
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
662+
} else {
663+
assert.ErrorIs(t, net.ErrClosed, err)
664+
}
665+
}
666+
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
667+
for _, c := range conn {
668+
// Check write error.
669+
err := c.Write(context.Background(), websocket.MessageText, want)
670+
assert.ErrorIs(t, net.ErrClosed, err)
671+
672+
_, _, err = c.Read(context.Background())
673+
checkReadErr(t, err)
674+
}
675+
}
676+
677+
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
678+
tt, this, other := newConnTest(t, nil, nil)
679+
680+
_ = this.CloseRead(tt.ctx)
681+
thisWriteErr := keepWriting(this)
682+
683+
_, got, err := other.Read(tt.ctx)
684+
assert.Success(t, err)
685+
assert.Equal(t, "msg", want, got)
686+
687+
err = other.Close(websocket.StatusNormalClosure, "")
688+
assert.Success(t, err)
689+
690+
select {
691+
case err := <-thisWriteErr:
692+
assert.ErrorIs(t, net.ErrClosed, err)
693+
case <-tt.ctx.Done():
694+
t.Fatal(tt.ctx.Err())
695+
}
696+
697+
checkConnErrs(t, this, other)
698+
})
699+
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
700+
tt, this, other := newConnTest(t, nil, nil)
701+
702+
_ = this.CloseRead(tt.ctx)
703+
thisWriteErr := keepWriting(this)
704+
otherReadErr := keepReading(other)
705+
706+
err := this.Close(websocket.StatusNormalClosure, "")
707+
assert.Success(t, err)
708+
709+
select {
710+
case err := <-thisWriteErr:
711+
assert.ErrorIs(t, net.ErrClosed, err)
712+
case <-tt.ctx.Done():
713+
t.Fatal(tt.ctx.Err())
714+
}
715+
716+
select {
717+
case err := <-otherReadErr:
718+
checkReadErr(t, err)
719+
case <-tt.ctx.Done():
720+
t.Fatal(tt.ctx.Err())
721+
}
722+
723+
checkConnErrs(t, this, other)
724+
})
725+
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
726+
tt, this, other := newConnTest(t, nil, nil)
727+
728+
_ = other.CloseRead(tt.ctx)
729+
errs := keepReading(this)
730+
731+
err := other.Write(tt.ctx, websocket.MessageText, want)
732+
assert.Success(t, err)
733+
734+
err = other.Close(websocket.StatusNormalClosure, "")
735+
assert.Success(t, err)
736+
737+
select {
738+
case err := <-errs:
739+
checkReadErr(t, err)
740+
case <-tt.ctx.Done():
741+
t.Fatal(tt.ctx.Err())
742+
}
743+
744+
checkConnErrs(t, this, other)
745+
})
746+
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
747+
tt, this, other := newConnTest(t, nil, nil)
748+
749+
thisReadErr := keepReading(this)
750+
otherReadErr := keepReading(other)
751+
752+
err := other.Write(tt.ctx, websocket.MessageText, want)
753+
assert.Success(t, err)
754+
755+
err = this.Close(websocket.StatusNormalClosure, "")
756+
assert.Success(t, err)
757+
758+
select {
759+
case err := <-thisReadErr:
760+
checkReadErr(t, err)
761+
case <-tt.ctx.Done():
762+
t.Fatal(tt.ctx.Err())
763+
}
764+
765+
select {
766+
case err := <-otherReadErr:
767+
checkReadErr(t, err)
768+
case <-tt.ctx.Done():
769+
t.Fatal(tt.ctx.Err())
770+
}
771+
772+
checkConnErrs(t, this, other)
773+
})
774+
}

read.go

+59-35
Original file line numberDiff line numberDiff line change
@@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
217217
}
218218
}
219219

220-
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
220+
// prepareRead sets the readTimeout context and returns a done function
221+
// to be called after the read is done. It also returns an error if the
222+
// connection is closed. The reference to the error is used to assign
223+
// an error depending on if the connection closed or the context timed
224+
// out during use. Typically the referenced error is a named return
225+
// variable of the function calling this method.
226+
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
221227
select {
222228
case <-c.closed:
223-
return header{}, net.ErrClosed
229+
return nil, net.ErrClosed
224230
case c.readTimeout <- ctx:
225231
}
226232

227-
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
228-
if err != nil {
233+
done := func() {
229234
select {
230235
case <-c.closed:
231-
return header{}, net.ErrClosed
232-
case <-ctx.Done():
233-
return header{}, ctx.Err()
234-
default:
235-
return header{}, err
236+
if *err != nil {
237+
*err = net.ErrClosed
238+
}
239+
case c.readTimeout <- context.Background():
240+
}
241+
if *err != nil && ctx.Err() != nil {
242+
*err = ctx.Err()
236243
}
237244
}
238245

239-
select {
240-
case <-c.closed:
241-
return header{}, net.ErrClosed
242-
case c.readTimeout <- context.Background():
246+
c.closeStateMu.Lock()
247+
closeReceivedErr := c.closeReceivedErr
248+
c.closeStateMu.Unlock()
249+
if closeReceivedErr != nil {
250+
defer done()
251+
return nil, closeReceivedErr
243252
}
244253

245-
return h, nil
254+
return done, nil
246255
}
247256

248-
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
249-
select {
250-
case <-c.closed:
251-
return 0, net.ErrClosed
252-
case c.readTimeout <- ctx:
257+
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
258+
readDone, err := c.prepareRead(ctx, &err)
259+
if err != nil {
260+
return header{}, err
253261
}
262+
defer readDone()
254263

255-
n, err := io.ReadFull(c.br, p)
264+
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
256265
if err != nil {
257-
select {
258-
case <-c.closed:
259-
return n, net.ErrClosed
260-
case <-ctx.Done():
261-
return n, ctx.Err()
262-
default:
263-
return n, fmt.Errorf("failed to read frame payload: %w", err)
264-
}
266+
return header{}, err
265267
}
266268

267-
select {
268-
case <-c.closed:
269-
return n, net.ErrClosed
270-
case c.readTimeout <- context.Background():
269+
return h, nil
270+
}
271+
272+
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
273+
readDone, err := c.prepareRead(ctx, &err)
274+
if err != nil {
275+
return 0, err
276+
}
277+
defer readDone()
278+
279+
n, err := io.ReadFull(c.br, p)
280+
if err != nil {
281+
return n, fmt.Errorf("failed to read frame payload: %w", err)
271282
}
272283

273284
return n, err
@@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
325336
}
326337

327338
err = fmt.Errorf("received close frame: %w", ce)
328-
c.writeClose(ce.Code, ce.Reason)
329-
c.readMu.unlock()
330-
c.close()
339+
c.closeStateMu.Lock()
340+
c.closeReceivedErr = err
341+
closeSent := c.closeSentErr != nil
342+
c.closeStateMu.Unlock()
343+
344+
// Only unlock readMu if this connection is being closed becaue
345+
// c.close will try to acquire the readMu lock. We unlock for
346+
// writeClose as well because it may also call c.close.
347+
if !closeSent {
348+
c.readMu.unlock()
349+
_ = c.writeClose(ce.Code, ce.Reason)
350+
}
351+
if !c.casClosing() {
352+
c.readMu.unlock()
353+
_ = c.close()
354+
}
331355
return err
332356
}
333357

0 commit comments

Comments
 (0)