@@ -4,8 +4,10 @@ import (
4
4
"context"
5
5
"fmt"
6
6
"sort"
7
+ "sync/atomic"
7
8
8
9
"google.golang.org/grpc"
10
+ grpcCodes "google.golang.org/grpc/codes"
9
11
10
12
"github.com/ydb-platform/ydb-go-sdk/v3/config"
11
13
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
@@ -19,7 +21,6 @@ import (
19
21
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
20
22
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
21
23
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
22
- "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
23
24
"github.com/ydb-platform/ydb-go-sdk/v3/retry"
24
25
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
25
26
)
@@ -40,33 +41,13 @@ type Balancer struct {
40
41
discoveryRepeater repeater.Repeater
41
42
localDCDetector func (ctx context.Context , endpoints []endpoint.Endpoint ) (string , error )
42
43
43
- mu xsync.RWMutex
44
- connectionsState * connectionsState [conn.Conn ]
44
+ connections atomic.Pointer [connections [conn.Conn ]]
45
45
46
46
closed chan struct {}
47
47
48
48
onApplyDiscoveredEndpoints []func (ctx context.Context , endpoints []endpoint.Info )
49
49
}
50
50
51
- func (b * Balancer ) HasNode (id uint32 ) bool {
52
- if b .config .SingleConn {
53
- return true
54
- }
55
- b .mu .RLock ()
56
- defer b .mu .RUnlock ()
57
- if _ , has := b .connectionsState .connByNodeID [id ]; has {
58
- return true
59
- }
60
-
61
- return false
62
- }
63
-
64
- func (b * Balancer ) OnUpdate (onApplyDiscoveredEndpoints func (ctx context.Context , endpoints []endpoint.Info )) {
65
- b .mu .WithLock (func () {
66
- b .onApplyDiscoveredEndpoints = append (b .onApplyDiscoveredEndpoints , onApplyDiscoveredEndpoints )
67
- })
68
- }
69
-
70
51
func (b * Balancer ) clusterDiscovery (ctx context.Context ) (err error ) {
71
52
return retry .Retry (
72
53
repeater .WithEvent (ctx , repeater .EventInit ),
@@ -135,37 +116,37 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
135
116
return nil
136
117
}
137
118
138
- func endpointsDiff (newestEndpoints []endpoint. Endpoint , previousConns []conn. Info ) (
119
+ func endpointsDiff (newestEndpoints []trace. EndpointInfo , previousEndpoints []trace. EndpointInfo ) (
139
120
nodes []trace.EndpointInfo ,
140
121
added []trace.EndpointInfo ,
141
122
dropped []trace.EndpointInfo ,
142
123
) {
143
124
nodes = make ([]trace.EndpointInfo , 0 , len (newestEndpoints ))
144
- added = make ([]trace.EndpointInfo , 0 , len (previousConns ))
145
- dropped = make ([]trace.EndpointInfo , 0 , len (previousConns ))
125
+ added = make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
126
+ dropped = make ([]trace.EndpointInfo , 0 , len (previousEndpoints ))
146
127
var (
147
128
newestMap = make (map [string ]struct {}, len (newestEndpoints ))
148
- previousMap = make (map [string ]struct {}, len (previousConns ))
129
+ previousMap = make (map [string ]struct {}, len (previousEndpoints ))
149
130
)
150
131
sort .Slice (newestEndpoints , func (i , j int ) bool {
151
132
return newestEndpoints [i ].Address () < newestEndpoints [j ].Address ()
152
133
})
153
- sort .Slice (previousConns , func (i , j int ) bool {
154
- return previousConns [i ].Endpoint (). Address () < previousConns [j ]. Endpoint () .Address ()
134
+ sort .Slice (previousEndpoints , func (i , j int ) bool {
135
+ return previousEndpoints [i ].Address () < previousEndpoints [j ].Address ()
155
136
})
156
- for _ , e := range previousConns {
157
- previousMap [e .Endpoint (). Address ()] = struct {}{}
137
+ for _ , e := range previousEndpoints {
138
+ previousMap [e .Address ()] = struct {}{}
158
139
}
159
- for _ , e := range newestEndpoints {
160
- nodes = append (nodes , e . Copy () )
161
- newestMap [e .Address ()] = struct {}{}
162
- if _ , has := previousMap [e .Address ()]; ! has {
163
- added = append (added , e . Copy () )
140
+ for _ , info := range newestEndpoints {
141
+ nodes = append (nodes , info )
142
+ newestMap [info .Address ()] = struct {}{}
143
+ if _ , has := previousMap [info .Address ()]; ! has {
144
+ added = append (added , info )
164
145
}
165
146
}
166
- for _ , c := range previousConns {
167
- if _ , has := newestMap [c . Endpoint () .Address ()]; ! has {
168
- dropped = append (dropped , c . Endpoint (). Copy () )
147
+ for _ , info := range previousEndpoints {
148
+ if _ , has := newestMap [info .Address ()]; ! has {
149
+ dropped = append (dropped , info )
169
150
}
170
151
}
171
152
@@ -180,41 +161,28 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
180
161
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints" ),
181
162
b .config .DetectLocalDC ,
182
163
)
183
- previousConns []conn.Info
184
164
)
185
- defer func () {
186
- nodes , added , dropped := endpointsDiff (endpoints , previousConns )
187
- onDone (nodes , added , dropped , localDC )
188
- }()
189
165
190
166
connections := endpointsToConnections (b .pool , endpoints )
191
167
for _ , c := range connections {
192
- if c .State () == conn .Banned {
193
- b .pool .Unban (ctx , c )
194
- }
195
168
c .Endpoint ().Touch ()
196
169
}
197
170
198
171
info := balancerConfig.Info {SelfLocation : localDC }
199
- state := newConnectionsState (connections , b .config .Filter , info , b .config .AllowFallback )
200
-
201
- endpointsInfo := make ([]endpoint.Info , len (endpoints ))
202
- for i , e := range endpoints {
203
- endpointsInfo [i ] = e
204
- }
205
-
206
- b .mu .WithLock (func () {
207
- if b .connectionsState != nil {
208
- previousConns = make ([]conn.Info , len (b .connectionsState .all ))
209
- for i := range b .connectionsState .all {
210
- previousConns [i ] = b .connectionsState .all [i ]
211
- }
212
- }
213
- b .connectionsState = state
214
- for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
215
- onApplyDiscoveredEndpoints (ctx , endpointsInfo )
172
+ newestConnections := newConns (connections , b .config .Filter , info , b .config .AllowFallback )
173
+ previousConnections := b .connections .Swap (newestConnections )
174
+ defer func () {
175
+ if previousConnections != nil {
176
+ nodes , added , dropped := endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), previousConnections .all .ToTraceEndpointInfo ())
177
+ onDone (nodes , added , dropped , localDC )
178
+ } else {
179
+ nodes , added , dropped := endpointsDiff (newestConnections .all .ToTraceEndpointInfo (), nil )
180
+ onDone (nodes , added , dropped , localDC )
216
181
}
217
- })
182
+ }()
183
+ for _ , onApplyDiscoveredEndpoints := range b .onApplyDiscoveredEndpoints {
184
+ onApplyDiscoveredEndpoints (ctx , newestConnections .all .ToEndpointInfo ())
185
+ }
218
186
}
219
187
220
188
func (b * Balancer ) Close (ctx context.Context ) (err error ) {
@@ -241,6 +209,44 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
241
209
return nil
242
210
}
243
211
212
+ func (b * Balancer ) markConnAsBad (ctx context.Context , cc conn.Conn , cause error ) {
213
+ onDone := trace .DriverOnBalancerMarkConnAsBad (
214
+ b .driverConfig .Trace (), & ctx ,
215
+ stack .FunctionID ("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).markConnAsBad" ),
216
+ cc .Endpoint (), cause ,
217
+ )
218
+
219
+ if ! xerrors .IsTransportError (cause ,
220
+ grpcCodes .ResourceExhausted ,
221
+ grpcCodes .Unavailable ,
222
+ // grpcCodes.OK,
223
+ // grpcCodes.Canceled,
224
+ // grpcCodes.Unknown,
225
+ // grpcCodes.InvalidArgument,
226
+ // grpcCodes.DeadlineExceeded,
227
+ // grpcCodes.NotFound,
228
+ // grpcCodes.AlreadyExists,
229
+ // grpcCodes.PermissionDenied,
230
+ // grpcCodes.FailedPrecondition,
231
+ // grpcCodes.Aborted,
232
+ // grpcCodes.OutOfRange,
233
+ // grpcCodes.Unimplemented,
234
+ // grpcCodes.Internal,
235
+ // grpcCodes.DataLoss,
236
+ // grpcCodes.Unauthenticated,
237
+ ) {
238
+ return
239
+ }
240
+
241
+ newestConns , changed := b .connections .Load ().withBadConn (cc )
242
+
243
+ if changed {
244
+ b .connections .Store (newestConns )
245
+ }
246
+
247
+ onDone (newestConns .prefer .ToTraceEndpointInfo (), newestConns .fallback .ToTraceEndpointInfo ())
248
+ }
249
+
244
250
func New (
245
251
ctx context.Context ,
246
252
driverConfig * config.Config ,
@@ -353,10 +359,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
353
359
}
354
360
355
361
defer func () {
356
- if err == nil {
357
- b .pool .Unban (ctx , cc )
358
- } else if xerrors .MustBanConn (err , b .driverConfig .ExcludeGRPCCodesForPessimization ()... ) {
359
- b .pool .Ban (ctx , cc , err )
362
+ if err != nil {
363
+ b .markConnAsBad (ctx , cc , err )
360
364
}
361
365
}()
362
366
@@ -383,13 +387,6 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
383
387
return nil
384
388
}
385
389
386
- func (b * Balancer ) connections () * connectionsState [conn.Conn ] {
387
- b .mu .RLock ()
388
- defer b .mu .RUnlock ()
389
-
390
- return b .connectionsState
391
- }
392
-
393
390
func (b * Balancer ) getConn (ctx context.Context ) (c conn.Conn , err error ) {
394
391
onDone := trace .DriverOnBalancerChooseEndpoint (
395
392
b .driverConfig .Trace (), & ctx ,
@@ -408,17 +405,17 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
408
405
}
409
406
410
407
var (
411
- state = b .connections ()
408
+ connections = b .connections . Load ()
412
409
failedCount int
413
410
)
414
411
415
412
defer func () {
416
- if failedCount * 2 > state .PreferredCount () && b .discoveryRepeater != nil {
413
+ if failedCount * 2 > connections .PreferredCount () && b .discoveryRepeater != nil {
417
414
b .discoveryRepeater .Force ()
418
415
}
419
416
}()
420
417
421
- c , failedCount = state . GetConnection (ctx )
418
+ c , failedCount = connections . GetConn (ctx )
422
419
if c == nil {
423
420
return nil , xerrors .WithStackTrace (
424
421
fmt .Errorf ("cannot get connection from Balancer after %d attempts: %w" , failedCount , ErrNoEndpoints ),
@@ -429,10 +426,10 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
429
426
}
430
427
431
428
func endpointsToConnections (p * conn.Pool , endpoints []endpoint.Endpoint ) []conn.Conn {
432
- conns := make ([]conn.Conn , 0 , len (endpoints ))
429
+ connections := make ([]conn.Conn , 0 , len (endpoints ))
433
430
for _ , e := range endpoints {
434
- conns = append (conns , p .Get (e ))
431
+ connections = append (connections , p .Get (e ))
435
432
}
436
433
437
- return conns
434
+ return connections
438
435
}
0 commit comments