Skip to content

Commit 08353de

Browse files
committed
Implement post migration step callbacks
1 parent 854ae05 commit 08353de

File tree

2 files changed

+170
-1
lines changed

2 files changed

+170
-1
lines changed

migrate.go

+59-1
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,61 @@ func (e ErrDirty) Error() string {
5555
return fmt.Sprintf("Dirty database version %v. Fix and force version.", e.Version)
5656
}
5757

58+
// PostStepCallback is a callback function type that can be used to execute a
59+
// Golang based migration step after a SQL based migration step has been
60+
// executed. The callback function receives the migration and the database
61+
// driver as arguments.
62+
type PostStepCallback func(migr *Migration, driver database.Driver) error
63+
5864
// options is a set of optional options that can be set when a Migrate instance
5965
// is created.
6066
type options struct {
67+
// postStepCallbacks is a map of PostStepCallback functions that can be
68+
// used to execute a Golang based migration step after a SQL based
69+
// migration step has been executed. The key is the migration version
70+
// and the value is the callback function that should be run _after_ the
71+
// step was executed (but within the same database transaction).
72+
postStepCallbacks map[uint]PostStepCallback
6173
}
6274

6375
// defaultOptions returns a new options struct with default values.
6476
func defaultOptions() options {
65-
return options{}
77+
return options{
78+
postStepCallbacks: make(map[uint]PostStepCallback),
79+
}
6680
}
6781

6882
// Option is a function that can be used to set options on a Migrate instance.
6983
type Option func(*options)
7084

