Skip to content

Commit cae12e7

Browse files
committed
WIP
1 parent abe3bf3 commit cae12e7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+33737
-510
lines changed

balancers/balancers_test.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/stretchr/testify/require"
7+
"google.golang.org/grpc/connectivity"
78

89
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
910
"github.com/ydb-platform/ydb-go-sdk/v3/internal/conn"
@@ -13,8 +14,8 @@ import (
1314
func TestPreferLocalDC(t *testing.T) {
1415
conns := []conn.Info{
1516
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
16-
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
17-
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
17+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: connectivity.Ready, EndpointLocationField: "2"},
18+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: connectivity.Ready, EndpointLocationField: "2"},
1819
}
1920
rr := PreferLocalDC(RandomChoice())
2021
require.False(t, rr.AllowFallback)
@@ -24,8 +25,8 @@ func TestPreferLocalDC(t *testing.T) {
2425
func TestPreferLocalDCWithFallBack(t *testing.T) {
2526
conns := []conn.Info{
2627
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "1"},
27-
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "2"},
28-
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "2"},
28+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: connectivity.Ready, EndpointLocationField: "2"},
29+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: connectivity.Ready, EndpointLocationField: "2"},
2930
}
3031
rr := PreferLocalDCWithFallBack(RandomChoice())
3132
require.True(t, rr.AllowFallback)
@@ -34,9 +35,9 @@ func TestPreferLocalDCWithFallBack(t *testing.T) {
3435

3536
func TestPreferLocations(t *testing.T) {
3637
conns := []conn.Info{
37-
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
38-
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
39-
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
38+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: connectivity.Ready},
39+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: connectivity.Ready, EndpointLocationField: "one"},
40+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: connectivity.Ready, EndpointLocationField: "two"},
4041
}
4142

