diff --git a/internal/balancer/balancer.go b/internal/balancer/balancer.go index c9ce8d9b8..76aaf75f8 100644 --- a/internal/balancer/balancer.go +++ b/internal/balancer/balancer.go @@ -27,7 +27,10 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) -var ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints")) +var ( + ErrNoEndpoints = xerrors.Wrap(fmt.Errorf("no endpoints")) + errBalancerClosed = xerrors.Wrap(fmt.Errorf("internal ydb sdk balancer closed")) +) type Balancer struct { driverConfig *config.Config @@ -40,6 +43,7 @@ type Balancer struct { localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error) connectionsState atomic.Pointer[connectionsState] + closed atomic.Bool } func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) { @@ -152,6 +156,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, newest []endpoi } func (b *Balancer) Close(ctx context.Context) (err error) { + b.closed.Store(true) + onDone := trace.DriverOnBalancerClose( b.driverConfig.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).Close"), @@ -281,6 +287,10 @@ func (b *Balancer) Invoke( reply interface{}, opts ...grpc.CallOption, ) error { + if b.closed.Load() { + return xerrors.WithStackTrace(errBalancerClosed) + } + return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { return cc.Invoke(ctx, method, args, reply, opts...) }) @@ -292,6 +302,10 @@ func (b *Balancer) NewStream( method string, opts ...grpc.CallOption, ) (_ grpc.ClientStream, err error) { + if b.closed.Load() { + return nil, xerrors.WithStackTrace(errBalancerClosed) + } + var client grpc.ClientStream err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error { client, err = cc.NewStream(ctx, desc, method, opts...) @@ -345,6 +359,10 @@ func (b *Balancer) connections() *connectionsState { } func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) { + if b.closed.Load() { + return nil, xerrors.WithStackTrace(errBalancerClosed) + } + onDone := trace.DriverOnBalancerChooseEndpoint( b.driverConfig.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer.(*Balancer).getConn"),