Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions ffitemplate/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"os"
"runtime"
"runtime/cgo"
"slices"
"strings"
"sync/atomic"
"unsafe"
Expand Down Expand Up @@ -148,6 +149,7 @@ func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) {
cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil)))))
cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t))

// SAFETY: no copy of fromCArr because these are written to, not read from
keys := fromCArr[*C.cchar_t](cErr.keys, numDetails)
values := fromCArr[*C.cuint8_t](cErr.values, numDetails)
lengths := fromCArr[C.size_t](cErr.lengths, numDetails)
Expand Down Expand Up @@ -266,6 +268,7 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T {
func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode {
lenWithTerminator := C.size_t(len(val) + 1)
if lenWithTerminator <= *length {
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length))
copy(sink, val)
sink[len(val)] = 0
Expand All @@ -276,6 +279,7 @@ func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusC

func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode {
if C.size_t(len(val)) <= *length {
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(out), int(*length))
copy(sink, val)
}
Expand Down Expand Up @@ -703,7 +707,7 @@ func {{.Prefix}}DatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t
}
cdb := getFromHandle[cDatabase](db.private_data)
k := C.GoString(key)
v := fromCArr[byte](value, int(length))
v := slices.Clone(fromCArr[byte](value, int(length)))

if cdb.db != nil {
e := cdb.db.SetOptionBytes(cdb.newContext(), k, v)
Expand Down Expand Up @@ -924,7 +928,7 @@ func {{.Prefix}}ConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cch
return C.ADBC_STATUS_INVALID_STATE
}

e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), fromCArr[byte](value, int(length)))
e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, int(length))))
return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1034,6 +1038,7 @@ func {{.Prefix}}ConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_A
return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Close(conn.newContext())))
}

// SAFETY: at each call site, consider whether a copy of the resulting slice must be made
func fromCArr[T, CType any](ptr *CType, sz int) []T {
if ptr == nil || sz == 0 {
return nil
Expand Down Expand Up @@ -1106,7 +1111,7 @@ func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint3
return C.ADBC_STATUS_INVALID_STATE
}

infoCodes := fromCArr[adbc.InfoCode](codes, int(len))
infoCodes := slices.Clone(fromCArr[adbc.InfoCode](codes, int(len)))
rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes)
if e != nil {
return C.AdbcStatusCode(errToAdbcErr(err, e))
Expand Down Expand Up @@ -1247,7 +1252,7 @@ func {{.Prefix}}ConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialize
return C.ADBC_STATUS_INVALID_STATE
}

rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen)))
rdr, e := conn.cnxn.ReadPartition(conn.newContext(), slices.Clone(fromCArr[byte](serialized, int(serializedLen))))
if e != nil {
return C.AdbcStatusCode(errToAdbcErr(err, e))
}
Expand Down Expand Up @@ -1605,7 +1610,7 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.
return C.ADBC_STATUS_INVALID_STATE
}

e := st.stmt.SetSubstraitPlan(st.newContext(), fromCArr[byte](plan, int(length)))
e := st.stmt.SetSubstraitPlan(st.newContext(), slices.Clone(fromCArr[byte](plan, int(length))))
return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1706,7 +1711,7 @@ func {{.Prefix}}StatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar
return C.ADBC_STATUS_NOT_IMPLEMENTED
}

e := opts.SetOptionBytes(st.newContext(), C.GoString(key), fromCArr[byte](value, int(length)))
e := opts.SetOptionBytes(st.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, int(length))))
return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1808,6 +1813,7 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema
totalLen += len(p)
}
partitions.private_data = C.calloc(C.size_t(totalLen), C.size_t(1))
// SAFETY: no copy of fromCArr because this is written to, not read from
dst := fromCArr[byte]((*byte)(partitions.private_data), totalLen)

