Skip to content

Commit

Permalink
merging
Browse files Browse the repository at this point in the history
  • Loading branch information
paganotoni committed Jun 21, 2024
2 parents 2163387 + a67d6fc commit 42cbeea
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 73 deletions.
9 changes: 8 additions & 1 deletion cmd/database/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
_ "github.com/leapkit/core/tools/envload"
// sqlite3 driver
_ "github.com/mattn/go-sqlite3"
// postgres driver
_ "github.com/lib/pq"
)

func main() {
Expand All @@ -33,7 +35,12 @@ func main() {

switch os.Args[1] {
case "migrate":
err := db.RunMigrationsDir(filepath.Join("internal", "migrations"), url)
conn, err := db.Connect(url)
if err != nil {
fmt.Println(err)
}

err = db.RunMigrationsDir(filepath.Join("internal", "migrations"), conn)
if err != nil {
fmt.Println(err)

Expand Down
23 changes: 17 additions & 6 deletions db/connection.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package db

import (
"database/sql"
"strings"
"sync"

"github.com/jmoiron/sqlx"
)

var (
conn *sqlx.DB
conn *sql.DB
cmux sync.Mutex

//DriverName defaults to postgres
Expand All @@ -17,7 +17,7 @@ var (
// ConnFn is the database connection builder function that
// will be used by the application based on the driver and
// connection string.
type ConnFn func() (*sqlx.DB, error)
type ConnFn func() (*sql.DB, error)

// connectionOptions for the database
type connectionOption func()
Expand All @@ -27,7 +27,7 @@ type connectionOption func()
// connection string. It opens the connection only once
// and return the same connection on subsequent calls.
func ConnectionFn(url string, opts ...connectionOption) ConnFn {
return func() (cx *sqlx.DB, err error) {
return func() (cx *sql.DB, err error) {
cmux.Lock()
defer cmux.Unlock()

Expand All @@ -40,7 +40,7 @@ func ConnectionFn(url string, opts ...connectionOption) ConnFn {
v()
}

conn, err = sqlx.Connect(driverName, url)
conn, err = sql.Open(driverName, url)
if err != nil {
return nil, err
}
Expand All @@ -56,3 +56,14 @@ func WithDriver(name string) connectionOption {
driverName = name
}
}

// connect to the database based on the driver and connection string.
func Connect(url string) (*sql.DB, error) {
// Based on DSN
driver := "sqlite3"
if strings.HasPrefix(url, "postgres") {
driver = "postgres"
}

return sql.Open(driver, url)
}
31 changes: 11 additions & 20 deletions db/migrations.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
package db

import (
"database/sql"
"embed"
"fmt"
"html/template"
"io/fs"
"os"
"path/filepath"
"regexp"

"github.com/jmoiron/sqlx"
"github.com/leapkit/core/db/migrations"
"github.com/leapkit/core/db/postgres"
"github.com/leapkit/core/db/sqlite"
"github.com/mattn/go-sqlite3"
)

// migratorFor the adapter for the passed SQL connection
// based on the driver name.
func migratorFor(conn *sqlx.DB) any {
// Migrator for the passed SQL connection.
switch conn.DriverName() {
case "postgres":
return postgres.New(conn)
case "sqlite", "sqlite3":
func migratorFor(conn *sql.DB) any {
switch conn.Driver().(type) {
case *sqlite3.SQLiteDriver:
return sqlite.New(conn)
default:
return nil
return postgres.New(conn)
}
}

Expand Down Expand Up @@ -58,20 +55,15 @@ func GenerateMigration(name string, options ...migrations.Option) error {

// RunMigrationsDir receives a folder and a database URL
// to apply the migrations to the database.
func RunMigrationsDir(dir, url string) error {
conn, err := sqlx.Open("sqlite3", url)
if err != nil {
return fmt.Errorf("error opening database connection: %w", err)
}

func RunMigrationsDir(dir string, conn *sql.DB) error {
migrator := migratorFor(conn).(migrations.Migrator)
err = migrator.Setup()
err := migrator.Setup()
if err != nil {
return fmt.Errorf("error setting up migrations: %w", err)
}

exp := regexp.MustCompile("(\\d{14})_(.*).sql")
return os.Walk(dir, func(path string, info os.FileInfo, err error) error {
return filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
if info.IsDir() {
return nil
}
Expand All @@ -94,13 +86,12 @@ func RunMigrationsDir(dir, url string) error {

return nil
})

}

// RunMigrations by checking in the migrations database
// table, each of the adapters take care of this.
func RunMigrations(folder embed.FS, conn *sqlx.DB) error {
dir, err := folder.ReadDir(".")
func RunMigrations(fs embed.FS, conn *sql.DB) error {
dir, err := fs.ReadDir(".")
if err != nil {
return fmt.Errorf("error reading migrations directory: %w", err)
}
Expand Down
7 changes: 3 additions & 4 deletions db/postgres/adapter.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package postgres

import (
"database/sql"
"regexp"

"github.com/jmoiron/sqlx"
)

var (
Expand All @@ -15,11 +14,11 @@ var (
// adapter for the sqlite database it includes the connection
// to perform the framework operations.
type adapter struct {
conn *sqlx.DB
conn *sql.DB
}

// New sqlite adapter with the passed connection.
func New(conn *sqlx.DB) *adapter {
func New(conn *sql.DB) *adapter {
return &adapter{
conn: conn,
}
Expand Down
14 changes: 8 additions & 6 deletions db/postgres/manager.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package postgres

import (
"database/sql"
"fmt"

"github.com/jmoiron/sqlx"
)

type manager struct {
Expand All @@ -23,13 +22,15 @@ func (m *manager) Create() error {
return fmt.Errorf("invalid database url: %s", m.url)
}

db, err := sqlx.Connect("postgres", fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable", matches[1], matches[2], matches[3], matches[4]))
db, err := sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable", matches[1], matches[2], matches[3], matches[4]))
if err != nil {
return fmt.Errorf("error connecting to database: %w", err)
}

var exists int
err = db.Get(&exists, "SELECT COUNT(datname) FROM pg_database WHERE datname ilike $1", matches[5])
row := db.QueryRow("SELECT COUNT(datname) FROM pg_database WHERE datname ilike $1", matches[5])
err = row.Scan(&exists)

if err != nil {
return err
}
Expand All @@ -53,13 +54,14 @@ func (m *manager) Drop() error {
return fmt.Errorf("invalid database url: %s", m.url)
}

db, err := sqlx.Connect("postgres", fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable", matches[1], matches[2], matches[3], matches[4]))
db, err := sql.Open("postgres", fmt.Sprintf("postgres://%s:%s@%s:%s/postgres?sslmode=disable", matches[1], matches[2], matches[3], matches[4]))
if err != nil {
return fmt.Errorf("error connecting to database: %w", err)
}

var dbexists int
err = db.Get(&dbexists, "SELECT COUNT(datname) FROM pg_database WHERE datname ilike $1", matches[5])
row := db.QueryRow("SELECT COUNT(datname) FROM pg_database WHERE datname ilike $1", matches[5])
err = row.Scan(&dbexists)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion db/postgres/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ func (a *adapter) Setup() error {
// on the migrations table.
func (a *adapter) Run(timestamp, sql string) error {
var exists bool
err := a.conn.Get(&exists, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE timestamp = $1)", timestamp)
row := a.conn.QueryRow("SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE timestamp = $1)", timestamp)
err := row.Scan(&exists)
if err != nil {
return fmt.Errorf("error running migration: %w", err)
}
Expand Down
8 changes: 5 additions & 3 deletions db/sqlite/adapter.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package sqlite

import "github.com/jmoiron/sqlx"
import (
"database/sql"
)

// adapter for the sqlite database it includes the connection
// to perform the framework operations.
type adapter struct {
conn *sqlx.DB
conn *sql.DB
}

// New sqlite adapter with the passed connection.
func New(conn *sqlx.DB) *adapter {
func New(conn *sql.DB) *adapter {
return &adapter{
conn: conn,
}
Expand Down
3 changes: 2 additions & 1 deletion db/sqlite/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ func (a *adapter) Setup() error {
// on the migrations table.
func (a *adapter) Run(timestamp, sql string) error {
var exists bool
err := a.conn.Get(&exists, "SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE timestamp = $1)", timestamp)
row := a.conn.QueryRow("SELECT EXISTS (SELECT 1 FROM schema_migrations WHERE timestamp = $1)", timestamp)
err := row.Scan(&exists)
if err != nil {
return fmt.Errorf("error running migration: %w", err)
}
Expand Down
33 changes: 22 additions & 11 deletions db/sqlite/migrations_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package sqlite_test

import (
"database/sql"
"path/filepath"
"testing"

"github.com/jmoiron/sqlx"
"github.com/leapkit/core/db/sqlite"
_ "github.com/mattn/go-sqlite3"
)

func TestSetup(t *testing.T) {
td := t.TempDir()
conn, err := sqlx.Connect("sqlite3", filepath.Join(td, "database.db"))
conn, err := sql.Open("sqlite3", filepath.Join(td, "database.db"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -22,21 +22,30 @@ func TestSetup(t *testing.T) {
t.Fatal(err)
}

result := struct{ Name string }{}
err = conn.Get(&result, "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations';")
var name string
rows, err := conn.Query("SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations';")
if err != nil {
t.Fatal("schema_migrations table not found")
}

if result.Name != "schema_migrations" {
if !rows.Next() {
t.Fatal("schema_migrations table not found")
}

err = rows.Scan(&name)
if err != nil {
t.Fatal(err)
}

if name != "schema_migrations" {
t.Fatal("schema_migrations table not found")
}
}

func TestRun(t *testing.T) {
t.Run("migration not found", func(t *testing.T) {
td := t.TempDir()
conn, err := sqlx.Connect("sqlite3", filepath.Join(td, "database.db"))
conn, err := sql.Open("sqlite3", filepath.Join(td, "database.db"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -52,16 +61,17 @@ func TestRun(t *testing.T) {
t.Fatal(err)
}

result := struct{ Name string }{}
err = conn.Get(&result, "SELECT name FROM sqlite_master WHERE type='table' AND name='users';")
var name string
row := conn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='users';")
err = row.Scan(&name)
if err != nil {
t.Fatal("users table not found")
}
})

t.Run("migration found", func(t *testing.T) {
td := t.TempDir()
conn, err := sqlx.Connect("sqlite3", filepath.Join(td, "database.db"))
conn, err := sql.Open("sqlite3", filepath.Join(td, "database.db"))
if err != nil {
t.Fatal(err)
}
Expand All @@ -82,8 +92,9 @@ func TestRun(t *testing.T) {
t.Fatal(err)
}

result := struct{ Name string }{}
err = conn.Get(&result, "SELECT name FROM sqlite_master WHERE type='table' AND name='users';")
var name string
row := conn.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name='users';")
err = row.Scan(&name)
if err != nil {
t.Fatal("users table not found")
}
Expand Down
2 changes: 1 addition & 1 deletion docs/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ To get started with LeapKit the [template](https://github.com/leapkit/template)
To generate your project you can use `gonew` to copy the template and generate the necessary files for your project. The following command will generate a new project called `superapp` in the current directory.

```
go run rsc.io/tmp/gonew@latest github.com/leapkit/[email protected].5 superapp
go run rsc.io/tmp/gonew@latest github.com/leapkit/[email protected].8 superapp
```

### Setup
Expand Down
Loading

0 comments on commit 42cbeea

Please sign in to comment.