From 4ea9dff73e238173bb073acd51ebc72efd44f8f6 Mon Sep 17 00:00:00 2001 From: Conrad Ludgate Date: Thu, 23 Sep 2021 12:00:03 +0100 Subject: [PATCH] feat: support driver.Connector interface --- go.mod | 4 ++-- go.sum | 11 +++++------ sqlhooks.go | 48 +++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 54 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index eb0c75a..181f536 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/qustavo/sqlhooks/v2 go 1.13 require ( - github.com/go-sql-driver/mysql v1.4.1 - github.com/lib/pq v1.2.0 + github.com/go-sql-driver/mysql v1.6.0 + github.com/lib/pq v1.10.3 github.com/mattn/go-sqlite3 v1.10.0 github.com/opentracing/opentracing-go v1.1.0 github.com/stretchr/testify v1.4.0 diff --git a/go.sum b/go.sum index f913b42..bd5b20c 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,11 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= -github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= -github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/lib/pq v1.10.3 h1:v9QZf2Sn6AmjXtQeFpdoq/eaNtYP6IN+7lcrygsIAtg= +github.com/lib/pq v1.10.3/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= -github.com/mattn/go-sqlite3 v1.11.0 h1:LDdKkqtYlom37fkvqs8rMPFKAMe8+SgjbwZ6ex1/A/Q= -github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -15,6 +13,7 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/sqlhooks.go b/sqlhooks.go index 3d52576..ceca574 100644 --- a/sqlhooks.go +++ b/sqlhooks.go @@ -48,13 +48,16 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { if err != nil { return conn, err } + return wrapConn(conn, drv.hooks) +} +func wrapConn(conn driver.Conn, hooks Hooks) (driver.Conn, error) { // Drivers that don't implement driver.ConnBeginTx are not supported. if _, ok := conn.(driver.ConnBeginTx); !ok { return nil, errors.New("driver must implement driver.ConnBeginTx") } - wrapped := &Conn{conn, drv.hooks} + wrapped := &Conn{conn, hooks} if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) { return &ExecerQueryerContextWithSessionResetter{wrapped, &ExecerContext{wrapped}, &QueryerContext{wrapped}, @@ -74,6 +77,49 @@ func (drv *Driver) Open(name string) (driver.Conn, error) { return wrapped, nil } +// borrowed from https://github.com/golang/go/blob/d0dd26a88c019d54f22463daae81e785f5867565/src/database/sql/sql.go#L755-L766 +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +func (drv *Driver) OpenConnector(name string) (driver.Connector, error) { + if driverCtx, ok := drv.Driver.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(name) + if err != nil { + return nil, err + } + return &Connector{connector, drv.hooks}, nil + } + return &Connector{dsnConnector{dsn: name, driver: drv.Driver}, drv.hooks}, nil +} + +// Driver implements a database/sql/driver.Driver +type Connector struct { + driver.Connector + hooks Hooks +} + +func (connector Connector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := connector.Connector.Connect(ctx) + if err != nil { + return conn, err + } + return wrapConn(conn, connector.hooks) +} + +func (connector Connector) Driver() driver.Driver { + return &Driver{connector.Connector.Driver(), connector.hooks} +} + // Conn implements a database/sql.driver.Conn type Conn struct { Conn driver.Conn