diff --git a/database/sqlserver/README.md b/database/sqlserver/README.md index 8ecd87723..aae7c28ab 100644 --- a/database/sqlserver/README.md +++ b/database/sqlserver/README.md @@ -17,6 +17,7 @@ | `encrypt` | | `disable` - Data send between client and server is not encrypted. `false` - Data sent between client and server is not encrypted beyond the login packet (Default). `true` - Data sent between client and server is encrypted. | | `app+name` || The application name (default is go-mssqldb). | | `useMsi` | | `true` - Use Azure MSI Authentication for connecting to Sql Server. Must be running from an Azure VM/an instance with MSI enabled. `false` - Use password authentication (Default). See [here for Azure MSI Auth details](https://docs.microsoft.com/en-us/azure/app-service/app-service-web-tutorial-connect-msi). NOTE: Since this cannot be tested locally, this is not officially supported. +| `x-batch` | | Enable batch statements (default: false) | See https://github.com/microsoft/go-mssqldb for full parameter list. diff --git a/database/sqlserver/sqlserver.go b/database/sqlserver/sqlserver.go index 92834d1ad..961257e5a 100644 --- a/database/sqlserver/sqlserver.go +++ b/database/sqlserver/sqlserver.go @@ -16,6 +16,7 @@ import ( "github.com/golang-migrate/migrate/v4/database" "github.com/hashicorp/go-multierror" mssql "github.com/microsoft/go-mssqldb" // mssql support + "github.com/microsoft/go-mssqldb/batch" ) func init() { @@ -30,7 +31,7 @@ var ( ErrNoDatabaseName = fmt.Errorf("no database name") ErrNoSchema = fmt.Errorf("no schema") ErrDatabaseDirty = fmt.Errorf("database is dirty") - ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed.") + ErrMultipleAuthOptionsPassed = fmt.Errorf("both password and useMsi=true were passed") ) var lockErrorMap = map[int]string{ @@ -42,9 +43,10 @@ var lockErrorMap = map[int]string{ // Config for database type Config struct { - MigrationsTable string - DatabaseName string - SchemaName string + MigrationsTable string + DatabaseName string + SchemaName string + BatchStatementEnabled bool } // SQL Server connection @@ -168,9 +170,18 @@ func (ss *SQLServer) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") + batchStatementEnabled := false + if s := purl.Query().Get("x-batch"); len(s) > 0 { + batchStatementEnabled, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("unable to parse option x-batch: %w", err) + } + } + px, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, - MigrationsTable: migrationsTable, + DatabaseName: purl.Path, + MigrationsTable: migrationsTable, + BatchStatementEnabled: batchStatementEnabled, }) if err != nil { @@ -247,15 +258,23 @@ func (ss *SQLServer) Run(migration io.Reader) error { // run migration query := string(migr[:]) - if _, err := ss.conn.ExecContext(context.Background(), query); err != nil { - if msErr, ok := err.(mssql.Error); ok { - message := fmt.Sprintf("migration failed: %s", msErr.Message) - if msErr.ProcName != "" { - message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + scripts := []string{query} + + if ss.config.BatchStatementEnabled { + scripts = batch.Split(query, "go") + } + + for _, script := range scripts { + if _, err := ss.conn.ExecContext(context.Background(), script); err != nil { + if msErr, ok := err.(mssql.Error); ok { + message := fmt.Sprintf("migration failed: %s", msErr.Message) + if msErr.ProcName != "" { + message = fmt.Sprintf("%s (proc name %s)", msErr.Message, msErr.ProcName) + } + return database.Error{OrigErr: err, Err: message, Query: []byte(script), Line: uint(msErr.LineNo)} } - return database.Error{OrigErr: err, Err: message, Query: migr, Line: uint(msErr.LineNo)} + return database.Error{OrigErr: err, Err: "migration failed", Query: []byte(script)} } - return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } return nil diff --git a/database/sqlserver/sqlserver_test.go b/database/sqlserver/sqlserver_test.go index 402f4480f..0f5034f60 100644 --- a/database/sqlserver/sqlserver_test.go +++ b/database/sqlserver/sqlserver_test.go @@ -37,8 +37,9 @@ var ( } ) -func msConnectionString(host, port string) string { - return fmt.Sprintf("sqlserver://sa:%v@%v:%v?database=master", saPassword, host, port) +func msConnectionString(host, port string, options ...string) string { + options = append(options, "database=master") + return fmt.Sprintf("sqlserver://sa:%v@%v:%v?%s", saPassword, host, port, strings.Join(options, "&")) } func msConnectionStringMsiWithPassword(host, port string, useMsi bool) string { @@ -87,6 +88,7 @@ func Test(t *testing.T) { t.Run("test", test) t.Run("testMigrate", testMigrate) t.Run("testMultiStatement", testMultiStatement) + t.Run("testBatchedStatement", testBatchedStatement) t.Run("testErrorParsing", testErrorParsing) t.Run("testLockWorks", testLockWorks) t.Run("testMsiTrue", testMsiTrue) @@ -191,6 +193,49 @@ func testMultiStatement(t *testing.T) { }) } +func testBatchedStatement(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.Port(defaultPort) + if err != nil { + t.Fatal(err) + } + + addr := msConnectionString(ip, port, "x-batch=true") + ms := &SQLServer{} + d, err := ms.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + if err := d.Run(strings.NewReader(`CREATE PROCEDURE uspA +AS +BEGIN + SELECT 1; +END; +GO +CREATE PROCEDURE uspB +AS +BEGIN + SELECT 2; +END`)); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure second proc exists + var exists int + if err := d.(*SQLServer).conn.QueryRowContext(context.Background(), "Select COUNT(1) from sysobjects where type = 'P' and category = 0 and [NAME] = 'uspB'").Scan(&exists); err != nil { + t.Fatal(err) + } + if exists != 1 { + t.Fatalf("expected proc uspB to exist") + } + }) +} + func testErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { SkipIfUnsupportedArch(t, c)