@@ -41,7 +41,9 @@ type Balancer struct {
41
41
localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
42
42
43
43
mu xsync.RWMutex
44
- connectionsState * connectionsState
44
+ connectionsState * connectionsState [conn.Conn ]
45
+
46
+ closed chan struct {}
45
47
46
48
onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
47
49
}
@@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
133
135
return nil
134
136
}
135
137
136
- func endpointsDiff (newestEndpoints []endpoint.Endpoint , previousConns []conn.Conn ) (
138
+ func endpointsDiff (newestEndpoints []endpoint.Endpoint , previousConns []conn.Info ) (
137
139
nodes []trace.EndpointInfo ,
138
140
added []trace.EndpointInfo ,
139
141
dropped []trace.EndpointInfo ,
@@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
178
180
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
179
181
b .config .DetectLocalDC ,
180
182
)
181
- previousConns []conn.Conn
183
+ previousConns []conn.Info
182
184
)
183
185
defer func () {
184
186
nodes , added , dropped := endpointsDiff (endpoints , previousConns )
@@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
187
189
188
190
connections := endpointsToConnections (b .pool , endpoints )
189
191
for _ , c := range connections {
190
- b .pool .Allow (ctx , c )
192
+ if c .State () == conn .Banned {
193
+ b .pool .Unban (ctx , c )
194
+ }
191
195
c .Endpoint ().Touch ()
192
196
}
193
197
@@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
201
205
202
206
b .mu .WithLock (func () {
203
207
if b .connectionsState != nil {
204
- previousConns = b .connectionsState .all
208
+ previousConns = make ([]conn.Info , len (b .connectionsState .all ))
209
+ for i := range b .connectionsState .all {
210
+ previousConns [i ] = b .connectionsState .all [i ]
211
+ }
205
212
}
206
213
b .connectionsState = state
207
214
for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
@@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
211
218
}
212
219
213
220
func (b * Balancer ) Close (ctx context.Context ) (err error ) {
221
+ close (b .closed )
222
+
214
223
onDone := trace .DriverOnBalancerClose (
215
224
b .driverConfig .Trace (), & ctx ,
216
225
stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close" ),
@@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
223
232
b .discoveryRepeater .Stop ()
224
233
}
225
234
235
+ b .applyDiscoveredEndpoints (ctx , nil , "" )
236
+
226
237
if err = b .discoveryClient .Close (ctx ); err != nil {
227
238
return xerrors .WithStackTrace (err )
228
239
}
@@ -258,6 +269,7 @@ func New(
258
269
driverConfig : driverConfig ,
259
270
pool : pool ,
260
271
localDCDetector : detectLocalDC ,
272
+ closed : make (chan struct {}),
261
273
}
262
274
d := internalDiscovery .New (ctx , pool .Get (
263
275
endpoint .New (driverConfig .Endpoint ()),
@@ -300,9 +312,14 @@ func (b *Balancer) Invoke(
300
312
reply interface {},
301
313
opts ... grpc.CallOption ,
302
314
) error {
303
- return b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
304
- return cc .Invoke (ctx , method , args , reply , opts ... )
305
- })
315
+ select {
316
+ case <- b .closed :
317
+ return xerrors .WithStackTrace (errBalancerClosed )
318
+ default :
319
+ return b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
320
+ return cc .Invoke (ctx , method , args , reply , opts ... )
321
+ })
322
+ }
306
323
}
307
324
308
325
func (b * Balancer ) NewStream (
@@ -311,17 +328,22 @@ func (b *Balancer) NewStream(
311
328
method string ,
312
329
opts ... grpc.CallOption ,
313
330
) (_ grpc.ClientStream , err error ) {
314
- var client grpc.ClientStream
315
- err = b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
316
- client , err = cc .NewStream (ctx , desc , method , opts ... )
331
+ select {
332
+ case <- b .closed :
333
+ return nil , xerrors .WithStackTrace (errBalancerClosed )
334
+ default :
335
+ var client grpc.ClientStream
336
+ err = b .wrapCall (ctx , func (ctx context.Context , cc conn.Conn ) error {
337
+ client , err = cc .NewStream (ctx , desc , method , opts ... )
338
+
339
+ return err
340
+ })
341
+ if err == nil {
342
+ return client , nil
343
+ }
317
344
318
- return err
319
- })
320
- if err == nil {
321
- return client , nil
345
+ return nil , err
322
346
}
323
-
324
- return nil , err
325
347
}
326
348
327
349
func (b * Balancer ) wrapCall (ctx context.Context , f func (ctx context.Context , cc conn.Conn ) error ) (err error ) {
@@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
332
354
333
355
defer func () {
334
356
if err == nil {
335
- if cc .GetState () == conn .Banned {
336
- b .pool .Allow (ctx , cc )
337
- }
338
- } else if xerrors .MustPessimizeEndpoint (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
357
+ b .pool .Unban (ctx , cc )
358
+ } else if xerrors .MustBanConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
339
359
b .pool .Ban (ctx , cc , err )
340
360
}
341
361
}()
@@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
363
383
return nil
364
384
}
365
385
366
- func (b * Balancer ) connections () * connectionsState {
386
+ func (b * Balancer ) connections () * connectionsState [conn. Conn ] {
367
387
b .mu .RLock ()
368
388
defer b .mu .RUnlock ()
369
389
@@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
401
421
c , failedCount = state .GetConnection (ctx )
402
422
if c == nil {
403
423
return nil , xerrors .WithStackTrace (
404
- fmt .Errorf ("%w: cannot get connection from Balancer after %d attempts" , ErrNoEndpoints , failedCount ),
424
+ fmt .Errorf ("cannot get connection from Balancer after %d attempts: %w " , failedCount , ErrNoEndpoints ),
405
425
)
406
426
}
407
427
0 commit comments