85+
// WithPostStepCallbacks is an option that can be used to set a map of
86+
// PostStepCallback functions that can be used to execute a Golang based
87+
// migration step after a SQL based migration step has been executed. The key is
88+
// the migration version and the value is the callback function that should be
89+
// run _after_ the step was executed (but before the version is marked as
90+
// cleanly executed). An error returned from the callback will cause the
91+
// migration to fail and the step to be marked as dirty.
92+
func WithPostStepCallbacks(
93+
postStepCallbacks map[uint]PostStepCallback) Option {
94+
95+
return func(o *options) {
96+
o.postStepCallbacks = postStepCallbacks
97+
}
98+
}
99+
100+
// WithPostStepCallback is an option that can be used to set a PostStepCallback
101+
// function that can be used to execute a Golang based migration step after the
102+
// SQL based migration step with the given version number has been executed. The
103+
// callback is the function that should be run _after_ the step was executed
104+
// (but before the version is marked as cleanly executed). An error returned
105+
// from the callback will cause the migration to fail and the step to be marked
106+
// as dirty.
107+
func WithPostStepCallback(version uint, callback PostStepCallback) Option {
108+
return func(o *options) {
109+
o.postStepCallbacks[version] = callback
110+
}
111+
}
112+
71113
type Migrate struct {
72114
sourceName string
73115
sourceDrv source.Driver
@@ -775,6 +817,22 @@ func (m *Migrate) runMigrations(ret <-chan interface{}) error {
775817
if err := m.databaseDrv.Run(migr.BufferedBody); err != nil {
776818
return err
777819
}
820+
821+
// If there is a post execution function for
822+
// this migration, run it now.
823+
cb, ok := m.opts.postStepCallbacks[migr.Version]
824+
if ok {
825+
m.logVerbosePrintf("Running post step "+
826+
"callback for %v\n", migr.LogString())
827+
828+
err := cb(migr, m.databaseDrv)
829+
if err != nil {
830+
return err
831+
}
832+
833+
m.logVerbosePrintf("Post step callback "+
834+
"finished for %v\n", migr.LogString())
835+
}
778836
}
779837

780838
// set clean state

migrate_test.go

+111
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111
"testing"
1212

13+
"github.com/golang-migrate/migrate/v4/database"
1314
dStub "github.com/golang-migrate/migrate/v4/database/stub"
1415
"github.com/golang-migrate/migrate/v4/source"
1516
sStub "github.com/golang-migrate/migrate/v4/source/stub"
@@ -878,6 +879,116 @@ func TestUpAndDown(t *testing.T) {
878879
equalDbSeq(t, 1, expectedSequence, dbDrv)
879880
}
880881

882+
func TestPostStepCallback(t *testing.T) {
883+
m, _ := New("stub://", "stub://", WithPostStepCallbacks(
884+
map[uint]PostStepCallback{
885+
1: func(m *Migration, driver database.Driver) error {
886+
return driver.Run(
887+
strings.NewReader("CALLBACK 1"),
888+
)
889+
},
890+
7: func(m *Migration, driver database.Driver) error {
891+
return driver.Run(
892+
strings.NewReader("CALLBACK 7"),
893+
)
894+
},
895+
},
896+
))
897+
m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations
898+
dbDrv := m.databaseDrv.(*dStub.Stub)
899+
900+
// go Up first
901+
if err := m.Up(); err != nil {
902+
t.Fatal(err)
903+
}
904+
expectedSequence := migrationSequence{
905+
mr("CREATE 1"),
906+
mr("CALLBACK 1"),
907+
mr("CREATE 3"),
908+
mr("CREATE 4"),
909+
mr("CREATE 7"),
910+
mr("CALLBACK 7"),
911+
}
912+
equalDbSeq(t, 0, expectedSequence, dbDrv)
913+
914+
if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 7")) {
915+
t.Fatalf("expected database last migration to be callback 7, "+
916+
"got %s", dbDrv.LastRunMigration)
917+
}
918+
919+
// go Down
920+
if err := m.Down(); err != nil {
921+
t.Fatal(err)
922+
}
923+
expectedSequence = migrationSequence{
924+
mr("CREATE 1"),
925+
mr("CALLBACK 1"),
926+
mr("CREATE 3"),
927+
mr("CREATE 4"),
928+
mr("CREATE 7"),
929+
mr("CALLBACK 7"),
930+
mr("DROP 7"),
931+
mr("CALLBACK 7"),
932+
mr("DROP 5"),
933+
mr("DROP 4"),
934+
mr("DROP 1"),
935+
mr("CALLBACK 1"),
936+
}
937+
equalDbSeq(t, 1, expectedSequence, dbDrv)
938+
939+
if !bytes.Equal(dbDrv.LastRunMigration, []byte("CALLBACK 1")) {
940+
t.Fatalf("expected database last migration to be callback 1, "+
941+
"got %s", dbDrv.LastRunMigration)
942+
}
943+
944+
// go 1 Up and then all the way Up
945+
if err := m.Steps(1); err != nil {
946+
t.Fatal(err)
947+
}
948+
expectedSequence = migrationSequence{
949+
mr("CREATE 1"),
950+
mr("CALLBACK 1"),
951+
mr("CREATE 3"),
952+
mr("CREATE 4"),
953+
mr("CREATE 7"),
954+
mr("CALLBACK 7"),
955+
mr("DROP 7"),
956+
mr("CALLBACK 7"),
957+
mr("DROP 5"),
958+
mr("DROP 4"),
959+
mr("DROP 1"),
960+
mr("CALLBACK 1"),
961+
mr("CREATE 1"),
962+
mr("CALLBACK 1"),
963+
}
964+
equalDbSeq(t, 2, expectedSequence, dbDrv)
965+
966+
if err := m.Up(); err != nil {
967+
t.Fatal(err)
968+
}
969+
expectedSequence = migrationSequence{
970+
mr("CREATE 1"),
971+
mr("CALLBACK 1"),
972+
mr("CREATE 3"),
973+
mr("CREATE 4"),
974+
mr("CREATE 7"),
975+
mr("CALLBACK 7"),
976+
mr("DROP 7"),
977+
mr("CALLBACK 7"),
978+
mr("DROP 5"),
979+
mr("DROP 4"),
980+
mr("DROP 1"),
981+
mr("CALLBACK 1"),
982+
mr("CREATE 1"),
983+
mr("CALLBACK 1"),
984+
mr("CREATE 3"),
985+
mr("CREATE 4"),
986+
mr("CREATE 7"),
987+
mr("CALLBACK 7"),
988+
}
989+
equalDbSeq(t, 3, expectedSequence, dbDrv)
990+
}
991+
881992
func TestUpDirty(t *testing.T) {
882993
m, _ := New("stub://", "stub://")
883994
dbDrv := m.databaseDrv.(*dStub.Stub)

0 commit comments

Comments
 (0)