From a770ce14a48a6ec427e862fa61353f3216f6e3dc Mon Sep 17 00:00:00 2001
From: Gustavo Chain <me@qustavo.cc>
Date: Tue, 21 Sep 2021 14:17:42 +0200
Subject: [PATCH] test: fix nil onError handler on testHooks

Closes #38
---
 benchmark_test.go         |  2 +-
 sqlhooks_mysql_test.go    |  2 +-
 sqlhooks_postgres_test.go |  2 +-
 sqlhooks_sqlite3_test.go  |  2 +-
 sqlhooks_test.go          | 19 +++++++++++++++----
 5 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/benchmark_test.go b/benchmark_test.go
index c2b3010..bbd5aa1 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -13,7 +13,7 @@ import (
 
 func init() {
 	hooks := &testHooks{}
-	hooks.noop()
+	hooks.reset()
 
 	sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks))
 	sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks))
diff --git a/sqlhooks_mysql_test.go b/sqlhooks_mysql_test.go
index 204392a..273c5b4 100644
--- a/sqlhooks_mysql_test.go
+++ b/sqlhooks_mysql_test.go
@@ -36,7 +36,7 @@ func TestMySQL(t *testing.T) {
 	s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")
 
 	t.Run("DBWorks", func(t *testing.T) {
-		s.hooks.noop()
+		s.hooks.reset()
 		if _, err := s.db.Exec("DELETE FROM users"); err != nil {
 			t.Fatal(err)
 		}
diff --git a/sqlhooks_postgres_test.go b/sqlhooks_postgres_test.go
index 6a6560b..4c69235 100644
--- a/sqlhooks_postgres_test.go
+++ b/sqlhooks_postgres_test.go
@@ -36,7 +36,7 @@ func TestPostgres(t *testing.T) {
 	s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")
 
 	t.Run("DBWorks", func(t *testing.T) {
-		s.hooks.noop()
+		s.hooks.reset()
 		if _, err := s.db.Exec("DELETE FROM users"); err != nil {
 			t.Fatal(err)
 		}
diff --git a/sqlhooks_sqlite3_test.go b/sqlhooks_sqlite3_test.go
index 1342997..f9e785f 100644
--- a/sqlhooks_sqlite3_test.go
+++ b/sqlhooks_sqlite3_test.go
@@ -34,7 +34,7 @@ func TestSQLite3(t *testing.T) {
 	s.TestErrHookHook(t, "SELECT * FROM users WHERE id = $2", "INVALID_ARGS")
 
 	t.Run("DBWorks", func(t *testing.T) {
-		s.hooks.noop()
+		s.hooks.reset()
 		if _, err := s.db.Exec("DELETE FROM users"); err != nil {
 			t.Fatal(err)
 		}
diff --git a/sqlhooks_test.go b/sqlhooks_test.go
index 904a920..ea07b7e 100644
--- a/sqlhooks_test.go
+++ b/sqlhooks_test.go
@@ -19,12 +19,22 @@ type testHooks struct {
 	onError ErrorHook
 }
 
-func (h *testHooks) noop() {
-	noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
+func newTestHooks() *testHooks {
+	th := &testHooks{}
+	th.reset()
+	return th
+}
+
+func (h *testHooks) reset() {
+	noop := func(ctx context.Context, _ string, _ ...interface{}) (context.Context, error) {
 		return ctx, nil
 	}
 
-	h.before, h.after = noop, noop
+	noopErr := func(_ context.Context, err error, _ string, _ ...interface{}) error {
+		return err
+	}
+
+	h.before, h.after, h.onError = noop, noop, noopErr
 }
 
 func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
@@ -45,7 +55,8 @@ type suite struct {
 }
 
 func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
-	hooks := &testHooks{}
+	hooks := newTestHooks()
+
 	driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String())
 	sql.Register(driverName, Wrap(driver, hooks))