Skip to content

Commit

Permalink
fix: Make sure that we can't connect after we stop the bindings liste…
Browse files Browse the repository at this point in the history
…ner (#469)
  • Loading branch information
jonstacks authored Oct 25, 2024
1 parent 304480e commit 06944f8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
16 changes: 8 additions & 8 deletions pkg/bindingsdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,11 @@ type ConnectionHandler func(net.Conn) error

type bindingsListener struct {
listener net.Listener
stop chan struct{}
cnxnHandler ConnectionHandler
log logr.Logger

stopOnce sync.Once
stop chan struct{}
}

func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindingsListener, error) {
Expand All @@ -72,8 +74,8 @@ func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindin

bl := &bindingsListener{
listener: l,
stop: make(chan struct{}),
cnxnHandler: cnxnHandler,
stop: make(chan struct{}),
}

go bl.run()
Expand All @@ -83,18 +85,16 @@ func newBindingsListener(address string, cnxnHandler ConnectionHandler) (*bindin

// Stop stops the listener. It is safe to call stop multiple times.
func (b *bindingsListener) Stop() {
select {
case b.stop <- struct{}{}:
close(b.stop)
default:
}
b.stopOnce.Do(func() {
b.listener.Close()
b.stop <- struct{}{}
})
}

func (b *bindingsListener) run() {
for {
select {
case <-b.stop:
b.listener.Close()
return
default:
}
Expand Down
10 changes: 9 additions & 1 deletion pkg/bindingsdriver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"math/rand/v2"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
)
Expand Down Expand Up @@ -35,7 +36,7 @@ func TestBindingsListener(t *testing.T) {
assert.NotNil(t, bl)

// test that we can connect to the listener
conn, err := net.Dial("tcp", loopbackAddr(port))
conn, err := net.DialTimeout("tcp", loopbackAddr(port), 10*time.Millisecond)
assert.NoError(t, err)

out, err := io.ReadAll(conn)
Expand All @@ -44,6 +45,13 @@ func TestBindingsListener(t *testing.T) {
assert.Equal(t, "hello world", string(out))

assert.NotPanics(t, func() { bl.Stop() })

// test that we can't connect to the listener after it's stopped
conn, err = net.DialTimeout("tcp", loopbackAddr(port), 10*time.Millisecond)
assert.Error(t, err)
assert.Nil(t, conn)

// test that we can stop the listener multiple times
assert.NotPanics(t, func() { bl.Stop() })
}

Expand Down

0 comments on commit 06944f8

Please sign in to comment.