@@ -49,6 +49,11 @@ type Conn struct {
49
49
// Effectively meaning whoever holds it gets to write to bw.
50
50
writeFrameLock chan struct {}
51
51
writeHeaderBuf []byte
52
+ writeHeader * header
53
+
54
+ // messageWriter state.
55
+ writeMsgOpcode opcode
56
+ writeMsgCtx context.Context
52
57
53
58
// Used to ensure the previous reader is read till EOF before allowing
54
59
// a new one.
@@ -58,6 +63,12 @@ type Conn struct {
58
63
readHeaderBuf []byte
59
64
controlPayloadBuf []byte
60
65
66
+ // messageReader state
67
+ readMsgCtx context.Context
68
+ readMsgHeader header
69
+ readFrameEOF bool
70
+ readMaskPos int
71
+
61
72
setReadTimeout chan context.Context
62
73
setWriteTimeout chan context.Context
63
74
@@ -81,6 +92,7 @@ func (c *Conn) init() {
81
92
c .activePings = make (map [string ]chan <- struct {})
82
93
83
94
c .writeHeaderBuf = makeWriteHeaderBuf ()
95
+ c .writeHeader = & header {}
84
96
c .readHeaderBuf = makeReadHeaderBuf ()
85
97
c .controlPayloadBuf = make ([]byte , maxControlFramePayload )
86
98
@@ -315,15 +327,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
315
327
if err != nil {
316
328
return 0 , nil , xerrors .Errorf ("failed to get reader: %w" , err )
317
329
}
318
- return typ , & limitedReader {
319
- c : c ,
320
- r : r ,
321
- left : c .msgReadLimit ,
322
- }, nil
330
+ return typ , r , nil
323
331
}
324
332
325
333
func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
326
- if c .previousReader != nil && c . previousReader . h != nil {
334
+ if c .previousReader != nil && ! c . readFrameEOF {
327
335
// The only way we know for sure the previous reader is not yet complete is
328
336
// if there is an active frame not yet fully read.
329
337
// Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -336,7 +344,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
336
344
return 0 , nil , err
337
345
}
338
346
339
- if c .previousReader != nil && ! c .previousReader .done {
347
+ if c .previousReader != nil && ! c .previousReader .eof {
340
348
if h .opcode != opContinuation {
341
349
err := xerrors .Errorf ("received new data message without finishing the previous message" )
342
350
c .Close (StatusProtocolError , err .Error ())
@@ -347,33 +355,36 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
347
355
return 0 , nil , xerrors .Errorf ("previous message not read to completion" )
348
356
}
349
357
350
- c .previousReader .done = true
358
+ c .previousReader .eof = true
351
359
352
- return c .reader (ctx )
360
+ h , err = c .readTillMsg (ctx )
361
+ if err != nil {
362
+ return 0 , nil , err
363
+ }
353
364
} else if h .opcode == opContinuation {
354
365
err := xerrors .Errorf ("received continuation frame not after data or text frame" )
355
366
c .Close (StatusProtocolError , err .Error ())
356
367
return 0 , nil , err
357
368
}
358
369
359
- r := & messageReader {
360
- ctx : ctx ,
361
- c : c ,
370
+ c .readMsgCtx = ctx
371
+ c .readMsgHeader = h
372
+ c .readFrameEOF = false
373
+ c .readMaskPos = 0
362
374
363
- h : & h ,
375
+ r := & messageReader {
376
+ c : c ,
377
+ left : c .msgReadLimit ,
364
378
}
365
379
c .previousReader = r
366
380
return MessageType (h .opcode ), r , nil
367
381
}
368
382
369
383
// messageReader enables reading a data frame from the WebSocket connection.
370
384
type messageReader struct {
371
- ctx context.Context
372
- c * Conn
373
-
374
- h * header
375
- maskPos int
376
- done bool
385
+ c * Conn
386
+ left int64
387
+ eof bool
377
388
}
378
389
379
390
// Read reads as many bytes as possible into p.
@@ -391,12 +402,22 @@ func (r *messageReader) Read(p []byte) (int, error) {
391
402
}
392
403
393
404
func (r * messageReader ) read (p []byte ) (int , error ) {
394
- if r .done {
405
+ if r .eof {
395
406
return 0 , xerrors .Errorf ("cannot use EOFed reader" )
396
407
}
397
408
398
- if r .h == nil {
399
- h , err := r .c .readTillMsg (r .ctx )
409
+ if r .left <= 0 {
410
+ err := xerrors .Errorf ("read limited at %v bytes" , r .c .msgReadLimit )
411
+ r .c .Close (StatusMessageTooBig , err .Error ())
412
+ return 0 , err
413
+ }
414
+
415
+ if int64 (len (p )) > r .left {
416
+ p = p [:r .left ]
417
+ }
418
+
419
+ if r .c .readFrameEOF {
420
+ h , err := r .c .readTillMsg (r .c .readMsgCtx )
400
421
if err != nil {
401
422
return 0 , err
402
423
}
@@ -406,38 +427,37 @@ func (r *messageReader) read(p []byte) (int, error) {
406
427
r .c .Close (StatusProtocolError , err .Error ())
407
428
return 0 , err
408
429
}
409
- r .h = & h
430
+
431
+ r .c .readMsgHeader = h
432
+ r .c .readFrameEOF = false
433
+ r .c .readMaskPos = 0
410
434
}
411
435
412
- if int64 (len (p )) > r .h .payloadLength {
413
- p = p [:r .h .payloadLength ]
436
+ h := r .c .readMsgHeader
437
+ if int64 (len (p )) > h .payloadLength {
438
+ p = p [:h .payloadLength ]
414
439
}
415
440
416
- n , err := r .c .readFramePayload (r .ctx , p )
441
+ n , err := r .c .readFramePayload (r .c . readMsgCtx , p )
417
442
418
- r .h .payloadLength -= int64 (n )
419
- if r .h .masked {
420
- r .maskPos = fastXOR (r .h .maskKey , r .maskPos , p )
443
+ h .payloadLength -= int64 (n )
444
+ r .left -= int64 (n )
445
+ if h .masked {
446
+ r .c .readMaskPos = fastXOR (h .maskKey , r .c .readMaskPos , p )
421
447
}
448
+ r .c .readMsgHeader = h
422
449
423
450
if err != nil {
424
451
return n , err
425
452
}
426
453
427
- if r .h .payloadLength == 0 {
428
- fin := r .h .fin
429
-
430
- // Need to nil this as Reader uses it to check
431
- // whether there is active data on the previous reader and
432
- // now there isn't.
433
- r .h = nil
454
+ if h .payloadLength == 0 {
455
+ r .c .readFrameEOF = true
434
456
435
- if fin {
436
- r .done = true
457
+ if h . fin {
458
+ r .eof = true
437
459
return n , io .EOF
438
460
}
439
-
440
- r .maskPos = 0
441
461
}
442
462
443
463
return n , nil
@@ -524,10 +544,10 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
524
544
if err != nil {
525
545
return nil , err
526
546
}
547
+ c .writeMsgCtx = ctx
548
+ c .writeMsgOpcode = opcode (typ )
527
549
return & messageWriter {
528
- ctx : ctx ,
529
- opcode : opcode (typ ),
530
- c : c ,
550
+ c : c ,
531
551
}, nil
532
552
}
533
553
@@ -556,8 +576,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
556
576
557
577
// messageWriter enables writing to a WebSocket connection.
558
578
type messageWriter struct {
559
- ctx context.Context
560
- opcode opcode
561
579
c * Conn
562
580
closed bool
563
581
}
@@ -575,11 +593,11 @@ func (w *messageWriter) write(p []byte) (int, error) {
575
593
if w .closed {
576
594
return 0 , xerrors .Errorf ("cannot use closed writer" )
577
595
}
578
- n , err := w .c .writeFrame (w .ctx , false , w .opcode , p )
596
+ n , err := w .c .writeFrame (w .c . writeMsgCtx , false , w .c . writeMsgOpcode , p )
579
597
if err != nil {
580
598
return n , xerrors .Errorf ("failed to write data frame: %w" , err )
581
599
}
582
- w .opcode = opContinuation
600
+ w .c . writeMsgOpcode = opContinuation
583
601
return n , nil
584
602
}
585
603
@@ -599,7 +617,7 @@ func (w *messageWriter) close() error {
599
617
}
600
618
w .closed = true
601
619
602
- _ , err := w .c .writeFrame (w .ctx , true , w .opcode , nil )
620
+ _ , err := w .c .writeFrame (w .c . writeMsgCtx , true , w .c . writeMsgOpcode , nil )
603
621
if err != nil {
604
622
return xerrors .Errorf ("failed to write fin frame: %w" , err )
605
623
}
@@ -618,20 +636,6 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
618
636
619
637
// writeFrame handles all writes to the connection.
620
638
func (c * Conn ) writeFrame (ctx context.Context , fin bool , opcode opcode , p []byte ) (int , error ) {
621
- h := header {
622
- fin : fin ,
623
- opcode : opcode ,
624
- masked : c .client ,
625
- payloadLength : int64 (len (p )),
626
- }
627
-
628
- if c .client {
629
- _ , err := io .ReadFull (cryptorand .Reader , h .maskKey [:])
630
- if err != nil {
631
- return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
632
- }
633
- }
634
-
635
639
err := c .acquireLock (ctx , c .writeFrameLock )
636
640
if err != nil {
637
641
return 0 , err
@@ -644,7 +648,19 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644
648
case c .setWriteTimeout <- ctx :
645
649
}
646
650
647
- n , err := c .realWriteFrame (ctx , h , p )
651
+ c .writeHeader .fin = fin
652
+ c .writeHeader .opcode = opcode
653
+ c .writeHeader .masked = c .client
654
+ c .writeHeader .payloadLength = int64 (len (p ))
655
+
656
+ if c .client {
657
+ _ , err := io .ReadFull (cryptorand .Reader , c .writeHeader .maskKey [:])
658
+ if err != nil {
659
+ return 0 , xerrors .Errorf ("failed to generate masking key: %w" , err )
660
+ }
661
+ }
662
+
663
+ n , err := c .realWriteFrame (ctx , * c .writeHeader , p )
648
664
if err != nil {
649
665
return n , err
650
666
}
0 commit comments