From 9b9dbd662a17e30170c518725179e969eb0bd868 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 2 Jun 2026 16:05:56 +0900 Subject: [PATCH] feat(go): add option to get the last query ID Implements #84. --- go/connection.go | 53 +++++++--- go/go.mod | 46 ++++----- go/go.sum | 96 +++++++++--------- go/pkg/driver.go | 138 +++++++++++++++++--------- go/statement.go | 58 +++++++++++ go/trino.go | 26 ++--- go/trino_test.go | 18 ++++ go/validation/tests/test_statement.py | 99 ++++++++++++++++++ go/validation/tests/trino.py | 4 +- 9 files changed, 385 insertions(+), 153 deletions(-) create mode 100644 go/statement.go diff --git a/go/connection.go b/go/connection.go index fb5ae13..9085d7d 100644 --- a/go/connection.go +++ b/go/connection.go @@ -33,6 +33,22 @@ const ( TrinoMaxQuerySizeBytes = 1_000_000 ) +// trinoConnectionImpl extends sqlwrapper connection with DbObjectsEnumerator +type trinoConnectionImpl struct { + *sqlwrapper.ConnectionImplBase // Embed sqlwrapper connection for all standard functionality + + version string +} + +// implements BulkIngester interface +var _ sqlwrapper.BulkIngester = (*trinoConnectionImpl)(nil) + +// implements DbObjectsEnumerator interface +var _ driverbase.DbObjectsEnumerator = (*trinoConnectionImpl)(nil) + +// implements CurrentNameSpacer interface +var _ driverbase.CurrentNamespacer = (*trinoConnectionImpl)(nil) + // GetCurrentCatalog implements driverbase.CurrentNamespacer. func (c *trinoConnectionImpl) GetCurrentCatalog(ctx context.Context) (string, error) { var catalog string @@ -186,15 +202,18 @@ func (c *trinoConnectionImpl) GetPlaceholder(field *arrow.Field, index int) stri var _ sqlwrapper.BulkIngester = (*trinoConnectionImpl)(nil) // ExecuteBulkIngest performs Trino bulk ingest using batched INSERT statements. -func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) { +func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, stmt sqlwrapper.StatementImpl, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) { + // inject the query ID capture through everything + params := stmt.(*trinoStatement).GetAdditionalExecParams() + schema := stream.Schema() - if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options); err != nil { + if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options, params); err != nil { return -1, c.ErrorHelper.WrapIO(err, "failed to create table") } if options.IngestBatchSize > 0 { return sqlwrapper.ExecuteBatchedBulkIngest( - ctx, conn, options, stream, + ctx, stmt, conn, options, stream, c.TypeConverter, c, &c.Base().ErrorHelper, ) } @@ -204,7 +223,7 @@ func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwr } // Use Trino-specific batching with accurate serialized size measurement - return c.executeDynamicBatchedIngest(ctx, conn, options, stream) + return c.executeDynamicBatchedIngest(ctx, conn, options, stream, params) } // executeDynamicBatchedIngest performs batched INSERT with incremental query building. @@ -220,6 +239,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest( conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader, + params []any, ) (int64, error) { var totalRowsInserted int64 schema := stream.Schema() @@ -269,7 +289,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest( if queryBuilder.Len()+additionalLength > options.MaxQuerySizeBytes && currentBatchRows > 0 { // Execute current batch - rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String()) + rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String(), params) if err != nil { return totalRowsInserted, c.ErrorHelper.WrapIO(err, "failed to insert batch at rows %d-%d", startRowIdx, startRowIdx+currentBatchRows-1) @@ -294,7 +314,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest( // Execute final batch for this record batch if currentBatchRows > 0 { - rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String()) + rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String(), params) if err != nil { return totalRowsInserted, c.ErrorHelper.WrapIO(err, "failed to insert final batch at rows %d-%d", startRowIdx, startRowIdx+currentBatchRows-1) @@ -355,12 +375,13 @@ func (c *trinoConnectionImpl) executeBatch( ctx context.Context, conn *sqlwrapper.LoggingConn, querySQL string, + params []any, ) (int64, error) { if querySQL == "" { return 0, nil } - result, err := conn.ExecContext(ctx, querySQL) + result, err := conn.ExecContext(ctx, querySQL, params...) if err != nil { return 0, err } @@ -374,20 +395,20 @@ func (c *trinoConnectionImpl) executeBatch( } // createTableIfNeeded creates the table based on the ingest mode -func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions) error { +func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions, params []any) error { switch options.Mode { case adbc.OptionValueIngestModeCreate: // Create the table (fail if exists) - return c.createTable(ctx, conn, tableName, schema, false) + return c.createTable(ctx, conn, tableName, schema, false, params) case adbc.OptionValueIngestModeCreateAppend: // Create the table if it doesn't exist - return c.createTable(ctx, conn, tableName, schema, true) + return c.createTable(ctx, conn, tableName, schema, true, params) case adbc.OptionValueIngestModeReplace: // Drop and recreate the table - if err := c.dropTable(ctx, conn, tableName); err != nil { + if err := c.dropTable(ctx, conn, tableName, params); err != nil { return err } - return c.createTable(ctx, conn, tableName, schema, false) + return c.createTable(ctx, conn, tableName, schema, false, params) case adbc.OptionValueIngestModeAppend: // Table should already exist, do nothing return nil @@ -397,7 +418,7 @@ func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sql } // createTable creates a Trino table from Arrow schema -func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, ifNotExists bool) error { +func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, ifNotExists bool, params []any) error { var queryBuilder strings.Builder queryBuilder.WriteString("CREATE TABLE ") if ifNotExists { @@ -424,14 +445,14 @@ func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper. queryBuilder.WriteString(")") - _, err := conn.ExecContext(ctx, queryBuilder.String()) + _, err := conn.ExecContext(ctx, queryBuilder.String(), params...) return err } // dropTable drops a Trino table -func (c *trinoConnectionImpl) dropTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string) error { +func (c *trinoConnectionImpl) dropTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, params []any) error { dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteIdentifier(tableName)) - _, err := conn.ExecContext(ctx, dropSQL) + _, err := conn.ExecContext(ctx, dropSQL, params...) return err } diff --git a/go/go.mod b/go/go.mod index b9623c5..24d9e5e 100644 --- a/go/go.mod +++ b/go/go.mod @@ -17,8 +17,8 @@ module github.com/adbc-drivers/trino go 1.26.0 require ( - github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260423045143-148150eea03b - github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260423045143-148150eea03b + github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608005845-f218ccb883e8 + github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260608005845-f218ccb883e8 github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260423045143-148150eea03b github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260423045143-148150eea03b github.com/apache/arrow-adbc/go/adbc v1.11.0 @@ -30,7 +30,7 @@ require ( require ( github.com/andybalholm/brotli v1.2.1 // indirect - github.com/apache/thrift v0.22.0 // indirect + github.com/apache/thrift v0.23.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect @@ -46,33 +46,33 @@ require ( github.com/jcmturner/goidentity/v6 v6.0.1 // indirect github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect github.com/jcmturner/rpc/v2 v2.0.3 // indirect - github.com/klauspost/compress v1.18.5 // indirect + github.com/klauspost/compress v1.18.6 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/pierrec/lz4 v2.6.1+incompatible // indirect - github.com/pierrec/lz4/v4 v4.1.26 // indirect + github.com/pierrec/lz4/v4 v4.1.27 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/zeebo/xxh3 v1.1.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/otel v1.43.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 // indirect - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect - go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 // indirect - go.opentelemetry.io/otel/metric v1.43.0 // indirect - go.opentelemetry.io/otel/sdk v1.43.0 // indirect - go.opentelemetry.io/otel/trace v1.43.0 // indirect + go.opentelemetry.io/otel v1.44.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.44.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.44.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.44.0 // indirect + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.44.0 // indirect + go.opentelemetry.io/otel/metric v1.44.0 // indirect + go.opentelemetry.io/otel/sdk v1.44.0 // indirect + go.opentelemetry.io/otel/trace v1.44.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect - golang.org/x/crypto v0.50.0 // indirect - golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect - golang.org/x/mod v0.35.0 // indirect - golang.org/x/net v0.53.0 // indirect + golang.org/x/crypto v0.52.0 // indirect + golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect + golang.org/x/mod v0.36.0 // indirect + golang.org/x/net v0.55.0 // indirect golang.org/x/sync v0.20.0 // indirect - golang.org/x/sys v0.43.0 // indirect - golang.org/x/text v0.36.0 // indirect - golang.org/x/tools v0.44.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260420184626-e10c466a9529 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect - google.golang.org/grpc v1.80.0 // indirect + golang.org/x/sys v0.45.0 // indirect + golang.org/x/text v0.37.0 // indirect + golang.org/x/tools v0.45.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect + google.golang.org/grpc v1.81.1 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go/go.sum b/go/go.sum index ce63496..f4e7598 100644 --- a/go/go.sum +++ b/go/go.sum @@ -6,10 +6,10 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= -github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260423045143-148150eea03b h1:0qcsKFrFGtzvdBeUjkCS4Y9AVwhsfhJL+JT01aYzLvs= -github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260423045143-148150eea03b/go.mod h1:RhEM8H8KNN5+UE/3g6vJM4eJsnOd9IBKYEq4vCOCie0= -github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260423045143-148150eea03b h1:AgGvSTRiyEIV/gXYpY1w3+8PJ1JIkRqtcdDJ5kLVVAo= -github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260423045143-148150eea03b/go.mod h1:0cNj4jqbY7zgcTYaAHhoM+InKuWnBXgy8jhyr9dUPkQ= +github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608005845-f218ccb883e8 h1:BG9ZygPtzRQWGUpdzu38ITrk669ao13EN797NQSij2E= +github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608005845-f218ccb883e8/go.mod h1:rCDugBvob6aEFZu7zJ9SqPSyyys6ipNWcsfvzF/Ns74= +github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260608005845-f218ccb883e8 h1:+EZnfQX2uHe+s/tlQfcS0rXx1y6dTQKg0A1EqL3RNko= +github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260608005845-f218ccb883e8/go.mod h1:H5NH2RgPyzQjwJFA7wYNbJNdxQXQcBcpN3+xQkTgP2s= github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260423045143-148150eea03b h1:lCFRgYUVoy3cFTnUoKi/px3ZiV2cmIHujy3OVEm93P4= github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260423045143-148150eea03b/go.mod h1:uKkYVA+iUXij7mPFproEi6sOIWFVNOT87UcqucqqnBg= github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260423045143-148150eea03b h1:hvd/bQfV1cdkYN0wAgDZVNC9Bwuq4IGVHLpG5gFBWU8= @@ -22,8 +22,8 @@ github.com/apache/arrow-adbc/go/adbc v1.11.0 h1:zcSLtV8CQ27chkYWZmySvd4+pkkDtWhR github.com/apache/arrow-adbc/go/adbc v1.11.0/go.mod h1:7BIIq4XHzttQVG293LRurIhF/UH+IWlgGRR6hUqPLG8= github.com/apache/arrow-go/v18 v18.6.0 h1:GX/Jyd3R7mCLiECAwY9FWbbaYblie2WXBSz4Sw8fNpM= github.com/apache/arrow-go/v18 v18.6.0/go.mod h1:gm3MiPpY82fLYK5VKPB3WoJbsiLVDfT7flD5/vHReKw= -github.com/apache/thrift v0.22.0 h1:r7mTJdj51TMDe6RtcmNdQxgn9XcyfGDOzegMDRg47uc= -github.com/apache/thrift v0.22.0/go.mod h1:1e7J/O1Ae6ZQMTYdy9xa3w9k+XHWPfRvdPyJeynQ+/g= +github.com/apache/thrift v0.23.0 h1:wKR6YnefQSEnxpEfmgTPuJibNG4bF0p2TK34tHLWi3s= +github.com/apache/thrift v0.23.0/go.mod h1:zPt6WxgvTOM6hF92y8C+MkEM5LMxZuk4JcQOiU4Esvs= github.com/aws/aws-sdk-go v1.55.8 h1:JRmEUbU52aJQZ2AjX4q4Wu7t4uZjOu71uyNmaWlUkJQ= github.com/aws/aws-sdk-go v1.55.8/go.mod h1:ZkViS9AqA6otK+JBBNH2++sx1sgxrPKcSzPPvQkUtXk= github.com/aws/aws-sdk-go-v2 v1.39.0 h1:xm5WV/2L4emMRmMjHFykqiA4M/ra0DJVSWUkDyBjbg4= @@ -126,8 +126,8 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6 github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= -github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= -github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -150,8 +150,8 @@ github.com/ory/dockertest/v3 v3.12.0 h1:3oV9d0sDzlSQfHtIaB5k6ghUCVMVLpAY8hwrqoCy github.com/ory/dockertest/v3 v3.12.0/go.mod h1:aKNDTva3cp8dwOWwb9cWuX84aH5akkxXRvO7KCwWVjE= github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= -github.com/pierrec/lz4/v4 v4.1.26 h1:GrpZw1gZttORinvzBdXPUXATeqlJjqUG/D87TKMnhjY= -github.com/pierrec/lz4/v4 v4.1.26/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= +github.com/pierrec/lz4/v4 v4.1.27 h1:+PhzhWDrjRj89TH2sw43nE3+4+W8lSxIuQadEHZyjUk= +github.com/pierrec/lz4/v4 v4.1.27/go.mod h1:EoQMVJgeeEOMsCqCzqFm2O0cJvljX2nGZjcRIPL34O4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -189,24 +189,24 @@ github.com/zeebo/xxh3 v1.1.0 h1:s7DLGDK45Dyfg7++yxI0khrfwq9661w9EN78eP/UZVs= github.com/zeebo/xxh3 v1.1.0/go.mod h1:IisAie1LELR4xhVinxWS5+zf1lA4p0MW4T+w+W07F5s= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= -go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 h1:RAE+JPfvEmvy+0LzyUA25/SGawPwIUbZ6u0Wug54sLc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0/go.mod h1:AGmbycVGEsRx9mXMZ75CsOyhSP6MFIcj/6dnG+vhVjk= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU= -go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= -go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= -go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= -go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= -go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= -go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= -go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= -go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.opentelemetry.io/otel v1.44.0 h1:JjwHmHpA4iZ3wBxluu2fbbE7j4kqlE8jXyAyPXH7HqU= +go.opentelemetry.io/otel v1.44.0/go.mod h1:BMgjTHL9WPRlRjL2oZCBTL4whCGtXch2H4BhOPIAyYc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.44.0 h1:4YsVu3B8+3qtWYYrsUYgn0OG78pN0rnNPRGX4SbokQI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.44.0/go.mod h1:+wnlSn0mD1ADVMe3v9Z/WIaiz6q6gL2J/ejaAmdmv80= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.44.0 h1:qazEJlUOQzhCpzQpFETGby7EdqjI1wsd0W+6Gg1SCTU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.44.0/go.mod h1:fOD2Yefuxixkx3ahVNf0O/PERb6r4OlbxfATVnYvzCo= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.44.0 h1:lgh3PiVrRUWMLOVSkQicxzZll5NjF1r+AtsX1XRIHw0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.44.0/go.mod h1:5Cnhth3m/AgOeTgE3ex12pPmiu/gGtZit03kSzx9X7s= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.44.0 h1:bl2S7Ubua0Nms+D/gAmznQTd4dxxMA93aKbcpKqiTCs= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.44.0/go.mod h1:L0hRV50XdVIODHUfWEqGRCXQvj2rV82STVo12FMFBU0= +go.opentelemetry.io/otel/metric v1.44.0 h1:1w0gILTcHdr3YI+ixLyjemwrVnsMURbTZFrSYCdDdmc= +go.opentelemetry.io/otel/metric v1.44.0/go.mod h1:8O7hanEPBNgEMmybD3s2VBKcgWOCsA6tzHBPODAiquo= +go.opentelemetry.io/otel/sdk v1.44.0 h1:nHYwb9lK+fJPU/dnT6s7W7Z8itMWyqrnVfbheVYrZ58= +go.opentelemetry.io/otel/sdk v1.44.0/go.mod h1:Osuydd3Se74nqjAKxid74N5eC+jfEqfTegHRnq58oK0= +go.opentelemetry.io/otel/sdk/metric v1.44.0 h1:3LlKgI+VjbVsjNRFZJZAJ30WjXC5VkNRks6si09iEfI= +go.opentelemetry.io/otel/sdk/metric v1.44.0/go.mod h1:5B5pMARnXxKhltooO4xUuCBorl65a4EpnTalObqOigA= +go.opentelemetry.io/otel/trace v1.44.0 h1:jxF5CsGYCe74MCRx2X4g7WsY/VBKRqqpNvXlX/6gtIk= +go.opentelemetry.io/otel/trace v1.44.0/go.mod h1:oLl1jrMQAVo6v3GAggN+1VH9VIz9iUSvW53sW1Q8PIE= go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -214,21 +214,21 @@ go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= -golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f h1:W3F4c+6OLc6H2lb//N1q4WpJkhzJCK5J6kUi1NTVXfM= -golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f/go.mod h1:J1xhfL/vlindoeF/aINzNzt2Bket5bjo9sdOYzOsU80= +golang.org/x/crypto v0.52.0 h1:RMs7fP2rXdep0CftQlK8Uf+kibLm7qkCcradZWYz988= +golang.org/x/crypto v0.52.0/go.mod h1:1QgfPxDqh0T2M/elOJtp9RvuR95kVjir0e6/BvEmGbc= +golang.org/x/exp v0.0.0-20260603202125-055de637280b h1:v1uXiEBHo8QA0LiGCo7UgHMzHT4Kdfpl2zmtH5vaP1Q= +golang.org/x/exp v0.0.0-20260603202125-055de637280b/go.mod h1:d2fgXJLVs4dYDHUk5lwMIfzRzSrWCfGZb0ZqeLa/Vcw= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= -golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= +golang.org/x/mod v0.36.0 h1:JJjpVx6myfUsUdAzZuOSTTmRE0PfZeNWzzvKrP7amb4= +golang.org/x/mod v0.36.0/go.mod h1:moc6ELqsWcOw5Ef3xVprK5ul/MvtVvkIXLziUOICjUQ= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= -golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= +golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8= +golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= @@ -239,8 +239,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= -golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= +golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -248,22 +248,22 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= -golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= +golang.org/x/text v0.37.0 h1:Cqjiwd9eSg8e0QAkyCaQTNHFIIzWtidPahFWR83rTrc= +golang.org/x/text v0.37.0/go.mod h1:a5sjxXGs9hsn/AJVwuElvCAo9v8QYLzvavO5z2PiM38= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= -golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= +golang.org/x/tools v0.45.0 h1:18qN3FAooORvApf5XjCXgsuayZOEtXf6JK18I3+ONa8= +golang.org/x/tools v0.45.0/go.mod h1:LuUGqqaXcXMEFEruIVJVm5mgDD8vww/z/SR1gQ4uE/0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/genproto/googleapis/api v0.0.0-20260420184626-e10c466a9529 h1:zUWMZsvo/IJcD1t6MNCPO/azZTwz0TvwCBqr5aifoVY= -google.golang.org/genproto/googleapis/api v0.0.0-20260420184626-e10c466a9529/go.mod h1:a5OGAgyRr4lqco7AG9hQM9Fwh0N2ZV4grR0eXFEsXQg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 h1:XF8+t6QQiS0o9ArVan/HW8Q7cycNPGsJf6GA2nXxYAg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= -google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= +google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa h1:Kjn0N0tCrDgiAFW+lGO4JZ3ck44CehvJQMAwj9QF0G8= +google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:q4lMZS6kskjT5HvCPrnnypcDPVJqT/f4nfxmkE7gryY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa h1:mZHHdPZl0dbGHCflZgAq/Q468DWVFcU2whhB2KAo8fk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.81.1 h1:VnnIIZ88UzOOKLukQi+ImGz8O1Wdp8nAGGnvOfEIWQQ= +google.golang.org/grpc v1.81.1/go.mod h1:xGH9GfzOyMTGIOXBJmXt+BX/V0kcdQbdcuwQ/zNw42I= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/go/pkg/driver.go b/go/pkg/driver.go index e2c0950..5bbacae 100644 --- a/go/pkg/driver.go +++ b/go/pkg/driver.go @@ -248,22 +248,22 @@ func initLoggingFromEnv(db adbc.DatabaseLogging) { db.SetLogger(logger) } -// Allocate a new cgo.Handle and store its address in a heap-allocated -// uintptr_t. Experimentally, this was found to be necessary, else -// something (the Go runtime?) would corrupt (garbage-collect?) the -// handle. +// cgo.Handle is a uintptr integer (not a pointer). Packing it directly into +// a void* field is safe: the CGO checker only rejects Go heap pointers, and +// handle values (small non-zero integers from a global counter) never alias +// Go-allocated memory. The GC does not scan C-managed memory, so it will +// never misinterpret the stored integer as a live pointer. No C allocation +// is needed — the handle value itself fits in the pointer-sized field. func createHandle(hndl cgo.Handle) unsafe.Pointer { - // uintptr_t* hptr = malloc(sizeof(uintptr_t)); - hptr := (*C.uintptr_t)(C.calloc(C.sizeof_uintptr_t, C.size_t(1))) - // *hptr = (uintptr)hndl; - *hptr = C.uintptr_t(uintptr(hndl)) - return unsafe.Pointer(hptr) + return unsafe.Pointer(uintptr(hndl)) +} + +func handleFromPtr(ptr unsafe.Pointer) cgo.Handle { + return cgo.Handle(uintptr(ptr)) } func getFromHandle[T any](ptr unsafe.Pointer) *T { - // uintptr_t* hptr = (uintptr_t*)ptr; - hptr := (*C.uintptr_t)(ptr) - return cgo.Handle((uintptr)(*hptr)).Value().(*T) + return handleFromPtr(ptr).Value().(*T) } func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { @@ -271,7 +271,7 @@ func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusC if lenWithTerminator <= *length { sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) copy(sink, val) - sink[lenWithTerminator] = 0 + sink[len(val)] = 0 } *length = lenWithTerminator return C.ADBC_STATUS_OK @@ -435,17 +435,16 @@ func TrinoArrayStreamRelease(stream *C.struct_ArrowArrayStream) { if stream == nil || stream.release != (*[0]byte)(C.TrinoArrayStreamRelease) || stream.private_data == nil { return } - h := (*(*cgo.Handle)(stream.private_data)) + h := handleFromPtr(stream.private_data) + stream.private_data = nil cStream := h.Value().(*cArrayStream) + h.Delete() cStream.rdr.Release() if cStream.adbcErr != nil { C.TrinoerrRelease(cStream.adbcErr) C.free(unsafe.Pointer(cStream.adbcErr)) } - C.free(unsafe.Pointer(stream.private_data)) - stream.private_data = nil - h.Delete() runtime.GC() } @@ -472,10 +471,17 @@ func exportRecordReader(rdr array.RecordReader, stream *C.struct_ArrowArrayStrea rdr.Retain() } +type unappliedOpt struct { + stringVal *string + int64Val *int64 + byteVal []byte + doubleVal *float64 +} + type cDatabase struct { cancellableContext - opts map[string]string + opts map[string]unappliedOpt db driverbase.Database } @@ -580,12 +586,35 @@ func TrinoDatabaseInit(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code return C.ADBC_STATUS_INVALID_STATE } - adb, aerr := drv.NewDatabaseWithContext(cdb.newContext(), cdb.opts) + stringOpts := map[string]string{} + for k, v := range cdb.opts { + if v.stringVal != nil { + stringOpts[k] = *v.stringVal + } + } + ctx := cdb.newContext() + adb, aerr := drv.NewDatabaseWithContext(ctx, stringOpts) if aerr != nil { return C.AdbcStatusCode(errToAdbcErr(err, aerr)) } cdb.db = adb.(driverbase.Database) + for k, v := range cdb.opts { + switch { + case v.stringVal != nil: + continue + case v.int64Val != nil: + aerr = cdb.db.SetOptionInt(ctx, k, *v.int64Val) + case v.byteVal != nil: + aerr = cdb.db.SetOptionBytes(ctx, k, v.byteVal) + case v.doubleVal != nil: + aerr = cdb.db.SetOptionDouble(ctx, k, *v.doubleVal) + } + if aerr != nil { + return C.AdbcStatusCode(errToAdbcErr(err, aerr)) + } + } + initLoggingFromEnv(cdb.db) return C.ADBC_STATUS_OK } @@ -605,7 +634,7 @@ func TrinoDatabaseNew(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (code setErr(err, "AdbcDatabaseNew: database already allocated") return C.ADBC_STATUS_INVALID_STATE } - dbobj := &cDatabase{opts: make(map[string]string)} + dbobj := &cDatabase{opts: make(map[string]unappliedOpt)} hndl := cgo.NewHandle(dbobj) db.private_data = createHandle(hndl) return C.ADBC_STATUS_OK @@ -621,19 +650,17 @@ func TrinoDatabaseRelease(db *C.struct_AdbcDatabase, err *C.struct_AdbcError) (c if !checkDBAlloc(db, err, "AdbcDatabaseRelease") { return C.ADBC_STATUS_INVALID_STATE } - h := (*(*cgo.Handle)(db.private_data)) + h := handleFromPtr(db.private_data) + db.private_data = nil cdb := h.Value().(*cDatabase) + h.Delete() if cdb.db != nil { cdb.db.Close(cdb.newContext()) cdb.db = nil } cdb.opts = nil - if db.private_data != nil { - C.free(unsafe.Pointer(db.private_data)) - db.private_data = nil - } - h.Delete() + // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before // the process ends, but GC is not deterministic. So by manually @@ -661,7 +688,7 @@ func TrinoDatabaseSetOption(db *C.struct_AdbcDatabase, key, value *C.cchar_t, er e := cdb.db.SetOption(cdb.newContext(), k, v) return C.AdbcStatusCode(errToAdbcErr(err, e)) } else { - cdb.opts[k] = v + cdb.opts[k] = unappliedOpt{stringVal: new(v)} } return C.ADBC_STATUS_OK @@ -674,13 +701,19 @@ func TrinoDatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t, valu code = poison(err, "AdbcDatabaseSetOptionBytes", e) } }() - cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionBytes") - if cdb == nil { + if !checkDBAlloc(db, err, "AdbcDatabaseSetOptionBytes") { return C.ADBC_STATUS_INVALID_STATE } + cdb := getFromHandle[cDatabase](db.private_data) + k := C.GoString(key) + v := fromCArr[byte](value, int(length)) - e := cdb.db.SetOptionBytes(cdb.newContext(), C.GoString(key), fromCArr[byte](value, int(length))) - return C.AdbcStatusCode(errToAdbcErr(err, e)) + if cdb.db != nil { + e := cdb.db.SetOptionBytes(cdb.newContext(), k, v) + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + cdb.opts[k] = unappliedOpt{byteVal: v} + return C.ADBC_STATUS_OK } //export TrinoDatabaseSetOptionDouble @@ -690,13 +723,19 @@ func TrinoDatabaseSetOptionDouble(db *C.struct_AdbcDatabase, key *C.cchar_t, val code = poison(err, "AdbcDatabaseSetOptionDouble", e) } }() - cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionDouble") - if cdb == nil { + if !checkDBAlloc(db, err, "AdbcDatabaseSetOptionDouble") { return C.ADBC_STATUS_INVALID_STATE } + cdb := getFromHandle[cDatabase](db.private_data) + k := C.GoString(key) + v := float64(value) - e := cdb.db.SetOptionDouble(cdb.newContext(), C.GoString(key), float64(value)) - return C.AdbcStatusCode(errToAdbcErr(err, e)) + if cdb.db != nil { + e := cdb.db.SetOptionDouble(cdb.newContext(), k, v) + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + cdb.opts[k] = unappliedOpt{doubleVal: new(v)} + return C.ADBC_STATUS_OK } //export TrinoDatabaseSetOptionInt @@ -706,13 +745,19 @@ func TrinoDatabaseSetOptionInt(db *C.struct_AdbcDatabase, key *C.cchar_t, value code = poison(err, "AdbcDatabaseSetOptionInt", e) } }() - cdb := checkDBInit(db, err, "AdbcDatabaseSetOptionInt") - if cdb == nil { + if !checkDBAlloc(db, err, "AdbcDatabaseSetOptionInt") { return C.ADBC_STATUS_INVALID_STATE } + cdb := getFromHandle[cDatabase](db.private_data) + k := C.GoString(key) + v := int64(value) - e := cdb.db.SetOptionInt(cdb.newContext(), C.GoString(key), int64(value)) - return C.AdbcStatusCode(errToAdbcErr(err, e)) + if cdb.db != nil { + e := cdb.db.SetOptionInt(cdb.newContext(), k, v) + return C.AdbcStatusCode(errToAdbcErr(err, e)) + } + cdb.opts[k] = unappliedOpt{int64Val: new(v)} + return C.ADBC_STATUS_OK } type cConn struct { @@ -969,15 +1014,15 @@ func TrinoConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_AdbcErr if !checkConnAlloc(cnxn, err, "AdbcConnectionRelease") { return C.ADBC_STATUS_INVALID_STATE } - h := (*(*cgo.Handle)(cnxn.private_data)) + h := handleFromPtr(cnxn.private_data) + cnxn.private_data = nil conn := h.Value().(*cConn) + h.Delete() defer func() { conn.cancelContext() conn.cnxn = nil - C.free(cnxn.private_data) - cnxn.private_data = nil - h.Delete() + // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before // the process ends, but GC is not deterministic. So by manually @@ -1418,15 +1463,14 @@ func TrinoStatementRelease(stmt *C.struct_AdbcStatement, err *C.struct_AdbcError if !checkStmtAlloc(stmt, err, "AdbcStatementRelease") { return C.ADBC_STATUS_INVALID_STATE } - h := (*(*cgo.Handle)(stmt.private_data)) + h := handleFromPtr(stmt.private_data) + stmt.private_data = nil st := h.Value().(*cStmt) + h.Delete() defer func() { st.cancelContext() st.stmt = nil - C.free(stmt.private_data) - stmt.private_data = nil - h.Delete() // manually trigger GC for two reasons: // 1. ASAN expects the release callback to be called before // the process ends, but GC is not deterministic. So by manually diff --git a/go/statement.go b/go/statement.go new file mode 100644 index 0000000..39ac350 --- /dev/null +++ b/go/statement.go @@ -0,0 +1,58 @@ +// Copyright (c) 2026 ADBC Drivers Contributors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package trino + +import ( + "context" + "database/sql" + "sync/atomic" + "time" + + sqlwrapper "github.com/adbc-drivers/driverbase-go/sqlwrapper" + "github.com/trinodb/trino-go-client/trino" +) + +type trinoStatement struct { + *sqlwrapper.StatementImplBase + + lastQueryId atomic.Pointer[string] +} + +func (st *trinoStatement) GetOption(ctx context.Context, key string) (string, error) { + switch key { + case "trino.statement.last_query_id": + val := st.lastQueryId.Load() + if val == nil { + return "", nil + } + return *val, nil + } + return st.StatementImplBase.GetOption(ctx, key) +} + +func (st *trinoStatement) GetAdditionalExecParams() []any { + return []any{ + sql.Named("X-Trino-Progress-Callback", trinoProgressCallback{st}), + sql.Named("X-Trino-Progress-Callback-Period", 5*time.Second), + } +} + +type trinoProgressCallback struct { + st *trinoStatement +} + +func (cb trinoProgressCallback) Update(info trino.QueryProgressInfo) { + cb.st.lastQueryId.Store(&info.QueryId) +} diff --git a/go/trino.go b/go/trino.go index fd1acc1..fc2b4b5 100644 --- a/go/trino.go +++ b/go/trino.go @@ -517,22 +517,6 @@ func convertDecimalToTrinoNumericFromInt(value *big.Int, scale int32) trino.Nume return trino.Numeric(rat.FloatString(int(scale))) } -// trinoConnectionImpl extends sqlwrapper connection with DbObjectsEnumerator -type trinoConnectionImpl struct { - *sqlwrapper.ConnectionImplBase // Embed sqlwrapper connection for all standard functionality - - version string -} - -// implements BulkIngester interface -var _ sqlwrapper.BulkIngester = (*trinoConnectionImpl)(nil) - -// implements DbObjectsEnumerator interface -var _ driverbase.DbObjectsEnumerator = (*trinoConnectionImpl)(nil) - -// implements CurrentNameSpacer interface -var _ driverbase.CurrentNamespacer = (*trinoConnectionImpl)(nil) - // trinoConnectionFactory creates Trino connections type trinoConnectionFactory struct{} @@ -547,6 +531,12 @@ func (f *trinoConnectionFactory) CreateConnection( }, nil } +func (f *trinoConnectionFactory) CreateStatement(stmt *sqlwrapper.StatementImplBase) (sqlwrapper.StatementImpl, error) { + return &trinoStatement{ + StatementImplBase: stmt, + }, nil +} + // NewDriver constructs the ADBC Driver for "trino". func NewDriver(alloc memory.Allocator) driverbase.DriverWithContext { vendorName := "Trino" @@ -554,8 +544,10 @@ func NewDriver(alloc memory.Allocator) driverbase.DriverWithContext { DefaultTypeConverter: sqlwrapper.DefaultTypeConverter{VendorName: vendorName}, } + factory := &trinoConnectionFactory{} driver := sqlwrapper.NewDriver(alloc, "trino", vendorName, NewTrinoDBFactory(), typeConverter). - WithConnectionFactory(&trinoConnectionFactory{}). + WithConnectionFactory(factory). + WithStatementFactory(factory). WithErrorInspector(TrinoErrorInspector{}) driver.DriverInfo.MustRegister(map[adbc.InfoCode]any{ adbc.InfoDriverName: "ADBC Driver Foundry Driver for Trino", diff --git a/go/trino_test.go b/go/trino_test.go index 36ba719..a5c6fe4 100644 --- a/go/trino_test.go +++ b/go/trino_test.go @@ -489,6 +489,24 @@ type selectCase struct { expected string } +func (s *TrinoTests) TestIngestQueryId() { + schema := arrow.NewSchema([]arrow.Field{ + { + Name: "ints", + Type: arrow.PrimitiveTypes.Int64, + Nullable: true, + }, + }, nil) + batch := testutil.RecordFromJSON(s.T(), s.Quirks.Alloc(), schema, `[{"ints": 1}, {"ints": 2}, {"ints": 3}]`) + defer batch.Release() + s.Require().NoError(s.stmt.Bind(s.ctx, batch)) + s.Require().NoError(s.stmt.SetOption(s.ctx, adbc.OptionKeyIngestTargetTable, "foobar")) + s.Require().NoError(s.stmt.SetOption(s.ctx, adbc.OptionKeyIngestMode, adbc.OptionValueIngestModeReplace)) + + _, err := s.stmt.ExecuteUpdate(s.ctx) + s.Require().NoError(err) +} + func (s *TrinoTests) TestSelect() { // Drop table if it exists first, then create test table with basic Trino types s.NoError(s.stmt.SetSqlQuery(s.ctx, `DROP TABLE IF EXISTS memory.default.test_types`)) diff --git a/go/validation/tests/test_statement.py b/go/validation/tests/test_statement.py index 03fd1b0..252e441 100644 --- a/go/validation/tests/test_statement.py +++ b/go/validation/tests/test_statement.py @@ -13,6 +13,7 @@ # limitations under the License. import adbc_drivers_validation.tests.statement as statement_tests +import pyarrow from . import trino @@ -52,3 +53,101 @@ def test_rows_affected(self, driver, conn) -> None: driver.drop_table(table_name="test_rows_affected") ) cursor.adbc_statement.execute_update() + + +def test_query_id_query(driver, conn) -> None: + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + cursor.adbc_statement.set_sql_query("SELECT 1") + stream, _ = cursor.adbc_statement.execute_query() + with pyarrow.RecordBatchReader._import_from_c(stream.address) as reader: + reader.read_all() + + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id is not None + assert query_id != "" + + +def test_query_id_query_bind(driver, conn) -> None: + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + cursor.adbc_statement.set_options(**{"adbc.statement.batch_size": 1}) + cursor.adbc_statement.set_sql_query("SELECT ? + 1") + cursor.adbc_statement.bind_stream(pyarrow.table({"col1": [1, 2, 3]})) + stream, _ = cursor.adbc_statement.execute_query() + + # each execution gets its own query ID + seen = set() + with pyarrow.RecordBatchReader._import_from_c(stream.address) as reader: + for batch in reader: + seen.add( + cursor.adbc_statement.get_option("trino.statement.last_query_id") + ) + + assert len(seen) == 3 + for query_id in seen: + assert query_id is not None + assert query_id != "" + + +def test_query_id_update(driver, conn) -> None: + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + cursor.adbc_statement.set_sql_query("DROP TABLE IF EXISTS foobar") + cursor.adbc_statement.execute_update() + + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id is not None + assert query_id != "" + + +def test_query_id_update_bind(driver, conn) -> None: + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + cursor.adbc_statement.set_sql_query("SELECT ? + 1") + cursor.adbc_statement.bind_stream(pyarrow.table({"col1": [1, 2, 3]})) + cursor.adbc_statement.execute_update() + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id is not None + assert query_id != "" + + +def test_query_id_ingest(driver, conn) -> None: + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + rb = pyarrow.record_batch({"col1": range(4096)}) + table = pyarrow.Table.from_batches([rb] * 16) + cursor.adbc_ingest("foobar", table, mode="replace") + + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id is not None + assert query_id != "" + + +def test_query_id_ingest_batch_size(driver, conn) -> None: + # driver takes a different route here + with conn.cursor() as cursor: + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id == "" + + cursor.adbc_statement.set_options( + **{"adbc.statement.ingest.batch_size": "2000"} + ) + + rb = pyarrow.record_batch({"col1": range(4096)}) + table = pyarrow.Table.from_batches([rb] * 16) + cursor.adbc_ingest("foobar", table, mode="replace") + + query_id = cursor.adbc_statement.get_option("trino.statement.last_query_id") + assert query_id is not None + assert query_id != "" diff --git a/go/validation/tests/trino.py b/go/validation/tests/trino.py index cf7f8b8..313b1a8 100644 --- a/go/validation/tests/trino.py +++ b/go/validation/tests/trino.py @@ -23,8 +23,8 @@ class TrinoQuirks(model.DriverQuirks): driver = "adbc_driver_trino" driver_name = "ADBC Driver Foundry Driver for Trino" vendor_name = "Trino" - vendor_version = re.compile(r"Trino 480") - short_version = "480" + vendor_version = re.compile(r"Trino 481") + short_version = "481" features = model.DriverFeatures( connection_get_table_schema=True, connection_transactions=False,