diff --git a/CHANGELOG.md b/CHANGELOG.md index 43f7842ea..2d57b0fa6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Fixed goroutine leak on closing `database/sql` driver * "No endpoints" is retriable error now ## v3.99.3 diff --git a/driver.go b/driver.go index a3ebf304d..a60717315 100644 --- a/driver.go +++ b/driver.go @@ -164,8 +164,8 @@ func (d *Driver) Close(ctx context.Context) (finalErr error) { d.ctxCancel() defer func() { - for _, f := range d.onClose { - f(d) + for _, onClose := range d.onClose { + onClose(d) } }() diff --git a/internal/repeater/repeater.go b/internal/repeater/repeater.go index d4d6c6bb1..01f9a6ec8 100644 --- a/internal/repeater/repeater.go +++ b/internal/repeater/repeater.go @@ -163,8 +163,10 @@ func (r *repeater) wakeUp(e Event) (err error) { } func (r *repeater) worker(ctx context.Context, tick clockwork.Ticker) { - defer close(r.stopped) - defer tick.Stop() + defer func() { + close(r.stopped) + tick.Stop() + }() // force returns backoff with delays [500ms...32s] force := backoff.New( diff --git a/internal/xsql/connector.go b/internal/xsql/connector.go index 95c9cc9c9..5936c321e 100644 --- a/internal/xsql/connector.go +++ b/internal/xsql/connector.go @@ -43,7 +43,7 @@ type ( LegacyOpts []legacy.Option Options []propose.Option disableServerBalancer bool - onCLose []func(*Connector) + onClose []func(*Connector) clock clockwork.Clock idleThreshold time.Duration @@ -204,7 +204,7 @@ func (c *Connector) Close() error { default: close(c.done) - for _, onClose := range c.onCLose { + for _, onClose := range c.onClose { onClose(c) } diff --git a/internal/xsql/options.go b/internal/xsql/options.go index a1293a60b..59319da9c 100644 --- a/internal/xsql/options.go +++ b/internal/xsql/options.go @@ -76,7 +76,7 @@ func (opt traceRetryOption) Apply(c *Connector) error { } func (onClose onCloseOption) Apply(c *Connector) error { - c.onCLose = append(c.onCLose, onClose) + c.onClose = append(c.onClose, onClose) return nil } diff --git a/sql.go b/sql.go index 3fcbbdfb7..26b91a97c 100644 --- a/sql.go +++ b/sql.go @@ -11,7 +11,6 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/legacy" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql/propose" - "github.com/ydb-platform/ydb-go-sdk/v3/internal/xsync" "github.com/ydb-platform/ydb-go-sdk/v3/table" "github.com/ydb-platform/ydb-go-sdk/v3/table/options" "github.com/ydb-platform/ydb-go-sdk/v3/trace" @@ -32,31 +31,13 @@ func withConnectorOptions(opts ...ConnectorOption) Option { } } -type sqlDriver struct { - connectors xsync.Map[*xsql.Connector, *Driver] -} +type sqlDriver struct{} var ( _ driver.Driver = &sqlDriver{} _ driver.DriverContext = &sqlDriver{} ) -func (d *sqlDriver) Close() error { - var errs []error - d.connectors.Range(func(c *xsql.Connector, _ *Driver) bool { - if err := c.Close(); err != nil { - errs = append(errs, err) - } - - return true - }) - if len(errs) > 0 { - return xerrors.NewWithIssues("ydb legacy driver close failed", errs...) - } - - return nil -} - // Open returns a new Driver to the ydb. func (d *sqlDriver) Open(string) (driver.Conn, error) { return nil, xsql.ErrUnsupported @@ -68,15 +49,16 @@ func (d *sqlDriver) OpenConnector(dataSourceName string) (driver.Connector, erro return nil, xerrors.WithStackTrace(fmt.Errorf("failed to connect by data source name '%s': %w", dataSourceName, err)) } - return Connector(db, db.databaseSQLOptions...) -} - -func (d *sqlDriver) attach(c *xsql.Connector, parent *Driver) { - d.connectors.Set(c, parent) -} + c, err := Connector(db, append(db.databaseSQLOptions, + xsql.WithOnClose(func(connector *xsql.Connector) { + _ = db.Close(context.Background()) + }), + )...) + if err != nil { + return nil, xerrors.WithStackTrace(fmt.Errorf("failed to create connector: %w", err)) + } -func (d *sqlDriver) detach(c *xsql.Connector) { - d.connectors.Delete(c) + return c, nil } type QueryMode int @@ -242,7 +224,6 @@ func Connector(parent *Driver, opts ...ConnectorOption) (SQLConnector, error) { parent.databaseSQLOptions, opts..., ), - xsql.WithOnClose(d.detach), xsql.WithTraceRetry(parent.config.TraceRetry()), xsql.WithRetryBudget(parent.config.RetryBudget()), )..., @@ -250,7 +231,6 @@ func Connector(parent *Driver, opts ...ConnectorOption) (SQLConnector, error) { if err != nil { return nil, xerrors.WithStackTrace(err) } - d.attach(c, parent) return c, nil } diff --git a/tests/integration/basic_example_database_sql_bindings_test.go b/tests/integration/basic_example_database_sql_bindings_test.go index b59df9461..9c191263f 100644 --- a/tests/integration/basic_example_database_sql_bindings_test.go +++ b/tests/integration/basic_example_database_sql_bindings_test.go @@ -26,6 +26,8 @@ import ( ) func TestBasicExampleDatabaseSqlBindings(t *testing.T) { + defer simpleDetectGoroutineLeak(t) + folder := t.Name() ctx, cancel := context.WithTimeout(xtest.Context(t), 42*time.Second) diff --git a/tests/integration/basic_example_database_sql_test.go b/tests/integration/basic_example_database_sql_test.go index 06af67bf9..2a51e4dfe 100644 --- a/tests/integration/basic_example_database_sql_test.go +++ b/tests/integration/basic_example_database_sql_test.go @@ -26,6 +26,8 @@ import ( ) func TestBasicExampleDatabaseSql(t *testing.T) { + defer simpleDetectGoroutineLeak(t) + folder := t.Name() ctx, cancel := context.WithTimeout(xtest.Context(t), 42*time.Second) diff --git a/tests/integration/helpers_test.go b/tests/integration/helpers_test.go index 65b92eb7b..f8c5b903e 100644 --- a/tests/integration/helpers_test.go +++ b/tests/integration/helpers_test.go @@ -10,6 +10,7 @@ import ( "fmt" "os" "path" + "runtime" "strings" "testing" "text/template" @@ -468,3 +469,16 @@ func driverEngine(db *sql.DB) (engine xsql.Engine) { return engine } + +func simpleDetectGoroutineLeak(t *testing.T) { + // 1) testing.go => main.main() + // 2) current test + const expectedGoroutinesCount = 2 + if num := runtime.NumGoroutine(); num > expectedGoroutinesCount { + bb := make([]byte, 2<<32) + if n := runtime.Stack(bb, true); n < len(bb) { + bb = bb[:n] + } + t.Error(fmt.Sprintf("unexpected goroutines:\n%s\n", string(bb[runtime.Stack(bb, false)+1:]))) + } +}