Skip to content

Commit d1e338e

Browse files
committed
TUN-7545: Add support for full bidirectionally streaming with close signal propagation
1 parent b243602 commit d1e338e

File tree

2 files changed

+188
-6
lines changed

2 files changed

+188
-6
lines changed

stream/stream.go

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,36 @@ import (
1515
"github.com/cloudflare/cloudflared/cfio"
1616
)
1717

18+
type Stream interface {
19+
Reader
20+
WriterCloser
21+
}
22+
23+
type Reader interface {
24+
io.Reader
25+
}
26+
27+
type WriterCloser interface {
28+
io.Writer
29+
WriteCloser
30+
}
31+
32+
type WriteCloser interface {
33+
CloseWrite() error
34+
}
35+
36+
type nopCloseWriterAdapter struct {
37+
io.ReadWriter
38+
}
39+
40+
func NopCloseWriterAdapter(stream io.ReadWriter) *nopCloseWriterAdapter {
41+
return &nopCloseWriterAdapter{stream}
42+
}
43+
44+
func (n *nopCloseWriterAdapter) CloseWrite() error {
45+
return nil
46+
}
47+
1848
type bidirectionalStreamStatus struct {
1949
doneChan chan struct{}
2050
anyDone uint32
@@ -32,25 +62,53 @@ func (s *bidirectionalStreamStatus) markUniStreamDone() {
3262
s.doneChan <- struct{}{}
3363
}
3464

35-
func (s *bidirectionalStreamStatus) waitAnyDone() {
65+
func (s *bidirectionalStreamStatus) wait(maxWaitForSecondStream time.Duration) error {
3666
<-s.doneChan
67+
68+
// Only wait for second stream to finish if maxWait is greater than zero
69+
if maxWaitForSecondStream > 0 {
70+
71+
timer := time.NewTimer(maxWaitForSecondStream)
72+
defer timer.Stop()
73+
74+
select {
75+
case <-timer.C:
76+
return fmt.Errorf("timeout waiting for second stream to finish")
77+
case <-s.doneChan:
78+
return nil
79+
}
80+
}
81+
82+
return nil
3783
}
3884
func (s *bidirectionalStreamStatus) isAnyDone() bool {
3985
return atomic.LoadUint32(&s.anyDone) > 0
4086
}
4187

4288
// Pipe copies copy data to & from provided io.ReadWriters.
4389
func Pipe(tunnelConn, originConn io.ReadWriter, log *zerolog.Logger) {
90+
PipeBidirectional(NopCloseWriterAdapter(tunnelConn), NopCloseWriterAdapter(originConn), 0, log)
91+
}
92+
93+
// PipeBidirectional copies data two BidirectionStreams. It is a special case of Pipe where it receives a concept that allows for Read and Write side to be closed independently.
94+
// The main difference is that when piping data from a reader to a writer, if EOF is read, then this implementation propagates the EOF signal to the destination/writer by closing the write side of the
95+
// Bidirectional Stream.
96+
// Finally, depending on once EOF is ready from one of the provided streams, the other direction of streaming data will have a configured time period to also finish, otherwise,
97+
// the method will return immediately with a timeout error. It is however, the responsability of the caller to close the associated streams in both ends in order to free all the resources/go-routines.
98+
func PipeBidirectional(downstream, upstream Stream, maxWaitForSecondStream time.Duration, log *zerolog.Logger) error {
4499
status := newBiStreamStatus()
45100

46-
go unidirectionalStream(tunnelConn, originConn, "origin->tunnel", status, log)
47-
go unidirectionalStream(originConn, tunnelConn, "tunnel->origin", status, log)
101+
go unidirectionalStream(downstream, upstream, "upstream->downstream", status, log)
102+
go unidirectionalStream(upstream, downstream, "downstream->upstream", status, log)
48103

49-
// If one side is done, we are done.
50-
status.waitAnyDone()
104+
if err := status.wait(maxWaitForSecondStream); err != nil {
105+
return errors.Wrap(err, "unable to wait for both streams while proxying")
106+
}
107+
108+
return nil
51109
}
52110

53-
func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
111+
func unidirectionalStream(dst WriterCloser, src Reader, dir string, status *bidirectionalStreamStatus, log *zerolog.Logger) {
54112
defer func() {
55113
// The bidirectional streaming spawns 2 goroutines to stream each direction.
56114
// If any ends, the callstack returns, meaning the Tunnel request/stream (depending on http2 vs quic) will
@@ -71,6 +129,8 @@ func unidirectionalStream(dst io.Writer, src io.Reader, dir string, status *bidi
71129
}
72130
}()
73131

132+
defer dst.CloseWrite()
133+
74134
_, err := copyData(dst, src, dir)
75135
if err != nil {
76136
log.Debug().Msgf("%s copy: %v", dir, err)

stream/stream_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package stream
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"sync"
7+
"testing"
8+
"time"
9+
10+
"github.com/rs/zerolog"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestPipeBidirectionalFinishBothSides(t *testing.T) {
15+
fun := func(upstream, downstream *mockedStream) {
16+
downstream.closeReader()
17+
upstream.closeReader()
18+
}
19+
20+
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
21+
}
22+
23+
func TestPipeBidirectionalFinishOneSideTimeout(t *testing.T) {
24+
fun := func(upstream, downstream *mockedStream) {
25+
downstream.closeReader()
26+
}
27+
28+
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
29+
}
30+
31+
func TestPipeBidirectionalClosingWriteBothSidesAlsoExists(t *testing.T) {
32+
fun := func(upstream, downstream *mockedStream) {
33+
downstream.CloseWrite()
34+
upstream.CloseWrite()
35+
36+
downstream.writeToReader("abc")
37+
upstream.writeToReader("abc")
38+
}
39+
40+
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, false)
41+
}
42+
43+
func TestPipeBidirectionalClosingWriteSingleSideAlsoExists(t *testing.T) {
44+
fun := func(upstream, downstream *mockedStream) {
45+
downstream.CloseWrite()
46+
47+
downstream.writeToReader("abc")
48+
upstream.writeToReader("abc")
49+
}
50+
51+
testPipeBidirectionalUnblocking(t, fun, time.Millisecond*200, true)
52+
}
53+
54+
func testPipeBidirectionalUnblocking(t *testing.T, afterFun func(*mockedStream, *mockedStream), timeout time.Duration, expectTimeout bool) {
55+
logger := zerolog.Nop()
56+
57+
downstream := newMockedStream()
58+
upstream := newMockedStream()
59+
60+
resultCh := make(chan error)
61+
go func() {
62+
resultCh <- PipeBidirectional(downstream, upstream, timeout, &logger)
63+
}()
64+
65+
afterFun(upstream, downstream)
66+
67+
select {
68+
case err := <-resultCh:
69+
if expectTimeout {
70+
require.NotNil(t, err)
71+
} else {
72+
require.Nil(t, err)
73+
}
74+
75+
case <-time.After(timeout * 2):
76+
require.Fail(t, "test timeout")
77+
}
78+
}
79+
80+
func newMockedStream() *mockedStream {
81+
return &mockedStream{
82+
readCh: make(chan *string),
83+
writeCh: make(chan struct{}),
84+
}
85+
}
86+
87+
type mockedStream struct {
88+
readCh chan *string
89+
writeCh chan struct{}
90+
91+
writeCloseOnce sync.Once
92+
}
93+
94+
func (m *mockedStream) Read(p []byte) (n int, err error) {
95+
result := <-m.readCh
96+
if result == nil {
97+
return 0, io.EOF
98+
}
99+
100+
return len(*result), nil
101+
}
102+
103+
func (m *mockedStream) Write(p []byte) (n int, err error) {
104+
<-m.writeCh
105+
106+
return 0, fmt.Errorf("closed")
107+
}
108+
109+
func (m *mockedStream) CloseWrite() error {
110+
m.writeCloseOnce.Do(func() {
111+
close(m.writeCh)
112+
})
113+
114+
return nil
115+
}
116+
117+
func (m *mockedStream) closeReader() {
118+
close(m.readCh)
119+
}
120+
func (m *mockedStream) writeToReader(content string) {
121+
m.readCh <- &content
122+
}

0 commit comments

Comments
 (0)