diff --git a/internal/codegen/golang/driver.go b/internal/codegen/golang/driver.go index 7ef723b55e..db30a3b0e2 100644 --- a/internal/codegen/golang/driver.go +++ b/internal/codegen/golang/driver.go @@ -1,5 +1,7 @@ package golang +import "github.com/sqlc-dev/sqlc/internal/config" + type SQLDriver string const ( @@ -15,12 +17,24 @@ const ( SQLDriverGoSQLDriverMySQL = "github.com/go-sql-driver/mysql" ) -func parseDriver(sqlPackage string) SQLDriver { +func parseDriver(sqlPackage string, engine string) SQLDriver { switch sqlPackage { case SQLPackagePGXV4: return SQLDriverPGXV4 case SQLPackagePGXV5: return SQLDriverPGXV5 + default: + driver := driverFromEngine(engine) + return driver + } +} + +func driverFromEngine(engine string) SQLDriver { + switch engine { + case string(config.EnginePostgreSQL): + return SQLDriverLibPQ + case string(config.EngineMySQL): + return SQLDriverGoSQLDriverMySQL default: return SQLDriverLibPQ } diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7cd0a8dccd..3ef76073e3 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -123,7 +123,6 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie Enums: enums, Structs: structs, } - golang := req.Settings.Go tctx := tmplCtx{ EmitInterface: golang.EmitInterface, @@ -137,7 +136,7 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie EmitAllEnumValues: golang.EmitAllEnumValues, UsesCopyFrom: usesCopyFrom(queries), UsesBatch: usesBatch(queries), - SQLDriver: parseDriver(golang.SqlPackage), + SQLDriver: parseDriver(golang.SqlPackage, req.Settings.Engine), Q: "`", Package: golang.Package, Enums: enums, @@ -145,15 +144,14 @@ func generate(req *plugin.CodeGenRequest, enums []Enum, structs []Struct, querie SqlcVersion: req.SqlcVersion, } - if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && golang.SqlDriver != SQLDriverGoSQLDriverMySQL { + if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && tctx.SQLDriver != SQLDriverGoSQLDriverMySQL { return nil, errors.New(":copyfrom is only supported by pgx and github.com/go-sql-driver/mysql") } - if tctx.UsesCopyFrom && golang.SqlDriver == SQLDriverGoSQLDriverMySQL { + if tctx.UsesCopyFrom && tctx.SQLDriver == SQLDriverGoSQLDriverMySQL { if err := checkNoTimesForMySQLCopyFrom(queries); err != nil { return nil, err } - tctx.SQLDriver = SQLDriverGoSQLDriverMySQL } if tctx.UsesBatch && !tctx.SQLDriver.IsPGX() { diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 71146ba643..ed30e0071c 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -116,7 +116,7 @@ func (i *importer) dbImports() fileImports { {Path: "context"}, } - sqlpkg := parseDriver(i.Settings.Go.SqlPackage) + sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine) switch sqlpkg { case SQLDriverPGXV4: pkg = append(pkg, ImportSpec{Path: "github.com/jackc/pgconn"}) @@ -160,7 +160,7 @@ func buildImports(settings *plugin.Settings, queries []Query, uses func(string) std["database/sql"] = struct{}{} } - sqlpkg := parseDriver(settings.Go.SqlPackage) + sqlpkg := parseDriver(settings.Go.SqlPackage, settings.Engine) for _, q := range queries { if q.Cmd == metadata.CmdExecResult { switch sqlpkg { @@ -374,7 +374,7 @@ func (i *importer) queryImports(filename string) fileImports { std["context"] = struct{}{} } - sqlpkg := parseDriver(i.Settings.Go.SqlPackage) + sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine) if sqlcSliceScan() { std["strings"] = struct{}{} } @@ -459,7 +459,7 @@ func (i *importer) batchImports() fileImports { std["context"] = struct{}{} std["errors"] = struct{}{} - sqlpkg := parseDriver(i.Settings.Go.SqlPackage) + sqlpkg := parseDriver(i.Settings.Go.SqlPackage, i.Settings.Engine) switch sqlpkg { case SQLDriverPGXV4: pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} diff --git a/internal/codegen/golang/postgresql_type.go b/internal/codegen/golang/postgresql_type.go index 1936de1f38..1ac6a94bb8 100644 --- a/internal/codegen/golang/postgresql_type.go +++ b/internal/codegen/golang/postgresql_type.go @@ -36,7 +36,7 @@ func parseIdentifierString(name string) (*plugin.Identifier, error) { func postgresType(req *plugin.CodeGenRequest, col *plugin.Column) string { columnType := sdk.DataType(col.Type) notNull := col.NotNull || col.IsArray - driver := parseDriver(req.Settings.Go.SqlPackage) + driver := parseDriver(req.Settings.Go.SqlPackage, req.Settings.Engine) emitPointersForNull := driver.IsPGX() && req.Settings.Go.EmitPointersForNullTypes switch columnType { diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 8e1c2714f7..9941b25166 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -208,7 +208,7 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) Comments: query.Comments, Table: query.InsertIntoTable, } - sqlpkg := parseDriver(req.Settings.Go.SqlPackage) + sqlpkg := parseDriver(req.Settings.Go.SqlPackage, req.Settings.Engine) qpl := int(*req.Settings.Go.QueryParameterLimit)