Skip to content

Commit 0b8b974

Browse files
committed
Reduce allocation overhead to absolute minimum
As far as I can tell, it can't go any lower than this: 16 bytes per Writer and 24 bytes per Reader. go tool pprof agrees with me on bytes per op, but says the allocs per op are 3 instead of 4 and thinks echoLoop is allocating; I don't know why. Lots of cleanup can still be performed. Closes #95
1 parent 029e412 commit 0b8b974

File tree

4 files changed

+86
-97
lines changed

4 files changed

+86
-97
lines changed

limitedreader.go

-33
This file was deleted.

websocket.go

+79-63
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ type Conn struct {
4949
// Effectively meaning whoever holds it gets to write to bw.
5050
writeFrameLock chan struct{}
5151
writeHeaderBuf []byte
52+
writeHeader *header
53+
54+
// messageWriter state.
55+
writeMsgOpcode opcode
56+
writeMsgCtx context.Context
5257

5358
// Used to ensure the previous reader is read till EOF before allowing
5459
// a new one.
@@ -58,6 +63,12 @@ type Conn struct {
5863
readHeaderBuf []byte
5964
controlPayloadBuf []byte
6065

66+
// messageReader state
67+
readMsgCtx context.Context
68+
readMsgHeader header
69+
readFrameEOF bool
70+
readMaskPos int
71+
6172
setReadTimeout chan context.Context
6273
setWriteTimeout chan context.Context
6374

@@ -81,6 +92,7 @@ func (c *Conn) init() {
8192
c.activePings = make(map[string]chan<- struct{})
8293

8394
c.writeHeaderBuf = makeWriteHeaderBuf()
95+
c.writeHeader = &header{}
8496
c.readHeaderBuf = makeReadHeaderBuf()
8597
c.controlPayloadBuf = make([]byte, maxControlFramePayload)
8698

@@ -315,15 +327,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
315327
if err != nil {
316328
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
317329
}
318-
return typ, &limitedReader{
319-
c: c,
320-
r: r,
321-
left: c.msgReadLimit,
322-
}, nil
330+
return typ, r, nil
323331
}
324332

325333
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
326-
if c.previousReader != nil && c.previousReader.h != nil {
334+
if c.previousReader != nil && !c.readFrameEOF {
327335
// The only way we know for sure the previous reader is not yet complete is
328336
// if there is an active frame not yet fully read.
329337
// Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -336,7 +344,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
336344
return 0, nil, err
337345
}
338346

339-
if c.previousReader != nil && !c.previousReader.done {
347+
if c.previousReader != nil && !c.previousReader.eof {
340348
if h.opcode != opContinuation {
341349
err := xerrors.Errorf("received new data message without finishing the previous message")
342350
c.Close(StatusProtocolError, err.Error())
@@ -347,33 +355,36 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
347355
return 0, nil, xerrors.Errorf("previous message not read to completion")
348356
}
349357

350-
c.previousReader.done = true
358+
c.previousReader.eof = true
351359

352-
return c.reader(ctx)
360+
h, err = c.readTillMsg(ctx)
361+
if err != nil {
362+
return 0, nil, err
363+
}
353364
} else if h.opcode == opContinuation {
354365
err := xerrors.Errorf("received continuation frame not after data or text frame")
355366
c.Close(StatusProtocolError, err.Error())
356367
return 0, nil, err
357368
}
358369

359-
r := &messageReader{
360-
ctx: ctx,
361-
c: c,
370+
c.readMsgCtx = ctx
371+
c.readMsgHeader = h
372+
c.readFrameEOF = false
373+
c.readMaskPos = 0
362374

363-
h: &h,
375+
r := &messageReader{
376+
c: c,
377+
left: c.msgReadLimit,
364378
}
365379
c.previousReader = r
366380
return MessageType(h.opcode), r, nil
367381
}
368382

369383
// messageReader enables reading a data frame from the WebSocket connection.
370384
type messageReader struct {
371-
ctx context.Context
372-
c *Conn
373-
374-
h *header
375-
maskPos int
376-
done bool
385+
c *Conn
386+
left int64
387+
eof bool
377388
}
378389

379390
// Read reads as many bytes as possible into p.
@@ -391,12 +402,22 @@ func (r *messageReader) Read(p []byte) (int, error) {
391402
}
392403

393404
func (r *messageReader) read(p []byte) (int, error) {
394-
if r.done {
405+
if r.eof {
395406
return 0, xerrors.Errorf("cannot use EOFed reader")
396407
}
397408

398-
if r.h == nil {
399-
h, err := r.c.readTillMsg(r.ctx)
409+
if r.left <= 0 {
410+
err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit)
411+
r.c.Close(StatusMessageTooBig, err.Error())
412+
return 0, err
413+
}
414+
415+
if int64(len(p)) > r.left {
416+
p = p[:r.left]
417+
}
418+
419+
if r.c.readFrameEOF {
420+
h, err := r.c.readTillMsg(r.c.readMsgCtx)
400421
if err != nil {
401422
return 0, err
402423
}
@@ -406,38 +427,37 @@ func (r *messageReader) read(p []byte) (int, error) {
406427
r.c.Close(StatusProtocolError, err.Error())
407428
return 0, err
408429
}
409-
r.h = &h
430+
431+
r.c.readMsgHeader = h
432+
r.c.readFrameEOF = false
433+
r.c.readMaskPos = 0
410434
}
411435

412-
if int64(len(p)) > r.h.payloadLength {
413-
p = p[:r.h.payloadLength]
436+
h := r.c.readMsgHeader
437+
if int64(len(p)) > h.payloadLength {
438+
p = p[:h.payloadLength]
414439
}
415440

416-
n, err := r.c.readFramePayload(r.ctx, p)
441+
n, err := r.c.readFramePayload(r.c.readMsgCtx, p)
417442

418-
r.h.payloadLength -= int64(n)
419-
if r.h.masked {
420-
r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p)
443+
h.payloadLength -= int64(n)
444+
r.left -= int64(n)
445+
if h.masked {
446+
r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p)
421447
}
448+
r.c.readMsgHeader = h
422449

423450
if err != nil {
424451
return n, err
425452
}
426453

427-
if r.h.payloadLength == 0 {
428-
fin := r.h.fin
429-
430-
// Need to nil this as Reader uses it to check
431-
// whether there is active data on the previous reader and
432-
// now there isn't.
433-
r.h = nil
454+
if h.payloadLength == 0 {
455+
r.c.readFrameEOF = true
434456

435-
if fin {
436-
r.done = true
457+
if h.fin {
458+
r.eof = true
437459
return n, io.EOF
438460
}
439-
440-
r.maskPos = 0
441461
}
442462

443463
return n, nil
@@ -524,10 +544,10 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
524544
if err != nil {
525545
return nil, err
526546
}
547+
c.writeMsgCtx = ctx
548+
c.writeMsgOpcode = opcode(typ)
527549
return &messageWriter{
528-
ctx: ctx,
529-
opcode: opcode(typ),
530-
c: c,
550+
c: c,
531551
}, nil
532552
}
533553

@@ -556,8 +576,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
556576

557577
// messageWriter enables writing to a WebSocket connection.
558578
type messageWriter struct {
559-
ctx context.Context
560-
opcode opcode
561579
c *Conn
562580
closed bool
563581
}
@@ -575,11 +593,11 @@ func (w *messageWriter) write(p []byte) (int, error) {
575593
if w.closed {
576594
return 0, xerrors.Errorf("cannot use closed writer")
577595
}
578-
n, err := w.c.writeFrame(w.ctx, false, w.opcode, p)
596+
n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
579597
if err != nil {
580598
return n, xerrors.Errorf("failed to write data frame: %w", err)
581599
}
582-
w.opcode = opContinuation
600+
w.c.writeMsgOpcode = opContinuation
583601
return n, nil
584602
}
585603

@@ -599,7 +617,7 @@ func (w *messageWriter) close() error {
599617
}
600618
w.closed = true
601619

602-
_, err := w.c.writeFrame(w.ctx, true, w.opcode, nil)
620+
_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
603621
if err != nil {
604622
return xerrors.Errorf("failed to write fin frame: %w", err)
605623
}
@@ -618,20 +636,6 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
618636

619637
// writeFrame handles all writes to the connection.
620638
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
621-
h := header{
622-
fin: fin,
623-
opcode: opcode,
624-
masked: c.client,
625-
payloadLength: int64(len(p)),
626-
}
627-
628-
if c.client {
629-
_, err := io.ReadFull(cryptorand.Reader, h.maskKey[:])
630-
if err != nil {
631-
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
632-
}
633-
}
634-
635639
err := c.acquireLock(ctx, c.writeFrameLock)
636640
if err != nil {
637641
return 0, err
@@ -644,7 +648,19 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644648
case c.setWriteTimeout <- ctx:
645649
}
646650

647-
n, err := c.realWriteFrame(ctx, h, p)
651+
c.writeHeader.fin = fin
652+
c.writeHeader.opcode = opcode
653+
c.writeHeader.masked = c.client
654+
c.writeHeader.payloadLength = int64(len(p))
655+
656+
if c.client {
657+
_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
658+
if err != nil {
659+
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
660+
}
661+
}
662+
663+
n, err := c.realWriteFrame(ctx, *c.writeHeader, p)
648664
if err != nil {
649665
return n, err
650666
}

websocket_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -879,7 +879,7 @@ func BenchmarkConn(b *testing.B) {
879879
b.Run("echo", func(b *testing.B) {
880880
for _, size := range sizes {
881881
b.Run(strconv.Itoa(size), func(b *testing.B) {
882-
benchConn(b, true, true, size)
882+
benchConn(b, true, false, size)
883883
})
884884
}
885885
})

xor_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"crypto/rand"
55
"strconv"
66
"testing"
7+
"unsafe"
78

89
"github.com/google/go-cmp/cmp"
910
)
@@ -80,3 +81,8 @@ func BenchmarkXOR(b *testing.B) {
8081
})
8182
}
8283
}
84+
85+
func TestFoo(t *testing.T) {
86+
t.Log(unsafe.Sizeof(messageWriter{}))
87+
t.Log(unsafe.Sizeof(messageReader{}))
88+
}

0 commit comments

Comments
 (0)