From b8a2f9f4ef2eff88173971b30eecd55b73077d89 Mon Sep 17 00:00:00 2001 From: Bob Vawter Date: Tue, 31 Oct 2023 19:03:26 -0400 Subject: [PATCH] wire: Make stopper.Context globally available This change cleans up some longstanding lifecycle issues around graceful draining. In the past, we relied on the ability to return a cancel function to Wire to shut down background goroutines that might be started by some service or another. This works, but it would be nice to standardize on the stopper.Context utility type for running goroutines that can benefit from a soft-exit condition. The crux of this change is to make `*stopper.Context` available from all injectors. The `context.Context` type remains injectable, but it becomes an alias for the `*stopper.Context`. The top-level injectors now require a `*stopper.Context`, which gives the callers the ability to trigger a soft shutdown of background processes before tearing down the entire stack. This should reduce the amount of log-spam during shutdown. This change will also make it far more straightforward to implement metrics that tick (e.g. #560) since a service or object that needs to produce a ticking metric has the option to use the stack-global `*stopper.Context.Go()`. Increase use of our `Go()` method also provides a future opportunity to improve observabiity around background processes within a cdc-sink binary. --- internal/cmd/fslogical/fslogical.go | 4 +- internal/cmd/mylogical/mylogical.go | 4 +- internal/cmd/pglogical/pglogical.go | 4 +- internal/cmd/start/start.go | 4 +- internal/script/injector.go | 9 +++- internal/script/wire_gen.go | 8 +-- internal/sinktest/all/wire_gen.go | 7 +-- internal/sinktest/base/provider.go | 20 ++++++-- internal/sinktest/base/wire_gen.go | 5 +- internal/source/cdc/provider.go | 15 +++--- internal/source/cdc/resolver.go | 29 ++--------- internal/source/cdc/test_fixture.go | 4 ++ internal/source/cdc/wire_gen.go | 3 +- internal/source/fslogical/injector.go | 5 +- internal/source/fslogical/provider.go | 39 +++++++-------- internal/source/fslogical/wire_gen.go | 46 ++++++++--------- internal/source/logical/factory.go | 57 +++++++++++----------- internal/source/logical/logical_test.go | 9 ++-- internal/source/logical/loop.go | 26 +++++----- internal/source/mylogical/injector.go | 4 +- internal/source/mylogical/provider.go | 7 +-- internal/source/mylogical/wire_gen.go | 7 ++- internal/source/pglogical/injector.go | 4 +- internal/source/pglogical/provider.go | 7 +-- internal/source/pglogical/wire_gen.go | 7 ++- internal/source/server/injector.go | 4 +- internal/source/server/integration_test.go | 3 +- internal/source/server/test_fixture.go | 4 +- internal/source/server/wire_gen.go | 36 ++++++-------- internal/staging/stage/factory.go | 10 ++-- internal/staging/stage/provider.go | 6 ++- internal/staging/stage/stage.go | 3 +- main.go | 27 ++++++++-- 33 files changed, 230 insertions(+), 197 deletions(-) diff --git a/internal/cmd/fslogical/fslogical.go b/internal/cmd/fslogical/fslogical.go index f52b3cc2d..1d80adbf0 100644 --- a/internal/cmd/fslogical/fslogical.go +++ b/internal/cmd/fslogical/fslogical.go @@ -21,6 +21,7 @@ package fslogical import ( "github.com/cockroachdb/cdc-sink/internal/source/fslogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a Google Cloud Firestore logical replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return fslogical.Start(cmd.Context(), cfg) + // main.go provides this stopper. + return fslogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "fslogical", }) diff --git a/internal/cmd/mylogical/mylogical.go b/internal/cmd/mylogical/mylogical.go index ee17cbafd..68304ca6b 100644 --- a/internal/cmd/mylogical/mylogical.go +++ b/internal/cmd/mylogical/mylogical.go @@ -21,6 +21,7 @@ package mylogical import ( "github.com/cockroachdb/cdc-sink/internal/source/mylogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a mySQL replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return mylogical.Start(cmd.Context(), cfg) + // main.go provides a stopper. + return mylogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "mylogical", }) diff --git a/internal/cmd/pglogical/pglogical.go b/internal/cmd/pglogical/pglogical.go index 814e8ea30..4de5d83da 100644 --- a/internal/cmd/pglogical/pglogical.go +++ b/internal/cmd/pglogical/pglogical.go @@ -21,6 +21,7 @@ package pglogical import ( "github.com/cockroachdb/cdc-sink/internal/source/pglogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a pg logical replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return pglogical.Start(cmd.Context(), cfg) + // main.go provides a stopper. + return pglogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "pglogical", }) diff --git a/internal/cmd/start/start.go b/internal/cmd/start/start.go index cb17ca0b3..88cc6fefb 100644 --- a/internal/cmd/start/start.go +++ b/internal/cmd/start/start.go @@ -20,6 +20,7 @@ package start import ( "github.com/cockroachdb/cdc-sink/internal/source/server" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -30,7 +31,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start the server", Start: func(cmd *cobra.Command) (any, func(), error) { - return server.NewServer(cmd.Context(), &cfg) + // main.go gives us a stopper, just unwrap it. + return server.NewServer(stopper.From(cmd.Context()), &cfg) }, Use: "start", }) diff --git a/internal/script/injector.go b/internal/script/injector.go index bdeb1c2f0..248d8eceb 100644 --- a/internal/script/injector.go +++ b/internal/script/injector.go @@ -27,24 +27,29 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Evaluate the loaded script. func Evaluate( - ctx context.Context, + ctx *stopper.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers, ) (*UserScript, error) { - panic(wire.Build(ProvideUserScript)) + panic(wire.Build( + ProvideUserScript, + wire.Bind(new(context.Context), new(*stopper.Context)), + )) } func newScriptFromFixture(*all.Fixture, *Config, TargetSchema) (*UserScript, error) { panic(wire.Build( Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.FieldsOf(new(*all.Fixture), "Diagnostics", "Fixture", "Configs", "Watchers"), wire.FieldsOf(new(*base.Fixture), "Context"), )) diff --git a/internal/script/wire_gen.go b/internal/script/wire_gen.go index 052cf1f8a..9039dea35 100644 --- a/internal/script/wire_gen.go +++ b/internal/script/wire_gen.go @@ -7,11 +7,11 @@ package script import ( - "context" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) import ( @@ -21,7 +21,7 @@ import ( // Injectors from injector.go: // Evaluate the loaded script. -func Evaluate(ctx context.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers) (*UserScript, error) { +func Evaluate(ctx *stopper.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers) (*UserScript, error) { userScript, err := ProvideUserScript(ctx, configs, loader, diags, targetSchema, watchers) if err != nil { return nil, err @@ -31,7 +31,7 @@ func Evaluate(ctx context.Context, loader *Loader, configs *applycfg.Configs, di func newScriptFromFixture(fixture *all.Fixture, config *Config, targetSchema TargetSchema) (*UserScript, error) { baseFixture := fixture.Fixture - contextContext := baseFixture.Context + context := baseFixture.Context configs := fixture.Configs loader, err := ProvideLoader(config) if err != nil { @@ -39,7 +39,7 @@ func newScriptFromFixture(fixture *all.Fixture, config *Config, targetSchema Tar } diagnostics := fixture.Diagnostics watchers := fixture.Watchers - userScript, err := ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { return nil, err } diff --git a/internal/sinktest/all/wire_gen.go b/internal/sinktest/all/wire_gen.go index 7a2dce40c..45687f596 100644 --- a/internal/sinktest/all/wire_gen.go +++ b/internal/sinktest/all/wire_gen.go @@ -23,10 +23,7 @@ import ( // NewFixture constructs a self-contained test fixture for all services // in the target sub-packages. func NewFixture() (*Fixture, func(), error) { - context, cleanup, err := base.ProvideContext() - if err != nil { - return nil, nil, err - } + context, cleanup := base.ProvideContext() diagnostics, cleanup2 := diag.New(context) sourcePool, cleanup3, err := base.ProvideSourcePool(context, diagnostics) if err != nil { @@ -160,7 +157,7 @@ func NewFixture() (*Fixture, func(), error) { cleanup() return nil, nil, err } - stagers := stage.ProvideFactory(stagingPool, stagingSchema) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, context) checker := version.ProvideChecker(stagingPool, memoMemo) watcher, err := ProvideWatcher(context, targetSchema, watchers) if err != nil { diff --git a/internal/sinktest/base/provider.go b/internal/sinktest/base/provider.go index c257a0de8..4a60589de 100644 --- a/internal/sinktest/base/provider.go +++ b/internal/sinktest/base/provider.go @@ -34,6 +34,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/retry" "github.com/cockroachdb/cdc-sink/internal/util/stdpool" "github.com/cockroachdb/cdc-sink/internal/util/stmtcache" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -95,6 +96,7 @@ var TestSet = wire.NewSet( ProvideTargetStatements, diag.New, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Struct(new(Fixture), "*"), ) @@ -102,7 +104,7 @@ var TestSet = wire.NewSet( // without the other services provided by the target package. One can be // constructed by calling NewFixture. type Fixture struct { - Context context.Context // The context for the test. + Context *stopper.Context // The context for the test. SourcePool *types.SourcePool // Access to user-data tables and changefeed creation. SourceSchema sinktest.SourceSchema // A container for tables within SourcePool. StagingPool *types.StagingPool // Access to __cdc_sink database. @@ -138,9 +140,19 @@ var caseTimout = flag.Duration( // ProvideContext returns an execution context that is associated with a // singleton connection to a CockroachDB cluster. -func ProvideContext() (context.Context, func(), error) { - ctx, cancel := context.WithTimeout(context.Background(), *caseTimout) - return ctx, cancel, nil +func ProvideContext() (*stopper.Context, func()) { + ctx := stopper.WithContext(context.Background()) + ctx.Go(func() error { + select { + case <-ctx.Stopping(): + // Clean shutdown, do nothing. + case <-time.After(*caseTimout): + // Just cancel immediately. + ctx.Stop(0) + } + return nil + }) + return ctx, func() { ctx.Stop(100 * time.Millisecond) } } // ProvideSourcePool connects to the source database. If the source is a diff --git a/internal/sinktest/base/wire_gen.go b/internal/sinktest/base/wire_gen.go index dccfcc054..4e92ae2b8 100644 --- a/internal/sinktest/base/wire_gen.go +++ b/internal/sinktest/base/wire_gen.go @@ -14,10 +14,7 @@ import ( // NewFixture constructs a self-contained test fixture. func NewFixture() (*Fixture, func(), error) { - context, cleanup, err := ProvideContext() - if err != nil { - return nil, nil, err - } + context, cleanup := ProvideContext() diagnostics, cleanup2 := diag.New(context) sourcePool, cleanup3, err := ProvideSourcePool(context, diagnostics) if err != nil { diff --git a/internal/source/cdc/provider.go b/internal/source/cdc/provider.go index ec4bde9d0..6dc0173ac 100644 --- a/internal/source/cdc/provider.go +++ b/internal/source/cdc/provider.go @@ -17,12 +17,12 @@ package cdc import ( - "context" "fmt" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" ) @@ -55,7 +55,7 @@ func ProvideMetaTable(cfg *Config) MetaTable { // ProvideResolvers is called by Wire. func ProvideResolvers( - ctx context.Context, + ctx *stopper.Context, cfg *Config, leases types.Leases, loops *logical.Factory, @@ -63,9 +63,9 @@ func ProvideResolvers( pool *types.StagingPool, stagers types.Stagers, watchers types.Watchers, -) (*Resolvers, func(), error) { +) (*Resolvers, error) { if _, err := pool.Exec(ctx, fmt.Sprintf(schema, metaTable.Table())); err != nil { - return nil, nil, errors.WithStack(err) + return nil, errors.WithStack(err) } ret := &Resolvers{ @@ -75,6 +75,7 @@ func ProvideResolvers( metaTable: metaTable.Table(), pool: pool, stagers: stagers, + stop: ctx, watchers: watchers, } ret.mu.instances = &ident.SchemaMap[*logical.Loop]{} @@ -82,13 +83,13 @@ func ProvideResolvers( // Resume from previous state. schemas, err := ScanForTargetSchemas(ctx, pool, ret.metaTable) if err != nil { - return nil, nil, err + return nil, err } for _, schema := range schemas { if _, _, err := ret.get(ctx, schema); err != nil { - return nil, nil, errors.Wrapf(err, "could not bootstrap resolver for schema %s", schema) + return nil, errors.Wrapf(err, "could not bootstrap resolver for schema %s", schema) } } - return ret, ret.close, nil + return ret, nil } diff --git a/internal/source/cdc/resolver.go b/internal/source/cdc/resolver.go index 142b2a636..2597bb027 100644 --- a/internal/source/cdc/resolver.go +++ b/internal/source/cdc/resolver.go @@ -595,33 +595,15 @@ type Resolvers struct { metaTable ident.Table pool *types.StagingPool stagers types.Stagers + stop *stopper.Context // Manage lifecycle of background processes. watchers types.Watchers mu struct { sync.Mutex - cleanups []func() instances *ident.SchemaMap[*logical.Loop] } } -// close will drain any running resolver loops. -func (r *Resolvers) close() { - r.mu.Lock() - defer r.mu.Unlock() - - // Cancel each loop. - for _, cancel := range r.mu.cleanups { - cancel() - } - // Wait for shutdown. - _ = r.mu.instances.Range(func(_ ident.Schema, l *logical.Loop) error { - <-l.Stopped() - return nil - }) - r.mu.cleanups = nil - r.mu.instances = nil -} - // get creates or returns the [logical.Loop] and the enclosed resolver. func (r *Resolvers) get( ctx context.Context, target ident.Schema, @@ -643,7 +625,7 @@ func (r *Resolvers) get( return nil, ret, nil } - loop, cleanup, err := r.loops.Start(&logical.LoopConfig{ + loop, err := r.loops.Start(r.stop, &logical.LoopConfig{ Dialect: ret, LoopName: "changefeed-" + target.Raw(), TargetSchema: target, @@ -655,13 +637,8 @@ func (r *Resolvers) get( r.mu.instances.Put(target, loop) // Start a goroutine to retire old data. - stop := stopper.WithContext(context.Background()) - ret.retireLoop(stop) + ret.retireLoop(r.stop) - r.mu.cleanups = append(r.mu.cleanups, - cleanup, - func() { stop.Stop(time.Second) }, - ) return loop, ret, nil } diff --git a/internal/source/cdc/test_fixture.go b/internal/source/cdc/test_fixture.go index 21ba36ca5..410c86fa3 100644 --- a/internal/source/cdc/test_fixture.go +++ b/internal/source/cdc/test_fixture.go @@ -20,6 +20,8 @@ package cdc import ( + "context" + "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/sinktest/base" @@ -28,6 +30,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging/leases" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -50,6 +53,7 @@ func newTestFixture(*all.Fixture, *Config) (*testFixture, func(), error) { target.Set, trust.New, // Is valid to use as a provider. wire.Struct(new(testFixture), "*"), + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), )) } diff --git a/internal/source/cdc/wire_gen.go b/internal/source/cdc/wire_gen.go index 3c3225f6e..18b5e166e 100644 --- a/internal/source/cdc/wire_gen.go +++ b/internal/source/cdc/wire_gen.go @@ -121,7 +121,7 @@ func newTestFixture(fixture *all.Fixture, config *Config) (*testFixture, func(), } metaTable := ProvideMetaTable(config) stagers := fixture.Stagers - resolvers, cleanup8, err := ProvideResolvers(context, config, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + resolvers, err := ProvideResolvers(context, config, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup7() cleanup6() @@ -147,7 +147,6 @@ func newTestFixture(fixture *all.Fixture, config *Config) (*testFixture, func(), Resolvers: resolvers, } return cdcTestFixture, func() { - cleanup8() cleanup7() cleanup6() cleanup5() diff --git a/internal/source/fslogical/injector.go b/internal/source/fslogical/injector.go index 9e1c166d0..d530bd303 100644 --- a/internal/source/fslogical/injector.go +++ b/internal/source/fslogical/injector.go @@ -29,13 +29,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(context.Context, *Config) (*FSLogical, func(), error) { +func Start(*stopper.Context, *Config) (*FSLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(FSLogical), "*"), ProvideFirestoreClient, @@ -53,6 +55,7 @@ func Start(context.Context, *Config) (*FSLogical, func(), error) { // Build remaining testable components from a common fixture. func startLoopsFromFixture(*all.Fixture, *Config) ([]*logical.Loop, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*base.Fixture), "Context"), wire.FieldsOf(new(*all.Fixture), diff --git a/internal/source/fslogical/provider.go b/internal/source/fslogical/provider.go index c0b4cd6fc..0fe8c7def 100644 --- a/internal/source/fslogical/provider.go +++ b/internal/source/fslogical/provider.go @@ -29,6 +29,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/golang/groupcache/lru" log "github.com/sirupsen/logrus" "google.golang.org/api/option" @@ -51,7 +52,7 @@ var enableWipe bool // ProvideLoops is called by wire to construct a logical-replication // loop for each configured collection/table pair. func ProvideLoops( - ctx context.Context, + ctx *stopper.Context, cfg *Config, fs *firestore.Client, loops *logical.Factory, @@ -59,22 +60,15 @@ func ProvideLoops( pool *types.StagingPool, st *Tombstones, userscript *script.UserScript, -) ([]*logical.Loop, func(), error) { +) ([]*logical.Loop, error) { if err := cfg.Preflight(); err != nil { - return nil, nil, err + return nil, err } idx := 0 ret := make([]*logical.Loop, userscript.Sources.Len()) recurseFilter := &ident.Map[struct{}]{} - cancels := make([]func(), userscript.Sources.Len()) - cancel := func() { - for _, fn := range cancels { - fn() - } - } - err := userscript.Sources.Range(func(sourceName ident.Ident, source *script.Source) error { var isGroup bool var sourcePath string @@ -112,7 +106,7 @@ func ProvideLoops( loopCfg.LoopName = sourceName.Raw() var err error - ret[idx], cancels[idx], err = loops.Start(loopCfg) + ret[idx], err = loops.Start(ctx, loopCfg) if err != nil { return err } @@ -121,11 +115,10 @@ func ProvideLoops( return nil }) if err != nil { - cancel() - return nil, nil, err + return nil, err } - return ret, cancel, nil + return ret, nil } // ProvideFirestoreClient is called by wire. If a local emulator is in @@ -174,12 +167,16 @@ func ProvideScriptTarget(cfg *Config) script.TargetSchema { // ProvideTombstones is called by wire to construct a helper that // manages document tombstones. func ProvideTombstones( - cfg *Config, fs *firestore.Client, loops *logical.Factory, userscript *script.UserScript, -) (*Tombstones, func(), error) { + ctx *stopper.Context, + cfg *Config, + fs *firestore.Client, + loops *logical.Factory, + userscript *script.UserScript, +) (*Tombstones, error) { ret := &Tombstones{cfg: cfg} if cfg.TombstoneCollection == "" { log.Trace("no tombstone collection was configured") - return ret, nil, nil + return ret, nil } ret.coll = fs.Collection(cfg.TombstoneCollection) @@ -188,7 +185,7 @@ func ProvideTombstones( ret.deletesTo.Put(source, dest.DeletesTo) return nil }); err != nil { - return nil, nil, err + return nil, err } ret.source = ident.New(cfg.TombstoneCollection) ret.mu.cache = &lru.Cache{MaxEntries: 1_000_000} @@ -196,11 +193,11 @@ func ProvideTombstones( loopConfig := cfg.LoopConfig.Copy() loopConfig.Dialect = ret loopConfig.LoopName = cfg.TombstoneCollection - _, cancel, err := loops.Start(loopConfig) + _, err := loops.Start(ctx, loopConfig) if err != nil { - return nil, nil, err + return nil, err } - return ret, cancel, nil + return ret, nil } // Wipe any leftover documents from testing. diff --git a/internal/source/fslogical/wire_gen.go b/internal/source/fslogical/wire_gen.go index 2e0610e88..9c8dbfac1 100644 --- a/internal/source/fslogical/wire_gen.go +++ b/internal/source/fslogical/wire_gen.go @@ -7,7 +7,6 @@ package fslogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/source/logical" @@ -18,14 +17,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(contextContext context.Context, config *Config) (*FSLogical, func(), error) { - diagnostics, cleanup := diag.New(contextContext) +func Start(context *stopper.Context, config *Config) (*FSLogical, func(), error) { + diagnostics, cleanup := diag.New(context) configs, err := applycfg.ProvideConfigs(diagnostics) if err != nil { cleanup() @@ -47,7 +47,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - targetPool, cleanup2, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup2, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -58,14 +58,14 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - userScript, err := script.ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := script.ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { cleanup3() cleanup2() cleanup() return nil, nil, err } - client, cleanup4, err := ProvideFirestoreClient(contextContext, config, userScript) + client, cleanup4, err := ProvideFirestoreClient(context, config, userScript) if err != nil { cleanup3() cleanup2() @@ -91,7 +91,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - stagingPool, cleanup7, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup7, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup6() cleanup5() @@ -112,7 +112,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - memoMemo, err := memo.ProvideMemo(contextContext, stagingPool, stagingSchema) + memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) if err != nil { cleanup7() cleanup6() @@ -124,7 +124,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), return nil, nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup7() cleanup6() @@ -135,7 +135,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - tombstones, cleanup8, err := ProvideTombstones(config, client, factory, userScript) + tombstones, err := ProvideTombstones(context, config, client, factory, userScript) if err != nil { cleanup7() cleanup6() @@ -146,9 +146,8 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - v, cleanup9, err := ProvideLoops(contextContext, config, client, factory, memoMemo, stagingPool, tombstones, userScript) + v, err := ProvideLoops(context, config, client, factory, memoMemo, stagingPool, tombstones, userScript) if err != nil { - cleanup8() cleanup7() cleanup6() cleanup5() @@ -163,8 +162,6 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), Loops: v, } return fsLogical, func() { - cleanup9() - cleanup8() cleanup7() cleanup6() cleanup5() @@ -178,7 +175,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), // Build remaining testable components from a common fixture. func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loop, func(), error) { baseFixture := fixture.Fixture - contextContext := baseFixture.Context + context := baseFixture.Context configs := fixture.Configs scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -188,14 +185,14 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo if err != nil { return nil, nil, err } - diagnostics, cleanup := diag.New(contextContext) + diagnostics, cleanup := diag.New(context) targetSchema := ProvideScriptTarget(config) baseConfig, err := logical.ProvideBaseConfig(config, loader) if err != nil { cleanup() return nil, nil, err } - targetPool, cleanup2, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup2, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -206,14 +203,14 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - userScript, err := script.ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := script.ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { cleanup3() cleanup2() cleanup() return nil, nil, err } - client, cleanup4, err := ProvideFirestoreClient(contextContext, config, userScript) + client, cleanup4, err := ProvideFirestoreClient(context, config, userScript) if err != nil { cleanup3() cleanup2() @@ -240,7 +237,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } typesMemo := fixture.Memo - stagingPool, cleanup7, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup7, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup6() cleanup5() @@ -251,7 +248,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } checker := fixture.VersionChecker - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, typesMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, typesMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup7() cleanup6() @@ -262,7 +259,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - tombstones, cleanup8, err := ProvideTombstones(config, client, factory, userScript) + tombstones, err := ProvideTombstones(context, config, client, factory, userScript) if err != nil { cleanup7() cleanup6() @@ -273,9 +270,8 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - v, cleanup9, err := ProvideLoops(contextContext, config, client, factory, typesMemo, stagingPool, tombstones, userScript) + v, err := ProvideLoops(context, config, client, factory, typesMemo, stagingPool, tombstones, userScript) if err != nil { - cleanup8() cleanup7() cleanup6() cleanup5() @@ -286,8 +282,6 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } return v, func() { - cleanup9() - cleanup8() cleanup7() cleanup6() cleanup5() diff --git a/internal/source/logical/factory.go b/internal/source/logical/factory.go index ab0554056..f5c9d4315 100644 --- a/internal/source/logical/factory.go +++ b/internal/source/logical/factory.go @@ -60,7 +60,8 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, // no-ops and wait to exit. The implementation of Process would make // the Events / Batcher available externally. This has the downside // of actually needing to start the loop goroutines. - fake, cancel, err := f.newLoop(stopper.From(ctx), &LoopConfig{ + stop := stopper.WithContext(ctx) + fake, err := f.newLoop(stop, &LoopConfig{ Dialect: &fakeDialect{}, LoopName: fmt.Sprintf("immediate-%s", target.Raw()), TargetSchema: target, @@ -68,6 +69,10 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, if err != nil { return nil, nil, err } + cancel := func() { + stop.Stop(f.baseConfig.ApplyTimeout) + <-stop.Done() + } if f.baseConfig.Immediate { return fake.loop.events.fan, cancel, nil @@ -76,32 +81,27 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, } // Start constructs a new replication Loop. -func (f *Factory) Start(config *LoopConfig) (*Loop, func(), error) { +func (f *Factory) Start(ctx *stopper.Context, config *LoopConfig) (*Loop, error) { var err error // Ensure the configuration is set up and validated. config, err = f.expandConfig(config) if err != nil { - return nil, nil, err + return nil, err } // Construct the new loop and start it. - stop := stopper.WithContext(context.Background()) - loop, cleanup, err := f.newLoop(stop, config) + ctx = stopper.WithContext(ctx) + loop, err := f.newLoop(ctx, config) if err != nil { - return nil, nil, err - } - go loop.loop.run() - - // Perform a graceful shutdown and wait for the loop to exit. - grace := f.baseConfig.ApplyTimeout - cancel := func() { - stop.Stop(grace) - <-loop.Stopped() - cleanup() + return nil, err } + ctx.Go(func() error { + loop.loop.run() + return nil + }) - return loop, cancel, nil + return loop, nil } // expandConfig returns a preflighted copy of the configuration. @@ -121,10 +121,10 @@ func (f *Factory) expandConfig(config *LoopConfig) (*LoopConfig, error) { } // newLoop constructs a loop, but does not start it. -func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func(), error) { +func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, error) { watcher, err := f.watchers.Get(ctx, config.TargetSchema) if err != nil { - return nil, nil, err + return nil, err } config = config.Copy() config.Dialect = WithChaos(config.Dialect, f.baseConfig.ChaosProb) @@ -135,7 +135,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func } initialPoint, err := loop.loadConsistentPoint(ctx) if err != nil { - return nil, nil, err + return nil, err } loop.consistentPoint.Set(initialPoint) @@ -161,7 +161,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func } else { // Sanity-check that there are no FKs defined. if len(watcher.Get().Order) > 1 { - return nil, nil, errors.New("the destination database has tables with foreign keys, " + + return nil, errors.New("the destination database has tables with foreign keys, " + "but support for FKs is not enabled") } } @@ -169,11 +169,14 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func // Create a branch in the diagnostics reporting for the loop. loopDiags, err := f.diags.Wrap(config.LoopName) if err != nil { - return nil, nil, err + return nil, err } - cancel := func() { + // Unregister the loop on shutdown. + ctx.Go(func() error { + <-ctx.Stopping() f.diags.Unregister(config.LoopName) - } + return nil + }) userscript, err := script.Evaluate( ctx, @@ -184,8 +187,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func f.watchers, ) if err != nil { - cancel() - return nil, nil, errors.Wrapf(err, "could not initialize userscript for %s", config.LoopName) + return nil, errors.Wrapf(err, "could not initialize userscript for %s", config.LoopName) } // Apply logic and configurations defined by the user-script. @@ -206,11 +208,10 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func loop.metrics.backfillStatus = backfillStatus.WithLabelValues(config.LoopName) if err := loopDiags.Register("loop", loop); err != nil { - cancel() - return nil, nil, err + return nil, err } - return &Loop{loop, initialPoint}, cancel, nil + return &Loop{loop, initialPoint}, nil } // singletonChannel returns a channel that emits a single value and is diff --git a/internal/source/logical/logical_test.go b/internal/source/logical/logical_test.go index f476eeb3a..0d1873e6a 100644 --- a/internal/source/logical/logical_test.go +++ b/internal/source/logical/logical_test.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/batches" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stamp" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -196,7 +197,8 @@ func testLogicalSmoke(t *testing.T, mode *logicalTestMode) { } defer cancelFactory() - loop, cancelLoop, err := factory.Start(&logical.LoopConfig{ + loopStopper := stopper.WithContext(fixture.Context) + loop, err := factory.Start(loopStopper, &logical.LoopConfig{ Dialect: gen, LoopName: "generator", TargetSchema: dbName, @@ -242,7 +244,7 @@ func testLogicalSmoke(t *testing.T, mode *logicalTestMode) { } // Wait for the loop to shut down, or a timeout. - cancelLoop() + loopStopper.Stop(100 * time.Millisecond) gen.emit(0) // Kick the simplistic ReadInto loop so that it exits. select { case <-loop.Stopped(): @@ -351,13 +353,12 @@ api.configureTable("t_2", { r.NoError(err) defer cancelFactory() - _, cancelLoop, err := factory.Start(&logical.LoopConfig{ + _, err = factory.Start(fixture.Context, &logical.LoopConfig{ Dialect: gen, LoopName: "generator", TargetSchema: dbName, }) r.NoError(err) - defer cancelLoop() // Wait for replication. for idx, tgt := range tgts { diff --git a/internal/source/logical/loop.go b/internal/source/logical/loop.go index a048d7533..3afc55d3e 100644 --- a/internal/source/logical/loop.go +++ b/internal/source/logical/loop.go @@ -176,7 +176,7 @@ func (l *loop) run() { defer log.Debugf("replication loop %q shut down", l.loopConfig.LoopName) for { - err := l.runOnce(l.running) + err := l.runOnce() // Otherwise, log any error, and sleep for a bit. if err != nil { @@ -199,13 +199,14 @@ func (l *loop) run() { // runOnce is called by run. If the Dialect implements a leasing // behavior, a lease will be obtained before any further action is // taken. -func (l *loop) runOnce(ctx context.Context) error { +func (l *loop) runOnce() error { + var stop *stopper.Context if lessor, ok := l.loopConfig.Dialect.(Lessor); ok { // Loop until we can acquire a lease. var lease types.Lease for { var err error - lease, err = lessor.Acquire(ctx) + lease, err = lessor.Acquire(l.running) // Lease acquired. if err == nil { log.Tracef("lease %s acquired", l.loopConfig.LoopName) @@ -223,8 +224,8 @@ func (l *loop) runOnce(ctx context.Context) error { select { case <-time.After(duration): continue - case <-ctx.Done(): - return ctx.Err() + case <-l.running.Stopping(): + return nil } } // General err, defer to the loop's retry delay. @@ -232,11 +233,13 @@ func (l *loop) runOnce(ctx context.Context) error { } defer lease.Release() // Ensure that all work is bound to the lifetime of the lease. - ctx = lease.Context() + stop = stopper.WithContext(lease.Context()) + } else { + stop = l.running } // Ensure our in-memory consistent point matches the database. - point, err := l.loadConsistentPoint(ctx) + point, err := l.loadConsistentPoint(stop) if err != nil { return err } @@ -245,7 +248,7 @@ func (l *loop) runOnce(ctx context.Context) error { // Determine how to perform the filling. source, events, isBackfilling := l.chooseFillStrategy() - return l.runOnceUsing(stopper.From(ctx), source, events, isBackfilling) + return l.runOnceUsing(stop, source, events, isBackfilling) } // runOnceUsing is called from runOnce or doBackfill. @@ -429,13 +432,14 @@ func (l *loop) doBackfill(ctx context.Context, loopName string, backfiller Backf cfg.Dialect = backfiller cfg.LoopName = loopName - // The incoming context should already be a stopper. + // Create a (most likely nested) stopper. stop := stopper.WithContext(ctx) - filler, cleanup, err := l.factory.newLoop(stop, cfg) + filler, err := l.factory.newLoop(stop, cfg) if err != nil { return err } - defer cleanup() + // We don't need any grace time since the sub-loop has exited. + defer func() { stop.Stop(0) }() return filler.loop.runOnceUsing( stop, diff --git a/internal/source/mylogical/injector.go b/internal/source/mylogical/injector.go index 6501921e7..c7bf67460 100644 --- a/internal/source/mylogical/injector.go +++ b/internal/source/mylogical/injector.go @@ -27,13 +27,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a MySQL/MariaDB logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*MYLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(MYLogical), "*"), Set, diff --git a/internal/source/mylogical/provider.go b/internal/source/mylogical/provider.go index 656997152..599451f6b 100644 --- a/internal/source/mylogical/provider.go +++ b/internal/source/mylogical/provider.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/go-mysql-org/go-mysql/replication" "github.com/google/wire" ) @@ -64,8 +65,8 @@ func ProvideDialect(config *Config, _ *script.Loader) (logical.Dialect, error) { // ProvideLoop is called by Wire to construct the sole logical loop used // in the mylogical mode. func ProvideLoop( - cfg *Config, dialect logical.Dialect, loops *logical.Factory, -) (*logical.Loop, func(), error) { + ctx *stopper.Context, cfg *Config, dialect logical.Dialect, loops *logical.Factory, +) (*logical.Loop, error) { cfg.Dialect = dialect - return loops.Start(&cfg.LoopConfig) + return loops.Start(ctx, &cfg.LoopConfig) } diff --git a/internal/source/mylogical/wire_gen.go b/internal/source/mylogical/wire_gen.go index 45faca347..9c1ea4ffd 100644 --- a/internal/source/mylogical/wire_gen.go +++ b/internal/source/mylogical/wire_gen.go @@ -7,7 +7,6 @@ package mylogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/staging/memo" @@ -17,13 +16,14 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a MySQL/MariaDB logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*MYLogical, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -120,7 +120,7 @@ func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { cleanup() return nil, nil, err } - loop, cleanup7, err := ProvideLoop(config, dialect, factory) + loop, err := ProvideLoop(ctx, config, dialect, factory) if err != nil { cleanup6() cleanup5() @@ -135,7 +135,6 @@ func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { Loop: loop, } return myLogical, func() { - cleanup7() cleanup6() cleanup5() cleanup4() diff --git a/internal/source/pglogical/injector.go b/internal/source/pglogical/injector.go index 8f519aef5..4e1769b24 100644 --- a/internal/source/pglogical/injector.go +++ b/internal/source/pglogical/injector.go @@ -27,13 +27,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*PGLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(PGLogical), "*"), Set, diff --git a/internal/source/pglogical/provider.go b/internal/source/pglogical/provider.go index 58a95c437..83688ecf4 100644 --- a/internal/source/pglogical/provider.go +++ b/internal/source/pglogical/provider.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stdpool" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -103,8 +104,8 @@ func ProvideDialect( // ProvideLoop is called by Wire to construct the sole logical loop used // in the pglogical mode. func ProvideLoop( - cfg *Config, dialect logical.Dialect, loops *logical.Factory, -) (*logical.Loop, func(), error) { + ctx *stopper.Context, cfg *Config, dialect logical.Dialect, loops *logical.Factory, +) (*logical.Loop, error) { cfg.Dialect = dialect - return loops.Start(&cfg.LoopConfig) + return loops.Start(ctx, &cfg.LoopConfig) } diff --git a/internal/source/pglogical/wire_gen.go b/internal/source/pglogical/wire_gen.go index 1633130dd..7185329e6 100644 --- a/internal/source/pglogical/wire_gen.go +++ b/internal/source/pglogical/wire_gen.go @@ -7,7 +7,6 @@ package pglogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/staging/memo" @@ -17,13 +16,14 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*PGLogical, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -120,7 +120,7 @@ func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { cleanup() return nil, nil, err } - loop, cleanup7, err := ProvideLoop(config, dialect, factory) + loop, err := ProvideLoop(ctx, config, dialect, factory) if err != nil { cleanup6() cleanup5() @@ -135,7 +135,6 @@ func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { Loop: loop, } return pgLogical, func() { - cleanup7() cleanup6() cleanup5() cleanup4() diff --git a/internal/source/server/injector.go b/internal/source/server/injector.go index da1edf4c9..57104dec1 100644 --- a/internal/source/server/injector.go +++ b/internal/source/server/injector.go @@ -28,10 +28,11 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) -func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { +func NewServer(ctx *stopper.Context, config *Config) (*Server, func(), error) { panic(wire.Build( Set, cdc.Set, @@ -40,6 +41,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { script.Set, staging.Set, target.Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*Config), "CDC"), )) diff --git a/internal/source/server/integration_test.go b/internal/source/server/integration_test.go index 6414c5a3f..488738748 100644 --- a/internal/source/server/integration_test.go +++ b/internal/source/server/integration_test.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/hlc" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" joonix "github.com/joonix/log" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -153,7 +154,7 @@ func testIntegration(t *testing.T, cfg testConfig) { targetPool := destFixture.TargetPool // The target fixture contains the cdc-sink server. - targetFixture, cancel, err := newTestFixture(ctx, &Config{ + targetFixture, cancel, err := newTestFixture(stopper.WithContext(ctx), &Config{ CDC: cdc.Config{ BaseConfig: logical.BaseConfig{ Immediate: cfg.immediate, diff --git a/internal/source/server/test_fixture.go b/internal/source/server/test_fixture.go index 4cb004bc4..42fb698c2 100644 --- a/internal/source/server/test_fixture.go +++ b/internal/source/server/test_fixture.go @@ -31,6 +31,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/diag" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -48,7 +49,7 @@ type testFixture struct { // We want this to be as close as possible to Start, it just exposes // additional plumbing details via the returned testFixture pointer. -func newTestFixture(context.Context, *Config) (*testFixture, func(), error) { +func newTestFixture(*stopper.Context, *Config) (*testFixture, func(), error) { panic(wire.Build( Set, cdc.Set, @@ -57,6 +58,7 @@ func newTestFixture(context.Context, *Config) (*testFixture, func(), error) { script.Set, staging.Set, target.Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*Config), "CDC"), wire.Struct(new(testFixture), "*"), diff --git a/internal/source/server/wire_gen.go b/internal/source/server/wire_gen.go index 878be13bc..f388ddf33 100644 --- a/internal/source/server/wire_gen.go +++ b/internal/source/server/wire_gen.go @@ -7,7 +7,6 @@ package server import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/cdc" "github.com/cockroachdb/cdc-sink/internal/source/logical" @@ -22,12 +21,13 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "net" ) // Injectors from injector.go: -func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { +func NewServer(ctx *stopper.Context, config *Config) (*Server, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -170,8 +170,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { return nil, nil, err } metaTable := cdc.ProvideMetaTable(cdcConfig) - stagers := stage.ProvideFactory(stagingPool, stagingSchema) - resolvers, cleanup10, err := cdc.ProvideResolvers(ctx, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, ctx) + resolvers, err := cdc.ProvideResolvers(ctx, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup9() cleanup8() @@ -196,7 +196,6 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { serveMux := ProvideMux(handler, stagingPool, targetPool) tlsConfig, err := ProvideTLSConfig(config) if err != nil { - cleanup10() cleanup9() cleanup8() cleanup7() @@ -208,9 +207,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { cleanup() return nil, nil, err } - server, cleanup11 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) + server, cleanup10 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) return server, func() { - cleanup11() cleanup10() cleanup9() cleanup8() @@ -228,8 +226,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { // We want this to be as close as possible to Start, it just exposes // additional plumbing details via the returned testFixture pointer. -func newTestFixture(contextContext context.Context, config *Config) (*testFixture, func(), error) { - diagnostics, cleanup := diag.New(contextContext) +func newTestFixture(context *stopper.Context, config *Config) (*testFixture, func(), error) { + diagnostics, cleanup := diag.New(context) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { cleanup() @@ -245,7 +243,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - stagingPool, cleanup2, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup2, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -256,7 +254,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - authenticator, cleanup3, err := ProvideAuthenticator(contextContext, diagnostics, config, stagingPool, stagingSchema) + authenticator, cleanup3, err := ProvideAuthenticator(context, diagnostics, config, stagingPool, stagingSchema) if err != nil { cleanup2() cleanup() @@ -270,7 +268,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } cdcConfig := &config.CDC - targetPool, cleanup5, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup5, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup4() cleanup3() @@ -320,7 +318,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - memoMemo, err := memo.ProvideMemo(contextContext, stagingPool, stagingSchema) + memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) if err != nil { cleanup8() cleanup7() @@ -333,7 +331,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup8() cleanup7() @@ -357,7 +355,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - typesLeases, err := leases.ProvideLeases(contextContext, stagingPool, stagingSchema) + typesLeases, err := leases.ProvideLeases(context, stagingPool, stagingSchema) if err != nil { cleanup9() cleanup8() @@ -371,8 +369,8 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } metaTable := cdc.ProvideMetaTable(cdcConfig) - stagers := stage.ProvideFactory(stagingPool, stagingSchema) - resolvers, cleanup10, err := cdc.ProvideResolvers(contextContext, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, context) + resolvers, err := cdc.ProvideResolvers(context, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup9() cleanup8() @@ -397,7 +395,6 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur serveMux := ProvideMux(handler, stagingPool, targetPool) tlsConfig, err := ProvideTLSConfig(config) if err != nil { - cleanup10() cleanup9() cleanup8() cleanup7() @@ -409,7 +406,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - server, cleanup11 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) + server, cleanup10 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) serverTestFixture := &testFixture{ Authenticator: authenticator, Config: config, @@ -422,7 +419,6 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur Watcher: watchers, } return serverTestFixture, func() { - cleanup11() cleanup10() cleanup9() cleanup8() diff --git a/internal/staging/stage/factory.go b/internal/staging/stage/factory.go index 0a0455b24..f93de0815 100644 --- a/internal/staging/stage/factory.go +++ b/internal/staging/stage/factory.go @@ -26,6 +26,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/hlc" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/jackc/pgx/v5" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -34,6 +35,7 @@ import ( type factory struct { db *types.StagingPool stagingDB ident.Schema + stop *stopper.Context mu struct { sync.RWMutex @@ -44,14 +46,14 @@ type factory struct { var _ types.Stagers = (*factory)(nil) // Get returns a memoized instance of a stage for the given table. -func (f *factory) Get(ctx context.Context, target ident.Table) (types.Stager, error) { +func (f *factory) Get(_ context.Context, target ident.Table) (types.Stager, error) { if ret := f.getUnlocked(target); ret != nil { return ret, nil } - return f.createUnlocked(ctx, target) + return f.createUnlocked(target) } -func (f *factory) createUnlocked(ctx context.Context, table ident.Table) (*stage, error) { +func (f *factory) createUnlocked(table ident.Table) (*stage, error) { f.mu.Lock() defer f.mu.Unlock() @@ -59,7 +61,7 @@ func (f *factory) createUnlocked(ctx context.Context, table ident.Table) (*stage return ret, nil } - ret, err := newStore(ctx, f.db, f.stagingDB, table) + ret, err := newStore(f.stop, f.db, f.stagingDB, table) if err == nil { f.mu.instances.Put(table, ret) } diff --git a/internal/staging/stage/provider.go b/internal/staging/stage/provider.go index a19d86850..cb7d7cfef 100644 --- a/internal/staging/stage/provider.go +++ b/internal/staging/stage/provider.go @@ -19,6 +19,7 @@ package stage import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -28,10 +29,13 @@ var Set = wire.NewSet( ) // ProvideFactory is called by Wire to construct the Stagers factory. -func ProvideFactory(db *types.StagingPool, stagingDB ident.StagingSchema) types.Stagers { +func ProvideFactory( + db *types.StagingPool, stagingDB ident.StagingSchema, stop *stopper.Context, +) types.Stagers { f := &factory{ db: db, stagingDB: stagingDB.Schema(), + stop: stop, } f.mu.instances = &ident.TableMap[*stage]{} return f diff --git a/internal/staging/stage/stage.go b/internal/staging/stage/stage.go index 226b39165..f3592a9ba 100644 --- a/internal/staging/stage/stage.go +++ b/internal/staging/stage/stage.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/metrics" "github.com/cockroachdb/cdc-sink/internal/util/retry" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/pkg/errors" @@ -79,7 +80,7 @@ var _ types.Stager = (*stage)(nil) // newStore constructs a new mutation stage that will track pending // mutations to be applied to the given target table. func newStore( - ctx context.Context, db *types.StagingPool, stagingDB ident.Schema, target ident.Table, + ctx *stopper.Context, db *types.StagingPool, stagingDB ident.Schema, target ident.Table, ) (*stage, error) { table := stagingTable(stagingDB, target) diff --git a/main.go b/main.go index c26c2bf09..865f290b0 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/cmd/version" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/util/logfmt" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" joonix "github.com/joonix/log" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -46,6 +47,7 @@ import ( ) func main() { + var gracePeriod time.Duration var logFormat, logDestination string var verbosity int root := &cobra.Command{ @@ -96,6 +98,7 @@ func main() { }, } f := root.PersistentFlags() + f.DurationVar(&gracePeriod, "gracePeriod", 30*time.Second, "allow background processes to exit") f.StringVar(&logFormat, "logFormat", "text", "choose log output format [ fluent, text ]") f.StringVar(&logDestination, "logDestination", "", "write logs to a file, instead of stdout") f.CountVarP(&verbosity, "verbose", "v", "increase logging verbosity to debug; repeat for trace") @@ -114,10 +117,28 @@ func main() { version.Command(), ) - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) - log.DeferExitHandler(cancel) + stop := stopper.WithContext(context.Background()) + // Stop cleanly on interrupt. + stop.Go(func() error { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + select { + case <-ctx.Done(): + log.Info("Interrupted") + stop.Stop(gracePeriod) + case <-stop.Stopping(): + // Nothing to do. + } + return nil + }) + // Allow log.Exit() or log.Fatal() to trigger shutdown. + log.DeferExitHandler(func() { + stop.Stop(gracePeriod) + <-stop.Done() + }) - if err := root.ExecuteContext(ctx); err != nil { + // Commands can unwrap the stopper as needed. + if err := root.ExecuteContext(stop); err != nil { log.WithError(err).Error("exited") log.Exit(1) }