Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 37 additions & 16 deletions go/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,22 @@ const (
TrinoMaxQuerySizeBytes = 1_000_000
)

// trinoConnectionImpl extends sqlwrapper connection with DbObjectsEnumerator
type trinoConnectionImpl struct {
*sqlwrapper.ConnectionImplBase // Embed sqlwrapper connection for all standard functionality

version string
}

// implements BulkIngester interface
var _ sqlwrapper.BulkIngester = (*trinoConnectionImpl)(nil)

// implements DbObjectsEnumerator interface
var _ driverbase.DbObjectsEnumerator = (*trinoConnectionImpl)(nil)

// implements CurrentNameSpacer interface
var _ driverbase.CurrentNamespacer = (*trinoConnectionImpl)(nil)

// GetCurrentCatalog implements driverbase.CurrentNamespacer.
func (c *trinoConnectionImpl) GetCurrentCatalog(ctx context.Context) (string, error) {
var catalog string
Expand Down Expand Up @@ -186,15 +202,18 @@ func (c *trinoConnectionImpl) GetPlaceholder(field *arrow.Field, index int) stri
var _ sqlwrapper.BulkIngester = (*trinoConnectionImpl)(nil)

// ExecuteBulkIngest performs Trino bulk ingest using batched INSERT statements.
func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) {
func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, stmt sqlwrapper.StatementImpl, conn *sqlwrapper.LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (rowCount int64, err error) {
// inject the query ID capture through everything
params := stmt.(*trinoStatement).GetAdditionalExecParams()

schema := stream.Schema()
if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options); err != nil {
if err := c.createTableIfNeeded(ctx, conn, options.TableName, schema, options, params); err != nil {
return -1, c.ErrorHelper.WrapIO(err, "failed to create table")
}

if options.IngestBatchSize > 0 {
return sqlwrapper.ExecuteBatchedBulkIngest(
ctx, conn, options, stream,
ctx, stmt, conn, options, stream,
c.TypeConverter, c, &c.Base().ErrorHelper,
)
}
Expand All @@ -204,7 +223,7 @@ func (c *trinoConnectionImpl) ExecuteBulkIngest(ctx context.Context, conn *sqlwr
}

// Use Trino-specific batching with accurate serialized size measurement
return c.executeDynamicBatchedIngest(ctx, conn, options, stream)
return c.executeDynamicBatchedIngest(ctx, conn, options, stream, params)
}

// executeDynamicBatchedIngest performs batched INSERT with incremental query building.
Expand All @@ -220,6 +239,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest(
conn *sqlwrapper.LoggingConn,
options *driverbase.BulkIngestOptions,
stream array.RecordReader,
params []any,
) (int64, error) {
var totalRowsInserted int64
schema := stream.Schema()
Expand Down Expand Up @@ -269,7 +289,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest(

if queryBuilder.Len()+additionalLength > options.MaxQuerySizeBytes && currentBatchRows > 0 {
// Execute current batch
rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String())
rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String(), params)
if err != nil {
return totalRowsInserted, c.ErrorHelper.WrapIO(err,
"failed to insert batch at rows %d-%d", startRowIdx, startRowIdx+currentBatchRows-1)
Expand All @@ -294,7 +314,7 @@ func (c *trinoConnectionImpl) executeDynamicBatchedIngest(

// Execute final batch for this record batch
if currentBatchRows > 0 {
rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String())
rowsInserted, err := c.executeBatch(ctx, conn, queryBuilder.String(), params)
if err != nil {
return totalRowsInserted, c.ErrorHelper.WrapIO(err,
"failed to insert final batch at rows %d-%d", startRowIdx, startRowIdx+currentBatchRows-1)
Expand Down Expand Up @@ -355,12 +375,13 @@ func (c *trinoConnectionImpl) executeBatch(
ctx context.Context,
conn *sqlwrapper.LoggingConn,
querySQL string,
params []any,
) (int64, error) {
if querySQL == "" {
return 0, nil
}

result, err := conn.ExecContext(ctx, querySQL)
result, err := conn.ExecContext(ctx, querySQL, params...)
if err != nil {
return 0, err
}
Expand All @@ -374,20 +395,20 @@ func (c *trinoConnectionImpl) executeBatch(
}

// createTableIfNeeded creates the table based on the ingest mode
func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions) error {
func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, options *driverbase.BulkIngestOptions, params []any) error {
switch options.Mode {
case adbc.OptionValueIngestModeCreate:
// Create the table (fail if exists)
return c.createTable(ctx, conn, tableName, schema, false)
return c.createTable(ctx, conn, tableName, schema, false, params)
case adbc.OptionValueIngestModeCreateAppend:
// Create the table if it doesn't exist
return c.createTable(ctx, conn, tableName, schema, true)
return c.createTable(ctx, conn, tableName, schema, true, params)
case adbc.OptionValueIngestModeReplace:
// Drop and recreate the table
if err := c.dropTable(ctx, conn, tableName); err != nil {
if err := c.dropTable(ctx, conn, tableName, params); err != nil {
return err
}
return c.createTable(ctx, conn, tableName, schema, false)
return c.createTable(ctx, conn, tableName, schema, false, params)
case adbc.OptionValueIngestModeAppend:
// Table should already exist, do nothing
return nil
Expand All @@ -397,7 +418,7 @@ func (c *trinoConnectionImpl) createTableIfNeeded(ctx context.Context, conn *sql
}

// createTable creates a Trino table from Arrow schema
func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, ifNotExists bool) error {
func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, schema *arrow.Schema, ifNotExists bool, params []any) error {
var queryBuilder strings.Builder
queryBuilder.WriteString("CREATE TABLE ")
if ifNotExists {
Expand All @@ -424,14 +445,14 @@ func (c *trinoConnectionImpl) createTable(ctx context.Context, conn *sqlwrapper.

queryBuilder.WriteString(")")

_, err := conn.ExecContext(ctx, queryBuilder.String())
_, err := conn.ExecContext(ctx, queryBuilder.String(), params...)
return err
}

// dropTable drops a Trino table
func (c *trinoConnectionImpl) dropTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string) error {
func (c *trinoConnectionImpl) dropTable(ctx context.Context, conn *sqlwrapper.LoggingConn, tableName string, params []any) error {
dropSQL := fmt.Sprintf("DROP TABLE IF EXISTS %s", quoteIdentifier(tableName))
_, err := conn.ExecContext(ctx, dropSQL)
_, err := conn.ExecContext(ctx, dropSQL, params...)
return err
}

Expand Down
46 changes: 23 additions & 23 deletions go/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ module github.com/adbc-drivers/trino
go 1.26.0

require (
github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260423045143-148150eea03b
github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260423045143-148150eea03b
github.com/adbc-drivers/driverbase-go/driverbase v0.0.0-20260608005845-f218ccb883e8
github.com/adbc-drivers/driverbase-go/sqlwrapper v0.0.0-20260608005845-f218ccb883e8
github.com/adbc-drivers/driverbase-go/testutil v0.0.0-20260423045143-148150eea03b
github.com/adbc-drivers/driverbase-go/validation v0.0.0-20260423045143-148150eea03b
github.com/apache/arrow-adbc/go/adbc v1.11.0
Expand All @@ -30,7 +30,7 @@ require (

require (
github.com/andybalholm/brotli v1.2.1 // indirect
github.com/apache/thrift v0.22.0 // indirect
github.com/apache/thrift v0.23.0 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
Expand All @@ -46,33 +46,33 @@ require (
github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
github.com/klauspost/compress v1.18.5 // indirect
github.com/klauspost/compress v1.18.6 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/pierrec/lz4 v2.6.1+incompatible // indirect
github.com/pierrec/lz4/v4 v4.1.26 // indirect
github.com/pierrec/lz4/v4 v4.1.27 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/zeebo/xxh3 v1.1.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/otel v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 // indirect
go.opentelemetry.io/otel/metric v1.43.0 // indirect
go.opentelemetry.io/otel/sdk v1.43.0 // indirect
go.opentelemetry.io/otel/trace v1.43.0 // indirect
go.opentelemetry.io/otel v1.44.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.44.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.44.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.44.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.44.0 // indirect
go.opentelemetry.io/otel/metric v1.44.0 // indirect
go.opentelemetry.io/otel/sdk v1.44.0 // indirect
go.opentelemetry.io/otel/trace v1.44.0 // indirect
go.opentelemetry.io/proto/otlp v1.10.0 // indirect
golang.org/x/crypto v0.50.0 // indirect
golang.org/x/exp v0.0.0-20260410095643-746e56fc9e2f // indirect
golang.org/x/mod v0.35.0 // indirect
golang.org/x/net v0.53.0 // indirect
golang.org/x/crypto v0.52.0 // indirect
golang.org/x/exp v0.0.0-20260603202125-055de637280b // indirect
golang.org/x/mod v0.36.0 // indirect
golang.org/x/net v0.55.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/text v0.36.0 // indirect
golang.org/x/tools v0.44.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260420184626-e10c466a9529 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260420184626-e10c466a9529 // indirect
google.golang.org/grpc v1.80.0 // indirect
golang.org/x/sys v0.45.0 // indirect
golang.org/x/text v0.37.0 // indirect
golang.org/x/tools v0.45.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20260526163538-3dc84a4a5aaa // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20260526163538-3dc84a4a5aaa // indirect
google.golang.org/grpc v1.81.1 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
Loading
Loading