Skip to content

Commit 5455fcc

Browse files
committed
Refactoring
1 parent a0abd92 commit 5455fcc

38 files changed

+1378
-980
lines changed

balancers/balancers.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func SingleConn() *balancerConfig.Config {
2626

2727
type filterLocalDC struct{}
2828

29-
func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Conn) bool {
29+
func (filterLocalDC) Allow(info balancerConfig.Info, c conn.Info) bool {
3030
return c.Endpoint().Location() == info.SelfLocation
3131
}
3232

@@ -56,7 +56,7 @@ func PreferLocalDCWithFallBack(balancer *balancerConfig.Config) *balancerConfig.
5656

5757
type filterLocations []string
5858

59-
func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Conn) bool {
59+
func (locations filterLocations) Allow(_ balancerConfig.Info, c conn.Info) bool {
6060
location := strings.ToUpper(c.Endpoint().Location())
6161
for _, l := range locations {
6262
if location == l {
@@ -118,9 +118,9 @@ type Endpoint interface {
118118
LocalDC() bool
119119
}
120120

121-
type filterFunc func(info balancerConfig.Info, c conn.Conn) bool
121+
type filterFunc func(info balancerConfig.Info, c conn.Info) bool
122122

123-
func (p filterFunc) Allow(info balancerConfig.Info, c conn.Conn) bool {
123+
func (p filterFunc) Allow(info balancerConfig.Info, c conn.Info) bool {
124124
return p(info, c)
125125
}
126126

@@ -131,7 +131,7 @@ func (p filterFunc) String() string {
131131
// Prefer creates balancer which use endpoints by filter
132132
// Balancer "balancer" defines balancing algorithm between endpoints selected with filter
133133
func Prefer(balancer *balancerConfig.Config, filter func(endpoint Endpoint) bool) *balancerConfig.Config {
134-
balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Conn) bool {
134+
balancer.Filter = filterFunc(func(_ balancerConfig.Info, c conn.Info) bool {
135135
return filter(c.Endpoint())
136136
})
137137

balancers/balancers_test.go

+23-23
Original file line numberDiff line numberDiff line change
@@ -11,56 +11,56 @@ import (
1111
)
1212

1313
func TestPreferLocalDC(t *testing.T) {
14-
conns := []conn.Conn{
15-
&mock.Conn{AddrField: "1", LocationField: "1"},
16-
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"},
17-
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"},
14+
conns := []conn.Info{
15+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
16+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
17+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
1818
}
1919
rr := PreferLocalDC(RandomChoice())
2020
require.False(t, rr.AllowFallback)
21-
require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
21+
require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
2222
}
2323

2424
func TestPreferLocalDCWithFallBack(t *testing.T) {
25-
conns := []conn.Conn{
26-
&mock.Conn{AddrField: "1", LocationField: "1"},
27-
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "2"},
28-
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "2"},
25+
conns := []conn.Info{
26+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
27+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
28+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
2929
}
3030
rr := PreferLocalDCWithFallBack(RandomChoice())
3131
require.True(t, rr.AllowFallback)
32-
require.Equal(t, []conn.Conn{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
32+
require.Equal(t, []conn.Info{conns[1], conns[2]}, applyPreferFilter(balancerConfig.Info{SelfLocation: "2"}, rr, conns))
3333
}
3434

3535
func TestPreferLocations(t *testing.T) {
36-
conns := []conn.Conn{
37-
&mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online},
38-
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"},
39-
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"},
36+
conns := []conn.Info{
37+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
38+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
39+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
4040
}
4141

4242
rr := PreferLocations(RandomChoice(), "zero", "two")
4343
require.False(t, rr.AllowFallback)
44-
require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
44+
require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
4545
}
4646

4747
func TestPreferLocationsWithFallback(t *testing.T) {
48-
conns := []conn.Conn{
49-
&mock.Conn{AddrField: "1", LocationField: "zero", State: conn.Online},
50-
&mock.Conn{AddrField: "2", State: conn.Online, LocationField: "one"},
51-
&mock.Conn{AddrField: "3", State: conn.Online, LocationField: "two"},
48+
conns := []conn.Info{
49+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
50+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
51+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
5252
}
5353

5454
rr := PreferLocationsWithFallback(RandomChoice(), "zero", "two")
5555
require.True(t, rr.AllowFallback)
56-
require.Equal(t, []conn.Conn{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
56+
require.Equal(t, []conn.Info{conns[0], conns[2]}, applyPreferFilter(balancerConfig.Info{}, rr, conns))
5757
}
5858

59-
func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, conns []conn.Conn) []conn.Conn {
59+
func applyPreferFilter(info balancerConfig.Info, b *balancerConfig.Config, conns []conn.Info) []conn.Info {
6060
if b.Filter == nil {
61-
b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Conn) bool { return true })
61+
b.Filter = filterFunc(func(info balancerConfig.Info, c conn.Info) bool { return true })
6262
}
63-
res := make([]conn.Conn, 0, len(conns))
63+
res := make([]conn.Info, 0, len(conns))
6464
for _, c := range conns {
6565
if b.Filter.Allow(info, c) {
6666
res = append(res, c)

balancers/config_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestFromConfig(t *testing.T) {
7171
}`,
7272
res: balancerConfig.Config{
7373
DetectLocalDC: true,
74-
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
74+
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
7575
// some non nil func
7676
return false
7777
}),
@@ -95,7 +95,7 @@ func TestFromConfig(t *testing.T) {
9595
res: balancerConfig.Config{
9696
AllowFallback: true,
9797
DetectLocalDC: true,
98-
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
98+
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
9999
// some non nil func
100100
return false
101101
}),
@@ -109,7 +109,7 @@ func TestFromConfig(t *testing.T) {
109109
"locations": ["AAA", "BBB", "CCC"]
110110
}`,
111111
res: balancerConfig.Config{
112-
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
112+
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
113113
// some non nil func
114114
return false
115115
}),
@@ -125,7 +125,7 @@ func TestFromConfig(t *testing.T) {
125125
}`,
126126
res: balancerConfig.Config{
127127
AllowFallback: true,
128-
Filter: filterFunc(func(info balancerConfig.Info, c conn.Conn) bool {
128+
Filter: filterFunc(func(info balancerConfig.Info, c conn.Info) bool {
129129
// some non nil func
130130
return false
131131
}),

driver.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) {
152152
d.query.Close,
153153
d.topic.Close,
154154
d.balancer.Close,
155-
d.pool.Release,
155+
d.pool.Detach,
156156
)
157157

158158
var issues []error

internal/balancer/balancer.go

+43-23
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ type Balancer struct {
4141
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4242

4343
mu xsync.RWMutex
44-
connectionsState *connectionsState
44+
connectionsState *connectionsState[conn.Conn]
45+
46+
closed chan struct{}
4547

4648
onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
4749
}
@@ -133,7 +135,7 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
133135
return nil
134136
}
135137

136-
func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Conn) (
138+
func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Info) (
137139
nodes []trace.EndpointInfo,
138140
added []trace.EndpointInfo,
139141
dropped []trace.EndpointInfo,
@@ -178,7 +180,7 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
178180
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
179181
b.config.DetectLocalDC,
180182
)
181-
previousConns []conn.Conn
183+
previousConns []conn.Info
182184
)
183185
defer func() {
184186
nodes, added, dropped := endpointsDiff(endpoints, previousConns)
@@ -187,7 +189,9 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
187189

188190
connections := endpointsToConnections(b.pool, endpoints)
189191
for _, c := range connections {
190-
b.pool.Allow(ctx, c)
192+
if c.State() == conn.Banned {
193+
b.pool.Unban(ctx, c)
194+
}
191195
c.Endpoint().Touch()
192196
}
193197

@@ -201,7 +205,10 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
201205

202206
b.mu.WithLock(func() {
203207
if b.connectionsState != nil {
204-
previousConns = b.connectionsState.all
208+
previousConns = make([]conn.Info, len(b.connectionsState.all))
209+
for i := range b.connectionsState.all {
210+
previousConns[i] = b.connectionsState.all[i]
211+
}
205212
}
206213
b.connectionsState = state
207214
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
@@ -211,6 +218,8 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
211218
}
212219

213220
func (b *Balancer) Close(ctx context.Context) (err error) {
221+
close(b.closed)
222+
214223
onDone := trace.DriverOnBalancerClose(
215224
b.driverConfig.Trace(), &ctx,
216225
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).Close"),
@@ -223,6 +232,8 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
223232
b.discoveryRepeater.Stop()
224233
}
225234

235+
b.applyDiscoveredEndpoints(ctx, nil, "")
236+
226237
if err = b.discoveryClient.Close(ctx); err != nil {
227238
return xerrors.WithStackTrace(err)
228239
}
@@ -258,6 +269,7 @@ func New(
258269
driverConfig: driverConfig,
259270
pool: pool,
260271
localDCDetector: detectLocalDC,
272+
closed: make(chan struct{}),
261273
}
262274
d := internalDiscovery.New(ctx, pool.Get(
263275
endpoint.New(driverConfig.Endpoint()),
@@ -300,9 +312,14 @@ func (b *Balancer) Invoke(
300312
reply interface{},
301313
opts ...grpc.CallOption,
302314
) error {
303-
return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
304-
return cc.Invoke(ctx, method, args, reply, opts...)
305-
})
315+
select {
316+
case <-b.closed:
317+
return xerrors.WithStackTrace(errBalancerClosed)
318+
default:
319+
return b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
320+
return cc.Invoke(ctx, method, args, reply, opts...)
321+
})
322+
}
306323
}
307324

308325
func (b *Balancer) NewStream(
@@ -311,17 +328,22 @@ func (b *Balancer) NewStream(
311328
method string,
312329
opts ...grpc.CallOption,
313330
) (_ grpc.ClientStream, err error) {
314-
var client grpc.ClientStream
315-
err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
316-
client, err = cc.NewStream(ctx, desc, method, opts...)
331+
select {
332+
case <-b.closed:
333+
return nil, xerrors.WithStackTrace(errBalancerClosed)
334+
default:
335+
var client grpc.ClientStream
336+
err = b.wrapCall(ctx, func(ctx context.Context, cc conn.Conn) error {
337+
client, err = cc.NewStream(ctx, desc, method, opts...)
338+
339+
return err
340+
})
341+
if err == nil {
342+
return client, nil
343+
}
317344

318-
return err
319-
})
320-
if err == nil {
321-
return client, nil
345+
return nil, err
322346
}
323-
324-
return nil, err
325347
}
326348

327349
func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc conn.Conn) error) (err error) {
@@ -332,10 +354,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
332354

333355
defer func() {
334356
if err == nil {
335-
if cc.GetState() == conn.Banned {
336-
b.pool.Allow(ctx, cc)
337-
}
338-
} else if xerrors.MustPessimizeEndpoint(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
357+
b.pool.Unban(ctx, cc)
358+
} else if xerrors.MustBanConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
339359
b.pool.Ban(ctx, cc, err)
340360
}
341361
}()
@@ -363,7 +383,7 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
363383
return nil
364384
}
365385

366-
func (b *Balancer) connections() *connectionsState {
386+
func (b *Balancer) connections() *connectionsState[conn.Conn] {
367387
b.mu.RLock()
368388
defer b.mu.RUnlock()
369389

@@ -401,7 +421,7 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
401421
c, failedCount = state.GetConnection(ctx)
402422
if c == nil {
403423
return nil, xerrors.WithStackTrace(
404-
fmt.Errorf("%w: cannot get connection from Balancer after %d attempts", ErrNoEndpoints, failedCount),
424+
fmt.Errorf("cannot get connection from Balancer after %d attempts: %w", failedCount, ErrNoEndpoints),
405425
)
406426
}
407427

internal/balancer/balancer_test.go

+19-19
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
func TestEndpointsDiff(t *testing.T) {
1616
for _, tt := range []struct {
1717
newestEndpoints []endpoint.Endpoint
18-
previousConns []conn.Conn
18+
previousConns []conn.Info
1919
nodes []trace.EndpointInfo
2020
added []trace.EndpointInfo
2121
dropped []trace.EndpointInfo
@@ -27,11 +27,11 @@ func TestEndpointsDiff(t *testing.T) {
2727
&mock.Endpoint{AddrField: "2"},
2828
&mock.Endpoint{AddrField: "0"},
2929
},
30-
previousConns: []conn.Conn{
31-
&mock.Conn{AddrField: "2"},
32-
&mock.Conn{AddrField: "1"},
33-
&mock.Conn{AddrField: "0"},
34-
&mock.Conn{AddrField: "3"},
30+
previousConns: []conn.Info{
31+
&mock.ConnInfo{EndpointAddrField: "2"},
32+
&mock.ConnInfo{EndpointAddrField: "1"},
33+
&mock.ConnInfo{EndpointAddrField: "0"},
34+
&mock.ConnInfo{EndpointAddrField: "3"},
3535
},
3636
nodes: []trace.EndpointInfo{
3737
&mock.Endpoint{AddrField: "0"},
@@ -49,10 +49,10 @@ func TestEndpointsDiff(t *testing.T) {
4949
&mock.Endpoint{AddrField: "2"},
5050
&mock.Endpoint{AddrField: "0"},
5151
},
52-
previousConns: []conn.Conn{
53-
&mock.Conn{AddrField: "1"},
54-
&mock.Conn{AddrField: "0"},
55-
&mock.Conn{AddrField: "3"},
52+
previousConns: []conn.Info{
53+
&mock.ConnInfo{EndpointAddrField: "1"},
54+
&mock.ConnInfo{EndpointAddrField: "0"},
55+
&mock.ConnInfo{EndpointAddrField: "3"},
5656
},
5757
nodes: []trace.EndpointInfo{
5858
&mock.Endpoint{AddrField: "0"},
@@ -71,11 +71,11 @@ func TestEndpointsDiff(t *testing.T) {
7171
&mock.Endpoint{AddrField: "3"},
7272
&mock.Endpoint{AddrField: "0"},
7373
},
74-
previousConns: []conn.Conn{
75-
&mock.Conn{AddrField: "1"},
76-
&mock.Conn{AddrField: "2"},
77-
&mock.Conn{AddrField: "0"},
78-
&mock.Conn{AddrField: "3"},
74+
previousConns: []conn.Info{
75+
&mock.ConnInfo{EndpointAddrField: "1"},
76+
&mock.ConnInfo{EndpointAddrField: "2"},
77+
&mock.ConnInfo{EndpointAddrField: "0"},
78+
&mock.ConnInfo{EndpointAddrField: "3"},
7979
},
8080
nodes: []trace.EndpointInfo{
8181
&mock.Endpoint{AddrField: "0"},
@@ -93,10 +93,10 @@ func TestEndpointsDiff(t *testing.T) {
9393
&mock.Endpoint{AddrField: "3"},
9494
&mock.Endpoint{AddrField: "0"},
9595
},
96-
previousConns: []conn.Conn{
97-
&mock.Conn{AddrField: "4"},
98-
&mock.Conn{AddrField: "7"},
99-
&mock.Conn{AddrField: "8"},
96+
previousConns: []conn.Info{
97+
&mock.ConnInfo{EndpointAddrField: "4"},
98+
&mock.ConnInfo{EndpointAddrField: "7"},
99+
&mock.ConnInfo{EndpointAddrField: "8"},
100100
},
101101
nodes: []trace.EndpointInfo{
102102
&mock.Endpoint{AddrField: "0"},

0 commit comments

Comments
 (0)