Skip to content

Commit

Permalink
fix(nats): wait for reconnects on setup (#439)
Browse files Browse the repository at this point in the history
If the initial connect fails, nats will spawn reconnect async
handlers. Thus, we need to wait for all reconnects to be attempted
before returning to the caller, otherwise, we won't be making
use of reconnections

* fix(app): init sig chan as buffered

* fix(etcd): prevent shutdown from crashing app

If the etcd module shuts down before all connections are set up,
it will crash trying to access sd.cli where it's still nil. Thus
adding a check on shutdown
  • Loading branch information
hspedro authored Jan 27, 2025
1 parent 2bbd948 commit 85f7615
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 15 deletions.
4 changes: 2 additions & 2 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ func (app *App) Start() {
app.running = false
}()

sg := make(chan os.Signal)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGKILL, syscall.SIGTERM)
sg := make(chan os.Signal, 1)
signal.Notify(sg, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM)

maxSessionCount := func() int64 {
count := app.sessionPool.GetSessionCount()
Expand Down
22 changes: 11 additions & 11 deletions cluster/etcd_service_discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,18 +614,18 @@ func (sd *etcdServiceDiscovery) revoke() error {
close(sd.stopLeaseChan)
c := make(chan error, 1)
go func() {
defer close(c)
logger.Log.Debug("waiting for etcd revoke")
ctx, cancel := context.WithTimeout(context.Background(), sd.revokeTimeout)
_, err := sd.cli.Revoke(ctx, sd.leaseID)
cancel()
c <- err
logger.Log.Debug("finished waiting for etcd revoke")
if sd.cli != nil {
defer close(c)
logger.Log.Debug("waiting for etcd revoke")
ctx, cancel := context.WithTimeout(context.Background(), sd.revokeTimeout)
_, err := sd.cli.Revoke(ctx, sd.leaseID)
cancel()
c <- err
logger.Log.Debug("finished waiting for etcd revoke")
}
}()
select {
case err := <-c:
return err // completed normally
}
err := <-c
return err // completed normally
}

func (sd *etcdServiceDiscovery) addServer(sv *Server) {
Expand Down
1 change: 1 addition & 0 deletions cluster/grpc_rpc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestGRPCServerInit(t *testing.T) {

sv := getServer()
gs, err := NewGRPCServer(c, sv, []metrics.Reporter{})
assert.NoError(t, err)
gs.SetPitayaServer(mockPitayaServer)
err = gs.Init()
assert.NoError(t, err)
Expand Down
44 changes: 42 additions & 2 deletions cluster/nats_rpc_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ package cluster

import (
"fmt"
"os"
"syscall"
"time"

nats "github.com/nats-io/nats.go"
"github.com/topfreegames/pitaya/v2/logger"
Expand All @@ -32,6 +35,8 @@ func getChannel(serverType, serverID string) string {
}

func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.Option) (*nats.Conn, error) {
connectedCh := make(chan bool)
initialConnectErrorCh := make(chan error)
natsOptions := append(
options,
nats.DisconnectErrHandler(func(_ *nats.Conn, err error) {
Expand All @@ -49,7 +54,19 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O

logger.Log.Errorf("nats connection closed. reason: %q", nc.LastError())
if appDieChan != nil {
appDieChan <- true
select {
case appDieChan <- true:
return
case initialConnectErrorCh <- nc.LastError():
logger.Log.Warnf("appDieChan not ready, sending error in initialConnectCh")
default:
logger.Log.Warnf("no termination channel available, sending SIGTERM to app")
err := syscall.Kill(os.Getpid(), syscall.SIGTERM)
if err != nil {
logger.Log.Errorf("could not kill the application via SIGTERM, exiting", err)
os.Exit(1)
}
}
}
}),
nats.ErrorHandler(func(nc *nats.Conn, sub *nats.Subscription, err error) {
Expand All @@ -61,11 +78,34 @@ func setupNatsConn(connectString string, appDieChan chan bool, options ...nats.O
logger.Log.Errorf(err.Error())
}
}),
nats.ConnectHandler(func(*nats.Conn) {
connectedCh <- true
}),
)

nc, err := nats.Connect(connectString, natsOptions...)
if err != nil {
return nil, err
}
return nc, nil
maxConnTimeout := nc.Opts.Timeout
if nc.Opts.RetryOnFailedConnect {
// This is non-deterministic becase jitter TLS is different and we need to simplify
// the calculations. What we want to do is simply not block forever the call while
// we don't set a timeout so low that hinders our own reconnect config:
// maxReconnectTimeout = reconnectWait + reconnectJitter + reconnectTimeout
// connectionTimeout + (maxReconnectionAttemps * maxReconnectTimeout)
// Thus, the time.After considers 2 times this value
maxReconnectionTimeout := nc.Opts.ReconnectWait + nc.Opts.ReconnectJitter + nc.Opts.Timeout
maxConnTimeout += time.Duration(nc.Opts.MaxReconnect) * maxReconnectionTimeout
}

logger.Log.Debugf("attempting nats connection for a max of %v", maxConnTimeout)
select {
case <-connectedCh:
return nc, nil
case err := <-initialConnectErrorCh:
return nil, err
case <-time.After(maxConnTimeout * 2):
return nil, fmt.Errorf("timeout setting up nats connection")
}
}
134 changes: 134 additions & 0 deletions cluster/nats_rpc_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"testing"
"time"

"github.com/nats-io/nats-server/v2/test"
nats "github.com/nats-io/nats.go"
"github.com/stretchr/testify/assert"
"github.com/topfreegames/pitaya/v2/helpers"
Expand Down Expand Up @@ -77,3 +78,136 @@ func TestNatsRPCCommonCloseHandler(t *testing.T) {
assert.True(t, ok)
assert.True(t, value)
}

func TestSetupNatsConnReconnection(t *testing.T) {
t.Run("waits for reconnection on initial failure", func(t *testing.T) {
// Use an invalid address first to force initial connection failure
invalidAddr := "nats://invalid:4222"
validAddr := "nats://localhost:4222"

urls := fmt.Sprintf("%s,%s", invalidAddr, validAddr)

go func() {
time.Sleep(50 * time.Millisecond)
ts := test.RunDefaultServer()
defer ts.Shutdown()
<-time.After(200 * time.Millisecond)
}()

// Setup connection with retry enabled
appDieCh := make(chan bool)
conn, err := setupNatsConn(
urls,
appDieCh,
nats.ReconnectWait(10*time.Millisecond),
nats.MaxReconnects(5),
nats.RetryOnFailedConnect(true),
)

assert.NoError(t, err)
assert.NotNil(t, conn)
assert.True(t, conn.IsConnected())

conn.Close()
})

t.Run("does not block indefinitely if all connect attempts fail", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(
invalidAddr,
appDieCh,
nats.ReconnectWait(10*time.Millisecond),
nats.MaxReconnects(2),
nats.RetryOnFailedConnect(true),
)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-appDieCh:
case <-done:
case <-time.After(250 * time.Millisecond):
t.Fail()
}
})

t.Run("if it fails to connect, exit with error even if appDieChan is not ready to listen", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(invalidAddr, appDieCh)
assert.Error(t, err)
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-done:
case <-time.After(50 * time.Millisecond):
t.Fail()
}
})

