diff --git a/internal/cmd/fslogical/fslogical.go b/internal/cmd/fslogical/fslogical.go index f52b3cc2d..1d80adbf0 100644 --- a/internal/cmd/fslogical/fslogical.go +++ b/internal/cmd/fslogical/fslogical.go @@ -21,6 +21,7 @@ package fslogical import ( "github.com/cockroachdb/cdc-sink/internal/source/fslogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a Google Cloud Firestore logical replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return fslogical.Start(cmd.Context(), cfg) + // main.go provides this stopper. + return fslogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "fslogical", }) diff --git a/internal/cmd/mylogical/mylogical.go b/internal/cmd/mylogical/mylogical.go index ee17cbafd..68304ca6b 100644 --- a/internal/cmd/mylogical/mylogical.go +++ b/internal/cmd/mylogical/mylogical.go @@ -21,6 +21,7 @@ package mylogical import ( "github.com/cockroachdb/cdc-sink/internal/source/mylogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a mySQL replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return mylogical.Start(cmd.Context(), cfg) + // main.go provides a stopper. + return mylogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "mylogical", }) diff --git a/internal/cmd/pglogical/pglogical.go b/internal/cmd/pglogical/pglogical.go index 814e8ea30..4de5d83da 100644 --- a/internal/cmd/pglogical/pglogical.go +++ b/internal/cmd/pglogical/pglogical.go @@ -21,6 +21,7 @@ package pglogical import ( "github.com/cockroachdb/cdc-sink/internal/source/pglogical" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -31,7 +32,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start a pg logical replication feed", Start: func(cmd *cobra.Command) (any, func(), error) { - return pglogical.Start(cmd.Context(), cfg) + // main.go provides a stopper. + return pglogical.Start(stopper.From(cmd.Context()), cfg) }, Use: "pglogical", }) diff --git a/internal/cmd/start/start.go b/internal/cmd/start/start.go index cb17ca0b3..88cc6fefb 100644 --- a/internal/cmd/start/start.go +++ b/internal/cmd/start/start.go @@ -20,6 +20,7 @@ package start import ( "github.com/cockroachdb/cdc-sink/internal/source/server" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/spf13/cobra" ) @@ -30,7 +31,8 @@ func Command() *cobra.Command { Bind: cfg.Bind, Short: "start the server", Start: func(cmd *cobra.Command) (any, func(), error) { - return server.NewServer(cmd.Context(), &cfg) + // main.go gives us a stopper, just unwrap it. + return server.NewServer(stopper.From(cmd.Context()), &cfg) }, Use: "start", }) diff --git a/internal/script/injector.go b/internal/script/injector.go index bdeb1c2f0..248d8eceb 100644 --- a/internal/script/injector.go +++ b/internal/script/injector.go @@ -27,24 +27,29 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Evaluate the loaded script. func Evaluate( - ctx context.Context, + ctx *stopper.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers, ) (*UserScript, error) { - panic(wire.Build(ProvideUserScript)) + panic(wire.Build( + ProvideUserScript, + wire.Bind(new(context.Context), new(*stopper.Context)), + )) } func newScriptFromFixture(*all.Fixture, *Config, TargetSchema) (*UserScript, error) { panic(wire.Build( Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.FieldsOf(new(*all.Fixture), "Diagnostics", "Fixture", "Configs", "Watchers"), wire.FieldsOf(new(*base.Fixture), "Context"), )) diff --git a/internal/script/wire_gen.go b/internal/script/wire_gen.go index 052cf1f8a..9039dea35 100644 --- a/internal/script/wire_gen.go +++ b/internal/script/wire_gen.go @@ -7,11 +7,11 @@ package script import ( - "context" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) import ( @@ -21,7 +21,7 @@ import ( // Injectors from injector.go: // Evaluate the loaded script. -func Evaluate(ctx context.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers) (*UserScript, error) { +func Evaluate(ctx *stopper.Context, loader *Loader, configs *applycfg.Configs, diags *diag.Diagnostics, targetSchema TargetSchema, watchers types.Watchers) (*UserScript, error) { userScript, err := ProvideUserScript(ctx, configs, loader, diags, targetSchema, watchers) if err != nil { return nil, err @@ -31,7 +31,7 @@ func Evaluate(ctx context.Context, loader *Loader, configs *applycfg.Configs, di func newScriptFromFixture(fixture *all.Fixture, config *Config, targetSchema TargetSchema) (*UserScript, error) { baseFixture := fixture.Fixture - contextContext := baseFixture.Context + context := baseFixture.Context configs := fixture.Configs loader, err := ProvideLoader(config) if err != nil { @@ -39,7 +39,7 @@ func newScriptFromFixture(fixture *all.Fixture, config *Config, targetSchema Tar } diagnostics := fixture.Diagnostics watchers := fixture.Watchers - userScript, err := ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { return nil, err } diff --git a/internal/sinktest/all/wire_gen.go b/internal/sinktest/all/wire_gen.go index 7a2dce40c..45687f596 100644 --- a/internal/sinktest/all/wire_gen.go +++ b/internal/sinktest/all/wire_gen.go @@ -23,10 +23,7 @@ import ( // NewFixture constructs a self-contained test fixture for all services // in the target sub-packages. func NewFixture() (*Fixture, func(), error) { - context, cleanup, err := base.ProvideContext() - if err != nil { - return nil, nil, err - } + context, cleanup := base.ProvideContext() diagnostics, cleanup2 := diag.New(context) sourcePool, cleanup3, err := base.ProvideSourcePool(context, diagnostics) if err != nil { @@ -160,7 +157,7 @@ func NewFixture() (*Fixture, func(), error) { cleanup() return nil, nil, err } - stagers := stage.ProvideFactory(stagingPool, stagingSchema) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, context) checker := version.ProvideChecker(stagingPool, memoMemo) watcher, err := ProvideWatcher(context, targetSchema, watchers) if err != nil { diff --git a/internal/sinktest/base/provider.go b/internal/sinktest/base/provider.go index c257a0de8..4a60589de 100644 --- a/internal/sinktest/base/provider.go +++ b/internal/sinktest/base/provider.go @@ -34,6 +34,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/retry" "github.com/cockroachdb/cdc-sink/internal/util/stdpool" "github.com/cockroachdb/cdc-sink/internal/util/stmtcache" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -95,6 +96,7 @@ var TestSet = wire.NewSet( ProvideTargetStatements, diag.New, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Struct(new(Fixture), "*"), ) @@ -102,7 +104,7 @@ var TestSet = wire.NewSet( // without the other services provided by the target package. One can be // constructed by calling NewFixture. type Fixture struct { - Context context.Context // The context for the test. + Context *stopper.Context // The context for the test. SourcePool *types.SourcePool // Access to user-data tables and changefeed creation. SourceSchema sinktest.SourceSchema // A container for tables within SourcePool. StagingPool *types.StagingPool // Access to __cdc_sink database. @@ -138,9 +140,19 @@ var caseTimout = flag.Duration( // ProvideContext returns an execution context that is associated with a // singleton connection to a CockroachDB cluster. -func ProvideContext() (context.Context, func(), error) { - ctx, cancel := context.WithTimeout(context.Background(), *caseTimout) - return ctx, cancel, nil +func ProvideContext() (*stopper.Context, func()) { + ctx := stopper.WithContext(context.Background()) + ctx.Go(func() error { + select { + case <-ctx.Stopping(): + // Clean shutdown, do nothing. + case <-time.After(*caseTimout): + // Just cancel immediately. + ctx.Stop(0) + } + return nil + }) + return ctx, func() { ctx.Stop(100 * time.Millisecond) } } // ProvideSourcePool connects to the source database. If the source is a diff --git a/internal/sinktest/base/wire_gen.go b/internal/sinktest/base/wire_gen.go index dccfcc054..4e92ae2b8 100644 --- a/internal/sinktest/base/wire_gen.go +++ b/internal/sinktest/base/wire_gen.go @@ -14,10 +14,7 @@ import ( // NewFixture constructs a self-contained test fixture. func NewFixture() (*Fixture, func(), error) { - context, cleanup, err := ProvideContext() - if err != nil { - return nil, nil, err - } + context, cleanup := ProvideContext() diagnostics, cleanup2 := diag.New(context) sourcePool, cleanup3, err := ProvideSourcePool(context, diagnostics) if err != nil { diff --git a/internal/source/cdc/provider.go b/internal/source/cdc/provider.go index ec4bde9d0..6dc0173ac 100644 --- a/internal/source/cdc/provider.go +++ b/internal/source/cdc/provider.go @@ -17,12 +17,12 @@ package cdc import ( - "context" "fmt" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" ) @@ -55,7 +55,7 @@ func ProvideMetaTable(cfg *Config) MetaTable { // ProvideResolvers is called by Wire. func ProvideResolvers( - ctx context.Context, + ctx *stopper.Context, cfg *Config, leases types.Leases, loops *logical.Factory, @@ -63,9 +63,9 @@ func ProvideResolvers( pool *types.StagingPool, stagers types.Stagers, watchers types.Watchers, -) (*Resolvers, func(), error) { +) (*Resolvers, error) { if _, err := pool.Exec(ctx, fmt.Sprintf(schema, metaTable.Table())); err != nil { - return nil, nil, errors.WithStack(err) + return nil, errors.WithStack(err) } ret := &Resolvers{ @@ -75,6 +75,7 @@ func ProvideResolvers( metaTable: metaTable.Table(), pool: pool, stagers: stagers, + stop: ctx, watchers: watchers, } ret.mu.instances = &ident.SchemaMap[*logical.Loop]{} @@ -82,13 +83,13 @@ func ProvideResolvers( // Resume from previous state. schemas, err := ScanForTargetSchemas(ctx, pool, ret.metaTable) if err != nil { - return nil, nil, err + return nil, err } for _, schema := range schemas { if _, _, err := ret.get(ctx, schema); err != nil { - return nil, nil, errors.Wrapf(err, "could not bootstrap resolver for schema %s", schema) + return nil, errors.Wrapf(err, "could not bootstrap resolver for schema %s", schema) } } - return ret, ret.close, nil + return ret, nil } diff --git a/internal/source/cdc/resolver.go b/internal/source/cdc/resolver.go index 142b2a636..2597bb027 100644 --- a/internal/source/cdc/resolver.go +++ b/internal/source/cdc/resolver.go @@ -595,33 +595,15 @@ type Resolvers struct { metaTable ident.Table pool *types.StagingPool stagers types.Stagers + stop *stopper.Context // Manage lifecycle of background processes. watchers types.Watchers mu struct { sync.Mutex - cleanups []func() instances *ident.SchemaMap[*logical.Loop] } } -// close will drain any running resolver loops. -func (r *Resolvers) close() { - r.mu.Lock() - defer r.mu.Unlock() - - // Cancel each loop. - for _, cancel := range r.mu.cleanups { - cancel() - } - // Wait for shutdown. - _ = r.mu.instances.Range(func(_ ident.Schema, l *logical.Loop) error { - <-l.Stopped() - return nil - }) - r.mu.cleanups = nil - r.mu.instances = nil -} - // get creates or returns the [logical.Loop] and the enclosed resolver. func (r *Resolvers) get( ctx context.Context, target ident.Schema, @@ -643,7 +625,7 @@ func (r *Resolvers) get( return nil, ret, nil } - loop, cleanup, err := r.loops.Start(&logical.LoopConfig{ + loop, err := r.loops.Start(r.stop, &logical.LoopConfig{ Dialect: ret, LoopName: "changefeed-" + target.Raw(), TargetSchema: target, @@ -655,13 +637,8 @@ func (r *Resolvers) get( r.mu.instances.Put(target, loop) // Start a goroutine to retire old data. - stop := stopper.WithContext(context.Background()) - ret.retireLoop(stop) + ret.retireLoop(r.stop) - r.mu.cleanups = append(r.mu.cleanups, - cleanup, - func() { stop.Stop(time.Second) }, - ) return loop, ret, nil } diff --git a/internal/source/cdc/test_fixture.go b/internal/source/cdc/test_fixture.go index 21ba36ca5..410c86fa3 100644 --- a/internal/source/cdc/test_fixture.go +++ b/internal/source/cdc/test_fixture.go @@ -20,6 +20,8 @@ package cdc import ( + "context" + "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/sinktest/base" @@ -28,6 +30,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging/leases" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -50,6 +53,7 @@ func newTestFixture(*all.Fixture, *Config) (*testFixture, func(), error) { target.Set, trust.New, // Is valid to use as a provider. wire.Struct(new(testFixture), "*"), + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), )) } diff --git a/internal/source/cdc/wire_gen.go b/internal/source/cdc/wire_gen.go index 3c3225f6e..18b5e166e 100644 --- a/internal/source/cdc/wire_gen.go +++ b/internal/source/cdc/wire_gen.go @@ -121,7 +121,7 @@ func newTestFixture(fixture *all.Fixture, config *Config) (*testFixture, func(), } metaTable := ProvideMetaTable(config) stagers := fixture.Stagers - resolvers, cleanup8, err := ProvideResolvers(context, config, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + resolvers, err := ProvideResolvers(context, config, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup7() cleanup6() @@ -147,7 +147,6 @@ func newTestFixture(fixture *all.Fixture, config *Config) (*testFixture, func(), Resolvers: resolvers, } return cdcTestFixture, func() { - cleanup8() cleanup7() cleanup6() cleanup5() diff --git a/internal/source/fslogical/injector.go b/internal/source/fslogical/injector.go index 9e1c166d0..d530bd303 100644 --- a/internal/source/fslogical/injector.go +++ b/internal/source/fslogical/injector.go @@ -29,13 +29,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(context.Context, *Config) (*FSLogical, func(), error) { +func Start(*stopper.Context, *Config) (*FSLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(FSLogical), "*"), ProvideFirestoreClient, @@ -53,6 +55,7 @@ func Start(context.Context, *Config) (*FSLogical, func(), error) { // Build remaining testable components from a common fixture. func startLoopsFromFixture(*all.Fixture, *Config) ([]*logical.Loop, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*base.Fixture), "Context"), wire.FieldsOf(new(*all.Fixture), diff --git a/internal/source/fslogical/provider.go b/internal/source/fslogical/provider.go index c0b4cd6fc..0fe8c7def 100644 --- a/internal/source/fslogical/provider.go +++ b/internal/source/fslogical/provider.go @@ -29,6 +29,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/golang/groupcache/lru" log "github.com/sirupsen/logrus" "google.golang.org/api/option" @@ -51,7 +52,7 @@ var enableWipe bool // ProvideLoops is called by wire to construct a logical-replication // loop for each configured collection/table pair. func ProvideLoops( - ctx context.Context, + ctx *stopper.Context, cfg *Config, fs *firestore.Client, loops *logical.Factory, @@ -59,22 +60,15 @@ func ProvideLoops( pool *types.StagingPool, st *Tombstones, userscript *script.UserScript, -) ([]*logical.Loop, func(), error) { +) ([]*logical.Loop, error) { if err := cfg.Preflight(); err != nil { - return nil, nil, err + return nil, err } idx := 0 ret := make([]*logical.Loop, userscript.Sources.Len()) recurseFilter := &ident.Map[struct{}]{} - cancels := make([]func(), userscript.Sources.Len()) - cancel := func() { - for _, fn := range cancels { - fn() - } - } - err := userscript.Sources.Range(func(sourceName ident.Ident, source *script.Source) error { var isGroup bool var sourcePath string @@ -112,7 +106,7 @@ func ProvideLoops( loopCfg.LoopName = sourceName.Raw() var err error - ret[idx], cancels[idx], err = loops.Start(loopCfg) + ret[idx], err = loops.Start(ctx, loopCfg) if err != nil { return err } @@ -121,11 +115,10 @@ func ProvideLoops( return nil }) if err != nil { - cancel() - return nil, nil, err + return nil, err } - return ret, cancel, nil + return ret, nil } // ProvideFirestoreClient is called by wire. If a local emulator is in @@ -174,12 +167,16 @@ func ProvideScriptTarget(cfg *Config) script.TargetSchema { // ProvideTombstones is called by wire to construct a helper that // manages document tombstones. func ProvideTombstones( - cfg *Config, fs *firestore.Client, loops *logical.Factory, userscript *script.UserScript, -) (*Tombstones, func(), error) { + ctx *stopper.Context, + cfg *Config, + fs *firestore.Client, + loops *logical.Factory, + userscript *script.UserScript, +) (*Tombstones, error) { ret := &Tombstones{cfg: cfg} if cfg.TombstoneCollection == "" { log.Trace("no tombstone collection was configured") - return ret, nil, nil + return ret, nil } ret.coll = fs.Collection(cfg.TombstoneCollection) @@ -188,7 +185,7 @@ func ProvideTombstones( ret.deletesTo.Put(source, dest.DeletesTo) return nil }); err != nil { - return nil, nil, err + return nil, err } ret.source = ident.New(cfg.TombstoneCollection) ret.mu.cache = &lru.Cache{MaxEntries: 1_000_000} @@ -196,11 +193,11 @@ func ProvideTombstones( loopConfig := cfg.LoopConfig.Copy() loopConfig.Dialect = ret loopConfig.LoopName = cfg.TombstoneCollection - _, cancel, err := loops.Start(loopConfig) + _, err := loops.Start(ctx, loopConfig) if err != nil { - return nil, nil, err + return nil, err } - return ret, cancel, nil + return ret, nil } // Wipe any leftover documents from testing. diff --git a/internal/source/fslogical/wire_gen.go b/internal/source/fslogical/wire_gen.go index 2e0610e88..9c8dbfac1 100644 --- a/internal/source/fslogical/wire_gen.go +++ b/internal/source/fslogical/wire_gen.go @@ -7,7 +7,6 @@ package fslogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/sinktest/all" "github.com/cockroachdb/cdc-sink/internal/source/logical" @@ -18,14 +17,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(contextContext context.Context, config *Config) (*FSLogical, func(), error) { - diagnostics, cleanup := diag.New(contextContext) +func Start(context *stopper.Context, config *Config) (*FSLogical, func(), error) { + diagnostics, cleanup := diag.New(context) configs, err := applycfg.ProvideConfigs(diagnostics) if err != nil { cleanup() @@ -47,7 +47,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - targetPool, cleanup2, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup2, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -58,14 +58,14 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - userScript, err := script.ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := script.ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { cleanup3() cleanup2() cleanup() return nil, nil, err } - client, cleanup4, err := ProvideFirestoreClient(contextContext, config, userScript) + client, cleanup4, err := ProvideFirestoreClient(context, config, userScript) if err != nil { cleanup3() cleanup2() @@ -91,7 +91,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - stagingPool, cleanup7, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup7, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup6() cleanup5() @@ -112,7 +112,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - memoMemo, err := memo.ProvideMemo(contextContext, stagingPool, stagingSchema) + memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) if err != nil { cleanup7() cleanup6() @@ -124,7 +124,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), return nil, nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup7() cleanup6() @@ -135,7 +135,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - tombstones, cleanup8, err := ProvideTombstones(config, client, factory, userScript) + tombstones, err := ProvideTombstones(context, config, client, factory, userScript) if err != nil { cleanup7() cleanup6() @@ -146,9 +146,8 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), cleanup() return nil, nil, err } - v, cleanup9, err := ProvideLoops(contextContext, config, client, factory, memoMemo, stagingPool, tombstones, userScript) + v, err := ProvideLoops(context, config, client, factory, memoMemo, stagingPool, tombstones, userScript) if err != nil { - cleanup8() cleanup7() cleanup6() cleanup5() @@ -163,8 +162,6 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), Loops: v, } return fsLogical, func() { - cleanup9() - cleanup8() cleanup7() cleanup6() cleanup5() @@ -178,7 +175,7 @@ func Start(contextContext context.Context, config *Config) (*FSLogical, func(), // Build remaining testable components from a common fixture. func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loop, func(), error) { baseFixture := fixture.Fixture - contextContext := baseFixture.Context + context := baseFixture.Context configs := fixture.Configs scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -188,14 +185,14 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo if err != nil { return nil, nil, err } - diagnostics, cleanup := diag.New(contextContext) + diagnostics, cleanup := diag.New(context) targetSchema := ProvideScriptTarget(config) baseConfig, err := logical.ProvideBaseConfig(config, loader) if err != nil { cleanup() return nil, nil, err } - targetPool, cleanup2, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup2, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -206,14 +203,14 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - userScript, err := script.ProvideUserScript(contextContext, configs, loader, diagnostics, targetSchema, watchers) + userScript, err := script.ProvideUserScript(context, configs, loader, diagnostics, targetSchema, watchers) if err != nil { cleanup3() cleanup2() cleanup() return nil, nil, err } - client, cleanup4, err := ProvideFirestoreClient(contextContext, config, userScript) + client, cleanup4, err := ProvideFirestoreClient(context, config, userScript) if err != nil { cleanup3() cleanup2() @@ -240,7 +237,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } typesMemo := fixture.Memo - stagingPool, cleanup7, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup7, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup6() cleanup5() @@ -251,7 +248,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } checker := fixture.VersionChecker - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, typesMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, typesMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup7() cleanup6() @@ -262,7 +259,7 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - tombstones, cleanup8, err := ProvideTombstones(config, client, factory, userScript) + tombstones, err := ProvideTombstones(context, config, client, factory, userScript) if err != nil { cleanup7() cleanup6() @@ -273,9 +270,8 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo cleanup() return nil, nil, err } - v, cleanup9, err := ProvideLoops(contextContext, config, client, factory, typesMemo, stagingPool, tombstones, userScript) + v, err := ProvideLoops(context, config, client, factory, typesMemo, stagingPool, tombstones, userScript) if err != nil { - cleanup8() cleanup7() cleanup6() cleanup5() @@ -286,8 +282,6 @@ func startLoopsFromFixture(fixture *all.Fixture, config *Config) ([]*logical.Loo return nil, nil, err } return v, func() { - cleanup9() - cleanup8() cleanup7() cleanup6() cleanup5() diff --git a/internal/source/logical/factory.go b/internal/source/logical/factory.go index ab0554056..f5c9d4315 100644 --- a/internal/source/logical/factory.go +++ b/internal/source/logical/factory.go @@ -60,7 +60,8 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, // no-ops and wait to exit. The implementation of Process would make // the Events / Batcher available externally. This has the downside // of actually needing to start the loop goroutines. - fake, cancel, err := f.newLoop(stopper.From(ctx), &LoopConfig{ + stop := stopper.WithContext(ctx) + fake, err := f.newLoop(stop, &LoopConfig{ Dialect: &fakeDialect{}, LoopName: fmt.Sprintf("immediate-%s", target.Raw()), TargetSchema: target, @@ -68,6 +69,10 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, if err != nil { return nil, nil, err } + cancel := func() { + stop.Stop(f.baseConfig.ApplyTimeout) + <-stop.Done() + } if f.baseConfig.Immediate { return fake.loop.events.fan, cancel, nil @@ -76,32 +81,27 @@ func (f *Factory) Immediate(ctx context.Context, target ident.Schema) (Batcher, } // Start constructs a new replication Loop. -func (f *Factory) Start(config *LoopConfig) (*Loop, func(), error) { +func (f *Factory) Start(ctx *stopper.Context, config *LoopConfig) (*Loop, error) { var err error // Ensure the configuration is set up and validated. config, err = f.expandConfig(config) if err != nil { - return nil, nil, err + return nil, err } // Construct the new loop and start it. - stop := stopper.WithContext(context.Background()) - loop, cleanup, err := f.newLoop(stop, config) + ctx = stopper.WithContext(ctx) + loop, err := f.newLoop(ctx, config) if err != nil { - return nil, nil, err - } - go loop.loop.run() - - // Perform a graceful shutdown and wait for the loop to exit. - grace := f.baseConfig.ApplyTimeout - cancel := func() { - stop.Stop(grace) - <-loop.Stopped() - cleanup() + return nil, err } + ctx.Go(func() error { + loop.loop.run() + return nil + }) - return loop, cancel, nil + return loop, nil } // expandConfig returns a preflighted copy of the configuration. @@ -121,10 +121,10 @@ func (f *Factory) expandConfig(config *LoopConfig) (*LoopConfig, error) { } // newLoop constructs a loop, but does not start it. -func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func(), error) { +func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, error) { watcher, err := f.watchers.Get(ctx, config.TargetSchema) if err != nil { - return nil, nil, err + return nil, err } config = config.Copy() config.Dialect = WithChaos(config.Dialect, f.baseConfig.ChaosProb) @@ -135,7 +135,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func } initialPoint, err := loop.loadConsistentPoint(ctx) if err != nil { - return nil, nil, err + return nil, err } loop.consistentPoint.Set(initialPoint) @@ -161,7 +161,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func } else { // Sanity-check that there are no FKs defined. if len(watcher.Get().Order) > 1 { - return nil, nil, errors.New("the destination database has tables with foreign keys, " + + return nil, errors.New("the destination database has tables with foreign keys, " + "but support for FKs is not enabled") } } @@ -169,11 +169,14 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func // Create a branch in the diagnostics reporting for the loop. loopDiags, err := f.diags.Wrap(config.LoopName) if err != nil { - return nil, nil, err + return nil, err } - cancel := func() { + // Unregister the loop on shutdown. + ctx.Go(func() error { + <-ctx.Stopping() f.diags.Unregister(config.LoopName) - } + return nil + }) userscript, err := script.Evaluate( ctx, @@ -184,8 +187,7 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func f.watchers, ) if err != nil { - cancel() - return nil, nil, errors.Wrapf(err, "could not initialize userscript for %s", config.LoopName) + return nil, errors.Wrapf(err, "could not initialize userscript for %s", config.LoopName) } // Apply logic and configurations defined by the user-script. @@ -206,11 +208,10 @@ func (f *Factory) newLoop(ctx *stopper.Context, config *LoopConfig) (*Loop, func loop.metrics.backfillStatus = backfillStatus.WithLabelValues(config.LoopName) if err := loopDiags.Register("loop", loop); err != nil { - cancel() - return nil, nil, err + return nil, err } - return &Loop{loop, initialPoint}, cancel, nil + return &Loop{loop, initialPoint}, nil } // singletonChannel returns a channel that emits a single value and is diff --git a/internal/source/logical/logical_test.go b/internal/source/logical/logical_test.go index f476eeb3a..0d1873e6a 100644 --- a/internal/source/logical/logical_test.go +++ b/internal/source/logical/logical_test.go @@ -30,6 +30,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/batches" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stamp" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -196,7 +197,8 @@ func testLogicalSmoke(t *testing.T, mode *logicalTestMode) { } defer cancelFactory() - loop, cancelLoop, err := factory.Start(&logical.LoopConfig{ + loopStopper := stopper.WithContext(fixture.Context) + loop, err := factory.Start(loopStopper, &logical.LoopConfig{ Dialect: gen, LoopName: "generator", TargetSchema: dbName, @@ -242,7 +244,7 @@ func testLogicalSmoke(t *testing.T, mode *logicalTestMode) { } // Wait for the loop to shut down, or a timeout. - cancelLoop() + loopStopper.Stop(100 * time.Millisecond) gen.emit(0) // Kick the simplistic ReadInto loop so that it exits. select { case <-loop.Stopped(): @@ -351,13 +353,12 @@ api.configureTable("t_2", { r.NoError(err) defer cancelFactory() - _, cancelLoop, err := factory.Start(&logical.LoopConfig{ + _, err = factory.Start(fixture.Context, &logical.LoopConfig{ Dialect: gen, LoopName: "generator", TargetSchema: dbName, }) r.NoError(err) - defer cancelLoop() // Wait for replication. for idx, tgt := range tgts { diff --git a/internal/source/logical/loop.go b/internal/source/logical/loop.go index a048d7533..3afc55d3e 100644 --- a/internal/source/logical/loop.go +++ b/internal/source/logical/loop.go @@ -176,7 +176,7 @@ func (l *loop) run() { defer log.Debugf("replication loop %q shut down", l.loopConfig.LoopName) for { - err := l.runOnce(l.running) + err := l.runOnce() // Otherwise, log any error, and sleep for a bit. if err != nil { @@ -199,13 +199,14 @@ func (l *loop) run() { // runOnce is called by run. If the Dialect implements a leasing // behavior, a lease will be obtained before any further action is // taken. -func (l *loop) runOnce(ctx context.Context) error { +func (l *loop) runOnce() error { + var stop *stopper.Context if lessor, ok := l.loopConfig.Dialect.(Lessor); ok { // Loop until we can acquire a lease. var lease types.Lease for { var err error - lease, err = lessor.Acquire(ctx) + lease, err = lessor.Acquire(l.running) // Lease acquired. if err == nil { log.Tracef("lease %s acquired", l.loopConfig.LoopName) @@ -223,8 +224,8 @@ func (l *loop) runOnce(ctx context.Context) error { select { case <-time.After(duration): continue - case <-ctx.Done(): - return ctx.Err() + case <-l.running.Stopping(): + return nil } } // General err, defer to the loop's retry delay. @@ -232,11 +233,13 @@ func (l *loop) runOnce(ctx context.Context) error { } defer lease.Release() // Ensure that all work is bound to the lifetime of the lease. - ctx = lease.Context() + stop = stopper.WithContext(lease.Context()) + } else { + stop = l.running } // Ensure our in-memory consistent point matches the database. - point, err := l.loadConsistentPoint(ctx) + point, err := l.loadConsistentPoint(stop) if err != nil { return err } @@ -245,7 +248,7 @@ func (l *loop) runOnce(ctx context.Context) error { // Determine how to perform the filling. source, events, isBackfilling := l.chooseFillStrategy() - return l.runOnceUsing(stopper.From(ctx), source, events, isBackfilling) + return l.runOnceUsing(stop, source, events, isBackfilling) } // runOnceUsing is called from runOnce or doBackfill. @@ -429,13 +432,14 @@ func (l *loop) doBackfill(ctx context.Context, loopName string, backfiller Backf cfg.Dialect = backfiller cfg.LoopName = loopName - // The incoming context should already be a stopper. + // Create a (most likely nested) stopper. stop := stopper.WithContext(ctx) - filler, cleanup, err := l.factory.newLoop(stop, cfg) + filler, err := l.factory.newLoop(stop, cfg) if err != nil { return err } - defer cleanup() + // We don't need any grace time since the sub-loop has exited. + defer func() { stop.Stop(0) }() return filler.loop.runOnceUsing( stop, diff --git a/internal/source/mylogical/injector.go b/internal/source/mylogical/injector.go index 6501921e7..c7bf67460 100644 --- a/internal/source/mylogical/injector.go +++ b/internal/source/mylogical/injector.go @@ -27,13 +27,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a MySQL/MariaDB logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*MYLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(MYLogical), "*"), Set, diff --git a/internal/source/mylogical/provider.go b/internal/source/mylogical/provider.go index 656997152..599451f6b 100644 --- a/internal/source/mylogical/provider.go +++ b/internal/source/mylogical/provider.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/go-mysql-org/go-mysql/replication" "github.com/google/wire" ) @@ -64,8 +65,8 @@ func ProvideDialect(config *Config, _ *script.Loader) (logical.Dialect, error) { // ProvideLoop is called by Wire to construct the sole logical loop used // in the mylogical mode. func ProvideLoop( - cfg *Config, dialect logical.Dialect, loops *logical.Factory, -) (*logical.Loop, func(), error) { + ctx *stopper.Context, cfg *Config, dialect logical.Dialect, loops *logical.Factory, +) (*logical.Loop, error) { cfg.Dialect = dialect - return loops.Start(&cfg.LoopConfig) + return loops.Start(ctx, &cfg.LoopConfig) } diff --git a/internal/source/mylogical/wire_gen.go b/internal/source/mylogical/wire_gen.go index 45faca347..9c1ea4ffd 100644 --- a/internal/source/mylogical/wire_gen.go +++ b/internal/source/mylogical/wire_gen.go @@ -7,7 +7,6 @@ package mylogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/staging/memo" @@ -17,13 +16,14 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a MySQL/MariaDB logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*MYLogical, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -120,7 +120,7 @@ func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { cleanup() return nil, nil, err } - loop, cleanup7, err := ProvideLoop(config, dialect, factory) + loop, err := ProvideLoop(ctx, config, dialect, factory) if err != nil { cleanup6() cleanup5() @@ -135,7 +135,6 @@ func Start(ctx context.Context, config *Config) (*MYLogical, func(), error) { Loop: loop, } return myLogical, func() { - cleanup7() cleanup6() cleanup5() cleanup4() diff --git a/internal/source/pglogical/injector.go b/internal/source/pglogical/injector.go index 8f519aef5..4e1769b24 100644 --- a/internal/source/pglogical/injector.go +++ b/internal/source/pglogical/injector.go @@ -27,13 +27,15 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*PGLogical, func(), error) { panic(wire.Build( + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.Struct(new(PGLogical), "*"), Set, diff --git a/internal/source/pglogical/provider.go b/internal/source/pglogical/provider.go index 58a95c437..83688ecf4 100644 --- a/internal/source/pglogical/provider.go +++ b/internal/source/pglogical/provider.go @@ -24,6 +24,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stdpool" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -103,8 +104,8 @@ func ProvideDialect( // ProvideLoop is called by Wire to construct the sole logical loop used // in the pglogical mode. func ProvideLoop( - cfg *Config, dialect logical.Dialect, loops *logical.Factory, -) (*logical.Loop, func(), error) { + ctx *stopper.Context, cfg *Config, dialect logical.Dialect, loops *logical.Factory, +) (*logical.Loop, error) { cfg.Dialect = dialect - return loops.Start(&cfg.LoopConfig) + return loops.Start(ctx, &cfg.LoopConfig) } diff --git a/internal/source/pglogical/wire_gen.go b/internal/source/pglogical/wire_gen.go index 1633130dd..7185329e6 100644 --- a/internal/source/pglogical/wire_gen.go +++ b/internal/source/pglogical/wire_gen.go @@ -7,7 +7,6 @@ package pglogical import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/logical" "github.com/cockroachdb/cdc-sink/internal/staging/memo" @@ -17,13 +16,14 @@ import ( "github.com/cockroachdb/cdc-sink/internal/target/schemawatch" "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" ) // Injectors from injector.go: // Start creates a PostgreSQL logical replication loop using the // provided configuration. -func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { +func Start(ctx *stopper.Context, config *Config) (*PGLogical, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -120,7 +120,7 @@ func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { cleanup() return nil, nil, err } - loop, cleanup7, err := ProvideLoop(config, dialect, factory) + loop, err := ProvideLoop(ctx, config, dialect, factory) if err != nil { cleanup6() cleanup5() @@ -135,7 +135,6 @@ func Start(ctx context.Context, config *Config) (*PGLogical, func(), error) { Loop: loop, } return pgLogical, func() { - cleanup7() cleanup6() cleanup5() cleanup4() diff --git a/internal/source/server/injector.go b/internal/source/server/injector.go index da1edf4c9..57104dec1 100644 --- a/internal/source/server/injector.go +++ b/internal/source/server/injector.go @@ -28,10 +28,11 @@ import ( "github.com/cockroachdb/cdc-sink/internal/staging" "github.com/cockroachdb/cdc-sink/internal/target" "github.com/cockroachdb/cdc-sink/internal/util/diag" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) -func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { +func NewServer(ctx *stopper.Context, config *Config) (*Server, func(), error) { panic(wire.Build( Set, cdc.Set, @@ -40,6 +41,7 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { script.Set, staging.Set, target.Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*Config), "CDC"), )) diff --git a/internal/source/server/integration_test.go b/internal/source/server/integration_test.go index 6414c5a3f..488738748 100644 --- a/internal/source/server/integration_test.go +++ b/internal/source/server/integration_test.go @@ -35,6 +35,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/hlc" "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/stdlogical" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" joonix "github.com/joonix/log" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -153,7 +154,7 @@ func testIntegration(t *testing.T, cfg testConfig) { targetPool := destFixture.TargetPool // The target fixture contains the cdc-sink server. - targetFixture, cancel, err := newTestFixture(ctx, &Config{ + targetFixture, cancel, err := newTestFixture(stopper.WithContext(ctx), &Config{ CDC: cdc.Config{ BaseConfig: logical.BaseConfig{ Immediate: cfg.immediate, diff --git a/internal/source/server/test_fixture.go b/internal/source/server/test_fixture.go index 4cb004bc4..42fb698c2 100644 --- a/internal/source/server/test_fixture.go +++ b/internal/source/server/test_fixture.go @@ -31,6 +31,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/diag" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -48,7 +49,7 @@ type testFixture struct { // We want this to be as close as possible to Start, it just exposes // additional plumbing details via the returned testFixture pointer. -func newTestFixture(context.Context, *Config) (*testFixture, func(), error) { +func newTestFixture(*stopper.Context, *Config) (*testFixture, func(), error) { panic(wire.Build( Set, cdc.Set, @@ -57,6 +58,7 @@ func newTestFixture(context.Context, *Config) (*testFixture, func(), error) { script.Set, staging.Set, target.Set, + wire.Bind(new(context.Context), new(*stopper.Context)), wire.Bind(new(logical.Config), new(*Config)), wire.FieldsOf(new(*Config), "CDC"), wire.Struct(new(testFixture), "*"), diff --git a/internal/source/server/wire_gen.go b/internal/source/server/wire_gen.go index 878be13bc..f388ddf33 100644 --- a/internal/source/server/wire_gen.go +++ b/internal/source/server/wire_gen.go @@ -7,7 +7,6 @@ package server import ( - "context" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/source/cdc" "github.com/cockroachdb/cdc-sink/internal/source/logical" @@ -22,12 +21,13 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/applycfg" "github.com/cockroachdb/cdc-sink/internal/util/diag" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "net" ) // Injectors from injector.go: -func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { +func NewServer(ctx *stopper.Context, config *Config) (*Server, func(), error) { diagnostics, cleanup := diag.New(ctx) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { @@ -170,8 +170,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { return nil, nil, err } metaTable := cdc.ProvideMetaTable(cdcConfig) - stagers := stage.ProvideFactory(stagingPool, stagingSchema) - resolvers, cleanup10, err := cdc.ProvideResolvers(ctx, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, ctx) + resolvers, err := cdc.ProvideResolvers(ctx, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup9() cleanup8() @@ -196,7 +196,6 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { serveMux := ProvideMux(handler, stagingPool, targetPool) tlsConfig, err := ProvideTLSConfig(config) if err != nil { - cleanup10() cleanup9() cleanup8() cleanup7() @@ -208,9 +207,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { cleanup() return nil, nil, err } - server, cleanup11 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) + server, cleanup10 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) return server, func() { - cleanup11() cleanup10() cleanup9() cleanup8() @@ -228,8 +226,8 @@ func NewServer(ctx context.Context, config *Config) (*Server, func(), error) { // We want this to be as close as possible to Start, it just exposes // additional plumbing details via the returned testFixture pointer. -func newTestFixture(contextContext context.Context, config *Config) (*testFixture, func(), error) { - diagnostics, cleanup := diag.New(contextContext) +func newTestFixture(context *stopper.Context, config *Config) (*testFixture, func(), error) { + diagnostics, cleanup := diag.New(context) scriptConfig, err := logical.ProvideUserScriptConfig(config) if err != nil { cleanup() @@ -245,7 +243,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - stagingPool, cleanup2, err := logical.ProvideStagingPool(contextContext, baseConfig, diagnostics) + stagingPool, cleanup2, err := logical.ProvideStagingPool(context, baseConfig, diagnostics) if err != nil { cleanup() return nil, nil, err @@ -256,7 +254,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - authenticator, cleanup3, err := ProvideAuthenticator(contextContext, diagnostics, config, stagingPool, stagingSchema) + authenticator, cleanup3, err := ProvideAuthenticator(context, diagnostics, config, stagingPool, stagingSchema) if err != nil { cleanup2() cleanup() @@ -270,7 +268,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } cdcConfig := &config.CDC - targetPool, cleanup5, err := logical.ProvideTargetPool(contextContext, baseConfig, diagnostics) + targetPool, cleanup5, err := logical.ProvideTargetPool(context, baseConfig, diagnostics) if err != nil { cleanup4() cleanup3() @@ -320,7 +318,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - memoMemo, err := memo.ProvideMemo(contextContext, stagingPool, stagingSchema) + memoMemo, err := memo.ProvideMemo(context, stagingPool, stagingSchema) if err != nil { cleanup8() cleanup7() @@ -333,7 +331,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } checker := version.ProvideChecker(stagingPool, memoMemo) - factory, err := logical.ProvideFactory(contextContext, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) + factory, err := logical.ProvideFactory(context, appliers, configs, baseConfig, diagnostics, memoMemo, loader, stagingPool, targetPool, watchers, checker) if err != nil { cleanup8() cleanup7() @@ -357,7 +355,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - typesLeases, err := leases.ProvideLeases(contextContext, stagingPool, stagingSchema) + typesLeases, err := leases.ProvideLeases(context, stagingPool, stagingSchema) if err != nil { cleanup9() cleanup8() @@ -371,8 +369,8 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur return nil, nil, err } metaTable := cdc.ProvideMetaTable(cdcConfig) - stagers := stage.ProvideFactory(stagingPool, stagingSchema) - resolvers, cleanup10, err := cdc.ProvideResolvers(contextContext, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) + stagers := stage.ProvideFactory(stagingPool, stagingSchema, context) + resolvers, err := cdc.ProvideResolvers(context, cdcConfig, typesLeases, factory, metaTable, stagingPool, stagers, watchers) if err != nil { cleanup9() cleanup8() @@ -397,7 +395,6 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur serveMux := ProvideMux(handler, stagingPool, targetPool) tlsConfig, err := ProvideTLSConfig(config) if err != nil { - cleanup10() cleanup9() cleanup8() cleanup7() @@ -409,7 +406,7 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur cleanup() return nil, nil, err } - server, cleanup11 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) + server, cleanup10 := ProvideServer(authenticator, diagnostics, listener, serveMux, tlsConfig) serverTestFixture := &testFixture{ Authenticator: authenticator, Config: config, @@ -422,7 +419,6 @@ func newTestFixture(contextContext context.Context, config *Config) (*testFixtur Watcher: watchers, } return serverTestFixture, func() { - cleanup11() cleanup10() cleanup9() cleanup8() diff --git a/internal/staging/stage/factory.go b/internal/staging/stage/factory.go index 0a0455b24..f93de0815 100644 --- a/internal/staging/stage/factory.go +++ b/internal/staging/stage/factory.go @@ -26,6 +26,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/hlc" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/jackc/pgx/v5" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -34,6 +35,7 @@ import ( type factory struct { db *types.StagingPool stagingDB ident.Schema + stop *stopper.Context mu struct { sync.RWMutex @@ -44,14 +46,14 @@ type factory struct { var _ types.Stagers = (*factory)(nil) // Get returns a memoized instance of a stage for the given table. -func (f *factory) Get(ctx context.Context, target ident.Table) (types.Stager, error) { +func (f *factory) Get(_ context.Context, target ident.Table) (types.Stager, error) { if ret := f.getUnlocked(target); ret != nil { return ret, nil } - return f.createUnlocked(ctx, target) + return f.createUnlocked(target) } -func (f *factory) createUnlocked(ctx context.Context, table ident.Table) (*stage, error) { +func (f *factory) createUnlocked(table ident.Table) (*stage, error) { f.mu.Lock() defer f.mu.Unlock() @@ -59,7 +61,7 @@ func (f *factory) createUnlocked(ctx context.Context, table ident.Table) (*stage return ret, nil } - ret, err := newStore(ctx, f.db, f.stagingDB, table) + ret, err := newStore(f.stop, f.db, f.stagingDB, table) if err == nil { f.mu.instances.Put(table, ret) } diff --git a/internal/staging/stage/provider.go b/internal/staging/stage/provider.go index a19d86850..cb7d7cfef 100644 --- a/internal/staging/stage/provider.go +++ b/internal/staging/stage/provider.go @@ -19,6 +19,7 @@ package stage import ( "github.com/cockroachdb/cdc-sink/internal/types" "github.com/cockroachdb/cdc-sink/internal/util/ident" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/google/wire" ) @@ -28,10 +29,13 @@ var Set = wire.NewSet( ) // ProvideFactory is called by Wire to construct the Stagers factory. -func ProvideFactory(db *types.StagingPool, stagingDB ident.StagingSchema) types.Stagers { +func ProvideFactory( + db *types.StagingPool, stagingDB ident.StagingSchema, stop *stopper.Context, +) types.Stagers { f := &factory{ db: db, stagingDB: stagingDB.Schema(), + stop: stop, } f.mu.instances = &ident.TableMap[*stage]{} return f diff --git a/internal/staging/stage/stage.go b/internal/staging/stage/stage.go index 226b39165..f3592a9ba 100644 --- a/internal/staging/stage/stage.go +++ b/internal/staging/stage/stage.go @@ -32,6 +32,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/util/ident" "github.com/cockroachdb/cdc-sink/internal/util/metrics" "github.com/cockroachdb/cdc-sink/internal/util/retry" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/pkg/errors" @@ -79,7 +80,7 @@ var _ types.Stager = (*stage)(nil) // newStore constructs a new mutation stage that will track pending // mutations to be applied to the given target table. func newStore( - ctx context.Context, db *types.StagingPool, stagingDB ident.Schema, target ident.Table, + ctx *stopper.Context, db *types.StagingPool, stagingDB ident.Schema, target ident.Table, ) (*stage, error) { table := stagingTable(stagingDB, target) diff --git a/main.go b/main.go index c26c2bf09..865f290b0 100644 --- a/main.go +++ b/main.go @@ -39,6 +39,7 @@ import ( "github.com/cockroachdb/cdc-sink/internal/cmd/version" "github.com/cockroachdb/cdc-sink/internal/script" "github.com/cockroachdb/cdc-sink/internal/util/logfmt" + "github.com/cockroachdb/cdc-sink/internal/util/stopper" joonix "github.com/joonix/log" "github.com/pkg/errors" log "github.com/sirupsen/logrus" @@ -46,6 +47,7 @@ import ( ) func main() { + var gracePeriod time.Duration var logFormat, logDestination string var verbosity int root := &cobra.Command{ @@ -96,6 +98,7 @@ func main() { }, } f := root.PersistentFlags() + f.DurationVar(&gracePeriod, "gracePeriod", 30*time.Second, "allow background processes to exit") f.StringVar(&logFormat, "logFormat", "text", "choose log output format [ fluent, text ]") f.StringVar(&logDestination, "logDestination", "", "write logs to a file, instead of stdout") f.CountVarP(&verbosity, "verbose", "v", "increase logging verbosity to debug; repeat for trace") @@ -114,10 +117,28 @@ func main() { version.Command(), ) - ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) - log.DeferExitHandler(cancel) + stop := stopper.WithContext(context.Background()) + // Stop cleanly on interrupt. + stop.Go(func() error { + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + select { + case <-ctx.Done(): + log.Info("Interrupted") + stop.Stop(gracePeriod) + case <-stop.Stopping(): + // Nothing to do. + } + return nil + }) + // Allow log.Exit() or log.Fatal() to trigger shutdown. + log.DeferExitHandler(func() { + stop.Stop(gracePeriod) + <-stop.Done() + }) - if err := root.ExecuteContext(ctx); err != nil { + // Commands can unwrap the stopper as needed. + if err := root.ExecuteContext(stop); err != nil { log.WithError(err).Error("exited") log.Exit(1) }