partIDs := fromCArr[*C.cuint8_t](partitions.partitions, int(partitions.num_partitions))
Expand All @@ -1829,9 +1835,11 @@ func AdbcDriver{{.Prefix}}Init(version C.int, rawDriver *C.void, err *C.struct_A

switch version {
case C.ADBC_VERSION_1_0_0:
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE)
memory.Set(sink, 0)
case C.ADBC_VERSION_1_1_0:
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE)
memory.Set(sink, 0)
default:
Expand Down
5 changes: 4 additions & 1 deletion sqlwrapper/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type ConnectionImplBase struct {
driverbase.ConnectionImplBase
Derived ConnectionImpl

dbImpl *databaseImpl

// Conn is the dedicated SQL connection for this ADBC session
Conn *LoggingConn
// TypeConverter handles SQL-to-Arrow type conversion
Expand All @@ -64,6 +66,7 @@ func newConnection(ctx context.Context, db *databaseImpl) (adbc.ConnectionWithCo
// Create the base sqlwrapper connection first
sqlwrapperConn := &ConnectionImplBase{
ConnectionImplBase: base,
dbImpl: db,
Conn: &LoggingConn{Conn: sqlConn, Logger: base.Logger},
TypeConverter: db.typeConverter,
Db: db.db,
Expand Down Expand Up @@ -112,7 +115,7 @@ func newConnection(ctx context.Context, db *databaseImpl) (adbc.ConnectionWithCo

// NewStatement satisfies adbc.Connection
func (c *ConnectionImplBase) NewStatement(ctx context.Context) (adbc.StatementWithContext, error) {
return newStatement(c), nil
return newStatement(c)
}

// SetTypeConverter allows higher-level drivers to customize type conversion
Expand Down
13 changes: 13 additions & 0 deletions sqlwrapper/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ type ConnectionFactory interface {
) (ConnectionImpl, error)
}

type StatementFactory interface {
CreateStatement(stmt *StatementImplBase) (StatementImpl, error)
}

// DBFactory handles creation of *sql.DB from connection options.
// Each driver is expected to implement this interface to provide database-specific
// DSN construction and connection logic for their particular database format.
Expand All @@ -51,6 +55,7 @@ type Driver struct {
driverName string
typeConverter TypeConverter
connectionFactory ConnectionFactory
stmtFactory StatementFactory
dbFactory DBFactory
errorInspector driverbase.ErrorInspector
}
Expand All @@ -70,6 +75,7 @@ func NewDriver(alloc memory.Allocator, driverName, vendorName string, dbFactory
driverName: driverName,
typeConverter: converter,
connectionFactory: nil, // No custom factory by default
stmtFactory: nil,
dbFactory: dbFactory,
}
}
Expand All @@ -82,6 +88,11 @@ func (d *Driver) WithConnectionFactory(factory ConnectionFactory) *Driver {
return d
}

func (d *Driver) WithStatementFactory(factory StatementFactory) *Driver {
d.stmtFactory = factory
return d
}

// WithErrorInspector sets a custom error inspector for extracting database error metadata.
// This allows drivers to map database-specific errors to ADBC status codes and extract
// SQLSTATE, vendor codes, and other error information.
Expand All @@ -100,6 +111,7 @@ type databaseImpl struct {
typeConverter TypeConverter
// connectionFactory creates custom connection implementations if provided
connectionFactory ConnectionFactory
stmtFactory StatementFactory
}

// NewDatabaseWithContext is the main entrypoint for driver‐agnostic ADBC database creation.
Expand Down Expand Up @@ -132,6 +144,7 @@ func (d *Driver) NewDatabaseWithContext(ctx context.Context, opts map[string]str
db: sqlDB,
typeConverter: d.typeConverter,
connectionFactory: d.connectionFactory,
stmtFactory: d.stmtFactory,
}
return driverbase.NewDatabase(db), nil
}
Expand Down
13 changes: 9 additions & 4 deletions sqlwrapper/record_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ type sqlRecordReaderImpl struct {
schema *arrow.Schema

// For bind parameter support
conn *LoggingConn // Database connection to execute queries
query string // Original SQL query with placeholders
stmt *LoggingStmt // Prepared statement (optional)
typeConverter TypeConverter // Type converter for building schemas
conn *LoggingConn // Database connection to execute queries
query string // Original SQL query with placeholders
stmt *LoggingStmt // Prepared statement (optional)
typeConverter TypeConverter // Type converter for building schemas
additionalParams []any

// Performance optimization: pre-computed inserters to avoid per-value type switching
columnInserters []Inserter
Expand Down Expand Up @@ -69,6 +70,10 @@ func (s *sqlRecordReaderImpl) NextResultSet(ctx context.Context, rec arrow.Recor
}
}

if len(s.additionalParams) > 0 {
args = append(args, s.additionalParams...)
}

// Execute query (with or without parameters)
if s.stmt != nil {
s.rows, err = s.stmt.QueryContext(ctx, args...)
Expand Down
Loading
Loading