Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow running Golang based post migration steps #1253

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,71 @@ func (e ErrDirty) Error() string {
return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
}

// PostStepCallback is a callback function type that can be used to execute a
// Golang based migration step after a SQL based migration step has been
// executed. The callback function receives the migration and the database
// driver as arguments.
type PostStepCallback func(migr *Migration, driver database.Driver) error

// options is a set of optional options that can be set when a Migrate instance
// is created.
type options struct {
// postStepCallbacks is a map of PostStepCallback functions that can be
// used to execute a Golang based migration step after a SQL based
// migration step has been executed. The key is the migration version
// and the value is the callback function that should be run _after_ the
// step was executed (but within the same database transaction).
postStepCallbacks map[uint]PostStepCallback
}

// defaultOptions returns a new options struct with default values.
func defaultOptions() options {
return options{
postStepCallbacks: make(map[uint]PostStepCallback),
}
}

// Option is a function that can be used to set options on a Migrate instance.
type Option func(*options)

// WithPostStepCallbacks is an option that can be used to set a map of
// PostStepCallback functions that can be used to execute a Golang based
// migration step after a SQL based migration step has been executed. The key is
// the migration version and the value is the callback function that should be
// run _after_ the step was executed (but before the version is marked as
// cleanly executed). An error returned from the callback will cause the
// migration to fail and the step to be marked as dirty.
func WithPostStepCallbacks(
postStepCallbacks map[uint]PostStepCallback) Option {

return func(o *options) {
o.postStepCallbacks = postStepCallbacks
}
}

// WithPostStepCallback is an option that can be used to set a PostStepCallback
// function that can be used to execute a Golang based migration step after the
// SQL based migration step with the given version number has been executed. The
// callback is the function that should be run _after_ the step was executed
// (but before the version is marked as cleanly executed). An error returned
// from the callback will cause the migration to fail and the step to be marked
// as dirty.
func WithPostStepCallback(version uint, callback PostStepCallback) Option {
return func(o *options) {
o.postStepCallbacks[version] = callback
}
}