4243
rr := PreferLocations(RandomChoice(), "zero", "two")
@@ -46,9 +47,9 @@ func TestPreferLocations(t *testing.T) {
4647

4748
func TestPreferLocationsWithFallback(t *testing.T) {
4849
conns := []conn.Info{
49-
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: conn.Online},
50-
&mock.ConnInfo{EndpointAddrField: "2", ConnState: conn.Online, EndpointLocationField: "one"},
51-
&mock.ConnInfo{EndpointAddrField: "3", ConnState: conn.Online, EndpointLocationField: "two"},
50+
&mock.ConnInfo{EndpointAddrField: "1", EndpointLocationField: "zero", ConnState: connectivity.Ready},
51+
&mock.ConnInfo{EndpointAddrField: "2", ConnState: connectivity.Ready, EndpointLocationField: "one"},
52+
&mock.ConnInfo{EndpointAddrField: "3", ConnState: connectivity.Ready, EndpointLocationField: "two"},
5253
}
5354

5455
rr := PreferLocationsWithFallback(RandomChoice(), "zero", "two")

internal/balancer/balancer.go

+78-81
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ import (
44
"context"
55
"fmt"
66
"sort"
7+
"sync/atomic"
78

89
"google.golang.org/grpc"
10+
grpcCodes "google.golang.org/grpc/codes"
911

1012
"github.com/ydb-platform/ydb-go-sdk/v3/config"
1113
balancerConfig "github.com/ydb-platform/ydb-go-sdk/v3/internal/balancer/config"
@@ -19,7 +21,6 @@ import (
1921
"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
2022
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext"
2123
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
22-
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync"
2324
"github.com/ydb-platform/ydb-go-sdk/v3/retry"
2425
"github.com/ydb-platform/ydb-go-sdk/v3/trace"
2526
)
@@ -40,33 +41,13 @@ type Balancer struct {
4041
discoveryRepeater repeater.Repeater
4142
localDCDetector func(ctx context.Context, endpoints []endpoint.Endpoint) (string, error)
4243

43-
mu xsync.RWMutex
44-
connectionsState *connectionsState[conn.Conn]
44+
connections atomic.Pointer[connections[conn.Conn]]
4545

4646
closed chan struct{}
4747

4848
onApplyDiscoveredEndpoints []func(ctx context.Context, endpoints []endpoint.Info)
4949
}
5050

51-
func (b *Balancer) HasNode(id uint32) bool {
52-
if b.config.SingleConn {
53-
return true
54-
}
55-
b.mu.RLock()
56-
defer b.mu.RUnlock()
57-
if _, has := b.connectionsState.connByNodeID[id]; has {
58-
return true
59-
}
60-
61-
return false
62-
}
63-
64-
func (b *Balancer) OnUpdate(onApplyDiscoveredEndpoints func(ctx context.Context, endpoints []endpoint.Info)) {
65-
b.mu.WithLock(func() {
66-
b.onApplyDiscoveredEndpoints = append(b.onApplyDiscoveredEndpoints, onApplyDiscoveredEndpoints)
67-
})
68-
}
69-
7051
func (b *Balancer) clusterDiscovery(ctx context.Context) (err error) {
7152
return retry.Retry(
7253
repeater.WithEvent(ctx, repeater.EventInit),
@@ -135,37 +116,37 @@ func (b *Balancer) clusterDiscoveryAttempt(ctx context.Context) (err error) {
135116
return nil
136117
}
137118

138-
func endpointsDiff(newestEndpoints []endpoint.Endpoint, previousConns []conn.Info) (
119+
func endpointsDiff(newestEndpoints []trace.EndpointInfo, previousEndpoints []trace.EndpointInfo) (
139120
nodes []trace.EndpointInfo,
140121
added []trace.EndpointInfo,
141122
dropped []trace.EndpointInfo,
142123
) {
143124
nodes = make([]trace.EndpointInfo, 0, len(newestEndpoints))
144-
added = make([]trace.EndpointInfo, 0, len(previousConns))
145-
dropped = make([]trace.EndpointInfo, 0, len(previousConns))
125+
added = make([]trace.EndpointInfo, 0, len(previousEndpoints))
126+
dropped = make([]trace.EndpointInfo, 0, len(previousEndpoints))
146127
var (
147128
newestMap = make(map[string]struct{}, len(newestEndpoints))
148-
previousMap = make(map[string]struct{}, len(previousConns))
129+
previousMap = make(map[string]struct{}, len(previousEndpoints))
149130
)
150131
sort.Slice(newestEndpoints, func(i, j int) bool {
151132
return newestEndpoints[i].Address() < newestEndpoints[j].Address()
152133
})
153-
sort.Slice(previousConns, func(i, j int) bool {
154-
return previousConns[i].Endpoint().Address() < previousConns[j].Endpoint().Address()
134+
sort.Slice(previousEndpoints, func(i, j int) bool {
135+
return previousEndpoints[i].Address() < previousEndpoints[j].Address()
155136
})
156-
for _, e := range previousConns {
157-
previousMap[e.Endpoint().Address()] = struct{}{}
137+
for _, e := range previousEndpoints {
138+
previousMap[e.Address()] = struct{}{}
158139
}
159-
for _, e := range newestEndpoints {
160-
nodes = append(nodes, e.Copy())
161-
newestMap[e.Address()] = struct{}{}
162-
if _, has := previousMap[e.Address()]; !has {
163-
added = append(added, e.Copy())
140+
for _, info := range newestEndpoints {
141+
nodes = append(nodes, info)
142+
newestMap[info.Address()] = struct{}{}
143+
if _, has := previousMap[info.Address()]; !has {
144+
added = append(added, info)
164145
}
165146
}
166-
for _, c := range previousConns {
167-
if _, has := newestMap[c.Endpoint().Address()]; !has {
168-
dropped = append(dropped, c.Endpoint().Copy())
147+
for _, info := range previousEndpoints {
148+
if _, has := newestMap[info.Address()]; !has {
149+
dropped = append(dropped, info)
169150
}
170151
}
171152

@@ -180,41 +161,28 @@ func (b *Balancer) applyDiscoveredEndpoints(ctx context.Context, endpoints []end
180161
"github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).applyDiscoveredEndpoints"),
181162
b.config.DetectLocalDC,
182163
)
183-
previousConns []conn.Info
184164
)
185-
defer func() {
186-
nodes, added, dropped := endpointsDiff(endpoints, previousConns)
187-
onDone(nodes, added, dropped, localDC)
188-
}()
189165

190166
connections := endpointsToConnections(b.pool, endpoints)
191167
for _, c := range connections {
192-
if c.State() == conn.Banned {
193-
b.pool.Unban(ctx, c)
194-
}
195168
c.Endpoint().Touch()
196169
}
197170

198171
info := balancerConfig.Info{SelfLocation: localDC}
199-
state := newConnectionsState(connections, b.config.Filter, info, b.config.AllowFallback)
200-
201-
endpointsInfo := make([]endpoint.Info, len(endpoints))
202-
for i, e := range endpoints {
203-
endpointsInfo[i] = e
204-
}
205-
206-
b.mu.WithLock(func() {
207-
if b.connectionsState != nil {
208-
previousConns = make([]conn.Info, len(b.connectionsState.all))
209-
for i := range b.connectionsState.all {
210-
previousConns[i] = b.connectionsState.all[i]
211-
}
212-
}
213-
b.connectionsState = state
214-
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
215-
onApplyDiscoveredEndpoints(ctx, endpointsInfo)
172+
newestConnections := newConns(connections, b.config.Filter, info, b.config.AllowFallback)
173+
previousConnections := b.connections.Swap(newestConnections)
174+
defer func() {
175+
if previousConnections != nil {
176+
nodes, added, dropped := endpointsDiff(newestConnections.all.ToTraceEndpointInfo(), previousConnections.all.ToTraceEndpointInfo())
177+
onDone(nodes, added, dropped, localDC)
178+
} else {
179+
nodes, added, dropped := endpointsDiff(newestConnections.all.ToTraceEndpointInfo(), nil)
180+
onDone(nodes, added, dropped, localDC)
216181
}
217-
})
182+
}()
183+
for _, onApplyDiscoveredEndpoints := range b.onApplyDiscoveredEndpoints {
184+
onApplyDiscoveredEndpoints(ctx, newestConnections.all.ToEndpointInfo())
185+
}
218186
}
219187

220188
func (b *Balancer) Close(ctx context.Context) (err error) {
@@ -241,6 +209,44 @@ func (b *Balancer) Close(ctx context.Context) (err error) {
241209
return nil
242210
}
243211

212+
func (b *Balancer) markConnAsBad(ctx context.Context, cc conn.Conn, cause error) {
213+
onDone := trace.DriverOnBalancerMarkConnAsBad(
214+
b.driverConfig.Trace(), &ctx,
215+
stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/3/internal/balancer.(*Balancer).markConnAsBad"),
216+
cc.Endpoint(), cause,
217+
)
218+
219+
if !xerrors.IsTransportError(cause,
220+
grpcCodes.ResourceExhausted,
221+
grpcCodes.Unavailable,
222+
// grpcCodes.OK,
223+
// grpcCodes.Canceled,
224+
// grpcCodes.Unknown,
225+
// grpcCodes.InvalidArgument,
226+
// grpcCodes.DeadlineExceeded,
227+
// grpcCodes.NotFound,
228+
// grpcCodes.AlreadyExists,
229+
// grpcCodes.PermissionDenied,
230+
// grpcCodes.FailedPrecondition,
231+
// grpcCodes.Aborted,
232+
// grpcCodes.OutOfRange,
233+
// grpcCodes.Unimplemented,
234+
// grpcCodes.Internal,
235+
// grpcCodes.DataLoss,
236+
// grpcCodes.Unauthenticated,
237+
) {
238+
return
239+
}
240+
241+
newestConns, changed := b.connections.Load().withBadConn(cc)
242+
243+
if changed {
244+
b.connections.Store(newestConns)
245+
}
246+
247+
onDone(newestConns.prefer.ToTraceEndpointInfo(), newestConns.fallback.ToTraceEndpointInfo())
248+
}
249+
244250
func New(
245251
ctx context.Context,
246252
driverConfig *config.Config,
@@ -353,10 +359,8 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
353359
}
354360

355361
defer func() {
356-
if err == nil {
357-
b.pool.Unban(ctx, cc)
358-
} else if xerrors.MustBanConn(err, b.driverConfig.ExcludeGRPCCodesForPessimization()...) {
359-
b.pool.Ban(ctx, cc, err)
362+
if err != nil {
363+
b.markConnAsBad(ctx, cc, err)
360364
}
361365
}()
362366

@@ -383,13 +387,6 @@ func (b *Balancer) wrapCall(ctx context.Context, f func(ctx context.Context, cc
383387
return nil
384388
}
385389

386-
func (b *Balancer) connections() *connectionsState[conn.Conn] {
387-
b.mu.RLock()
388-
defer b.mu.RUnlock()
389-
390-
return b.connectionsState
391-
}
392-
393390
func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
394391
onDone := trace.DriverOnBalancerChooseEndpoint(
395392
b.driverConfig.Trace(), &ctx,
@@ -408,17 +405,17 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
408405
}
409406

410407
var (
411-
state = b.connections()
408+
connections = b.connections.Load()
412409
failedCount int
413410
)
414411

415412
defer func() {
416-
if failedCount*2 > state.PreferredCount() && b.discoveryRepeater != nil {
413+
if failedCount*2 > connections.PreferredCount() && b.discoveryRepeater != nil {
417414
b.discoveryRepeater.Force()
418415
}
419416
}()
420417

421-
c, failedCount = state.GetConnection(ctx)
418+
c, failedCount = connections.GetConn(ctx)
422419
if c == nil {
423420
return nil, xerrors.WithStackTrace(
424421
fmt.Errorf("cannot get connection from Balancer after %d attempts: %w", failedCount, ErrNoEndpoints),
@@ -429,10 +426,10 @@ func (b *Balancer) getConn(ctx context.Context) (c conn.Conn, err error) {
429426
}
430427

431428
func endpointsToConnections(p *conn.Pool, endpoints []endpoint.Endpoint) []conn.Conn {
432-
conns := make([]conn.Conn, 0, len(endpoints))
429+
connections := make([]conn.Conn, 0, len(endpoints))
433430
for _, e := range endpoints {
434-
conns = append(conns, p.Get(e))
431+
connections = append(connections, p.Get(e))
435432
}
436433

437-
return conns
434+
return connections
438435
}

0 commit comments

Comments
 (0)