diff --git a/internal/cmd/preflight/preflight.go b/internal/cmd/preflight/preflight.go index a9f9516ed..78ecf594b 100644 --- a/internal/cmd/preflight/preflight.go +++ b/internal/cmd/preflight/preflight.go @@ -22,6 +22,8 @@ import ( "time" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" + "github.com/cockroachdb/replicator/internal/staging/memo" "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/stdpool" "github.com/pkg/errors" @@ -77,6 +79,8 @@ func testTargetConnection(ctx *stopper.Context, connString string) error { pool, err := stdpool.OpenTarget( ctx, connString, + stdpool.ProvideBackup(&memo.Memory{}, nil), + sinktest.NewBreakers(), stdpool.WithConnectionLifetime(5*time.Minute, time.Minute, 15*time.Second), stdpool.WithTransactionTimeout(time.Minute), ) diff --git a/internal/cmd/workload/config.go b/internal/cmd/workload/config.go index 26750cd07..e674d6b6f 100644 --- a/internal/cmd/workload/config.go +++ b/internal/cmd/workload/config.go @@ -59,6 +59,7 @@ type clientConfig struct { requestTimeout time.Duration resolvedInterval time.Duration retryMin, retryMax time.Duration + suffix string childTable ident.Table // Derived from targetSchema parentTable ident.Table // Derived from targetSchema @@ -85,6 +86,8 @@ func (c *clientConfig) Bind(flags *pflag.FlagSet) { "the maximum delay between HTTP retry attempts") flags.DurationVar(&c.retryMin, "retryMin", defaultRetryMin, "the minimum delay between HTTP retry attempts") + flags.StringVar(&c.suffix, "suffix", "", + "parent/child table suffix") flags.StringVar(&c.token, "token", "", "JWT bearer token if security is enabled") flags.StringVar(&c.url, "url", defaultURL, @@ -133,10 +136,10 @@ func (c *clientConfig) Preflight() error { } } if c.childTable.Empty() { - c.childTable = ident.NewTable(c.targetSchema, ident.New("child")) + c.childTable = ident.NewTable(c.targetSchema, ident.New("child"+c.suffix)) } if c.parentTable.Empty() { - c.parentTable = ident.NewTable(c.targetSchema, ident.New("parent")) + c.parentTable = ident.NewTable(c.targetSchema, ident.New("parent"+c.suffix)) } return nil } @@ -144,7 +147,9 @@ func (c *clientConfig) Preflight() error { func (c *clientConfig) createTables(ctx *stopper.Context, targetPool *types.TargetPool) error { // We need a 64-bit type. bigType := "BIGINT" - uniq := fmt.Sprintf("_%d_%d", os.Getpid(), rand.Int32N(10000)) + if c.suffix == "" { + c.suffix = fmt.Sprintf("_%d_%d", os.Getpid(), rand.Int32N(10000)) + } // Create the tables within the "current" schema specified on the // command-line. We'll use uniquely-named tables to ensure that @@ -179,8 +184,8 @@ func (c *clientConfig) createTables(ctx *stopper.Context, targetPool *types.Targ if err != nil { return err } - c.parentTable = ident.NewTable(c.targetSchema, ident.New("parent"+uniq)) - c.childTable = ident.NewTable(c.targetSchema, ident.New("child"+uniq)) + c.parentTable = ident.NewTable(c.targetSchema, ident.New("parent"+c.suffix)) + c.childTable = ident.NewTable(c.targetSchema, ident.New("child"+c.suffix)) if _, err := targetPool.ExecContext(ctx, fmt.Sprintf( `CREATE TABLE %s(parent %[2]s PRIMARY KEY, val %[2]s DEFAULT 0 NOT NULL)`, @@ -194,7 +199,7 @@ child %[4]s PRIMARY KEY, parent %[4]s NOT NULL, val %[4]s DEFAULT 0 NOT NULL, CONSTRAINT parent_fk%[2]s FOREIGN KEY(parent) REFERENCES %[3]s(parent) -)`, c.childTable, uniq, c.parentTable, bigType)); err != nil { +)`, c.childTable, c.suffix, c.parentTable, bigType)); err != nil { return errors.WithStack(err) } diff --git a/internal/cmd/workload/workload.go b/internal/cmd/workload/workload.go index 95a06658e..1e689430a 100644 --- a/internal/cmd/workload/workload.go +++ b/internal/cmd/workload/workload.go @@ -17,6 +17,8 @@ package workload import ( + "fmt" + "github.com/cockroachdb/field-eng-powertools/stopper" "github.com/cockroachdb/replicator/internal/source/cdc/server" "github.com/cockroachdb/replicator/internal/util/workload" @@ -71,6 +73,10 @@ func pcDemo() *cobra.Command { return err } + fmt.Printf("Changefeed target URL: %s\n", cfg.url) + fmt.Printf("Parent table: %s\n", cfg.parentTable) + fmt.Printf("Child table: %s\n", cfg.childTable) + if !serverCfg.HTTP.DisableAuth { if err := cfg.generateJWT(ctx, svr); err != nil { return err diff --git a/internal/cmd/workload/workload_test.go b/internal/cmd/workload/workload_test.go index ffaeeb69c..7288879f0 100644 --- a/internal/cmd/workload/workload_test.go +++ b/internal/cmd/workload/workload_test.go @@ -23,9 +23,11 @@ import ( "github.com/cockroachdb/field-eng-powertools/notify" "github.com/cockroachdb/field-eng-powertools/stopper" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/sinktest/all" "github.com/cockroachdb/replicator/internal/sinktest/base" "github.com/cockroachdb/replicator/internal/source/cdc/server" + "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/hlc" "github.com/cockroachdb/replicator/internal/util/stdserver" log "github.com/sirupsen/logrus" @@ -119,6 +121,166 @@ func TestWorkload(t *testing.T) { workload.CheckConsistent(ctx, t) } +func initSystem( + t *testing.T, +) ( + *runner, + *types.TableGroup, + *server.Server, + *all.Workload, + *stopper.Context, + *sinktest.Breakers, +) { + r := require.New(t) + + fixture, err := all.NewFixture(t) + r.NoError(err) + ctx := fixture.Context + + serverCfg := &server.Config{ + HTTP: stdserver.Config{ + BindAddr: "127.0.0.1:0", + GenerateSelfSigned: true, + }, + Staging: sinkprod.StagingConfig{ + CommonConfig: sinkprod.CommonConfig{ + Conn: fixture.StagingPool.ConnectionString, + }, + CreateSchema: true, + Schema: fixture.StagingDB.Schema(), + }, + Target: sinkprod.TargetConfig{ + CommonConfig: sinkprod.CommonConfig{ + Conn: fixture.TargetPool.ConnectionString, + }, + }, + } + + workload, group, err := fixture.NewWorkload(ctx, &all.WorkloadConfig{}) + r.NoError(err) + + cfg := &clientConfig{ + childTable: workload.Child.Name(), + parentTable: workload.Parent.Name(), + targetSchema: fixture.TargetSchema.Schema(), + } + + svr, err := cfg.newServer(ctx, serverCfg) + r.NoError(err) + + r.NoError(cfg.initURL(svr.GetListener())) + + r.NoError(cfg.generateJWT(ctx, svr)) + + // Create a runner, but inject the generator from above. This will + // allow us to validate the behavior later. + runner, err := cfg.newRunner(ctx, workload.GeneratorBase) + r.NoError(err) + + return runner, group, svr, workload, ctx, fixture.Breakers +} + +// This is a white-box test that creates a server and executes the +// workload against it for a few seconds, stops the workload, then +// starts it again, but this time with a failed connection to the +// target database. The target database connection then recovers +// after a period. +func TestColdStart(t *testing.T) { + t.Parallel() + const initialTime = 5 * time.Second + const targetDownTime = 5 * time.Second + const recoveryTime = 5 * time.Second + + runner, group, svr, workload, ctx, _ := initSystem(t) + + // Phase one: initial connection and run + + // Create a nested stopper, so we can run the workload generator + // for a period of time. + runnerCtx := stopper.WithContext(ctx) + runnerCtx.Go(func(runnerCtx *stopper.Context) error { + return runner.Run(runnerCtx) + }) + + r := require.New(t) + // Wait for a bit. + select { + case <-time.After(initialTime): + log.Info("waiting for runner context to finish") + runnerCtx.Stop(time.Second) + r.NoError(runnerCtx.Wait()) + case <-ctx.Stopping(): + r.Fail("test context stopping") + } + + var resolvedRange notify.Var[hlc.Range] + _, err := svr.Checkpoints.Start(ctx, group, &resolvedRange) + r.NoError(err) + for { + progress, changed := resolvedRange.Get() + if hlc.Compare(progress.Min(), runner.lastResolved) >= 0 { + break + } + log.Infof("waiting for resolved timestamp progress: %s vs %s", progress, runner.lastResolved) + select { + case <-changed: + case <-ctx.Done(): + r.NoError(ctx.Err()) + } + } + log.Infof("resolved timestamps have caught up; validating initial workload") + + workload.CheckConsistent(ctx, t) + + // Second phase: start the workload with the target down + runner, group, svr, workload, ctx, breakers := initSystem(t) + breakers.TargetConnectionFails.Store(true) + + // Create a nested stopper, so we can run the workload generator + // for a period of time. + runnerCtx = stopper.WithContext(ctx) + runnerCtx.Go(func(runnerCtx *stopper.Context) error { + return runner.Run(runnerCtx) + }) + + // Wait for a bit, re-enable target connections + select { + case <-time.After(targetDownTime): + breakers.TargetConnectionFails.Store(false) + case <-ctx.Stopping(): + r.Fail("test context stopping") + } + + // Third phase: target recovers, and we proceed + select { + case <-time.After(recoveryTime): + log.Info("waiting for runner context to finish") + runnerCtx.Stop(time.Second) + r.NoError(runnerCtx.Wait()) + case <-ctx.Stopping(): + r.Fail("test context stopping") + } + + var resolvedRangeFinal notify.Var[hlc.Range] + _, err = svr.Checkpoints.Start(ctx, group, &resolvedRangeFinal) + r.NoError(err) + for { + progress, changed := resolvedRangeFinal.Get() + if hlc.Compare(progress.Min(), runner.lastResolved) >= 0 { + break + } + log.Infof("waiting for resolved timestamp progress: %s vs %s", progress, runner.lastResolved) + select { + case <-changed: + case <-ctx.Done(): + r.NoError(ctx.Err()) + } + } + log.Infof("resolved timestamps have caught up; validating final workload") + + workload.CheckConsistent(ctx, t) +} + // This is a black-box test to ensure the demo command fires up. func TestDemoCommand(t *testing.T) { const testTime = 5 * time.Second diff --git a/internal/sinkprod/sinkprod.go b/internal/sinkprod/sinkprod.go index c8decf8e5..91d3e7331 100644 --- a/internal/sinkprod/sinkprod.go +++ b/internal/sinkprod/sinkprod.go @@ -21,6 +21,8 @@ package sinkprod import ( "time" + "github.com/cockroachdb/replicator/internal/sinktest" + "github.com/cockroachdb/replicator/internal/util/stdpool" "github.com/google/wire" ) @@ -29,7 +31,9 @@ var Set = wire.NewSet( ProvideStagingDB, ProvideStagingPool, ProvideTargetPool, + stdpool.ProvideBackup, ProvideStatementCache, + sinktest.NewBreakers, ) const ( diff --git a/internal/sinkprod/target.go b/internal/sinkprod/target.go index ead164011..55c08f854 100644 --- a/internal/sinkprod/target.go +++ b/internal/sinkprod/target.go @@ -20,6 +20,7 @@ import ( "time" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/staging/version" "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/diag" @@ -74,7 +75,12 @@ func (c *TargetConfig) Preflight() error { // Adding the Replicator version checker here is a bit of a hack. If // Wire supported eager dependencies, this would be an obvious use. func ProvideTargetPool( - ctx *stopper.Context, check *version.Checker, config *TargetConfig, diags *diag.Diagnostics, + ctx *stopper.Context, + check *version.Checker, + config *TargetConfig, + diags *diag.Diagnostics, + backup *stdpool.Backup, + breakers *sinktest.Breakers, ) (*types.TargetPool, error) { missing, err := check.Check(ctx) if err != nil { @@ -94,7 +100,7 @@ func ProvideTargetPool( stdpool.WithTransactionTimeout(config.ApplyTimeout), } - ret, err := stdpool.OpenTarget(ctx, config.Conn, options...) + ret, err := stdpool.OpenTarget(ctx, config.Conn, backup, breakers, options...) if err != nil { return nil, err } diff --git a/internal/sinktest/all/wire_gen.go b/internal/sinktest/all/wire_gen.go index 510921e37..62faa8354 100644 --- a/internal/sinktest/all/wire_gen.go +++ b/internal/sinktest/all/wire_gen.go @@ -7,6 +7,7 @@ package all import ( + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/sinktest/base" "github.com/cockroachdb/replicator/internal/staging/checkpoint" "github.com/cockroachdb/replicator/internal/staging/memo" @@ -18,6 +19,7 @@ import ( "github.com/cockroachdb/replicator/internal/target/schemawatch" "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" "testing" ) @@ -26,6 +28,7 @@ import ( // NewFixture constructs a self-contained test fixture for all services // in the target sub-packages. func NewFixture(t testing.TB) (*Fixture, error) { + breakers := sinktest.NewBreakers() context := base.ProvideContext(t) diagnostics := diag.New(context) sourcePool, err := base.ProvideSourcePool(context, diagnostics) @@ -44,16 +47,22 @@ func NewFixture(t testing.TB) (*Fixture, error) { if err != nil { return nil, err } - targetPool, err := base.ProvideTargetPool(context, sourcePool, diagnostics) + memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) + if err != nil { + return nil, err + } + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + targetPool, err := base.ProvideTargetPool(context, sourcePool, backup, diagnostics, breakers) if err != nil { return nil, err } targetStatements := base.ProvideTargetStatements(context, targetPool) - targetSchema, err := base.ProvideTargetSchema(context, diagnostics, targetPool, targetStatements) + targetSchema, err := base.ProvideTargetSchema(context, diagnostics, targetPool, targetStatements, backup, breakers) if err != nil { return nil, err } fixture := &base.Fixture{ + Breakers: breakers, Context: context, SourcePool: sourcePool, SourceSchema: sourceSchema, @@ -71,12 +80,8 @@ func NewFixture(t testing.TB) (*Fixture, error) { if err != nil { return nil, err } - memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) - if err != nil { - return nil, err - } - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } diff --git a/internal/sinktest/base/injector.go b/internal/sinktest/base/injector.go index e21aee6af..bb710288b 100644 --- a/internal/sinktest/base/injector.go +++ b/internal/sinktest/base/injector.go @@ -27,5 +27,5 @@ import ( // NewFixture constructs a self-contained test fixture. func NewFixture(t testing.TB) (*Fixture, error) { - panic(wire.Build(TestSet)) + panic(wire.Build(TestSet, ProvideMemory)) } diff --git a/internal/sinktest/base/provider.go b/internal/sinktest/base/provider.go index ad3c78be5..d19365c5f 100644 --- a/internal/sinktest/base/provider.go +++ b/internal/sinktest/base/provider.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/field-eng-powertools/stopper" "github.com/cockroachdb/replicator/internal/sinktest" + "github.com/cockroachdb/replicator/internal/staging/memo" "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/diag" "github.com/cockroachdb/replicator/internal/util/ident" @@ -101,15 +102,23 @@ var TestSet = wire.NewSet( ProvideTargetSchema, ProvideTargetStatements, diag.New, + sinktest.NewBreakers, + stdpool.ProvideBackup, wire.Bind(new(context.Context), new(*stopper.Context)), wire.Struct(new(Fixture), "*"), ) +// ProvideMemory returns the test implementation of types.Memo +func ProvideMemory() types.Memo { + return &memo.Memory{} +} + // Fixture can be used for tests that "just need a database", // without the other services provided by the target package. One can be // constructed by calling NewFixture. type Fixture struct { + Breakers *sinktest.Breakers // Breakers for the test Context *stopper.Context // The context for the test. SourcePool *types.SourcePool // Access to user-data tables and changefeed creation. SourceSchema sinktest.SourceSchema // A container for tables within SourcePool. @@ -186,6 +195,8 @@ func ProvideSourcePool(ctx *stopper.Context, diags *diag.Diagnostics) (*types.So tgt := *sourceConn log.Infof("source connect string: %s", tgt) ret, err := stdpool.OpenTarget(ctx, tgt, + stdpool.ProvideBackup(&memo.Memory{}, nil), + sinktest.NewBreakers(), stdpool.WithDiagnostics(diags, "source"), stdpool.WithTestControls(stdpool.TestControls{ WaitForStartup: true, @@ -277,7 +288,11 @@ func ProvideStagingPool(ctx *stopper.Context) (*types.StagingPool, error) { // ProvideTargetPool connects to the target database (which is most // often the same as the source database). func ProvideTargetPool( - ctx *stopper.Context, source *types.SourcePool, diags *diag.Diagnostics, + ctx *stopper.Context, + source *types.SourcePool, + backup *stdpool.Backup, + diags *diag.Diagnostics, + breakers *sinktest.Breakers, ) (*types.TargetPool, error) { tgt := *targetString if tgt == source.ConnectionString { @@ -285,15 +300,9 @@ func ProvideTargetPool( return (*types.TargetPool)(source), nil } log.Infof("target connect string: %s", tgt) - pool, err := stdpool.OpenTarget(ctx, *targetString, - stdpool.WithDiagnostics(diags, "target"), - stdpool.WithTestControls(stdpool.TestControls{ - WaitForStartup: true, - }), - stdpool.WithConnectionLifetime(time.Minute, 15*time.Second, 5*time.Second), - stdpool.WithPoolSize(32), - stdpool.WithTransactionTimeout(2*time.Minute), // Aligns with test case timeout. - ) + pool, err := stdpool.OpenTarget(ctx, *targetString, backup, breakers, stdpool.WithDiagnostics(diags, "target"), stdpool.WithTestControls(stdpool.TestControls{ + WaitForStartup: true, + }), stdpool.WithConnectionLifetime(time.Minute, 15*time.Second, 5*time.Second), stdpool.WithPoolSize(32), stdpool.WithTransactionTimeout(2*time.Minute)) if err != nil { return nil, err } @@ -317,6 +326,8 @@ func ProvideTargetSchema( diags *diag.Diagnostics, pool *types.TargetPool, stmts *types.TargetStatements, + backup *stdpool.Backup, + breakers *sinktest.Breakers, ) (sinktest.TargetSchema, error) { sch, err := provideSchema(ctx, pool, "tgt") ret := sinktest.TargetSchema(sch) @@ -338,7 +349,7 @@ func ProvideTargetSchema( if pool.Info().Product == types.ProductPostgreSQL { db, _ := sch.Split() conn := fmt.Sprintf("%s/%s", pool.ConnectionString, db.Raw()) - next, err := stdpool.OpenPgxAsTarget(ctx, conn, + next, err := stdpool.OpenPgxAsTarget(ctx, conn, backup, breakers, stdpool.WithDiagnostics(diags, "target_reopened")) if err != nil { return sinktest.TargetSchema{}, err @@ -358,7 +369,7 @@ func ProvideTargetSchema( u.User = url.UserPassword(sch.Raw(), DummyPassword) conn := u.String() - next, err := stdpool.OpenOracleAsTarget(ctx, conn, + next, err := stdpool.OpenOracleAsTarget(ctx, conn, backup, breakers, stdpool.WithDiagnostics(diags, "target_reopened")) if err != nil { return sinktest.TargetSchema{}, err diff --git a/internal/sinktest/base/wire_gen.go b/internal/sinktest/base/wire_gen.go index d3dde8f60..90c0fcbae 100644 --- a/internal/sinktest/base/wire_gen.go +++ b/internal/sinktest/base/wire_gen.go @@ -7,7 +7,9 @@ package base import ( + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" "testing" ) @@ -15,6 +17,7 @@ import ( // NewFixture constructs a self-contained test fixture. func NewFixture(t testing.TB) (*Fixture, error) { + breakers := sinktest.NewBreakers() context := ProvideContext(t) diagnostics := diag.New(context) sourcePool, err := ProvideSourcePool(context, diagnostics) @@ -33,16 +36,19 @@ func NewFixture(t testing.TB) (*Fixture, error) { if err != nil { return nil, err } - targetPool, err := ProvideTargetPool(context, sourcePool, diagnostics) + memo := ProvideMemory() + backup := stdpool.ProvideBackup(memo, stagingPool) + targetPool, err := ProvideTargetPool(context, sourcePool, backup, diagnostics, breakers) if err != nil { return nil, err } targetStatements := ProvideTargetStatements(context, targetPool) - targetSchema, err := ProvideTargetSchema(context, diagnostics, targetPool, targetStatements) + targetSchema, err := ProvideTargetSchema(context, diagnostics, targetPool, targetStatements, backup, breakers) if err != nil { return nil, err } fixture := &Fixture{ + Breakers: breakers, Context: context, SourcePool: sourcePool, SourceSchema: sourceSchema, diff --git a/internal/sinktest/breakers.go b/internal/sinktest/breakers.go new file mode 100644 index 000000000..3919f4994 --- /dev/null +++ b/internal/sinktest/breakers.go @@ -0,0 +1,29 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package sinktest + +import "sync/atomic" + +// Breakers is a collection of shared-state runtime controls that cause aspects of the system to fail, for testing. +type Breakers struct { + TargetConnectionFails atomic.Bool +} + +// NewBreakers constructs a Breakers instance +func NewBreakers() *Breakers { + return &Breakers{} +} diff --git a/internal/source/cdc/server/wire_gen.go b/internal/source/cdc/server/wire_gen.go index 29fca9a17..177caae2f 100644 --- a/internal/source/cdc/server/wire_gen.go +++ b/internal/source/cdc/server/wire_gen.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/replicator/internal/sequencer/staging" "github.com/cockroachdb/replicator/internal/sequencer/switcher" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/source/cdc" staging2 "github.com/cockroachdb/replicator/internal/staging" "github.com/cockroachdb/replicator/internal/staging/checkpoint" @@ -36,6 +37,7 @@ import ( "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" "github.com/cockroachdb/replicator/internal/util/ident" + "github.com/cockroachdb/replicator/internal/util/stdpool" "github.com/cockroachdb/replicator/internal/util/stdserver" "github.com/google/wire" "net" @@ -82,7 +84,9 @@ func NewServer(ctx *stopper.Context, config *Config) (*Server, error) { return nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, err } @@ -91,8 +95,8 @@ func NewServer(ctx *stopper.Context, config *Config) (*Server, error) { return nil, err } dlqConfig := cdc.ProvideDLQConfig(cdcConfig) - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } @@ -184,7 +188,9 @@ func newTestFixture(context *stopper.Context, config *Config) (*testFixture, fun } cdcConfig := &config.CDC checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(context, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(context, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, nil, err } @@ -197,8 +203,8 @@ func newTestFixture(context *stopper.Context, config *Config) (*testFixture, fun return nil, nil, err } dlqConfig := cdc.ProvideDLQConfig(cdcConfig) - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, nil, err } diff --git a/internal/source/kafka/wire_gen.go b/internal/source/kafka/wire_gen.go index e26544baf..6f6fbeec0 100644 --- a/internal/source/kafka/wire_gen.go +++ b/internal/source/kafka/wire_gen.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/replicator/internal/sequencer/staging" "github.com/cockroachdb/replicator/internal/sequencer/switcher" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/staging/checkpoint" "github.com/cockroachdb/replicator/internal/staging/leases" "github.com/cockroachdb/replicator/internal/staging/memo" @@ -31,6 +32,7 @@ import ( "github.com/cockroachdb/replicator/internal/target/schemawatch" "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" ) // Injectors from injector.go: @@ -64,7 +66,9 @@ func Start(ctx *stopper.Context, config *Config) (*Kafka, error) { return nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, err } @@ -73,8 +77,8 @@ func Start(ctx *stopper.Context, config *Config) (*Kafka, error) { return nil, err } dlqConfig := &eagerConfig.DLQ - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } diff --git a/internal/source/mylogical/wire_gen.go b/internal/source/mylogical/wire_gen.go index 9ccc16c94..5cdb759f0 100644 --- a/internal/source/mylogical/wire_gen.go +++ b/internal/source/mylogical/wire_gen.go @@ -14,6 +14,7 @@ import ( "github.com/cockroachdb/replicator/internal/sequencer/immediate" script2 "github.com/cockroachdb/replicator/internal/sequencer/script" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/staging/memo" "github.com/cockroachdb/replicator/internal/staging/stage" "github.com/cockroachdb/replicator/internal/staging/version" @@ -23,6 +24,7 @@ import ( "github.com/cockroachdb/replicator/internal/target/schemawatch" "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" ) // Injectors from injector.go: @@ -59,7 +61,9 @@ func Start(ctx *stopper.Context, config *Config) (*MYLogical, error) { return nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, err } @@ -68,8 +72,8 @@ func Start(ctx *stopper.Context, config *Config) (*MYLogical, error) { return nil, err } dlqConfig := &eagerConfig.DLQ - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } diff --git a/internal/source/objstore/wire_gen.go b/internal/source/objstore/wire_gen.go index 0c2214b1b..e9c99a60d 100644 --- a/internal/source/objstore/wire_gen.go +++ b/internal/source/objstore/wire_gen.go @@ -20,6 +20,7 @@ import ( "github.com/cockroachdb/replicator/internal/sequencer/staging" "github.com/cockroachdb/replicator/internal/sequencer/switcher" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/staging/checkpoint" "github.com/cockroachdb/replicator/internal/staging/leases" "github.com/cockroachdb/replicator/internal/staging/memo" @@ -31,6 +32,7 @@ import ( "github.com/cockroachdb/replicator/internal/target/schemawatch" "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" ) // Injectors from injector.go: @@ -63,7 +65,9 @@ func Start(ctx *stopper.Context, config *Config) (*Objstore, error) { return nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(ctx, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, err } @@ -72,8 +76,8 @@ func Start(ctx *stopper.Context, config *Config) (*Objstore, error) { return nil, err } dlqConfig := &eagerConfig.DLQ - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(ctx, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } diff --git a/internal/source/pglogical/wire_gen.go b/internal/source/pglogical/wire_gen.go index ee31ae9cb..fe0f5403f 100644 --- a/internal/source/pglogical/wire_gen.go +++ b/internal/source/pglogical/wire_gen.go @@ -14,6 +14,7 @@ import ( "github.com/cockroachdb/replicator/internal/sequencer/immediate" script2 "github.com/cockroachdb/replicator/internal/sequencer/script" "github.com/cockroachdb/replicator/internal/sinkprod" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/staging/memo" "github.com/cockroachdb/replicator/internal/staging/stage" "github.com/cockroachdb/replicator/internal/staging/version" @@ -23,6 +24,7 @@ import ( "github.com/cockroachdb/replicator/internal/target/schemawatch" "github.com/cockroachdb/replicator/internal/util/applycfg" "github.com/cockroachdb/replicator/internal/util/diag" + "github.com/cockroachdb/replicator/internal/util/stdpool" ) // Injectors from injector.go: @@ -59,7 +61,9 @@ func Start(context *stopper.Context, config *Config) (*PGLogical, error) { return nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - targetPool, err := sinkprod.ProvideTargetPool(context, checker, targetConfig, diagnostics) + backup := stdpool.ProvideBackup(memoMemo, stagingPool) + breakers := sinktest.NewBreakers() + targetPool, err := sinkprod.ProvideTargetPool(context, checker, targetConfig, diagnostics, backup, breakers) if err != nil { return nil, err } @@ -68,8 +72,8 @@ func Start(context *stopper.Context, config *Config) (*PGLogical, error) { return nil, err } dlqConfig := &eagerConfig.DLQ - backup := schemawatch.ProvideBackup(memoMemo, stagingPool) - watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, backup) + schemawatchBackup := schemawatch.ProvideBackup(memoMemo, stagingPool) + watchers, err := schemawatch.ProvideFactory(context, targetPool, diagnostics, schemawatchBackup) if err != nil { return nil, err } diff --git a/internal/target/schemawatch/watcher.go b/internal/target/schemawatch/watcher.go index ba37aa0a6..3023df87b 100644 --- a/internal/target/schemawatch/watcher.go +++ b/internal/target/schemawatch/watcher.go @@ -64,16 +64,16 @@ func newWatcher( var data *types.SchemaData if err := tx.Ping(); err != nil { - log.WithError(err).Warn("failed to ping target database; trying to restore backup") + log.WithError(err).Warn("failed to ping target database; trying to restore schema backup") // On start up, when the target database is down, fall back to // the staging memo about the table schema data, err = b.restore(ctx, schema) if err != nil { - return nil, err + return nil, errors.Wrapf(err, "failed to restore schema backup for %s", schema) } if data == nil { // Restore saw no errors, but we also didn't have a value - return nil, errors.Wrapf(err, "no backup schema data for %s", schema) + return nil, fmt.Errorf("no backup schema data for %s", schema) } } else { // Initial data load to sanity-check and make ready. diff --git a/internal/util/stdpool/backup.go b/internal/util/stdpool/backup.go new file mode 100644 index 000000000..ea6286159 --- /dev/null +++ b/internal/util/stdpool/backup.go @@ -0,0 +1,71 @@ +// Copyright 2024 The Cockroach Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package stdpool + +import ( + "fmt" + "net/url" + + "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/types" + "github.com/pkg/errors" +) + +// ProvideBackup provides the version backup service +func ProvideBackup(memo types.Memo, stagingPool *types.StagingPool) *Backup { + return &Backup{ + memo: memo, + stagingPool: stagingPool, + } +} + +// Backup backs up the version string for the target database +type Backup struct { + memo types.Memo + stagingPool *types.StagingPool +} + +// Store the value +func (b *Backup) Store(ctx *stopper.Context, connectString, ver string) error { + key, err := b.targetVersionMemoKey(connectString) + if err != nil { + return err + } + return b.memo.Put(ctx, b.stagingPool, key, []byte(ver)) +} + +// Load the value +func (b *Backup) Load(ctx *stopper.Context, connectString string) (string, error) { + key, err := b.targetVersionMemoKey(connectString) + if err != nil { + return "", err + } + bs, err := b.memo.Get(ctx, b.stagingPool, key) + if err != nil { + return "", err + } + return string(bs), nil +} + +func (b *Backup) targetVersionMemoKey(connectString string) (string, error) { + u, err := url.Parse(connectString) + if err != nil { + return "", errors.Wrap(err, "could not parse connection string") + } + key := fmt.Sprintf("%s://%s:%s/dbVersion", u.Scheme, u.Hostname(), u.Port()) + return key, nil +} diff --git a/internal/util/stdpool/my.go b/internal/util/stdpool/my.go index 533a6aa35..1f602eb1e 100644 --- a/internal/util/stdpool/my.go +++ b/internal/util/stdpool/my.go @@ -18,9 +18,11 @@ package stdpool import ( + "context" "crypto/tls" "database/sql" sqldriver "database/sql/driver" + errors2 "errors" "fmt" "net/url" "strconv" @@ -29,6 +31,7 @@ import ( "time" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/secure" "github.com/go-sql-driver/mysql" @@ -87,25 +90,126 @@ func (o *onomastic) newName(prefix string) string { // OpenMySQLAsTarget opens a database connection, returning it as // a single connection. func OpenMySQLAsTarget( - ctx *stopper.Context, connectString string, url *url.URL, options ...Option, + ctx *stopper.Context, + connectString string, + backup *Backup, + breakers *sinktest.Breakers, + options ...Option, ) (*types.TargetPool, error) { var tc TestControls if err := attachOptions(ctx, &tc, options); err != nil { return nil, err } - // Use a unique name for each call of OpenMySQLAsTarget. - tlsConfigName := tlsConfigNames.newName("mysql_driver") - tlsConfigs, err := secure.ParseTLSOptions(url) + + connector, err := newTLSFallbackConnector(connectString, &breakers.TargetConnectionFails) if err != nil { + return nil, errors.WithStack(err) + } + + ret := &types.TargetPool{ + DB: sql.OpenDB(connector), + PoolInfo: types.PoolInfo{ + ConnectionString: connectString, + Product: types.ProductMySQL, + + ErrCode: myErrCode, + IsDeferrable: myErrDeferrable, + ShouldRetry: myErrRetryable, + }, + } + + if tc.WaitForStartup { + if err := awaitMySQLReady(ctx, ret); err != nil { + return nil, err + } + if ctx.IsStopping() { + return nil, ctx.Err() + } + } + + // Testing that connection is usable. + if err := ret.QueryRow("SELECT VERSION();").Scan(&ret.Version); err != nil { + queryErr := errors.Wrap(err, "could not query version") + ver, err := backup.Load(ctx, connectString) + if err != nil { + return nil, errors2.Join(queryErr, errors.Wrap(err, "could not load version from staging")) + } + if ver == "" { + return nil, fmt.Errorf("empty version loaded from staging") + } + ret.Version = ver + } else if err := backup.Store(ctx, connectString, ret.Version); err != nil { + return nil, errors.Wrap(err, "could not store version to staging") + } + log.Infof("Version %s.", ret.Version) + if strings.Contains(ret.Version, "MariaDB") { + ret.PoolInfo.Product = types.ProductMariaDB + } + if err := setTableHint(ret.Info()); err != nil { return nil, err } - var ret *types.TargetPool - var transportError error + // If debug is enabled we print sql mode and ssl info. + if log.IsLevelEnabled(log.DebugLevel) { + var mode string + if err := ret.QueryRow("SELECT @@sql_mode").Scan(&mode); err != nil { + log.Errorf("could not query sql mode %s", err.Error()) + } + var varName, cipher string + if err := ret.QueryRow("SHOW STATUS LIKE 'Ssl_cipher';").Scan(&varName, &cipher); err != nil { + log.Errorf("could not query ssl info %s", err.Error()) + } + log.Debugf("Mode %s. %s %s", mode, varName, cipher) + ret.Version = fmt.Sprintf("%s cipher[%s]", ret.Version, cipher) + } + if err := attachOptions(ctx, ret.DB, options); err != nil { + return nil, err + } + if err := attachOptions(ctx, &ret.PoolInfo, options); err != nil { + return nil, err + } + + return ret, nil +} + +type mySQLDriver struct{} + +func (m *mySQLDriver) Open(name string) (sqldriver.Conn, error) { + u, err := url.Parse(name) + if err != nil { + return nil, err + } + tlsConfigs, err := secure.ParseTLSOptions(u) + if err != nil { + return nil, err + } + + ctor := tlsFallbackConnector{ + tlsConfigs: tlsConfigs, + } + return ctor.Connect(context.Background()) +} + +type tlsFallbackConnector struct { + tlsConfigs []*tls.Config + connectionURL *url.URL + delegate sqldriver.Connector + failConnections *atomic.Bool +} + +func (t *tlsFallbackConnector) findDelegate() (sqldriver.Connector, error) { + if t.delegate != nil { + return t.delegate, nil + } + + // Use a unique name for each call of OpenMySQLAsTarget. + tlsConfigName := tlsConfigNames.newName("mysql_driver") + + var lastErr error // Try all possible transport options. // The first one that works is the one we will use. - for _, tlsConfig := range tlsConfigs { + for _, tlsConfig := range t.tlsConfigs { mysql.DeregisterTLSConfig(tlsConfigName) - mySQLString, err := getConnString(url, tlsConfigName, tlsConfig) + mySQLString, err := getConnString(t.connectionURL, tlsConfigName, tlsConfig) if err != nil { return nil, errors.WithStack(err) } @@ -117,6 +221,7 @@ func OpenMySQLAsTarget( if err != nil { return nil, errors.WithStack(err) } + // This impacts the use of db.Query() and friends, but not // prepared statements. We set this as a workaround for MariaDB // queries where the parameter types in prepared statements @@ -124,79 +229,82 @@ func OpenMySQLAsTarget( // the ? markers with the literal values, rather than doing a // prepare, bind, exec. cfg.InterpolateParams = true - connector, err := mysql.NewConnector(cfg) + mySqlConnector, err := mysql.NewConnector(cfg) if err != nil { - log.WithError(err).Trace("failed to connect to database server") - transportError = err - // Try a different option. - continue - } - ret = &types.TargetPool{ - DB: sql.OpenDB(connector), - PoolInfo: types.PoolInfo{ - ConnectionString: connectString, - Product: types.ProductMySQL, - - ErrCode: myErrCode, - IsDeferrable: myErrDeferrable, - ShouldRetry: myErrRetryable, - }, - } - ctx.Defer(func() { _ = ret.Close() }) - - ping: - if err := ret.Ping(); err != nil { - // For some errors, we retry. - if tc.WaitForStartup && isMySQLStartupError(err) { - log.WithError(err).Info("waiting for database to become ready") - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(10 * time.Second): - goto ping - } - } - transportError = err - _ = ret.Close() - // Try a different option. - continue - } - // Testing that connection is usable. - if err := ret.QueryRow("SELECT VERSION();").Scan(&ret.Version); err != nil { - return nil, errors.Wrap(err, "could not query version") + return nil, errors.WithStack(err) } - log.Infof("Version %s.", ret.Version) - if strings.Contains(ret.Version, "MariaDB") { - ret.PoolInfo.Product = types.ProductMariaDB + + myDB := sql.OpenDB(mySqlConnector) + defer myDB.Close() + if err := myDB.Ping(); err != nil { + lastErr = err + } else { + return mySqlConnector, nil } - if err := setTableHint(ret.Info()); err != nil { + } + + // Nothing worked; return the last error + return nil, lastErr +} + +func (t *tlsFallbackConnector) Connect(ctx context.Context) (sqldriver.Conn, error) { + if t.failConnections.Load() { + return nil, fmt.Errorf("testing connection failure") + } + if t.delegate == nil { + delegate, err := t.findDelegate() + if err != nil { return nil, err } - // If debug is enabled we print sql mode and ssl info. - if log.IsLevelEnabled(log.DebugLevel) { - var mode string - if err := ret.QueryRow("SELECT @@sql_mode").Scan(&mode); err != nil { - log.Errorf("could not query sql mode %s", err.Error()) - } - var varName, cipher string - if err := ret.QueryRow("SHOW STATUS LIKE 'Ssl_cipher';").Scan(&varName, &cipher); err != nil { - log.Errorf("could not query ssl info %s", err.Error()) - } - log.Debugf("Mode %s. %s %s", mode, varName, cipher) - ret.Version = fmt.Sprintf("%s cipher[%s]", ret.Version, cipher) + if delegate == nil { + // This shouldn't happen; we should either find a delegate, or error + return nil, fmt.Errorf("could not find MySQL connector delegate") } - if err := attachOptions(ctx, ret.DB, options); err != nil { - return nil, err + t.delegate = delegate + } + return t.delegate.Connect(ctx) +} + +func (t *tlsFallbackConnector) Driver() sqldriver.Driver { + return &mySQLDriver{} +} + +func newTLSFallbackConnector( + connStr string, failConnections *atomic.Bool, +) (sqldriver.Connector, error) { + u, err := url.Parse(connStr) + if err != nil { + return nil, err + } + + tlsConfigs, err := secure.ParseTLSOptions(u) + if err != nil { + return nil, err + } + + return &tlsFallbackConnector{ + tlsConfigs: tlsConfigs, + connectionURL: u, + failConnections: failConnections, + }, nil +} + +func awaitMySQLReady(ctx *stopper.Context, db *types.TargetPool) error { + for { + err := db.Ping() + if err == nil || !isMySQLStartupError(err) { + return err } - if err := attachOptions(ctx, &ret.PoolInfo, options); err != nil { - return nil, err + + // We have a startup error + log.WithError(err).Info("waiting for database to become ready") + + select { + case <-ctx.Stopping(): + return nil + case <-time.After(10 * time.Second): } - // The connection meets the client/server requirements, - // no need to try other transport options. - return ret, nil } - // All the options have been exhausted, returning the last error. - return nil, transportError } // TODO (silvano): verify error codes. diff --git a/internal/util/stdpool/ora.go b/internal/util/stdpool/ora.go index 824f9226a..ebf271066 100644 --- a/internal/util/stdpool/ora.go +++ b/internal/util/stdpool/ora.go @@ -19,11 +19,17 @@ package stdpool import ( + "context" "database/sql" + "database/sql/driver" + errors2 "errors" + "fmt" "strconv" + "sync/atomic" "time" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/types" "github.com/godror/godror" "github.com/pkg/errors" @@ -78,7 +84,11 @@ func oraErrorRetryable(err error) bool { // OpenOracleAsTarget opens a connection to an Oracle database endpoint and // return it as a [types.TargetPool]. func OpenOracleAsTarget( - ctx *stopper.Context, connectString string, options ...Option, + ctx *stopper.Context, + connectString string, + backup *Backup, + breakers *sinktest.Breakers, + options ...Option, ) (*types.TargetPool, error) { var tc TestControls if err := attachOptions(ctx, &tc, options); err != nil { @@ -101,8 +111,13 @@ func OpenOracleAsTarget( } connector := godror.NewConnector(params) + proxy := &oraConnectorProxy{ + flag: &breakers.TargetConnectionFails, + delegate: connector, + } + ret := &types.TargetPool{ - DB: sql.OpenDB(connector), + DB: sql.OpenDB(proxy), PoolInfo: types.PoolInfo{ ConnectionString: connectString, Product: types.ProductOracle, @@ -114,23 +129,21 @@ func OpenOracleAsTarget( } ctx.Defer(func() { _ = ret.Close() }) -ping: - if err := ret.Ping(); err != nil { - if tc.WaitForStartup && isOracleStartupError(err) { - log.WithError(err).Info("waiting for database to become ready") - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(10 * time.Second): - goto ping - } + if err := ret.QueryRow("SELECT banner FROM V$VERSION").Scan(&ret.Version); err != nil { + queryErr := errors.Wrap(err, "could not query version") + log.WithError(queryErr).Warn("could not query database version; trying to load from staging backup") + ver, err := backup.Load(ctx, connectString) + if err != nil { + return nil, errors2.Join(queryErr, errors.Wrap(err, "could not load version from staging")) + } + if ver == "" { + return nil, fmt.Errorf("empty version loaded from staging") } - return nil, errors.Wrap(err, "could not ping the database") + ret.Version = ver + } else if err := backup.Store(ctx, connectString, ret.Version); err != nil { + return nil, errors.Wrap(err, "could not store version to staging") } - if err := ret.QueryRow("SELECT banner FROM V$VERSION").Scan(&ret.Version); err != nil { - return nil, errors.Wrap(err, "could not query version") - } if err := setTableHint(ret.Info()); err != nil { return nil, err } @@ -159,3 +172,19 @@ func isOracleStartupError(err error) bool { } return oracleStartupErrors[code] } + +type oraConnectorProxy struct { + flag *atomic.Bool + delegate driver.Connector +} + +func (t *oraConnectorProxy) Connect(ctx context.Context) (driver.Conn, error) { + if t.flag.Load() { + return nil, errors.New("testing connection failure") + } + return t.delegate.Connect(ctx) +} + +func (t *oraConnectorProxy) Driver() driver.Driver { + return t.delegate.Driver() +} diff --git a/internal/util/stdpool/ora_unsupported.go b/internal/util/stdpool/ora_unsupported.go index e95799c17..aa7d91e60 100644 --- a/internal/util/stdpool/ora_unsupported.go +++ b/internal/util/stdpool/ora_unsupported.go @@ -20,13 +20,18 @@ package stdpool import ( "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/types" "github.com/pkg/errors" ) // OpenOracleAsTarget returns an unsupported error. func OpenOracleAsTarget( - ctx *stopper.Context, connectString string, options ...Option, + ctx *stopper.Context, + connectString string, + backup *Backup, + breakers *sinktest.Breakers, + options ...Option, ) (*types.TargetPool, error) { return nil, errors.New("this build does not support Oracle Database") } diff --git a/internal/util/stdpool/pgx.go b/internal/util/stdpool/pgx.go index 80e33ab45..6df79395b 100644 --- a/internal/util/stdpool/pgx.go +++ b/internal/util/stdpool/pgx.go @@ -20,9 +20,13 @@ package stdpool import ( "context" "database/sql" + errors2 "errors" + "fmt" "strings" + "sync/atomic" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/types" "github.com/cockroachdb/replicator/internal/util/retry" "github.com/jackc/pgx/v5" @@ -132,11 +136,15 @@ func OpenPgxAsStaging( // OpenPgxAsTarget uses pgx to open a database connection, returning it as a // stdlib pool. func OpenPgxAsTarget( - ctx *stopper.Context, connectString string, options ...Option, + ctx *stopper.Context, + connectString string, + backup *Backup, + breakers *sinktest.Breakers, + options ...Option, ) (*types.TargetPool, error) { db, err := openPgx(ctx, connectString, options, func(ctx *stopper.Context, cfg *pgxpool.Config) (*sql.DB, error) { - impl := stdlib.OpenDB(*cfg.ConnConfig) + impl := stdlib.OpenDB(*cfg.ConnConfig, pgxBeforeConnectBreaker(&breakers.TargetConnectionFails)) ctx.Defer(func() { _ = impl.Close() }) return impl, nil }) @@ -155,10 +163,21 @@ func OpenPgxAsTarget( } if err := retry.Retry(ctx, ret, func(ctx context.Context) error { - return ret.QueryRowContext(ctx, "SELECT version()").Scan(&ret.Version) + return ret.QueryRow("SELECT VERSION();").Scan(&ret.Version) }); err != nil { - return nil, errors.Wrap(err, "could not determine cluster version") + queryErr := errors.Wrap(err, "could not query version") + ver, err := backup.Load(ctx, connectString) + if err != nil { + return nil, errors2.Join(queryErr, errors.Wrap(err, "could not load version from staging")) + } + if ver == "" { + return nil, fmt.Errorf("empty version loaded from staging") + } + ret.Version = ver + } else if err := backup.Store(ctx, connectString, ret.Version); err != nil { + return nil, errors.Wrap(err, "could not store version to staging") } + if err := setTableHint(ret.Info()); err != nil { return nil, err } @@ -205,3 +224,12 @@ func openPgx[P attachable]( return ret, attachOptions(ctx, ret, options) } + +func pgxBeforeConnectBreaker(flag *atomic.Bool) stdlib.OptionOpenDB { + return stdlib.OptionBeforeConnect(func(ctx context.Context, config *pgx.ConnConfig) error { + if flag.Load() { + return errors.New("testing connection failure") + } + return nil + }) +} diff --git a/internal/util/stdpool/target.go b/internal/util/stdpool/target.go index bad2c9ec9..35c58fb7e 100644 --- a/internal/util/stdpool/target.go +++ b/internal/util/stdpool/target.go @@ -21,6 +21,7 @@ import ( "strings" "github.com/cockroachdb/field-eng-powertools/stopper" + "github.com/cockroachdb/replicator/internal/sinktest" "github.com/cockroachdb/replicator/internal/types" "github.com/pkg/errors" ) @@ -28,7 +29,11 @@ import ( // OpenTarget selects from target connector implementations based on the // URL scheme contained in the connection string. func OpenTarget( - ctx *stopper.Context, connectString string, options ...Option, + ctx *stopper.Context, + connectString string, + backup *Backup, + breakers *sinktest.Breakers, + options ...Option, ) (*types.TargetPool, error) { u, err := url.Parse(connectString) if err != nil { @@ -37,11 +42,11 @@ func OpenTarget( switch strings.ToLower(u.Scheme) { case "mysql": - return OpenMySQLAsTarget(ctx, connectString, u, options...) + return OpenMySQLAsTarget(ctx, connectString, backup, breakers, options...) case "pg", "pgx", "postgres", "postgresql": - return OpenPgxAsTarget(ctx, connectString, options...) + return OpenPgxAsTarget(ctx, connectString, backup, breakers, options...) case "ora", "oracle": - return OpenOracleAsTarget(ctx, connectString, options...) + return OpenOracleAsTarget(ctx, connectString, backup, breakers, options...) default: return nil, errors.Errorf("unknown URL scheme: %s", u.Scheme) } diff --git a/qol/fmt.sh b/qol/fmt.sh new file mode 100755 index 000000000..8c13bf54a --- /dev/null +++ b/qol/fmt.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Ensure copyright in new files +go run github.com/google/addlicense -c "The Cockroach Authors" -l apache -s -v -check -ignore '**/testdata/**/*.sql' -ignore '**/thirdparty/**' -ignore '*.ddt' -ignore '*.md' -ignore '.idea/*' . + +# Standardize formatting +go run github.com/cockroachdb/crlfmt -w -ignore '_gen.go|plsql_parser.go' . + +# Lints +go run golang.org/x/lint/golint -set_exit_status $(go list ./... | grep -v "/thirdparty" | grep -v "/oracleparser") +go run honnef.co/go/tools/cmd/staticcheck -checks all $(go list ./... | grep -v "/thirdparty" | grep -v "/oracleparser")