diff --git a/README.md b/README.md index 65fd7f70..c8a5089f 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,8 @@ production: table: migrations ``` +To process only the specified directory, set `dir: migrations/sqlite3`, to process the specified directory and all of its subdirectories recursively, set `dir: migrations/sqlite3/*`. + (See more examples for different set ups [here](test-integration/dbconfig.yml)) Also one can obtain env variables in datasource field via `os.ExpandEnv` embedded call for the field. diff --git a/migrate.go b/migrate.go index d08e22b4..7d98ab97 100644 --- a/migrate.go +++ b/migrate.go @@ -8,9 +8,11 @@ import ( "errors" "fmt" "io" + "io/fs" "net/http" "os" "path" + "path/filepath" "regexp" "sort" "strconv" @@ -300,6 +302,80 @@ func migrationFromFile(dir http.FileSystem, root string, info os.FileInfo) (*Mig return migration, nil } +type RecursiveFileMigrationSource struct { + Dir string +} + +var _ MigrationSource = (*RecursiveFileMigrationSource)(nil) + +func (r RecursiveFileMigrationSource) FindMigrations() ([]*Migration, error) { + dirfs := os.DirFS(r.Dir) + migrationIDs := make(map[string]string) + var migrations []*Migration + if err := fs.WalkDir(dirfs, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("Error while walkdir %s: %w", path, err) + } + if d.IsDir() { + return nil + } + + migrationID := d.Name() + if existPath, exist := migrationIDs[migrationID]; exist { + return fmt.Errorf("duplicate ID '%s' found in both path %s and path %s", migrationID, existPath, path) + } + migrationIDs[migrationID] = path + + migration, err := migrationFromFS(dirfs, migrationID, path) + if err != nil { + return err + } + migrations = append(migrations, migration) + return nil + }); err != nil { + return nil, err + } + + // Make sure migrations are sorted + sort.Sort(byId(migrations)) + + return migrations, nil +} + +func migrationFromFS(dirfs fs.FS, migrationID, path string) (*Migration, error) { + file, err := dirfs.Open(path) + if err != nil { + return nil, fmt.Errorf("Error while opening %s: %w", path, err) + } + defer func() { _ = file.Close() }() + + rs, ok := file.(io.ReadSeeker) + if !ok { + data, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("Error while read file %s: %w", path, err) + } + rs = io.NewSectionReader(bytes.NewReader(data), 0, int64(len(data))) + } + migration, err := ParseMigration(migrationID, rs) + if err != nil { + return nil, fmt.Errorf("Error while parsing %s: %w", path, err) + } + + return migration, nil +} + +func MakeFileMigrationSource(dir string) MigrationSource { + if _, last := filepath.Split(dir); last == "*" { + return RecursiveFileMigrationSource{ + Dir: filepath.Dir(dir), + } + } + return FileMigrationSource{ + Dir: dir, + } +} + // Migrations from a bindata asset set. type AssetMigrationSource struct { // Asset should return content of file in path if exists diff --git a/migrate_test.go b/migrate_test.go index 3d6e80e4..87df7499 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "embed" "net/http" + "path/filepath" "time" "github.com/go-gorp/gorp/v3" @@ -130,6 +131,39 @@ func (s *SqliteMigrateSuite) TestFileMigrate(c *C) { c.Assert(id, Equals, int64(1)) } +func (s *SqliteMigrateSuite) TestRecursiveFileMigrate(c *C) { + migrations := &RecursiveFileMigrationSource{ + Dir: "test-migrations", + } + + // Executes two migrations + n, err := Exec(s.Db, "sqlite3", migrations, Up) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + + // Has data + id, err := s.DbMap.SelectInt("SELECT id FROM people") + c.Assert(err, IsNil) + c.Assert(id, Equals, int64(1)) + + name, err := s.DbMap.SelectStr("SELECT name FROM people") + c.Assert(err, IsNil) + c.Assert(name, Equals, "test") +} + +func (s *SqliteMigrateSuite) TestMakeFileMigrationSource(c *C) { + { + dir := filepath.Join("aaa", "bbb", "ccc") + got := MakeFileMigrationSource(dir) + c.Assert(got, Equals, FileMigrationSource{Dir: dir}) + } + { + dir := filepath.Join("aaa", "bbb", "*") + got := MakeFileMigrationSource(dir) + c.Assert(got, Equals, RecursiveFileMigrationSource{Dir: filepath.Join("aaa", "bbb")}) + } +} + func (s *SqliteMigrateSuite) TestHttpFileSystemMigrate(c *C) { migrations := &HttpFileSystemMigrationSource{ FileSystem: http.Dir("test-migrations"), diff --git a/sql-migrate/command_common.go b/sql-migrate/command_common.go index e517c24b..5bfb853e 100644 --- a/sql-migrate/command_common.go +++ b/sql-migrate/command_common.go @@ -18,9 +18,7 @@ func ApplyMigrations(dir migrate.MigrationDirection, dryrun bool, limit int, ver } defer db.Close() - source := migrate.FileMigrationSource{ - Dir: env.Dir, - } + source := migrate.MakeFileMigrationSource(env.Dir) if dryrun { var migrations []*migrate.PlannedMigration diff --git a/test-migrations/dir1/3_add_column.sql b/test-migrations/dir1/3_add_column.sql new file mode 100644 index 00000000..373d4f31 --- /dev/null +++ b/test-migrations/dir1/3_add_column.sql @@ -0,0 +1,5 @@ +-- +migrate Up +ALTER TABLE people ADD COLUMN name varchar(255); + +-- +migrate Down +ALTER TABLE people DROP COLUMN name varchar(255); diff --git a/test-migrations/dir1/dir2/4_set_name.sql b/test-migrations/dir1/dir2/4_set_name.sql new file mode 100644 index 00000000..d6620949 --- /dev/null +++ b/test-migrations/dir1/dir2/4_set_name.sql @@ -0,0 +1,5 @@ +-- +migrate Up +UPDATE people SET name = 'test' WHERE id = 1; + +-- +migrate Down +UPDATE people SET name = NULL WHERE id = 1;