Skip to content

Commit 5024792

Browse files
committed
Fix NetConn read bug
See #100 (comment)
1 parent 97f63d0 commit 5024792

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

Diff for: netconn.go

+13
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
// reading/writing goroutines are interrupted but the connection is kept alive.
2323
//
2424
// The Addr methods will return a mock net.Addr.
25+
//
26+
// A received StatusNormalClosure close frame will be translated to EOF when reading.
2527
func NetConn(c *Conn) net.Conn {
2628
nc := &netConn{
2729
c: c,
@@ -47,6 +49,7 @@ type netConn struct {
4749

4850
readTimer *time.Timer
4951
readContext context.Context
52+
eofed bool
5053

5154
reader io.Reader
5255
}
@@ -66,9 +69,18 @@ func (c *netConn) Write(p []byte) (int, error) {
6669
}
6770

6871
func (c *netConn) Read(p []byte) (int, error) {
72+
if c.eofed {
73+
return 0, io.EOF
74+
}
75+
6976
if c.reader == nil {
7077
typ, r, err := c.c.Reader(c.readContext)
7178
if err != nil {
79+
var ce CloseError
80+
if xerrors.As(err, &ce) && (ce.Code == StatusNormalClosure) {
81+
c.eofed = true
82+
return 0, io.EOF
83+
}
7284
return 0, err
7385
}
7486
if typ != MessageBinary {
@@ -81,6 +93,7 @@ func (c *netConn) Read(p []byte) (int, error) {
8193
n, err := c.reader.Read(p)
8294
if err == io.EOF {
8395
c.reader = nil
96+
err = nil
8497
}
8598
return n, err
8699
}

Diff for: websocket_test.go

+35-10
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@ func TestHandshake(t *testing.T) {
130130
nc := websocket.NetConn(c)
131131
defer nc.Close()
132132

133-
nc.SetWriteDeadline(time.Now().Add(time.Second * 10))
133+
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
134134

135-
_, err = nc.Write([]byte("hello"))
136-
if err != nil {
137-
return err
135+
for i := 0; i < 3; i++ {
136+
_, err = nc.Write([]byte("hello"))
137+
if err != nil {
138+
return err
139+
}
138140
}
139141

140142
return nil
@@ -151,16 +153,39 @@ func TestHandshake(t *testing.T) {
151153
nc := websocket.NetConn(c)
152154
defer nc.Close()
153155

154-
nc.SetReadDeadline(time.Now().Add(time.Second * 10))
156+
nc.SetReadDeadline(time.Now().Add(time.Second * 15))
155157

156-
p := make([]byte, len("hello"))
157-
_, err = io.ReadFull(nc, p)
158-
if err != nil {
158+
read := func() error {
159+
p := make([]byte, len("hello"))
160+
// We do not use io.ReadFull here as it masks EOFs.
161+
// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
162+
_, err = nc.Read(p)
163+
if err != nil {
164+
return err
165+
}
166+
167+
if string(p) != "hello" {
168+
return xerrors.Errorf("unexpected payload %q received", string(p))
169+
}
170+
return nil
171+
}
172+
173+
for i := 0; i < 3; i++ {
174+
err = read()
175+
if err != nil {
176+
return err
177+
}
178+
}
179+
180+
// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
181+
err = read()
182+
if err != io.EOF {
159183
return err
160184
}
161185

162-
if string(p) != "hello" {
163-
return xerrors.Errorf("unexpected payload %q received", string(p))
186+
err = read()
187+
if err != io.EOF {
188+
return err
164189
}
165190

166191
return nil

0 commit comments

Comments
 (0)