type Migrate struct {
sourceName string
sourceDrv source.Driver
databaseName string
databaseDrv database.Driver

// opts is a set of options that can be used to modify the behavior
// of the Migrate instance.
opts options

// Log accepts a Logger interface
Log Logger

Expand All @@ -84,8 +143,8 @@ type Migrate struct {

// New returns a new Migrate instance from a source URL and a database URL.
// The URL scheme is defined by each driver.
func New(sourceURL, databaseURL string) (*Migrate, error) {
m := newCommon()
func New(sourceURL, databaseURL string, opts ...Option) (*Migrate, error) {
m := newMigrateWithOptions(opts)

sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
Expand Down Expand Up @@ -118,8 +177,10 @@ func New(sourceURL, databaseURL string) (*Migrate, error) {
// and an existing database instance. The source URL scheme is defined by each driver.
// Use any string that can serve as an identifier during logging as databaseName.
// You are responsible for closing the underlying database client if necessary.
func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()
func NewWithDatabaseInstance(sourceURL string, databaseName string,
databaseInstance database.Driver, opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

sourceName, err := iurl.SchemeFromURL(sourceURL)
if err != nil {
Expand All @@ -144,8 +205,10 @@ func NewWithDatabaseInstance(sourceURL string, databaseName string, databaseInst
// and a database URL. The database URL scheme is defined by each driver.
// Use any string that can serve as an identifier during logging as sourceName.
// You are responsible for closing the underlying source client if necessary.
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, databaseURL string) (*Migrate, error) {
m := newCommon()
func NewWithSourceInstance(sourceName string, sourceInstance source.Driver,
databaseURL string, opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

databaseName, err := iurl.SchemeFromURL(databaseURL)
if err != nil {
Expand All @@ -170,8 +233,11 @@ func NewWithSourceInstance(sourceName string, sourceInstance source.Driver, data
// database instance. Use any string that can serve as an identifier during logging
// as sourceName and databaseName. You are responsible for closing down
// the underlying source and database client if necessary.
func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseName string, databaseInstance database.Driver) (*Migrate, error) {
m := newCommon()
func NewWithInstance(sourceName string, sourceInstance source.Driver,
databaseName string, databaseInstance database.Driver,
opts ...Option) (*Migrate, error) {

m := newMigrateWithOptions(opts)

m.sourceName = sourceName
m.databaseName = databaseName
Expand All @@ -182,8 +248,13 @@ func NewWithInstance(sourceName string, sourceInstance source.Driver, databaseNa
return m, nil
}

func newCommon() *Migrate {
func newMigrateWithOptions(optFunctions []Option) *Migrate {
opts := defaultOptions()
for _, opt := range optFunctions {
opt(&opts)
}
return &Migrate{
opts: opts,
GracefulStop: make(chan bool, 1),
PrefetchMigrations: DefaultPrefetchMigrations,
LockTimeout: DefaultLockTimeout,
Expand Down Expand Up @@ -746,6 +817,22 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
return err
}

// If there is a post execution function for
// this migration, run it now.
cb, ok := m.opts.postStepCallbacks[migr.Version]
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC if the post step callback fails, then the version remains dirty. This could end up being a problem requiring down migration from the user. I wonder if it'd be better to just run the migration, the callback and a final SetVersion in a transaction?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree. But the thing is: The version is also dirty if the SQL based migration fails, as we also error out then.
And there doesn't seem to be the concept of DB transactions in the migration tool, my guess is because not every supported database backend can do transactions...

So not really sure what to do differently here.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC if the post step callback fails, then the version remains dirty.

This appears to be aa quirk related to the way the library works:

migrate/FAQ.md

Lines 62 to 65 in 604248c

#### What does "dirty" database mean?
Before a migration runs, each database sets a dirty flag. Execution stops if a migration fails and the dirty state persists,
which prevents attempts to run more migrations on top of a failed migration. You need to manually fix the error
and then "force" the expected version.
.

In the context of our desired usage we typically catch situations like this via unit tests of the migration itself.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reading a bit more of the codebase, I think it's possible to run the callback function in the same db transaction as the migration:

if m.config.NoTxWrap {
return m.executeQueryNoTx(query)
}
return m.executeQuery(query)
.

With the way the interfaces work, if we modify those (using something other than a func opt), then we'd need to update every single driver in the codebase.

Perhaps the slimmest change would be to add the functional opt to the Run method in the main Driver interface?

if ok {
m.logVerbosePrintf("Running post step "+
"callback for %v\n", migr.LogString())

err := cb(migr, m.databaseDrv)
if err != nil {
return err
}
Comment on lines +828 to +831
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should wrap err in the return statement with a contextual message using fmt.Errorf for better clarity.

return fmt.Errorf("failed to execute post migration callback: %w", err)


m.logVerbosePrintf("Post step callback "+
"finished for %v\n", migr.LogString())
}
}

// set clean state
Expand Down
111 changes: 111 additions & 0 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"strings"
"testing"

"github.com/golang-migrate/migrate/v4/database"
dStub "github.com/golang-migrate/migrate/v4/database/stub"
"github.com/golang-migrate/migrate/v4/source"
sStub "github.com/golang-migrate/migrate/v4/source/stub"
Expand Down Expand Up @@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) {
equalDbSeq(t, 1, expectedSequence, dbDrv)
}

func TestPostStepCallback(t *testing.T) {
m, _ := New("stub://", "stub://", WithPostStepCallbacks(
map[uint]PostStepCallback{
1: func(m *Migration, driver database.Driver) error {
return driver.Run(
strings.NewReader("CALLBACK 1"),
)
},
7: func(m *Migration, driver database.Driver) error {
return driver.Run(
strings.NewReader("CALLBACK 7"),
)
},
},
))
m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations
dbDrv := m.databaseDrv.(*dStub.Stub)

// go Up first
if err := m.Up(); err != nil {
t.Fatal(err)
}
expectedSequence := migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
}
equalDbSeq(t, 0, expectedSequence, dbDrv)

if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) {
t.Fatalf("expected database last migration to be callback 7, "+
"got %s", dbDrv.LastRunMigration)
}

// go Down
if err := m.Down(); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
}
equalDbSeq(t, 1, expectedSequence, dbDrv)

if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) {
t.Fatalf("expected database last migration to be callback 1, "+
"got %s", dbDrv.LastRunMigration)
}

// go 1 Up and then all the way Up
if err := m.Steps(1); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
mr("CREATE 1"),
mr("CALLBACK 1"),
}
equalDbSeq(t, 2, expectedSequence, dbDrv)

if err := m.Up(); err != nil {
t.Fatal(err)
}
expectedSequence = migrationSequence{
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
mr("DROP 7"),
mr("CALLBACK 7"),
mr("DROP 5"),
mr("DROP 4"),
mr("DROP 1"),
mr("CALLBACK 1"),
mr("CREATE 1"),
mr("CALLBACK 1"),
mr("CREATE 3"),
mr("CREATE 4"),
mr("CREATE 7"),
mr("CALLBACK 7"),
}
equalDbSeq(t, 3, expectedSequence, dbDrv)
}

func TestUpDirty(t *testing.T) {
m, _ := New("stub://", "stub://")
dbDrv := m.databaseDrv.(*dStub.Stub)
Expand Down