From c94edc49221cf0035d5df8623e18282ad4aaf83d Mon Sep 17 00:00:00 2001 From: David Syers Date: Mon, 19 May 2025 13:25:18 +0100 Subject: [PATCH 01/11] Initial --- .golangci.yml | 47 ++++--- database/cassandra/cassandra.go | 18 +++ database/clickhouse/clickhouse.go | 18 +++ database/cockroachdb/cockroachdb.go | 18 +++ database/driver.go | 15 +++ database/driver_test.go | 6 + database/firebird/firebird.go | 18 +++ database/mongodb/mongodb.go | 18 +++ database/mysql/mysql.go | 55 ++++++++ database/neo4j/neo4j.go | 18 +++ database/pgx/pgx.go | 18 +++ database/pgx/v5/pgx.go | 18 +++ database/postgres/postgres.go | 18 +++ database/ql/ql.go | 19 +++ database/redshift/redshift.go | 18 +++ database/rqlite/rqlite.go | 18 +++ database/snowflake/snowflake.go | 18 +++ database/spanner/spanner.go | 18 +++ database/sqlcipher/sqlcipher.go | 18 +++ database/sqlite/sqlite.go | 18 +++ database/sqlite3/sqlite3.go | 18 +++ database/sqlserver/sqlserver.go | 18 +++ database/stub/stub.go | 20 ++- database/yugabytedb/yugabytedb.go | 18 +++ migrate.go | 192 ++++++++++++++++++---------- 25 files changed, 589 insertions(+), 89 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 68a8e953b..913bddb87 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,30 +1,37 @@ -run: - # timeout for analysis, e.g. 30s, 5m, default is 1m - timeout: 5m +version: "2" linters: enable: - #- golint - #- interfacer - - unconvert - #- dupl - goconst - - gofmt - misspell - - unparam - nakedret - prealloc - revive - #- gosec -linters-settings: - misspell: - locale: US - revive: + - unconvert + - unparam + settings: + misspell: + locale: US + revive: + rules: + - name: redundant-build-tag + exclusions: + generated: lax rules: - - name: redundant-build-tag + - path: (.+)\.go$ + text: G104 + paths: + - third_party$ + - builtin$ + - examples$ issues: - max-same-issues: 0 max-issues-per-linter: 0 - exclude-use-default: false - exclude: - # gosec: Duplicated errcheck checks - - G104 + max-same-issues: 0 +formatters: + enable: + - gofmt + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 74eecc98e..a231a2eae 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -42,6 +42,8 @@ type Config struct { KeyspaceName string MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Cassandra struct { @@ -198,6 +200,22 @@ func (c *Cassandra) Close() error { return nil } +func (c *Cassandra) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + c.config.Triggers = t +} + +func (c *Cassandra) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(c, detail) + } + + return nil +} + func (c *Cassandra) Lock() error { if !c.isLocked.CAS(false, true) { return database.ErrLocked diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index d2b65c0ce..8382110af 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -34,6 +34,8 @@ type Config struct { MigrationsTableEngine string MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } func init() { @@ -306,6 +308,22 @@ func (ch *ClickHouse) Unlock() error { } func (ch *ClickHouse) Close() error { return ch.conn.Close() } +func (ch *ClickHouse) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + ch.config.Triggers = t +} + +func (ch *ClickHouse) Trigger(name string, detail interface{}) error { + if ch.config.Triggers == nil { + return nil + } + + if trigger, ok := ch.config.Triggers[name]; ok { + return trigger(ch, detail) + } + + return nil +} + // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 func quoteIdentifier(name string) string { end := strings.IndexRune(name, 0) diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 699b3facd..0cb04612e 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -37,6 +37,8 @@ type Config struct { LockTable string ForceLock bool DatabaseName string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type CockroachDb struct { @@ -144,6 +146,22 @@ func (c *CockroachDb) Close() error { return c.db.Close() } +func (c *CockroachDb) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + c.config.Triggers = t +} + +func (c *CockroachDb) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(c, detail) + } + + return nil +} + // Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed // See: https://github.com/cockroachdb/cockroach/issues/13546 func (c *CockroachDb) Lock() error { diff --git a/database/driver.go b/database/driver.go index 11268e6b9..0536e5cb0 100644 --- a/database/driver.go +++ b/database/driver.go @@ -19,6 +19,14 @@ var ( const NilVersion int = -1 +const TrigRunPre string = "RunPre" +const TrigRunPost string = "RunPost" +const TrigSetVersionPre string = "SetVersionPre" +const TrigSetVersionPost string = "SetVersionPost" +const TrigVersionTableExists string = "VersionTableExists" +const TrigVersionTablePre string = "VersionTablePre" +const TrigVersionTablePost string = "VersionTablePost" + var driversMu sync.RWMutex var drivers = make(map[string]Driver) @@ -52,6 +60,13 @@ type Driver interface { // Migrate will call this function only once per instance. Close() error + // AddTriggers adds triggers to the database. The map key is the trigger name + AddTriggers(t map[string]func(m Driver, detail interface{}) error) + + // Trigger is called when a trigger is fired. The name is the trigger name + // and detail is the trigger detail. + Trigger(name string, detail interface{}) error + // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. diff --git a/database/driver_test.go b/database/driver_test.go index 7880f3208..65fed957d 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -28,6 +28,12 @@ func (m *mockDriver) Close() error { return nil } +func (m *mockDriver) AddTriggers(t map[string]func(m Driver, detail interface{}) error) {} + +func (m *mockDriver) Trigger(name string, detail interface{}) error { + return nil +} + func (m *mockDriver) Lock() error { return nil } diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index e15ea96b8..3b7191cc3 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -31,6 +31,8 @@ var ( type Config struct { DatabaseName string MigrationsTable string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Firebird struct { @@ -106,6 +108,22 @@ func (f *Firebird) Close() error { return nil } +func (f *Firebird) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + f.config.Triggers = t +} + +func (f *Firebird) Trigger(name string, detail interface{}) error { + if f.config.Triggers == nil { + return nil + } + + if trigger, ok := f.config.Triggers[name]; ok { + return trigger(f, detail) + } + + return nil +} + func (f *Firebird) Lock() error { if !f.isLocked.CAS(false, true) { return database.ErrLocked diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 3a9a6be9e..4dcbd4f64 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -59,6 +59,8 @@ type Config struct { MigrationsCollection string TransactionMode bool Locking Locking + + Triggers map[string]func(d database.Driver, detail interface{}) error } type versionInfo struct { Version int `bson:"version"` @@ -297,6 +299,22 @@ func (m *Mongo) Close() error { return m.client.Disconnect(context.TODO()) } +func (m *Mongo) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Mongo) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Mongo) Drop() error { return m.db.Drop(context.TODO()) } diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 711ba5187..028855e14 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -43,6 +43,8 @@ type Config struct { DatabaseName string NoLock bool StatementTimeout time.Duration + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Mysql struct { @@ -283,6 +285,22 @@ func (m *Mysql) Close() error { return nil } +func (m *Mysql) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Mysql) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Mysql) Lock() error { return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error { if m.config.NoLock { @@ -347,9 +365,15 @@ func (m *Mysql) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct{ Query string }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := m.conn.ExecContext(ctx, query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := m.Trigger(database.TrigRunPost, struct{ Query string }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -360,6 +384,16 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM `" + m.config.MigrationsTable + "` LIMIT 1" if _, err := tx.ExecContext(context.Background(), query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -381,6 +415,16 @@ func (m *Mysql) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -486,14 +530,25 @@ func (m *Mysql) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } } else { + if err := m.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)" if _, err := m.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/neo4j/neo4j.go b/database/neo4j/neo4j.go index 179e0da60..1cde8f97b 100644 --- a/database/neo4j/neo4j.go +++ b/database/neo4j/neo4j.go @@ -34,6 +34,8 @@ type Config struct { MigrationsLabel string MultiStatement bool MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Neo4j struct { @@ -118,6 +120,22 @@ func (n *Neo4j) Close() error { return n.driver.Close() } +func (n *Neo4j) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + n.config.Triggers = t +} + +func (n *Neo4j) Trigger(name string, detail interface{}) error { + if n.config.Triggers == nil { + return nil + } + + if trigger, ok := n.config.Triggers[name]; ok { + return trigger(n, detail) + } + + return nil +} + // local locking in order to pass tests, Neo doesn't support database locking func (n *Neo4j) Lock() error { if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) { diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index efe8bea80..9b6021f55 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -64,6 +64,8 @@ type Config struct { MigrationsTableQuoted bool MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Postgres struct { @@ -247,6 +249,22 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(p, detail) + } + + return nil +} + func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { switch p.config.LockStrategy { diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index 303174495..fe626b3c8 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -52,6 +52,8 @@ type Config struct { MigrationsTableQuoted bool MultiStatementEnabled bool MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Postgres struct { @@ -218,6 +220,22 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(p, detail) + } + + return nil +} + // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 5e4519115..dc25dd377 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -52,6 +52,8 @@ type Config struct { migrationsTableName string StatementTimeout time.Duration MultiStatementMaxSize int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Postgres struct { @@ -230,6 +232,22 @@ func (p *Postgres) Close() error { return nil } +func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + p.config.Triggers = t +} + +func (p *Postgres) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(p, detail) + } + + return nil +} + // https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS func (p *Postgres) Lock() error { return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error { diff --git a/database/ql/ql.go b/database/ql/ql.go index 37c062455..5f7936a69 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -30,6 +30,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Ql struct { @@ -125,6 +127,23 @@ func (m *Ql) Open(url string) (database.Driver, error) { func (m *Ql) Close() error { return m.db.Close() } + +func (m *Ql) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Ql) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Ql) Drop() (err error) { query := `SELECT Name FROM __Table` tables, err := m.db.Query(query) diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 7687b9d9a..87c07d58e 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -34,6 +34,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Redshift struct { @@ -125,6 +127,22 @@ func (p *Redshift) Close() error { return nil } +func (p *Redshift) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + p.config.Triggers = t +} + +func (p *Redshift) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(p, detail) + } + + return nil +} + // Redshift does not support advisory lock functions: https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-functions.html func (p *Redshift) Lock() error { if !p.isLocked.CAS(false, true) { diff --git a/database/rqlite/rqlite.go b/database/rqlite/rqlite.go index af0d53007..e6227a7f1 100644 --- a/database/rqlite/rqlite.go +++ b/database/rqlite/rqlite.go @@ -39,6 +39,8 @@ type Config struct { ConnectInsecure bool // MigrationsTable configures the migrations table name MigrationsTable string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Rqlite struct { @@ -138,6 +140,22 @@ func (r *Rqlite) Close() error { return nil } +func (r *Rqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + r.config.Triggers = t +} + +func (r *Rqlite) Trigger(name string, detail interface{}) error { + if r.config.Triggers == nil { + return nil + } + + if trigger, ok := r.config.Triggers[name]; ok { + return trigger(r, detail) + } + + return nil +} + // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index 46ce30200..d59f0cab8 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -35,6 +35,8 @@ var ( type Config struct { MigrationsTable string DatabaseName string + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Snowflake struct { @@ -158,6 +160,22 @@ func (p *Snowflake) Close() error { return nil } +func (p *Snowflake) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + p.config.Triggers = t +} + +func (p *Snowflake) Trigger(name string, detail interface{}) error { + if p.config.Triggers == nil { + return nil + } + + if trigger, ok := p.config.Triggers[name]; ok { + return trigger(p, detail) + } + + return nil +} + func (p *Snowflake) Lock() error { if !p.isLocked.CAS(false, true) { return database.ErrLocked diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index b733302d5..b4b5fa711 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -56,6 +56,8 @@ type Config struct { // Parsing outputs clean DDL statements such as reformatted // and void of comments. CleanStatements bool + + Triggers map[string]func(d database.Driver, detail interface{}) error } // Spanner implements database.Driver for Google Cloud Spanner @@ -150,6 +152,22 @@ func (s *Spanner) Close() error { return s.db.admin.Close() } +func (s *Spanner) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + s.config.Triggers = t +} + +func (s *Spanner) Trigger(name string, detail interface{}) error { + if s.config.Triggers == nil { + return nil + } + + if trigger, ok := s.config.Triggers[name]; ok { + return trigger(s, detail) + } + + return nil +} + // Lock implements database.Driver but doesn't do anything because Spanner only // enqueues the UpdateDatabaseDdlRequest. func (s *Spanner) Lock() error { diff --git a/database/sqlcipher/sqlcipher.go b/database/sqlcipher/sqlcipher.go index f98fb3a21..42bb52692 100644 --- a/database/sqlcipher/sqlcipher.go +++ b/database/sqlcipher/sqlcipher.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Sqlite struct { @@ -133,6 +135,22 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) diff --git a/database/sqlite/sqlite.go b/database/sqlite/sqlite.go index ce449dfa0..d6ae9bc43 100644 --- a/database/sqlite/sqlite.go +++ b/database/sqlite/sqlite.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Sqlite struct { @@ -133,6 +135,22 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 56bb23338..4d1a9235b 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -31,6 +31,8 @@ type Config struct { MigrationsTable string DatabaseName string NoTxWrap bool + + Triggers map[string]func(d database.Driver, detail interface{}) error } type Sqlite struct { @@ -133,6 +135,22 @@ func (m *Sqlite) Close() error { return m.db.Close() } +func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + m.config.Triggers = t +} + +func (m *Sqlite) Trigger(name string, detail interface{}) error { + if m.config.Triggers == nil { + return nil + } + + if trigger, ok := m.config.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + func (m *Sqlite) Drop() (err error) { query := `SELECT name FROM sqlite_master WHERE type = 'table';` tables, err := m.db.Query(query) diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 92834d1ad..75b9915f9 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -45,6 +45,8 @@ type Config struct { MigrationsTable string DatabaseName string SchemaName string + + Triggers map[string]func(d database.Driver, detail interface{}) error } // SQL Server connection @@ -190,6 +192,22 @@ func (ss *SQLServer) Close() error { return nil } +func (ss *SQLServer) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + ss.config.Triggers = t +} + +func (ss *SQLServer) Trigger(name string, detail interface{}) error { + if ss.config.Triggers == nil { + return nil + } + + if trigger, ok := ss.config.Triggers[name]; ok { + return trigger(ss, detail) + } + + return nil +} + // Lock creates an advisory local on the database to prevent multiple migrations from running at the same time. func (ss *SQLServer) Lock() error { return database.CasRestoreOnErr(&ss.isLocked, false, true, database.ErrLocked, func() error { diff --git a/database/stub/stub.go b/database/stub/stub.go index ae502650b..7b4363e08 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -34,7 +34,9 @@ func (s *Stub) Open(url string) (database.Driver, error) { }, nil } -type Config struct{} +type Config struct { + Triggers map[string]func(d database.Driver, detail interface{}) error +} func WithInstance(instance interface{}, config *Config) (database.Driver, error) { return &Stub{ @@ -49,6 +51,22 @@ func (s *Stub) Close() error { return nil } +func (s *Stub) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + s.Config.Triggers = t +} + +func (s *Stub) Trigger(name string, detail interface{}) error { + if s.Config.Triggers == nil { + return nil + } + + if trigger, ok := s.Config.Triggers[name]; ok { + return trigger(s, detail) + } + + return nil +} + func (s *Stub) Lock() error { if !s.isLocked.CAS(false, true) { return database.ErrLocked diff --git a/database/yugabytedb/yugabytedb.go b/database/yugabytedb/yugabytedb.go index 764d23c02..c173377e1 100644 --- a/database/yugabytedb/yugabytedb.go +++ b/database/yugabytedb/yugabytedb.go @@ -49,6 +49,8 @@ type Config struct { MaxRetryInterval time.Duration MaxRetryElapsedTime time.Duration MaxRetries int + + Triggers map[string]func(d database.Driver, detail interface{}) error } type YugabyteDB struct { @@ -189,6 +191,22 @@ func (c *YugabyteDB) Close() error { return c.db.Close() } +func (c *YugabyteDB) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { + c.config.Triggers = t +} + +func (c *YugabyteDB) Trigger(name string, detail interface{}) error { + if c.config.Triggers == nil { + return nil + } + + if trigger, ok := c.config.Triggers[name]; ok { + return trigger(c, detail) + } + + return nil +} + // Locking is done manually with a separate lock table. Implementing advisory locks in YugabyteDB is being discussed // See: https://github.com/yugabyte/yugabyte-db/issues/3642 func (c *YugabyteDB) Lock() error { diff --git a/migrate.go b/migrate.go index 266cc04eb..c17dcf9be 100644 --- a/migrate.go +++ b/migrate.go @@ -36,6 +36,11 @@ var ( ErrLockTimeout = errors.New("timeout: can't acquire database lock") ) +const TrigRunMigrationPre = "RunMigrationPre" +const TrigRunMigrationPost = "RunMigrationPost" +const TrigRunMigrationVersionPre = "RunMigrationVersionPre" +const TrigRunMigrationVersionPost = "RunMigrationVersionPost" + // ErrShortLimit is an error returned when not enough migrations // can be returned by a source for a given limit. type ErrShortLimit struct { @@ -80,64 +85,103 @@ type Migrate struct { // LockTimeout defaults to DefaultLockTimeout, // but can be set per Migrate instance. LockTimeout time.Duration + + Triggers map[string]func(m *Migrate, detail interface{}) error } -// New returns a new Migrate instance from a source URL and a database URL. -// The URL scheme is defined by each driver. -func New(sourceURL, databaseURL string) (*Migrate, error) { +type Options struct { + // Source from URL + // The URL scheme is defined by each driver. + SourceURL string + + // Source from Instance + // Use any string that can serve as an identifier during logging as sourceName. + // You are responsible for closing down the underlying source if necessary. + SourceName string + SourceInstance source.Driver + + // Database from URL + // The URL scheme is defined by each driver. + DatabaseURL string + + // Database from Instance + // Use any string that can serve as an identifier during logging as databaseName. + // You are responsible for closing the underlying database client if necessary. + // You will also need to setup your own triggers if needed. + DatabaseName string + DatabaseInstance database.Driver + + // Triggers - these can be used to execute arbitrary code to meet any additional + // requirements that may be needed (i.e. some people need a history of migrations) + MigrateTriggers map[string]func(m *Migrate, detail interface{}) error + DatabaseTriggers map[string]func(d database.Driver, detail interface{}) error +} + +// NewFromOptions returns a new Migrate instance from the options provided. +func NewFromOptions(o Options) (*Migrate, error) { m := newCommon() - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) - } - m.sourceName = sourceName + if o.SourceURL != "" { + sourceName, err := iurl.SchemeFromURL(o.SourceURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from source URL: %w", err) + } + m.sourceName = sourceName - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + sourceDrv, err := source.Open(o.SourceURL) + if err != nil { + return nil, fmt.Errorf("failed to open source, %q: %w", o.SourceURL, err) + } + m.sourceDrv = sourceDrv + } else if o.SourceName != "" && o.SourceInstance != nil { + m.sourceName = o.SourceName + m.sourceDrv = o.SourceInstance + } else { + return nil, fmt.Errorf("must specify either SourceURL or SourceName and SourceInstance") } - m.databaseName = databaseName - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv + if o.DatabaseURL != "" { + databaseName, err := iurl.SchemeFromURL(o.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) + } + m.databaseName = databaseName - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) + databaseDrv, err := database.Open(o.DatabaseURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + m.databaseDrv = databaseDrv + m.databaseDrv.AddTriggers(o.DatabaseTriggers) + } else if o.DatabaseName != "" && o.DatabaseInstance != nil { + m.databaseName = o.DatabaseName + m.databaseDrv = o.DatabaseInstance } - m.databaseDrv = databaseDrv + + m.Triggers = o.MigrateTriggers return m, nil } +// New returns a new Migrate instance from a source URL and a database URL. +// The URL scheme is defined by each driver. +func New(sourceURL, databaseURL string) (*Migrate, error) { + return NewFromOptions(Options{ + SourceURL: sourceURL, + DatabaseURL: databaseURL, + }) +} + // NewWithDatabaseInstance returns a new Migrate instance from a source URL // and an existing database instance. The source URL scheme is defined by each driver. // Use any string that can serve as an identifier during logging as databaseName. // You are responsible for closing the underlying database client if necessary. func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() - - sourceName, err := iurl.SchemeFromURL(sourceURL) - if err != nil { - return nil, err - } - m.sourceName = sourceName - - m.databaseName = databaseName - - sourceDrv, err := source.Open(sourceURL) - if err != nil { - return nil, fmt.Errorf("failed to open source, %q: %w", sourceURL, err) - } - m.sourceDrv = sourceDrv - - m.databaseDrv = databaseInstance - - return m, nil + return NewFromOptions(Options{ + SourceURL: sourceURL, + DatabaseName: databaseName, + DatabaseInstance: databaseInstance, + }) } // NewWithSourceInstance returns a new Migrate instance from an existing source instance @@ -145,25 +189,11 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst // Use any string that can serve as an identifier during logging as sourceName. // You are responsible for closing the underlying source client if necessary. func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) { - m := newCommon() - - databaseName, err := iurl.SchemeFromURL(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to parse scheme from database URL: %w", err) - } - m.databaseName = databaseName - - m.sourceName = sourceName - - databaseDrv, err := database.Open(databaseURL) - if err != nil { - return nil, fmt.Errorf("failed to open database: %w", err) - } - m.databaseDrv = databaseDrv - - m.sourceDrv = sourceInstance - - return m, nil + return NewFromOptions(Options{ + SourceName: sourceName, + SourceInstance: sourceInstance, + DatabaseURL: databaseURL, + }) } // NewWithInstance returns a new Migrate instance from an existing source and @@ -171,15 +201,12 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data // as sourceName and databaseName. You are responsible for closing down // the underlying source and database client if necessary. func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) { - m := newCommon() - - m.sourceName = sourceName - m.databaseName = databaseName - - m.sourceDrv = sourceInstance - m.databaseDrv = databaseInstance - - return m, nil + return NewFromOptions(Options{ + SourceName: sourceName, + SourceInstance: sourceInstance, + DatabaseName: databaseName, + DatabaseInstance: databaseInstance, + }) } func newCommon() *Migrate { @@ -191,6 +218,18 @@ func newCommon() *Migrate { } } +func (m *Migrate) Trigger(name string, detail interface{}) error { + if m.Triggers == nil { + return nil + } + + if trigger, ok := m.Triggers[name]; ok { + return trigger(m, detail) + } + + return nil +} + // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { databaseSrvClose := make(chan error) @@ -723,6 +762,10 @@ func (m *Migrate) readDown(from int, limit int, ret chan<- interface{}) { // to stop execution because it might have received a stop signal on the // GracefulStop channel. func (m *Migrate) runMigrations(ret <-chan interface{}) error { + if err := m.Trigger(TrigRunMigrationPre, nil); err != nil { + return err + } + for r := range ret { if m.stop() { @@ -742,10 +785,18 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { } if migr.Body != nil { + if err := m.Trigger(TrigRunMigrationVersionPre, struct{ Migration *Migration }{migr}); err != nil { + return err + } + m.logVerbosePrintf("Read and execute %v\n", migr.LogString()) if err := m.databaseDrv.Run(migr.BufferedBody); err != nil { return err } + + if err := m.Trigger(TrigRunMigrationVersionPost, struct{ Migration *Migration }{migr}); err != nil { + return err + } } // set clean state @@ -770,6 +821,11 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error { return fmt.Errorf("unknown type: %T with value: %+v", r, r) } } + + if err := m.Trigger(TrigRunMigrationPost, nil); err != nil { + return err + } + return nil } From 73a2435b9473400fa3e3887d995d878fac2a3462 Mon Sep 17 00:00:00 2001 From: David Syers Date: Mon, 19 May 2025 14:31:12 +0100 Subject: [PATCH 02/11] temp --- go.mod | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 3c20151f2..3de2f3325 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,5 @@ -module github.com/golang-migrate/migrate/v4 +//module github.com/golang-migrate/migrate/v4 +module github.com/dsyers/migrate/v4 go 1.23.0 From 5a5d7a488276c7ce03c6aa785b1847ac771f3e12 Mon Sep 17 00:00:00 2001 From: David Syers Date: Tue, 20 May 2025 12:33:54 +0100 Subject: [PATCH 03/11] Tidy / example --- database/driver.go | 2 +- database/mysql/examples/triggers/main.go | 120 ++++++++++++++++++ .../migrations/20250101_tusers.down.sql | 1 + .../migrations/20250101_tusers.up.sql | 6 + .../migrations/20250102_tusers_alter.down.sql | 4 + .../migrations/20250102_tusers_alter.up.sql | 4 + database/mysql/mysql.go | 27 +++- go.mod | 3 +- migrate.go | 16 ++- 9 files changed, 171 insertions(+), 12 deletions(-) create mode 100644 database/mysql/examples/triggers/main.go create mode 100644 database/mysql/examples/triggers/migrations/20250101_tusers.down.sql create mode 100644 database/mysql/examples/triggers/migrations/20250101_tusers.up.sql create mode 100644 database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql create mode 100644 database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql diff --git a/database/driver.go b/database/driver.go index 0536e5cb0..8d756aef6 100644 --- a/database/driver.go +++ b/database/driver.go @@ -61,7 +61,7 @@ type Driver interface { Close() error // AddTriggers adds triggers to the database. The map key is the trigger name - AddTriggers(t map[string]func(m Driver, detail interface{}) error) + AddTriggers(t map[string]func(response interface{}) error) // Trigger is called when a trigger is fired. The name is the trigger name // and detail is the trigger detail. diff --git a/database/mysql/examples/triggers/main.go b/database/mysql/examples/triggers/main.go new file mode 100644 index 000000000..7c69ff42c --- /dev/null +++ b/database/mysql/examples/triggers/main.go @@ -0,0 +1,120 @@ +package main + +import ( + "context" + "database/sql" + "fmt" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/mysql" + _ "github.com/golang-migrate/migrate/v4/source/file" + "os" +) + +type App struct { + Connection *sql.Conn + MigrationTable string + HistoryID *int64 +} + +func main() { + db, err := sql.Open("mysql", "root:root@tcp(localhost:3306)/db?multiStatements=true") + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + ctx := context.Background() + conn, err := db.Conn(ctx) + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + defer conn.Close() + defer db.Close() + + app := &App{ + Connection: conn, + MigrationTable: mysql.DefaultMigrationsTable, + } + + databaseDrv, err := mysql.WithConnection(ctx, conn, &mysql.Config{ + DatabaseName: "db", + Triggers: map[string]func(response interface{}) error{ + database.TrigVersionTableExists: app.MigrationHistoryTable, + database.TrigVersionTablePost: app.MigrationHistoryTable, + database.TrigRunPost: app.DatabaseRunPost, + }, + }) + if err != nil { + fmt.Printf("%v\n", err) + os.Exit(1) + } + m, err := migrate.NewFromOptions(migrate.Options{ + DatabaseInstance: databaseDrv, + DatabaseName: "db", + SourceURL: "file://migrations", + MigrateTriggers: map[string]func(response migrate.TriggerResponse) error{ + migrate.TrigRunMigrationVersionPre: app.RunMigrationVersionPre, + migrate.TrigRunMigrationVersionPost: app.RunMigrationVersionPost, + }, + }) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + err = m.Up() + //err = m.Down() + if err != nil { + fmt.Printf("%v\n", err) + return + } +} + +func (app *App) MigrationHistoryTable(response interface{}) error { + r, _ := response.(mysql.TriggerResponse) + fmt.Printf("Executing database trigger %s\n, %v\n", r.Trigger, r) + query := "CREATE TABLE IF NOT EXISTS `" + app.MigrationTable + "_history` (`id` bigint not null primary key auto_increment, `version` bigint not null, `target` bigint, identifier varchar(255), `dirty` tinyint not null, migration text, `timestamp` datetime not null)" + _, err := app.Connection.ExecContext(context.TODO(), query) + return err +} + +func (app *App) DatabaseRunPost(response interface{}) error { + r, _ := response.(mysql.TriggerResponse) + detail := r.Detail.(struct{ Query string }) + fmt.Printf("Executing database trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "UPDATE `" + app.MigrationTable + "_history` SET `migration` = ? WHERE `id` = ?" + _, err := app.Connection.ExecContext(context.TODO(), query, detail.Query, app.HistoryID) + if err != nil { + fmt.Printf("Error updating migration history: %v\n", err) + return err + } + + return nil +} + +func (app *App) RunMigrationVersionPre(r migrate.TriggerResponse) error { + detail := r.Detail.(struct{ Migration *migrate.Migration }) + fmt.Printf("Executing migration trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "INSERT INTO `" + app.MigrationTable + "_history` (`version`, `identifier`, `target`, `dirty`, `timestamp`) VALUES (?, ?, ?, 1, NOW())" + _, err := app.Connection.ExecContext(context.TODO(), query, detail.Migration.Version, detail.Migration.Identifier, detail.Migration.TargetVersion) + if err != nil { + fmt.Printf("Error inserting migration history: %v\n", err) + return err + } + query = "SELECT LAST_INSERT_ID()" + row := app.Connection.QueryRowContext(context.TODO(), query) + return row.Scan(&app.HistoryID) +} + +func (app *App) RunMigrationVersionPost(r migrate.TriggerResponse) error { + fmt.Printf("Executing migration trigger %s\n, %v\n", r.Trigger, r.Detail) + query := "UPDATE `" + app.MigrationTable + "_history` SET `dirty` = 0, `timestamp` = NOW() WHERE `id` = ?" + _, err := app.Connection.ExecContext(context.TODO(), query, app.HistoryID) + if err != nil { + fmt.Printf("Error updating migration history: %v\n", err) + return err + } + app.HistoryID = nil + + return nil +} diff --git a/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql b/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql new file mode 100644 index 000000000..3ad31301a --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250101_tusers.down.sql @@ -0,0 +1 @@ +DROP TABLE tusers; \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql b/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql new file mode 100644 index 000000000..0f9c910b4 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250101_tusers.up.sql @@ -0,0 +1,6 @@ +CREATE TABLE tusers ( + id SERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + email VARCHAR(100) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql new file mode 100644 index 000000000..3058138c8 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.down.sql @@ -0,0 +1,4 @@ +ALTER TABLE tusers + DROP COLUMN last_login, + DROP COLUMN status, + DROP COLUMN profile_picture; \ No newline at end of file diff --git a/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql new file mode 100644 index 000000000..f7d9e5709 --- /dev/null +++ b/database/mysql/examples/triggers/migrations/20250102_tusers_alter.up.sql @@ -0,0 +1,4 @@ +ALTER TABLE tusers + ADD COLUMN last_login TIMESTAMP, + ADD COLUMN status VARCHAR(20) DEFAULT 'active', + ADD COLUMN profile_picture VARCHAR(255); \ No newline at end of file diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 028855e14..8dd6c246a 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -44,7 +44,15 @@ type Config struct { NoLock bool StatementTimeout time.Duration - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Mysql + Config *Config + Connection *sql.Conn + Trigger string + Detail interface{} } type Mysql struct { @@ -285,7 +293,7 @@ func (m *Mysql) Close() error { return nil } -func (m *Mysql) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Mysql) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -295,7 +303,12 @@ func (m *Mysql) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -365,13 +378,17 @@ func (m *Mysql) Run(migration io.Reader) error { } query := string(migr[:]) - if err := m.Trigger(database.TrigRunPre, struct{ Query string }{Query: query}); err != nil { + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} } if _, err := m.conn.ExecContext(ctx, query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } - if err := m.Trigger(database.TrigRunPost, struct{ Query string }{Query: query}); err != nil { + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} } diff --git a/go.mod b/go.mod index 3de2f3325..3c20151f2 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,4 @@ -//module github.com/golang-migrate/migrate/v4 -module github.com/dsyers/migrate/v4 +module github.com/golang-migrate/migrate/v4 go 1.23.0 diff --git a/migrate.go b/migrate.go index c17dcf9be..495603013 100644 --- a/migrate.go +++ b/migrate.go @@ -86,7 +86,7 @@ type Migrate struct { // but can be set per Migrate instance. LockTimeout time.Duration - Triggers map[string]func(m *Migrate, detail interface{}) error + Triggers map[string]func(response TriggerResponse) error } type Options struct { @@ -113,8 +113,13 @@ type Options struct { // Triggers - these can be used to execute arbitrary code to meet any additional // requirements that may be needed (i.e. some people need a history of migrations) - MigrateTriggers map[string]func(m *Migrate, detail interface{}) error - DatabaseTriggers map[string]func(d database.Driver, detail interface{}) error + MigrateTriggers map[string]func(response TriggerResponse) error + DatabaseTriggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Trigger string + Detail interface{} } // NewFromOptions returns a new Migrate instance from the options provided. @@ -224,7 +229,10 @@ func (m *Migrate) Trigger(name string, detail interface{}) error { } if trigger, ok := m.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Trigger: name, + Detail: detail, + }) } return nil From 9e1ba3c199925c1b858758edcc096b4794ad70e7 Mon Sep 17 00:00:00 2001 From: David Syers Date: Tue, 20 May 2025 19:32:03 +0100 Subject: [PATCH 04/11] triggers --- database/cassandra/cassandra.go | 63 ++++++++++++++++++++++-- database/clickhouse/clickhouse.go | 72 ++++++++++++++++++++++++++-- database/cockroachdb/cockroachdb.go | 54 +++++++++++++++++++-- database/firebird/firebird.go | 50 +++++++++++++++++-- database/mongodb/mongodb.go | 49 +++++++++++++++++-- database/mysql/mysql.go | 9 ++-- database/neo4j/neo4j.go | 69 +++++++++++++++++++++++++-- database/pgx/pgx.go | 74 +++++++++++++++++++++++++++-- database/pgx/v5/pgx.go | 74 +++++++++++++++++++++++++++-- database/postgres/postgres.go | 74 +++++++++++++++++++++++++++-- database/ql/ql.go | 67 ++++++++++++++++++++++++-- database/redshift/redshift.go | 60 +++++++++++++++++++++-- database/rqlite/rqlite.go | 50 +++++++++++++++++-- database/snowflake/snowflake.go | 59 +++++++++++++++++++++-- database/spanner/spanner.go | 55 +++++++++++++++++++-- database/sqlcipher/sqlcipher.go | 70 +++++++++++++++++++++++++-- database/sqlite/sqlite.go | 66 +++++++++++++++++++++++-- database/sqlite3/sqlite3.go | 60 +++++++++++++++++++++-- database/sqlserver/sqlserver.go | 56 ++++++++++++++++++++-- database/stub/stub.go | 18 +++++-- database/yugabytedb/yugabytedb.go | 53 +++++++++++++++++++-- 21 files changed, 1123 insertions(+), 79 deletions(-) diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index a231a2eae..5b51e154a 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -43,7 +43,14 @@ type Config struct { MultiStatementEnabled bool MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Cassandra + Config *Config + Trigger string + Detail interface{} } type Cassandra struct { @@ -200,7 +207,7 @@ func (c *Cassandra) Close() error { return nil } -func (c *Cassandra) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (c *Cassandra) AddTriggers(t map[string]func(response interface{}) error) { c.config.Triggers = t } @@ -210,7 +217,12 @@ func (c *Cassandra) Trigger(name string, detail interface{}) error { } if trigger, ok := c.config.Triggers[name]; ok { - return trigger(c, detail) + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -238,10 +250,22 @@ func (c *Cassandra) Run(migration io.Reader) error { if tq == "" { return true } + if e := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"} + return false + } if e := c.session.Query(tq).Exec(); e != nil { err = database.Error{OrigErr: e, Err: "migration failed", Query: m} return false } + if e := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"} + return false + } return true }); e != nil { return e @@ -253,15 +277,32 @@ func (c *Cassandra) Run(migration io.Reader) error { if err != nil { return err } + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } // run migration if err := c.session.Query(string(migr)).Exec(); err != nil { // TODO: cast to Cassandra error and get line number return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *Cassandra) SetVersion(version int, dirty bool) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + // DELETE instead of TRUNCATE because AWS Keyspaces does not support it // see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html squery := `SELECT version FROM "` + c.config.MigrationsTable + `"` @@ -287,6 +328,13 @@ func (c *Cassandra) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -342,6 +390,10 @@ func (c *Cassandra) ensureVersionTable() (err error) { } }() + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() if err != nil { return err @@ -349,6 +401,11 @@ func (c *Cassandra) ensureVersionTable() (err error) { if _, _, err = c.Version(); err != nil { return err } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/clickhouse/clickhouse.go b/database/clickhouse/clickhouse.go index 8382110af..58bf8bf70 100644 --- a/database/clickhouse/clickhouse.go +++ b/database/clickhouse/clickhouse.go @@ -35,7 +35,7 @@ type Config struct { MultiStatementEnabled bool MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } func init() { @@ -69,6 +69,13 @@ type ClickHouse struct { isLocked atomic.Bool } +type TriggerResponse struct { + Driver *ClickHouse + Config *Config + Trigger string + Detail interface{} +} + func (ch *ClickHouse) Open(dsn string) (database.Driver, error) { purl, err := url.Parse(dsn) if err != nil { @@ -143,10 +150,22 @@ func (ch *ClickHouse) Run(r io.Reader) error { if tq == "" { return true } + if e := ch.Trigger(database.TrigRunPre, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPre"} + return false + } if _, e := ch.conn.Exec(string(m)); e != nil { err = database.Error{OrigErr: e, Err: "migration failed", Query: m} return false } + if e := ch.Trigger(database.TrigRunPost, struct { + Query string + }{Query: tq}); e != nil { + err = database.Error{OrigErr: e, Err: "failed to trigger RunPost"} + return false + } return true }); e != nil { return e @@ -159,10 +178,22 @@ func (ch *ClickHouse) Run(r io.Reader) error { return err } + if err := ch.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migration)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if _, err := ch.conn.Exec(string(migration)); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migration} } + if err := ch.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migration)}); err != nil { + return database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (ch *ClickHouse) Version() (int, bool, error) { @@ -182,7 +213,7 @@ func (ch *ClickHouse) Version() (int, bool, error) { func (ch *ClickHouse) SetVersion(version int, dirty bool) error { var ( - bool = func(v bool) uint8 { + booln = func(v bool) uint8 { if v { return 1 } @@ -194,11 +225,25 @@ func (ch *ClickHouse) SetVersion(version int, dirty bool) error { return err } + if err := ch.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)" - if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil { + if _, err := tx.Exec(query, version, booln(dirty), time.Now().UnixNano()); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := ch.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return tx.Commit() } @@ -230,9 +275,16 @@ func (ch *ClickHouse) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } } else { + if err := ch.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := ch.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table if len(ch.config.ClusterName) > 0 { query = fmt.Sprintf(` @@ -257,6 +309,11 @@ func (ch *ClickHouse) ensureVersionTable() (err error) { if _, err := ch.conn.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := ch.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -308,7 +365,7 @@ func (ch *ClickHouse) Unlock() error { } func (ch *ClickHouse) Close() error { return ch.conn.Close() } -func (ch *ClickHouse) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (ch *ClickHouse) AddTriggers(t map[string]func(response interface{}) error) { ch.config.Triggers = t } @@ -318,7 +375,12 @@ func (ch *ClickHouse) Trigger(name string, detail interface{}) error { } if trigger, ok := ch.config.Triggers[name]; ok { - return trigger(ch, detail) + return trigger(TriggerResponse{ + Driver: ch, + Config: ch.config, + Trigger: name, + Detail: detail, + }) } return nil diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 0cb04612e..19b9e4490 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -38,7 +38,7 @@ type Config struct { ForceLock bool DatabaseName string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type CockroachDb struct { @@ -49,6 +49,13 @@ type CockroachDb struct { config *Config } +type TriggerResponse struct { + Driver *CockroachDb + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -146,7 +153,7 @@ func (c *CockroachDb) Close() error { return c.db.Close() } -func (c *CockroachDb) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (c *CockroachDb) AddTriggers(t map[string]func(response interface{}) error) { c.config.Triggers = t } @@ -156,7 +163,12 @@ func (c *CockroachDb) Trigger(name string, detail interface{}) error { } if trigger, ok := c.config.Triggers[name]; ok { - return trigger(c, detail) + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -236,15 +248,32 @@ func (c *CockroachDb) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := c.db.Exec(query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *CockroachDb) SetVersion(version int, dirty bool) error { return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return err + } + if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err } @@ -258,6 +287,13 @@ func (c *CockroachDb) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return err + } + return nil }) } @@ -351,14 +387,26 @@ func (c *CockroachDb) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := c.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)` if _, err := c.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index 3b7191cc3..78973a4a2 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -32,7 +32,7 @@ type Config struct { DatabaseName string MigrationsTable string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Firebird struct { @@ -45,6 +45,13 @@ type Firebird struct { config *Config } +type TriggerResponse struct { + Driver *Firebird + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -108,7 +115,7 @@ func (f *Firebird) Close() error { return nil } -func (f *Firebird) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (f *Firebird) AddTriggers(t map[string]func(response interface{}) error) { f.config.Triggers = t } @@ -118,7 +125,12 @@ func (f *Firebird) Trigger(name string, detail interface{}) error { } if trigger, ok := f.config.Triggers[name]; ok { - return trigger(f, detail) + return trigger(TriggerResponse{ + Driver: f, + Config: f.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -146,9 +158,19 @@ func (f *Firebird) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := f.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := f.conn.ExecContext(context.Background(), query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := f.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -158,6 +180,13 @@ func (f *Firebird) SetVersion(version int, dirty bool) error { // for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 + if err := f.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + // TODO: parameterize this SQL statement // https://firebirdsql.org/refdocs/langrefupd20-execblock.html // VALUES (?, ?) doesn't work @@ -171,6 +200,13 @@ func (f *Firebird) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := f.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -249,6 +285,10 @@ func (f *Firebird) ensureVersionTable() (err error) { } }() + if err := f.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)'; @@ -259,6 +299,10 @@ func (f *Firebird) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := f.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/mongodb/mongodb.go b/database/mongodb/mongodb.go index 4dcbd4f64..ede3b3c6f 100644 --- a/database/mongodb/mongodb.go +++ b/database/mongodb/mongodb.go @@ -60,7 +60,7 @@ type Config struct { TransactionMode bool Locking Locking - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type versionInfo struct { Version int `bson:"version"` @@ -77,6 +77,13 @@ type findFilter struct { Key int `bson:"locking_key"` } +type TriggerResponse struct { + Driver *Mongo + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -218,6 +225,13 @@ func parseInt(urlParam string, defaultValue int) (int, error) { return defaultValue, nil } func (m *Mongo) SetVersion(version int, dirty bool) error { + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + migrationsCollection := m.db.Collection(m.config.MigrationsCollection) if err := migrationsCollection.Drop(context.TODO()); err != nil { return &database.Error{OrigErr: err, Err: "drop migrations collection failed"} @@ -226,6 +240,14 @@ func (m *Mongo) SetVersion(version int, dirty bool) error { if err != nil { return &database.Error{OrigErr: err, Err: "save version failed"} } + + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -252,6 +274,11 @@ func (m *Mongo) Run(migration io.Reader) error { if err != nil { return fmt.Errorf("unmarshaling json error: %s", err) } + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if m.config.TransactionMode { if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil { return err @@ -261,6 +288,11 @@ func (m *Mongo) Run(migration io.Reader) error { return err } } + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -299,7 +331,7 @@ func (m *Mongo) Close() error { return m.client.Disconnect(context.TODO()) } -func (m *Mongo) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Mongo) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -309,7 +341,12 @@ func (m *Mongo) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -354,9 +391,15 @@ func (m *Mongo) ensureVersionTable() (err error) { if err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } if _, _, err = m.Version(); err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go index 8dd6c246a..b857a3ce6 100644 --- a/database/mysql/mysql.go +++ b/database/mysql/mysql.go @@ -48,11 +48,10 @@ type Config struct { } type TriggerResponse struct { - Driver *Mysql - Config *Config - Connection *sql.Conn - Trigger string - Detail interface{} + Driver *Mysql + Config *Config + Trigger string + Detail interface{} } type Mysql struct { diff --git a/database/neo4j/neo4j.go b/database/neo4j/neo4j.go index 1cde8f97b..3dff54853 100644 --- a/database/neo4j/neo4j.go +++ b/database/neo4j/neo4j.go @@ -35,7 +35,7 @@ type Config struct { MultiStatement bool MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Neo4j struct { @@ -46,6 +46,13 @@ type Neo4j struct { config *Config } +type TriggerResponse struct { + Driver *Neo4j + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -120,7 +127,7 @@ func (n *Neo4j) Close() error { return n.driver.Close() } -func (n *Neo4j) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (n *Neo4j) AddTriggers(t map[string]func(response interface{}) error) { n.config.Triggers = t } @@ -130,7 +137,12 @@ func (n *Neo4j) Trigger(name string, detail interface{}) error { } if trigger, ok := n.config.Triggers[name]; ok { - return trigger(n, detail) + return trigger(TriggerResponse{ + Driver: n, + Config: n.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -176,11 +188,26 @@ func (n *Neo4j) Run(migration io.Reader) (err error) { return true } + if err = n.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(trimStmt)}); err != nil { + stmtRunErr = err + return false + } + result, err := transaction.Run(string(trimStmt), nil) if _, err := neo4j.Collect(result, err); err != nil { stmtRunErr = err return false } + + if err = n.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(trimStmt)}); err != nil { + stmtRunErr = err + return false + } + return true }); err != nil { return nil, err @@ -194,8 +221,19 @@ func (n *Neo4j) Run(migration io.Reader) (err error) { if err != nil { return err } - - _, err = neo4j.Collect(session.Run(string(body[:]), nil)) + if err = n.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(body[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if _, err = neo4j.Collect(session.Run(string(body[:]), nil)); err != nil { + return err + } + if err = n.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(body[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return err } @@ -210,12 +248,24 @@ func (n *Neo4j) SetVersion(version int, dirty bool) (err error) { } }() + if err := n.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()", n.config.MigrationsLabel) _, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty})) if err != nil { return err } + if err := n.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } return nil } @@ -310,12 +360,21 @@ func (n *Neo4j) ensureVersionConstraint() (err error) { return err } if len(res) == 1 { + if err := n.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := n.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel) if _, err := neo4j.Collect(session.Run(query, nil)); err != nil { return err } + if err := n.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index 9b6021f55..ce1fde69b 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -65,7 +65,7 @@ type Config struct { MultiStatementEnabled bool MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -78,6 +78,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -249,7 +256,7 @@ func (p *Postgres) Close() error { return nil } -func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { p.config.Triggers = t } @@ -259,7 +266,12 @@ func (p *Postgres) Trigger(name string, detail interface{}) error { } if trigger, ok := p.config.Triggers[name]; ok { - return trigger(p, detail) + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -381,9 +393,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -394,7 +416,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -470,6 +505,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -491,6 +536,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -597,14 +652,25 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/pgx/v5/pgx.go b/database/pgx/v5/pgx.go index fe626b3c8..eb0e81864 100644 --- a/database/pgx/v5/pgx.go +++ b/database/pgx/v5/pgx.go @@ -53,7 +53,7 @@ type Config struct { MultiStatementEnabled bool MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -66,6 +66,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -220,7 +227,7 @@ func (p *Postgres) Close() error { return nil } -func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { p.config.Triggers = t } @@ -230,7 +237,12 @@ func (p *Postgres) Trigger(name string, detail interface{}) error { } if trigger, ok := p.config.Triggers[name]; ok { - return trigger(p, detail) + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -272,9 +284,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -285,7 +307,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -361,6 +396,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -382,6 +427,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -482,14 +537,25 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index dc25dd377..453c39fbd 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -53,7 +53,7 @@ type Config struct { StatementTimeout time.Duration MultiStatementMaxSize int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Postgres struct { @@ -66,6 +66,13 @@ type Postgres struct { config *Config } +type TriggerResponse struct { + Driver *Postgres + Config *Config + Trigger string + Detail interface{} +} + func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Postgres, error) { if config == nil { return nil, ErrNilConfig @@ -232,7 +239,7 @@ func (p *Postgres) Close() error { return nil } -func (p *Postgres) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (p *Postgres) AddTriggers(t map[string]func(response interface{}) error) { p.config.Triggers = t } @@ -242,7 +249,12 @@ func (p *Postgres) Trigger(name string, detail interface{}) error { } if trigger, ok := p.config.Triggers[name]; ok { - return trigger(p, detail) + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -285,9 +297,19 @@ func (p *Postgres) Run(migration io.Reader) error { if p.config.MultiStatementEnabled { var err error if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool { + if e := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } if err = p.runStatement(m); err != nil { return false } + if e := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m)}); e != nil { + return false + } return true }); e != nil { return e @@ -298,7 +320,20 @@ func (p *Postgres) Run(migration io.Reader) error { if err != nil { return err } - return p.runStatement(migr) + if err = p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if err = p.runStatement(migr); err != nil { + return err + } + if err = p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (p *Postgres) runStatement(statement []byte) error { @@ -377,6 +412,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -398,6 +443,16 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -498,13 +553,24 @@ func (p *Postgres) ensureVersionTable() (err error) { } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.migrationsSchemaName) + `.` + pq.QuoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/ql/ql.go b/database/ql/ql.go index 5f7936a69..6c37651db 100644 --- a/database/ql/ql.go +++ b/database/ql/ql.go @@ -31,7 +31,7 @@ type Config struct { MigrationsTable string DatabaseName string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Ql struct { @@ -41,6 +41,13 @@ type Ql struct { config *Config } +type TriggerResponse struct { + Driver *Ql + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -86,6 +93,12 @@ func (m *Ql) ensureVersionTable() (err error) { if err != nil { return err } + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + if err := tx.Rollback(); err != nil { + return err + } + return err + } if _, err := tx.Exec(fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -95,6 +108,12 @@ func (m *Ql) ensureVersionTable() (err error) { } return err } + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + if err := tx.Rollback(); err != nil { + return err + } + return err + } if err := tx.Commit(); err != nil { return err } @@ -128,7 +147,7 @@ func (m *Ql) Close() error { return m.db.Close() } -func (m *Ql) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Ql) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -138,7 +157,12 @@ func (m *Ql) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -203,7 +227,22 @@ func (m *Ql) Run(migration io.Reader) error { } query := string(migr[:]) - return m.executeQuery(query) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + + if err = m.executeQuery(query); err != nil { + return err + } + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } func (m *Ql) executeQuery(query string) error { tx, err := m.db.Begin() @@ -227,6 +266,16 @@ func (m *Ql) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "TRUNCATE TABLE " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -246,6 +295,16 @@ func (m *Ql) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/redshift/redshift.go b/database/redshift/redshift.go index 87c07d58e..cff6973e1 100644 --- a/database/redshift/redshift.go +++ b/database/redshift/redshift.go @@ -35,7 +35,7 @@ type Config struct { MigrationsTable string DatabaseName string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Redshift struct { @@ -47,6 +47,13 @@ type Redshift struct { config *Config } +type TriggerResponse struct { + Driver *Redshift + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -127,7 +134,7 @@ func (p *Redshift) Close() error { return nil } -func (p *Redshift) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (p *Redshift) AddTriggers(t map[string]func(response interface{}) error) { p.config.Triggers = t } @@ -137,7 +144,12 @@ func (p *Redshift) Trigger(name string, detail interface{}) error { } if trigger, ok := p.config.Triggers[name]; ok { - return trigger(p, detail) + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -166,6 +178,11 @@ func (p *Redshift) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := p.conn.ExecContext(context.Background(), query); err != nil { if pgErr, ok := err.(*pq.Error); ok { var line uint @@ -187,6 +204,11 @@ func (p *Redshift) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -232,6 +254,16 @@ func (p *Redshift) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `DELETE FROM "` + p.config.MigrationsTable + `"` if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -253,6 +285,16 @@ func (p *Redshift) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -346,13 +388,25 @@ func (p *Redshift) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` if _, err := p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/rqlite/rqlite.go b/database/rqlite/rqlite.go index e6227a7f1..fd3f633fa 100644 --- a/database/rqlite/rqlite.go +++ b/database/rqlite/rqlite.go @@ -40,7 +40,7 @@ type Config struct { // MigrationsTable configures the migrations table name MigrationsTable string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Rqlite struct { @@ -50,6 +50,13 @@ type Rqlite struct { config *Config } +type TriggerResponse struct { + Driver *Rqlite + Config *Config + Trigger string + Detail interface{} +} + // WithInstance creates a rqlite database driver with an existing gorqlite database connection // and a Config struct func WithInstance(instance *gorqlite.Connection, config *Config) (database.Driver, error) { @@ -99,6 +106,10 @@ func (r *Rqlite) ensureVersionTable() (err error) { } }() + if err := r.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + stmts := []string{ fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (version uint64, dirty bool)`, r.config.MigrationsTable), fmt.Sprintf(`CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version)`, r.config.MigrationsTable), @@ -108,6 +119,10 @@ func (r *Rqlite) ensureVersionTable() (err error) { return err } + if err := r.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -140,7 +155,7 @@ func (r *Rqlite) Close() error { return nil } -func (r *Rqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (r *Rqlite) AddTriggers(t map[string]func(response interface{}) error) { r.config.Triggers = t } @@ -150,7 +165,12 @@ func (r *Rqlite) Trigger(name string, detail interface{}) error { } if trigger, ok := r.config.Triggers[name]; ok { - return trigger(r, detail) + return trigger(TriggerResponse{ + Driver: r, + Config: r.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -184,9 +204,19 @@ func (r *Rqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := r.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := r.db.WriteOne(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := r.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -195,6 +225,13 @@ func (r *Rqlite) Run(migration io.Reader) error { // Migrate will call this function before and after each call to Run. // version must be >= -1. -1 means NilVersion. func (r *Rqlite) SetVersion(version int, dirty bool) error { + if err := r.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + deleteQuery := fmt.Sprintf(`DELETE FROM %s`, r.config.MigrationsTable) statements := []gorqlite.ParameterizedStatement{ { @@ -228,6 +265,13 @@ func (r *Rqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Query: []byte(deleteQuery + "\n" + insertQuery)} } + if err := r.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } diff --git a/database/snowflake/snowflake.go b/database/snowflake/snowflake.go index d59f0cab8..9ea573da1 100644 --- a/database/snowflake/snowflake.go +++ b/database/snowflake/snowflake.go @@ -36,7 +36,7 @@ type Config struct { MigrationsTable string DatabaseName string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Snowflake struct { @@ -48,6 +48,13 @@ type Snowflake struct { config *Config } +type TriggerResponse struct { + Driver *Snowflake + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -160,7 +167,7 @@ func (p *Snowflake) Close() error { return nil } -func (p *Snowflake) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (p *Snowflake) AddTriggers(t map[string]func(response interface{}) error) { p.config.Triggers = t } @@ -170,7 +177,12 @@ func (p *Snowflake) Trigger(name string, detail interface{}) error { } if trigger, ok := p.config.Triggers[name]; ok { - return trigger(p, detail) + return trigger(TriggerResponse{ + Driver: p, + Config: p.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -198,6 +210,11 @@ func (p *Snowflake) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := p.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := p.conn.ExecContext(context.Background(), query); err != nil { if pgErr, ok := err.(*pq.Error); ok { var line uint @@ -219,6 +236,11 @@ func (p *Snowflake) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := p.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -264,6 +286,16 @@ func (p *Snowflake) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := p.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `DELETE FROM "` + p.config.MigrationsTable + `"` if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -287,6 +319,16 @@ func (p *Snowflake) SetVersion(version int, dirty bool) error { } } + if err := p.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -380,9 +422,16 @@ func (p *Snowflake) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := p.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := p.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE if not exists "` + p.config.MigrationsTable + `" ( version bigint not null primary key, dirty boolean not null)` @@ -390,5 +439,9 @@ func (p *Snowflake) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := p.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/spanner/spanner.go b/database/spanner/spanner.go index b4b5fa711..b4071b98b 100644 --- a/database/spanner/spanner.go +++ b/database/spanner/spanner.go @@ -57,7 +57,7 @@ type Config struct { // and void of comments. CleanStatements bool - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } // Spanner implements database.Driver for Google Cloud Spanner @@ -74,6 +74,13 @@ type DB struct { data *spanner.Client } +type TriggerResponse struct { + Driver *Spanner + Config *Config + Trigger string + Detail interface{} +} + func NewDB(admin sdb.DatabaseAdminClient, data spanner.Client) *DB { return &DB{ admin: &admin, @@ -152,7 +159,7 @@ func (s *Spanner) Close() error { return s.db.admin.Close() } -func (s *Spanner) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (s *Spanner) AddTriggers(t map[string]func(response interface{}) error) { s.config.Triggers = t } @@ -162,7 +169,12 @@ func (s *Spanner) Trigger(name string, detail interface{}) error { } if trigger, ok := s.config.Triggers[name]; ok { - return trigger(s, detail) + return trigger(TriggerResponse{ + Driver: s, + Config: s.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -192,6 +204,12 @@ func (s *Spanner) Run(migration io.Reader) error { return err } + if err := s.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + stmts := []string{string(migr)} if s.config.CleanStatements { stmts, err = cleanStatements(migr) @@ -214,6 +232,12 @@ func (s *Spanner) Run(migration io.Reader) error { return &database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := s.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(migr)}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } @@ -221,6 +245,13 @@ func (s *Spanner) Run(migration io.Reader) error { func (s *Spanner) SetVersion(version int, dirty bool) error { ctx := context.Background() + if err := s.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + _, err := s.db.data.ReadWriteTransaction(ctx, func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { m := []*spanner.Mutation{ @@ -235,6 +266,13 @@ func (s *Spanner) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err} } + if err := s.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil } @@ -335,9 +373,16 @@ func (s *Spanner) ensureVersionTable() (err error) { tbl := s.config.MigrationsTable iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), []string{"Version"}) if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { + if err := s.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := s.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + stmt := fmt.Sprintf(`CREATE TABLE %s ( Version INT64 NOT NULL, Dirty BOOL NOT NULL @@ -355,6 +400,10 @@ func (s *Spanner) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(stmt)} } + if err := s.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/sqlcipher/sqlcipher.go b/database/sqlcipher/sqlcipher.go index 42bb52692..70c4f1aea 100644 --- a/database/sqlcipher/sqlcipher.go +++ b/database/sqlcipher/sqlcipher.go @@ -32,7 +32,7 @@ type Config struct { DatabaseName string NoTxWrap bool - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -42,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -83,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -91,6 +102,11 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } @@ -135,7 +151,7 @@ func (m *Sqlite) Close() error { return m.db.Close() } -func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -145,7 +161,12 @@ func (m *Sqlite) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -216,10 +237,29 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else { + if err = m.executeQuery(query); err != nil { + return err + } } - return m.executeQuery(query) + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -252,6 +292,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -270,6 +320,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlite/sqlite.go b/database/sqlite/sqlite.go index d6ae9bc43..58b2b1966 100644 --- a/database/sqlite/sqlite.go +++ b/database/sqlite/sqlite.go @@ -32,7 +32,7 @@ type Config struct { DatabaseName string NoTxWrap bool - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -42,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -83,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -91,6 +102,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } @@ -135,7 +150,7 @@ func (m *Sqlite) Close() error { return m.db.Close() } -func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -145,7 +160,12 @@ func (m *Sqlite) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -216,10 +236,26 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else if err = m.executeQuery(query); err != nil { + return err + } + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} } - return m.executeQuery(query) + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -252,6 +288,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -270,6 +316,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 4d1a9235b..62a032654 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -32,7 +32,7 @@ type Config struct { DatabaseName string NoTxWrap bool - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type Sqlite struct { @@ -42,6 +42,13 @@ type Sqlite struct { config *Config } +type TriggerResponse struct { + Driver *Sqlite + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -83,6 +90,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { } }() + if err := m.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s (version uint64,dirty bool); CREATE UNIQUE INDEX IF NOT EXISTS version_unique ON %s (version); @@ -91,6 +102,10 @@ func (m *Sqlite) ensureVersionTable() (err error) { if _, err := m.db.Exec(query); err != nil { return err } + + if err := m.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } @@ -135,7 +150,7 @@ func (m *Sqlite) Close() error { return m.db.Close() } -func (m *Sqlite) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (m *Sqlite) AddTriggers(t map[string]func(response interface{}) error) { m.config.Triggers = t } @@ -145,7 +160,12 @@ func (m *Sqlite) Trigger(name string, detail interface{}) error { } if trigger, ok := m.config.Triggers[name]; ok { - return trigger(m, detail) + return trigger(TriggerResponse{ + Driver: m, + Config: m.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -217,9 +237,19 @@ func (m *Sqlite) Run(migration io.Reader) error { query := string(migr[:]) if m.config.NoTxWrap { - return m.executeQueryNoTx(query) + if err = m.executeQueryNoTx(query); err != nil { + return err + } + } else if err = m.executeQuery(query); err != nil { + return err + } + + if err := m.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} } - return m.executeQuery(query) + return nil } func (m *Sqlite) executeQuery(query string) error { @@ -252,6 +282,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := m.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := "DELETE FROM " + m.config.MigrationsTable if _, err := tx.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -270,6 +310,16 @@ func (m *Sqlite) SetVersion(version int, dirty bool) error { } } + if err := m.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 75b9915f9..9544c86e4 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -46,7 +46,7 @@ type Config struct { DatabaseName string SchemaName string - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } // SQL Server connection @@ -60,6 +60,13 @@ type SQLServer struct { config *Config } +type TriggerResponse struct { + Driver *SQLServer + Config *Config + Trigger string + Detail interface{} +} + // WithInstance returns a database instance from an already created database connection. // // Note that the deprecated `mssql` driver is not supported. Please use the newer `sqlserver` driver. @@ -192,7 +199,7 @@ func (ss *SQLServer) Close() error { return nil } -func (ss *SQLServer) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (ss *SQLServer) AddTriggers(t map[string]func(response interface{}) error) { ss.config.Triggers = t } @@ -202,7 +209,12 @@ func (ss *SQLServer) Trigger(name string, detail interface{}) error { } if trigger, ok := ss.config.Triggers[name]; ok { - return trigger(ss, detail) + return trigger(TriggerResponse{ + Driver: ss, + Config: ss.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -265,6 +277,11 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := ss.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { if msErr, ok := err.(mssql.Error); ok { message := fmt.Sprintf("migration failed: %s", msErr.Message) @@ -275,6 +292,11 @@ func (ss *SQLServer) Run(migration io.Reader) error { } return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := ss.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } @@ -287,6 +309,16 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } + if err := ss.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + query := `TRUNCATE TABLE ` + ss.getMigrationTable() if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -312,6 +344,16 @@ func (ss *SQLServer) SetVersion(version int, dirty bool) error { } } + if err := ss.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + if errRollback := tx.Rollback(); errRollback != nil { + err = multierror.Append(err, errRollback) + } + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + if err := tx.Commit(); err != nil { return &database.Error{OrigErr: err, Err: "transaction commit failed"} } @@ -386,6 +428,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { } }() + if err := ss.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + query := `IF NOT EXISTS (SELECT * FROM sysobjects @@ -398,6 +444,10 @@ func (ss *SQLServer) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } + if err := ss.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } + return nil } diff --git a/database/stub/stub.go b/database/stub/stub.go index 7b4363e08..f059b8003 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -35,7 +35,14 @@ func (s *Stub) Open(url string) (database.Driver, error) { } type Config struct { - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error +} + +type TriggerResponse struct { + Driver *Stub + Config *Config + Trigger string + Detail interface{} } func WithInstance(instance interface{}, config *Config) (database.Driver, error) { @@ -51,7 +58,7 @@ func (s *Stub) Close() error { return nil } -func (s *Stub) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (s *Stub) AddTriggers(t map[string]func(response interface{}) error) { s.Config.Triggers = t } @@ -61,7 +68,12 @@ func (s *Stub) Trigger(name string, detail interface{}) error { } if trigger, ok := s.Config.Triggers[name]; ok { - return trigger(s, detail) + return trigger(TriggerResponse{ + Driver: s, + Config: s.Config, + Trigger: name, + Detail: detail, + }) } return nil diff --git a/database/yugabytedb/yugabytedb.go b/database/yugabytedb/yugabytedb.go index c173377e1..3a507292c 100644 --- a/database/yugabytedb/yugabytedb.go +++ b/database/yugabytedb/yugabytedb.go @@ -50,7 +50,7 @@ type Config struct { MaxRetryElapsedTime time.Duration MaxRetries int - Triggers map[string]func(d database.Driver, detail interface{}) error + Triggers map[string]func(response interface{}) error } type YugabyteDB struct { @@ -61,6 +61,13 @@ type YugabyteDB struct { config *Config } +type TriggerResponse struct { + Driver *YugabyteDB + Config *Config + Trigger string + Detail interface{} +} + func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig @@ -191,7 +198,7 @@ func (c *YugabyteDB) Close() error { return c.db.Close() } -func (c *YugabyteDB) AddTriggers(t map[string]func(d database.Driver, detail interface{}) error) { +func (c *YugabyteDB) AddTriggers(t map[string]func(response interface{}) error) { c.config.Triggers = t } @@ -201,7 +208,12 @@ func (c *YugabyteDB) Trigger(name string, detail interface{}) error { } if trigger, ok := c.config.Triggers[name]; ok { - return trigger(c, detail) + return trigger(TriggerResponse{ + Driver: c, + Config: c.config, + Trigger: name, + Detail: detail, + }) } return nil @@ -281,15 +293,32 @@ func (c *YugabyteDB) Run(migration io.Reader) error { // run migration query := string(migr[:]) + if err := c.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } if _, err := c.db.Exec(query); err != nil { return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } + if err := c.Trigger(database.TrigRunPost, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } return nil } func (c *YugabyteDB) SetVersion(version int, dirty bool) error { return c.doTxWithRetry(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable}, func(tx *sql.Tx) error { + if err := c.Trigger(database.TrigSetVersionPre, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPre"} + } + if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil { return err } @@ -303,6 +332,13 @@ func (c *YugabyteDB) SetVersion(version int, dirty bool) error { } } + if err := c.Trigger(database.TrigSetVersionPost, struct { + Version int + Dirty bool + }{Version: version, Dirty: dirty}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger SetVersionPost"} + } + return nil }) } @@ -393,14 +429,25 @@ func (c *YugabyteDB) ensureVersionTable() (err error) { return &database.Error{OrigErr: err, Query: []byte(query)} } if count == 1 { + if err := c.Trigger(database.TrigVersionTableExists, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTableExists"} + } return nil } + if err := c.Trigger(database.TrigVersionTablePre, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePre"} + } + // if not, create the empty migration table query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)` if _, err := c.db.Exec(query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } + + if err := c.Trigger(database.TrigVersionTablePost, nil); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger VersionTablePost"} + } return nil } From 67a598bd236a56c8c1a6422f2aa787d037024871 Mon Sep 17 00:00:00 2001 From: David Syers Date: Tue, 20 May 2025 19:59:34 +0100 Subject: [PATCH 05/11] Mock --- .golangci.yml | 47 ++++++++++++++++++----------------------- database/driver_test.go | 2 +- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 913bddb87..68a8e953b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,37 +1,30 @@ -version: "2" +run: + # timeout for analysis, e.g. 30s, 5m, default is 1m + timeout: 5m linters: enable: + #- golint + #- interfacer + - unconvert + #- dupl - goconst + - gofmt - misspell + - unparam - nakedret - prealloc - revive - - unconvert - - unparam - settings: - misspell: - locale: US - revive: - rules: - - name: redundant-build-tag - exclusions: - generated: lax + #- gosec +linters-settings: + misspell: + locale: US + revive: rules: - - path: (.+)\.go$ - text: G104 - paths: - - third_party$ - - builtin$ - - examples$ + - name: redundant-build-tag issues: - max-issues-per-linter: 0 max-same-issues: 0 -formatters: - enable: - - gofmt - exclusions: - generated: lax - paths: - - third_party$ - - builtin$ - - examples$ + max-issues-per-linter: 0 + exclude-use-default: false + exclude: + # gosec: Duplicated errcheck checks + - G104 diff --git a/database/driver_test.go b/database/driver_test.go index 65fed957d..1b157cde4 100644 --- a/database/driver_test.go +++ b/database/driver_test.go @@ -28,7 +28,7 @@ func (m *mockDriver) Close() error { return nil } -func (m *mockDriver) AddTriggers(t map[string]func(m Driver, detail interface{}) error) {} +func (m *mockDriver) AddTriggers(t map[string]func(detail interface{}) error) {} func (m *mockDriver) Trigger(name string, detail interface{}) error { return nil From 1c61b86631fab5651aa93d987fab018839ff2ba9 Mon Sep 17 00:00:00 2001 From: David Syers Date: Fri, 11 Jul 2025 16:22:01 +0100 Subject: [PATCH 06/11] fix --- migrate.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/migrate.go b/migrate.go index 495603013..f61a855fc 100644 --- a/migrate.go +++ b/migrate.go @@ -158,7 +158,7 @@ func NewFromOptions(o Options) (*Migrate, error) { } m.databaseDrv = databaseDrv m.databaseDrv.AddTriggers(o.DatabaseTriggers) - } else if o.DatabaseName != "" && o.DatabaseInstance != nil { + } else if o.DatabaseInstance != nil { m.databaseName = o.DatabaseName m.databaseDrv = o.DatabaseInstance } From 7842b9260869c5e048c1e16c45423795edf77d47 Mon Sep 17 00:00:00 2001 From: David Syers Date: Sat, 12 Jul 2025 20:02:39 +0100 Subject: [PATCH 07/11] Tests --- Makefile | 1 - database/cassandra/cassandra.go | 2 +- database/cockroachdb/cockroachdb.go | 1 + database/testing/migrate_testing.go | 69 +++++++++++++++++++++++++++++ migrate.go | 4 ++ 5 files changed, 75 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 8e23a43c7..b7d48f803 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,6 @@ test: @mkdir $(COVERAGE_DIR) make test-with-flags TEST_FLAGS='-v -race -covermode atomic -coverprofile $$(COVERAGE_DIR)/combined.txt -bench=. -benchmem -timeout 20m' - test-with-flags: @echo SOURCE: $(SOURCE) @echo DATABASE_TEST: $(DATABASE_TEST) diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 5b51e154a..50f141a33 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -287,7 +287,7 @@ func (c *Cassandra) Run(migration io.Reader) error { // TODO: cast to Cassandra error and get line number return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } - if err := c.Trigger(database.TrigRunPre, struct { + if err := c.Trigger(database.TrigRunPost, struct { Query string }{Query: string(migr)}); err != nil { return database.Error{OrigErr: err, Err: "failed to trigger RunPost"} diff --git a/database/cockroachdb/cockroachdb.go b/database/cockroachdb/cockroachdb.go index 19b9e4490..0d0d680da 100644 --- a/database/cockroachdb/cockroachdb.go +++ b/database/cockroachdb/cockroachdb.go @@ -411,6 +411,7 @@ func (c *CockroachDb) ensureVersionTable() (err error) { } func (c *CockroachDb) ensureLockTable() error { + // check if lock table exists var count int query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go index be8ed195f..43d0cd14e 100644 --- a/database/testing/migrate_testing.go +++ b/database/testing/migrate_testing.go @@ -4,6 +4,8 @@ package testing import ( + "github.com/golang-migrate/migrate/v4/database" + "reflect" "testing" ) @@ -28,7 +30,74 @@ func TestMigrateDrop(t *testing.T, m *migrate.Migrate) { func TestMigrateUp(t *testing.T, m *migrate.Migrate) { t.Log("UP") + + tt := &triggerTest{ + t: t, + m: m, + triggered: map[string]bool{ + migrate.TrigRunMigrationPre: false, + migrate.TrigRunMigrationPost: false, + migrate.TrigRunMigrationVersionPre: false, + migrate.TrigRunMigrationVersionPost: false, + database.TrigRunPre: false, + database.TrigRunPost: false, + }, + } + + m.Triggers = map[string]func(r migrate.TriggerResponse) error{ + migrate.TrigRunMigrationPre: tt.trigMigrationCheck, + migrate.TrigRunMigrationPost: tt.trigMigrationCheck, + migrate.TrigRunMigrationVersionPre: tt.trigMigrationCheck, + migrate.TrigRunMigrationVersionPost: tt.trigMigrationCheck, + } + + m.AddDatabaseTriggers(map[string]func(response interface{}) error{ + database.TrigRunPre: tt.trigDatabaseMigrationCheck, + database.TrigRunPost: tt.trigDatabaseMigrationCheck, + }) + if err := m.Up(); err != nil { t.Fatal(err) } + + if !tt.triggered[migrate.TrigRunMigrationPre] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationPre) + } + if !tt.triggered[migrate.TrigRunMigrationPost] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationPost) + } + if !tt.triggered[migrate.TrigRunMigrationVersionPre] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationVersionPre) + } + if !tt.triggered[migrate.TrigRunMigrationVersionPost] { + t.Fatalf("expected trigger %s to be called, but it was not", migrate.TrigRunMigrationVersionPost) + } + if !tt.triggered[database.TrigRunPre] { + t.Fatalf("expected database trigger %s to be called, but it was not", database.TrigRunPre) + } + if !tt.triggered[database.TrigRunPost] { + t.Fatalf("expected database trigger %s to be called, but it was not", database.TrigRunPost) + } +} + +type triggerTest struct { + t *testing.T + m *migrate.Migrate + triggered map[string]bool +} + +func (tt *triggerTest) trigMigrationCheck(r migrate.TriggerResponse) error { + tt.triggered[r.Trigger] = true + return nil +} + +func (tt *triggerTest) trigDatabaseMigrationCheck(response interface{}) error { + val := reflect.ValueOf(response) + field := val.FieldByName("Trigger") + if !field.IsValid() { + tt.t.Fatalf("expected response to have a Trigger field, got %T", response) + } + + tt.triggered[field.String()] = true + return nil } diff --git a/migrate.go b/migrate.go index f61a855fc..948e775c3 100644 --- a/migrate.go +++ b/migrate.go @@ -238,6 +238,10 @@ func (m *Migrate) Trigger(name string, detail interface{}) error { return nil } +func (m *Migrate) AddDatabaseTriggers(t map[string]func(response interface{}) error) { + m.databaseDrv.AddTriggers(t) +} + // Close closes the source and the database. func (m *Migrate) Close() (source error, database error) { databaseSrvClose := make(chan error) From 1e71c35bdb9d6cb90f9a7bef96c62fc20a13bdd3 Mon Sep 17 00:00:00 2001 From: David Syers Date: Sat, 12 Jul 2025 20:03:57 +0100 Subject: [PATCH 08/11] space --- Makefile | 1 + 1 file changed, 1 insertion(+) diff --git a/Makefile b/Makefile index b7d48f803..8e23a43c7 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ test: @mkdir $(COVERAGE_DIR) make test-with-flags TEST_FLAGS='-v -race -covermode atomic -coverprofile $$(COVERAGE_DIR)/combined.txt -bench=. -benchmem -timeout 20m' + test-with-flags: @echo SOURCE: $(SOURCE) @echo DATABASE_TEST: $(DATABASE_TEST) From 6c6bedded8c87fa5635e4599568ede763715ef8e Mon Sep 17 00:00:00 2001 From: David Syers Date: Mon, 14 Jul 2025 22:19:38 +0100 Subject: [PATCH 09/11] Missed one --- database/sqlite3/sqlite3.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/database/sqlite3/sqlite3.go b/database/sqlite3/sqlite3.go index 62a032654..ab449555e 100644 --- a/database/sqlite3/sqlite3.go +++ b/database/sqlite3/sqlite3.go @@ -236,6 +236,12 @@ func (m *Sqlite) Run(migration io.Reader) error { } query := string(migr[:]) + if err := m.Trigger(database.TrigRunPre, struct { + Query string + }{Query: query}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + if m.config.NoTxWrap { if err = m.executeQueryNoTx(query); err != nil { return err From b226f737e2e7957e626cb6ffb21492f397166136 Mon Sep 17 00:00:00 2001 From: David Syers Date: Tue, 15 Jul 2025 14:53:18 +0100 Subject: [PATCH 10/11] stub --- database/stub/stub.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/database/stub/stub.go b/database/stub/stub.go index f059b8003..73f788eeb 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -98,8 +98,22 @@ func (s *Stub) Run(migration io.Reader) error { if err != nil { return err } + + if err := s.Trigger(database.TrigRunPre, struct { + Query string + }{Query: string(m[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPre"} + } + s.LastRunMigration = m s.MigrationSequence = append(s.MigrationSequence, string(m[:])) + + if err := s.Trigger(database.TrigRunPost, struct { + Query string + }{Query: string(m[:])}); err != nil { + return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} + } + return nil } From 167724507290243006d0cae7af46cd12695721a3 Mon Sep 17 00:00:00 2001 From: David Syers Date: Tue, 15 Jul 2025 14:55:07 +0100 Subject: [PATCH 11/11] lint --- database/stub/stub.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/database/stub/stub.go b/database/stub/stub.go index 73f788eeb..b28017fb0 100644 --- a/database/stub/stub.go +++ b/database/stub/stub.go @@ -113,7 +113,7 @@ func (s *Stub) Run(migration io.Reader) error { }{Query: string(m[:])}); err != nil { return &database.Error{OrigErr: err, Err: "failed to trigger RunPost"} } - + return nil }