t.Run("if connection takes too long, exit with error after waiting maxReconnTimeout", func(t *testing.T) {
invalidAddr := "nats://invalid:4222"

appDieCh := make(chan bool)
done := make(chan any)

initialConnectionTimeout := time.Nanosecond
maxReconnectionAtetmpts := 1
reconnectWait := time.Nanosecond
reconnectJitter := time.Nanosecond
maxReconnectionTimeout := reconnectWait + reconnectJitter + initialConnectionTimeout
maxReconnTimeout := initialConnectionTimeout + (time.Duration(maxReconnectionAtetmpts) * maxReconnectionTimeout)

maxTestTimeout := 100 * time.Millisecond

// Assert that if it fails because of connection timeout the test will capture
assert.Greater(t, maxTestTimeout, maxReconnTimeout)

ts := test.RunDefaultServer()
defer ts.Shutdown()

go func() {
conn, err := setupNatsConn(
invalidAddr,
appDieCh,
nats.Timeout(initialConnectionTimeout),
nats.ReconnectWait(reconnectWait),
nats.MaxReconnects(maxReconnectionAtetmpts),
nats.ReconnectJitter(reconnectJitter, reconnectJitter),
nats.RetryOnFailedConnect(true),
)
assert.Error(t, err)
assert.ErrorContains(t, err, "timeout setting up nats connection")
assert.Nil(t, conn)
close(done)
close(appDieCh)
}()

select {
case <-done:
case <-time.After(maxTestTimeout):
t.Fail()
}
})
}

0 comments on commit 85f7615

Please sign in to comment.