From 6b9eb08a988c714edaa62a5738f36ef7ee53d1e5 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 2 Jun 2026 15:09:51 +0900 Subject: [PATCH 1/3] feat(sqlwrapper): allow extending statement implementation --- sqlwrapper/connection.go | 5 +- sqlwrapper/driver.go | 13 +++++ sqlwrapper/record_reader.go | 13 +++-- sqlwrapper/statement.go | 105 +++++++++++++++++++++++++----------- 4 files changed, 101 insertions(+), 35 deletions(-) diff --git a/sqlwrapper/connection.go b/sqlwrapper/connection.go index 5bddbf7..20b5b5c 100644 --- a/sqlwrapper/connection.go +++ b/sqlwrapper/connection.go @@ -40,6 +40,8 @@ type ConnectionImplBase struct { driverbase.ConnectionImplBase Derived ConnectionImpl + dbImpl *databaseImpl + // Conn is the dedicated SQL connection for this ADBC session Conn *LoggingConn // TypeConverter handles SQL-to-Arrow type conversion @@ -64,6 +66,7 @@ func newConnection(ctx context.Context, db *databaseImpl) (adbc.ConnectionWithCo // Create the base sqlwrapper connection first sqlwrapperConn := &ConnectionImplBase{ ConnectionImplBase: base, + dbImpl: db, Conn: &LoggingConn{Conn: sqlConn, Logger: base.Logger}, TypeConverter: db.typeConverter, Db: db.db, @@ -112,7 +115,7 @@ func newConnection(ctx context.Context, db *databaseImpl) (adbc.ConnectionWithCo // NewStatement satisfies adbc.Connection func (c *ConnectionImplBase) NewStatement(ctx context.Context) (adbc.StatementWithContext, error) { - return newStatement(c), nil + return newStatement(c) } // SetTypeConverter allows higher-level drivers to customize type conversion diff --git a/sqlwrapper/driver.go b/sqlwrapper/driver.go index 3a7ef4a..8dd3149 100644 --- a/sqlwrapper/driver.go +++ b/sqlwrapper/driver.go @@ -37,6 +37,10 @@ type ConnectionFactory interface { ) (ConnectionImpl, error) } +type StatementFactory interface { + CreateStatement(stmt *StatementImplBase) (StatementImpl, error) +} + // DBFactory handles creation of *sql.DB from connection options. // Each driver is expected to implement this interface to provide database-specific // DSN construction and connection logic for their particular database format. @@ -51,6 +55,7 @@ type Driver struct { driverName string typeConverter TypeConverter connectionFactory ConnectionFactory + stmtFactory StatementFactory dbFactory DBFactory errorInspector driverbase.ErrorInspector } @@ -70,6 +75,7 @@ func NewDriver(alloc memory.Allocator, driverName, vendorName string, dbFactory driverName: driverName, typeConverter: converter, connectionFactory: nil, // No custom factory by default + stmtFactory: nil, dbFactory: dbFactory, } } @@ -82,6 +88,11 @@ func (d *Driver) WithConnectionFactory(factory ConnectionFactory) *Driver { return d } +func (d *Driver) WithStatementFactory(factory StatementFactory) *Driver { + d.stmtFactory = factory + return d +} + // WithErrorInspector sets a custom error inspector for extracting database error metadata. // This allows drivers to map database-specific errors to ADBC status codes and extract // SQLSTATE, vendor codes, and other error information. @@ -100,6 +111,7 @@ type databaseImpl struct { typeConverter TypeConverter // connectionFactory creates custom connection implementations if provided connectionFactory ConnectionFactory + stmtFactory StatementFactory } // NewDatabaseWithContext is the main entrypoint for driver‐agnostic ADBC database creation. @@ -132,6 +144,7 @@ func (d *Driver) NewDatabaseWithContext(ctx context.Context, opts map[string]str db: sqlDB, typeConverter: d.typeConverter, connectionFactory: d.connectionFactory, + stmtFactory: d.stmtFactory, } return driverbase.NewDatabase(db), nil } diff --git a/sqlwrapper/record_reader.go b/sqlwrapper/record_reader.go index 1fa7a30..ffd0f67 100644 --- a/sqlwrapper/record_reader.go +++ b/sqlwrapper/record_reader.go @@ -35,10 +35,11 @@ type sqlRecordReaderImpl struct { schema *arrow.Schema // For bind parameter support - conn *LoggingConn // Database connection to execute queries - query string // Original SQL query with placeholders - stmt *LoggingStmt // Prepared statement (optional) - typeConverter TypeConverter // Type converter for building schemas + conn *LoggingConn // Database connection to execute queries + query string // Original SQL query with placeholders + stmt *LoggingStmt // Prepared statement (optional) + typeConverter TypeConverter // Type converter for building schemas + additionalParams []any // Performance optimization: pre-computed inserters to avoid per-value type switching columnInserters []Inserter @@ -69,6 +70,10 @@ func (s *sqlRecordReaderImpl) NextResultSet(ctx context.Context, rec arrow.Recor } } + if len(s.additionalParams) > 0 { + args = append(args, s.additionalParams...) + } + // Execute query (with or without parameters) if s.stmt != nil { s.rows, err = s.stmt.QueryContext(ctx, args...) diff --git a/sqlwrapper/statement.go b/sqlwrapper/statement.go index 5904fb1..51a4f9c 100644 --- a/sqlwrapper/statement.go +++ b/sqlwrapper/statement.go @@ -31,7 +31,7 @@ import ( // BulkIngester interface allows drivers to implement database-specific bulk ingest functionality type BulkIngester interface { - ExecuteBulkIngest(ctx context.Context, conn *LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (int64, error) + ExecuteBulkIngest(ctx context.Context, stmt StatementImpl, conn *LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader) (int64, error) // QuoteIdentifier quotes a table/column identifier for SQL QuoteIdentifier(name string) string @@ -49,9 +49,17 @@ const ( OptionKeyBatchSize = "adbc.statement.batch_size" ) -// statementImpl implements the ADBC Statement interface on top of database/sql. -type statementImpl struct { +// StatementImpl implements the ADBC Statement interface on top of database/sql. +type StatementImpl interface { + driverbase.StatementImpl + + // Inject driver-specific params into Exec, as many drivers implement special functionality in this way + GetAdditionalExecParams() []any +} + +type StatementImplBase struct { driverbase.StatementImplBase + Derived StatementImpl // conn is the dedicated SQL connection conn *LoggingConn @@ -64,7 +72,7 @@ type statementImpl struct { // boundStream holds the bound Arrow record stream for bulk operations boundStream array.RecordReader // batchSize controls how many records to process at once during streaming execution - batchSize int + batchSize int64 // typeConverter handles SQL-to-Arrow type conversion typeConverter TypeConverter @@ -75,25 +83,42 @@ type statementImpl struct { } // Base returns the embedded StatementImplBase for driverbase plumbing -func (s *statementImpl) Base() *driverbase.StatementImplBase { +func (s *StatementImplBase) Base() *driverbase.StatementImplBase { return &s.StatementImplBase } +func (s *StatementImplBase) GetAdditionalExecParams() []any { + return nil +} + // newStatement constructs a new StatementImpl wrapped by driverbase -func newStatement(c *ConnectionImplBase) adbc.StatementWithContext { +func newStatement(c *ConnectionImplBase) (adbc.StatementWithContext, error) { base := driverbase.NewStatementImplBase(&c.ConnectionImplBase, c.ErrorHelper) - return driverbase.NewStatement(&statementImpl{ + wrapper := &StatementImplBase{ StatementImplBase: base, conn: c.Conn, connectionImpl: c.Derived, batchSize: 1000, // Default batch size for streaming operations typeConverter: c.TypeConverter, bulkIngestOptions: driverbase.NewBulkIngestOptions(), - }) + } + + var impl StatementImpl + if c.dbImpl.stmtFactory != nil { + var err error + impl, err = c.dbImpl.stmtFactory.CreateStatement(wrapper) + if err != nil { + return nil, err + } + } else { + impl = wrapper + } + wrapper.Derived = impl + return driverbase.NewStatement(impl), nil } // SetSqlQuery stores the SQL text on the statement -func (s *statementImpl) SetSqlQuery(ctx context.Context, query string) error { +func (s *StatementImplBase) SetSqlQuery(ctx context.Context, query string) error { if err := s.connectionImpl.ClearPending(); err != nil { return err } @@ -114,7 +139,7 @@ func (s *statementImpl) SetSqlQuery(ctx context.Context, query string) error { } // SetOption sets a string option on this statement -func (s *statementImpl) SetOption(ctx context.Context, key, val string) error { +func (s *StatementImplBase) SetOption(ctx context.Context, key, val string) error { // Let driverbase handle standard bulk ingest options first if handled, err := s.bulkIngestOptions.SetOption(&s.Base().ErrorHelper, key, val); err != nil { return err @@ -128,14 +153,23 @@ func (s *statementImpl) SetOption(ctx context.Context, key, val string) error { if err != nil { return s.Base().ErrorHelper.InvalidArgument("invalid batch size: %v", err) } - return s.SetBatchSize(size) + return s.SetBatchSize(int64(size)) + default: + return s.StatementImplBase.SetOption(ctx, key, val) + } +} + +func (s *StatementImplBase) SetOptionInt(ctx context.Context, key string, val int64) error { + switch key { + case OptionKeyBatchSize: + return s.SetBatchSize(val) default: - return s.Base().ErrorHelper.NotImplemented("unsupported option: %s", key) + return s.StatementImplBase.SetOptionInt(ctx, key, val) } } // Bind uses an arrow record batch to bind parameters to the query -func (s *statementImpl) Bind(ctx context.Context, record arrow.RecordBatch) error { +func (s *StatementImplBase) Bind(ctx context.Context, record arrow.RecordBatch) error { if record == nil { return s.Base().ErrorHelper.InvalidArgument("record cannot be nil") } @@ -152,7 +186,7 @@ func (s *statementImpl) Bind(ctx context.Context, record arrow.RecordBatch) erro } // BindStream uses a record batch stream to bind parameters for bulk operations -func (s *statementImpl) BindStream(ctx context.Context, stream array.RecordReader) error { +func (s *StatementImplBase) BindStream(ctx context.Context, stream array.RecordReader) error { if stream == nil { return s.Base().ErrorHelper.InvalidArgument("stream cannot be nil") } @@ -171,7 +205,7 @@ func (s *statementImpl) BindStream(ctx context.Context, stream array.RecordReade } // ExecuteUpdate runs DML/DDL and returns rows affected -func (s *statementImpl) ExecuteUpdate(ctx context.Context) (int64, error) { +func (s *StatementImplBase) ExecuteUpdate(ctx context.Context) (int64, error) { if err := s.connectionImpl.ClearPending(); err != nil { return -1, err } @@ -198,10 +232,11 @@ func (s *statementImpl) ExecuteUpdate(ctx context.Context) (int64, error) { var res sql.Result var err error + params := s.Derived.GetAdditionalExecParams() if s.stmt != nil { - res, err = s.stmt.ExecContext(ctx) + res, err = s.stmt.ExecContext(ctx, params...) } else { - res, err = s.conn.ExecContext(ctx, s.query) + res, err = s.conn.ExecContext(ctx, s.query, params...) } if err != nil { return -1, s.Base().ErrorHelper.WrapIO(err, "failed to execute statement") @@ -214,7 +249,7 @@ func (s *statementImpl) ExecuteUpdate(ctx context.Context) (int64, error) { } // ExecuteSchema returns the Arrow schema by querying zero rows -func (s *statementImpl) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { +func (s *StatementImplBase) ExecuteSchema(ctx context.Context) (schema *arrow.Schema, err error) { if s.query == "" { return nil, s.Base().ErrorHelper.InvalidState("no query set") } @@ -259,7 +294,7 @@ func (c closer) Close() error { } // ExecuteQuery runs a SELECT and returns a RecordReader for streaming Arrow records -func (s *statementImpl) ExecuteQuery(ctx context.Context) (reader array.RecordReader, rowCount int64, err error) { +func (s *StatementImplBase) ExecuteQuery(ctx context.Context) (reader array.RecordReader, rowCount int64, err error) { if s.bulkIngestOptions.IsSet() { rowCount, err = s.executeBulkIngest(ctx) return @@ -275,10 +310,11 @@ func (s *statementImpl) ExecuteQuery(ctx context.Context) (reader array.RecordRe // Create the record reader implementation with all the state impl := &sqlRecordReaderImpl{ - conn: s.conn, - query: s.query, - stmt: s.stmt, - typeConverter: s.typeConverter, + conn: s.conn, + query: s.query, + stmt: s.stmt, + typeConverter: s.typeConverter, + additionalParams: s.Derived.GetAdditionalExecParams(), } // Let BaseRecordReader handle parameterized vs non-parameterized logic @@ -290,7 +326,7 @@ func (s *statementImpl) ExecuteQuery(ctx context.Context) (reader array.RecordRe // TODO(lidavidm): when given ctx is cancelled, cancel the query, but // not in a way that breaks the connection! options := driverbase.BaseRecordReaderOptions{ - BatchRowLimit: int64(s.batchSize), + BatchRowLimit: s.batchSize, } if err := baseRecordReader.Init(context.Background(), memory.DefaultAllocator, s.connectionImpl.Base().Logger, s.boundStream, options, impl); err != nil { @@ -308,7 +344,7 @@ func (s *statementImpl) ExecuteQuery(ctx context.Context) (reader array.RecordRe } // Close shuts down the prepared stmt (if any) and releases bound resources -func (s *statementImpl) Close(ctx context.Context) error { +func (s *StatementImplBase) Close(ctx context.Context) error { // Check if already closed if s.closed { return s.Base().ErrorHelper.InvalidState("statement already closed") @@ -331,7 +367,7 @@ func (s *statementImpl) Close(ctx context.Context) error { return s.StatementImplBase.Close(ctx) } -func (s *statementImpl) Prepare(ctx context.Context) (err error) { +func (s *StatementImplBase) Prepare(ctx context.Context) (err error) { if s.query == "" { return s.Base().ErrorHelper.InvalidArgument("no query to prepare") } @@ -356,7 +392,7 @@ func (s *statementImpl) Prepare(ctx context.Context) (err error) { } // SetBatchSize configures the batch size for streaming operations -func (s *statementImpl) SetBatchSize(size int) error { +func (s *StatementImplBase) SetBatchSize(size int64) error { if size <= 0 { return s.Base().ErrorHelper.InvalidArgument("batch size must be positive") } @@ -365,7 +401,7 @@ func (s *statementImpl) SetBatchSize(size int) error { } // executeBulkUpdate executes bulk updates by iterating through the bound stream directly -func (s *statementImpl) executeBulkUpdate(ctx context.Context) (totalAffected int64, err error) { +func (s *StatementImplBase) executeBulkUpdate(ctx context.Context) (totalAffected int64, err error) { if s.query == "" { return -1, s.Base().ErrorHelper.InvalidArgument("no query set") } @@ -390,6 +426,10 @@ func (s *statementImpl) executeBulkUpdate(ctx context.Context) (totalAffected in } params := make([]any, s.boundStream.Schema().NumFields()) + additionalParams := s.Derived.GetAdditionalExecParams() + if len(additionalParams) > 0 { + params = append(params, additionalParams...) + } for s.boundStream.Next() { record := s.boundStream.RecordBatch() for rowIdx := range int(record.NumRows()) { @@ -425,7 +465,7 @@ func (s *statementImpl) executeBulkUpdate(ctx context.Context) (totalAffected in } // executeBulkIngest executes bulk ingest using the connection's ExecuteBulkIngest method -func (s *statementImpl) executeBulkIngest(ctx context.Context) (int64, error) { +func (s *StatementImplBase) executeBulkIngest(ctx context.Context) (int64, error) { // Check for proper bulk ingest setup if s.boundStream == nil { return -1, s.Base().ErrorHelper.InvalidState("bulk ingest options are set but no stream is bound - call BindStream() first") @@ -438,7 +478,7 @@ func (s *statementImpl) executeBulkIngest(ctx context.Context) (int64, error) { s.boundStream = nil }() - rowCount, err := ingester.ExecuteBulkIngest(ctx, s.conn, &s.bulkIngestOptions, s.boundStream) + rowCount, err := ingester.ExecuteBulkIngest(ctx, s.Derived, s.conn, &s.bulkIngestOptions, s.boundStream) if err != nil { return -1, err } @@ -460,6 +500,7 @@ func (s *statementImpl) executeBulkIngest(ctx context.Context) (int64, error) { // - uses options.IngestBatchSize (defaults to 1000 if <= 0). func ExecuteBatchedBulkIngest( ctx context.Context, + stmtImpl StatementImpl, conn *LoggingConn, options *driverbase.BulkIngestOptions, stream array.RecordReader, @@ -483,6 +524,7 @@ func ExecuteBatchedBulkIngest( } quotedTableName := ingester.QuoteIdentifier(options.TableName) + params := stmtImpl.GetAdditionalExecParams() iterator, err := NewRowBufferIterator(stream, batchSize, typeConverter) if err != nil { @@ -500,6 +542,9 @@ func ExecuteBatchedBulkIngest( for iterator.Next() { buffer, rowCount := iterator.CurrentBatch() + if len(params) > 0 { + buffer = append(buffer, params...) + } if rowCount == batchSize { // Full batch: use pre-prepared statement From 813a03cfc1735ae71f285a941f6a2125f45d6e51 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 5 Jun 2026 15:05:26 +0900 Subject: [PATCH 2/3] fix(ffitemplate): clone input slices where necessary --- ffitemplate/_tmpl/driver.go.tmpl | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/ffitemplate/_tmpl/driver.go.tmpl b/ffitemplate/_tmpl/driver.go.tmpl index e0a72bd..28f760e 100644 --- a/ffitemplate/_tmpl/driver.go.tmpl +++ b/ffitemplate/_tmpl/driver.go.tmpl @@ -60,6 +60,7 @@ import ( "os" "runtime" "runtime/cgo" + "slices" "strings" "sync/atomic" "unsafe" @@ -148,6 +149,7 @@ func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) { cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil))))) cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t)) + // SAFETY: no copy of fromCArr because these are written to, not read from keys := fromCArr[*C.cchar_t](cErr.keys, numDetails) values := fromCArr[*C.cuint8_t](cErr.values, numDetails) lengths := fromCArr[C.size_t](cErr.lengths, numDetails) @@ -266,6 +268,7 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T { func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode { lenWithTerminator := C.size_t(len(val) + 1) if lenWithTerminator <= *length { + // SAFETY: no copy of fromCArr because this is written to, not read from sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length)) copy(sink, val) sink[len(val)] = 0 @@ -276,6 +279,7 @@ func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusC func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode { if C.size_t(len(val)) <= *length { + // SAFETY: no copy of fromCArr because this is written to, not read from sink := fromCArr[byte]((*byte)(out), int(*length)) copy(sink, val) } @@ -703,7 +707,7 @@ func {{.Prefix}}DatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t } cdb := getFromHandle[cDatabase](db.private_data) k := C.GoString(key) - v := fromCArr[byte](value, int(length)) + v := slices.Clone(fromCArr[byte](value, int(length))) if cdb.db != nil { e := cdb.db.SetOptionBytes(cdb.newContext(), k, v) @@ -924,7 +928,7 @@ func {{.Prefix}}ConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cch return C.ADBC_STATUS_INVALID_STATE } - e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), fromCArr[byte](value, int(length))) + e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, int(length)))) return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -1034,6 +1038,7 @@ func {{.Prefix}}ConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_A return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Close(conn.newContext()))) } +// SAFETY: at each call site, consider whether a copy of the resulting slice must be made func fromCArr[T, CType any](ptr *CType, sz int) []T { if ptr == nil || sz == 0 { return nil @@ -1106,7 +1111,7 @@ func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint3 return C.ADBC_STATUS_INVALID_STATE } - infoCodes := fromCArr[adbc.InfoCode](codes, int(len)) + infoCodes := slices.Clone(fromCArr[adbc.InfoCode](codes, int(len))) rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) @@ -1247,7 +1252,7 @@ func {{.Prefix}}ConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialize return C.ADBC_STATUS_INVALID_STATE } - rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen))) + rdr, e := conn.cnxn.ReadPartition(conn.newContext(), slices.Clone(fromCArr[byte](serialized, int(serializedLen)))) if e != nil { return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -1605,7 +1610,7 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C. return C.ADBC_STATUS_INVALID_STATE } - e := st.stmt.SetSubstraitPlan(st.newContext(), fromCArr[byte](plan, int(length))) + e := st.stmt.SetSubstraitPlan(st.newContext(), slices.Clone(fromCArr[byte](plan, int(length)))) return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -1706,7 +1711,7 @@ func {{.Prefix}}StatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar return C.ADBC_STATUS_NOT_IMPLEMENTED } - e := opts.SetOptionBytes(st.newContext(), C.GoString(key), fromCArr[byte](value, int(length))) + e := opts.SetOptionBytes(st.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, int(length)))) return C.AdbcStatusCode(errToAdbcErr(err, e)) } @@ -1808,6 +1813,7 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema totalLen += len(p) } partitions.private_data = C.calloc(C.size_t(totalLen), C.size_t(1)) + // SAFETY: no copy of fromCArr because this is written to, not read from dst := fromCArr[byte]((*byte)(partitions.private_data), totalLen) partIDs := fromCArr[*C.cuint8_t](partitions.partitions, int(partitions.num_partitions)) @@ -1829,9 +1835,11 @@ func AdbcDriver{{.Prefix}}Init(version C.int, rawDriver *C.void, err *C.struct_A switch version { case C.ADBC_VERSION_1_0_0: + // SAFETY: no copy of fromCArr because this is written to, not read from sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE) memory.Set(sink, 0) case C.ADBC_VERSION_1_1_0: + // SAFETY: no copy of fromCArr because this is written to, not read from sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE) memory.Set(sink, 0) default: From c33fa1b45e0bc08da3b7b5675d698b45f36a32b1 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 8 Jun 2026 07:33:43 +0900 Subject: [PATCH 3/3] feedback --- sqlwrapper/statement.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/sqlwrapper/statement.go b/sqlwrapper/statement.go index 51a4f9c..a0d623b 100644 --- a/sqlwrapper/statement.go +++ b/sqlwrapper/statement.go @@ -149,11 +149,11 @@ func (s *StatementImplBase) SetOption(ctx context.Context, key, val string) erro switch key { case OptionKeyBatchSize: - size, err := strconv.Atoi(val) + size, err := strconv.ParseInt(val, 10, 64) if err != nil { return s.Base().ErrorHelper.InvalidArgument("invalid batch size: %v", err) } - return s.SetBatchSize(int64(size)) + return s.SetBatchSize(size) default: return s.StatementImplBase.SetOption(ctx, key, val) } @@ -508,8 +508,20 @@ func ExecuteBatchedBulkIngest( ingester BulkIngester, errorHelper *driverbase.ErrorHelper, ) (totalRowsInserted int64, err error) { - if stream == nil { - return -1, errorHelper.InvalidArgument("stream cannot be nil") + if errorHelper == nil { + return -1, errors.New("errorHelper cannot be nil") + } else if stream == nil { + return -1, errorHelper.Internal("stream cannot be nil") + } else if stmtImpl == nil { + return -1, errorHelper.Internal("stmtImpl cannot be nil") + } else if conn == nil { + return -1, errorHelper.Internal("conn cannot be nil") + } else if options == nil { + return -1, errorHelper.Internal("options cannot be nil") + } else if typeConverter == nil { + return -1, errorHelper.Internal("typeConverter cannot be nil") + } else if ingester == nil { + return -1, errorHelper.Internal("ingester cannot be nil") } batchSize := options.IngestBatchSize