diff --git a/benchmark_test.go b/benchmark_test.go index c2b3010..bbd5aa1 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -13,7 +13,7 @@ import ( func init() { hooks := &testHooks{} - hooks.noop() + hooks.reset() sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks)) sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks)) diff --git a/sqlhooks_mysql_test.go b/sqlhooks_mysql_test.go index 204392a..273c5b4 100644 --- a/sqlhooks_mysql_test.go +++ b/sqlhooks_mysql_test.go @@ -36,7 +36,7 @@ func TestMySQL(t *testing.T) { s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { - s.hooks.noop() + s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } diff --git a/sqlhooks_postgres_test.go b/sqlhooks_postgres_test.go index 6a6560b..4c69235 100644 --- a/sqlhooks_postgres_test.go +++ b/sqlhooks_postgres_test.go @@ -36,7 +36,7 @@ func TestPostgres(t *testing.T) { s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { - s.hooks.noop() + s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } diff --git a/sqlhooks_sqlite3_test.go b/sqlhooks_sqlite3_test.go index 1342997..f9e785f 100644 --- a/sqlhooks_sqlite3_test.go +++ b/sqlhooks_sqlite3_test.go @@ -34,7 +34,7 @@ func TestSQLite3(t *testing.T) { s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS") t.Run("DBWorks", func(t *testing.T) { - s.hooks.noop() + s.hooks.reset() if _, err := s.db.Exec("DELETE FROM users"); err != nil { t.Fatal(err) } diff --git a/sqlhooks_test.go b/sqlhooks_test.go index 904a920..ea07b7e 100644 --- a/sqlhooks_test.go +++ b/sqlhooks_test.go @@ -19,12 +19,22 @@ type testHooks struct { onError ErrorHook } -func (h *testHooks) noop() { - noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) { +func newTestHooks() *testHooks { + th := &testHooks{} + th.reset() + return th +} + +func (h *testHooks) reset() { + noop := func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) { return ctx, nil } - h.before, h.after = noop, noop + noopErr := func(_ context.Context, err error, _ string, _ ...interface{}) error { + return err + } + + h.before, h.after, h.onError = noop, noop, noopErr } func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { @@ -45,7 +55,8 @@ type suite struct { } func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite { - hooks := &testHooks{} + hooks := newTestHooks() + driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String()) sql.Register(driverName, Wrap(driver, hooks))