Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 51 additions & 9 deletions ffitemplate/_tmpl/driver.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ import (
"fmt"
"io"
"log/slog"
"math"
"os"
"runtime"
"runtime/cgo"
"slices"
"strings"
"sync/atomic"
"unsafe"
Expand Down Expand Up @@ -148,6 +150,7 @@ func setErrWithDetails(err *C.struct_AdbcError, adbcError adbc.Error) {
cErr.values = (**C.cuint8_t)(C.calloc(C.size_t(numDetails), C.size_t(unsafe.Sizeof((*C.cuint8_t)(nil)))))
cErr.lengths = (*C.size_t)(C.calloc(C.size_t(numDetails), C.sizeof_size_t))

// SAFETY: no copy of fromCArr because these are written to, not read from
keys := fromCArr[*C.cchar_t](cErr.keys, numDetails)
values := fromCArr[*C.cuint8_t](cErr.values, numDetails)
lengths := fromCArr[C.size_t](cErr.lengths, numDetails)
Expand Down Expand Up @@ -264,9 +267,10 @@ func getFromHandle[T any](ptr unsafe.Pointer) *T {
}

func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusCode {
lenWithTerminator := C.size_t(len(val) + 1)
lenWithTerminator := C.size_t(len(val)+1)
if lenWithTerminator <= *length {
sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), int(*length))
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(out)), len(val)+1)
copy(sink, val)
sink[len(val)] = 0
}
Expand All @@ -276,7 +280,8 @@ func exportStringOption(val string, out *C.char, length *C.size_t) C.AdbcStatusC

func exportBytesOption(val []byte, out *C.uint8_t, length *C.size_t) C.AdbcStatusCode {
if C.size_t(len(val)) <= *length {
sink := fromCArr[byte]((*byte)(out), int(*length))
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(out), len(val))
copy(sink, val)
}
*length = C.size_t(len(val))
Expand Down Expand Up @@ -703,7 +708,11 @@ func {{.Prefix}}DatabaseSetOptionBytes(db *C.struct_AdbcDatabase, key *C.cchar_t
}
cdb := getFromHandle[cDatabase](db.private_data)
k := C.GoString(key)
v := fromCArr[byte](value, int(length))
var safeLen int
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use c.GoBytes(value, length)?

That would do the checks and create the copy, outputting a []byte for you and simplify this a bit.

if safeLen, code = checkLengthToInt(length, err); code != C.ADBC_STATUS_OK {
return code
}
v := slices.Clone(fromCArr[byte](value, safeLen))

if cdb.db != nil {
e := cdb.db.SetOptionBytes(cdb.newContext(), k, v)
Expand Down Expand Up @@ -924,7 +933,11 @@ func {{.Prefix}}ConnectionSetOptionBytes(db *C.struct_AdbcConnection, key *C.cch
return C.ADBC_STATUS_INVALID_STATE
}

e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), fromCArr[byte](value, int(length)))
var safeLen int
if safeLen, code = checkLengthToInt(length, err); code != C.ADBC_STATUS_OK {
return code
}
e := conn.cnxn.SetOptionBytes(conn.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, safeLen)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above, just use c.GoBytes instead of slices.Clone(fromCArr....)

return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1034,6 +1047,7 @@ func {{.Prefix}}ConnectionRelease(cnxn *C.struct_AdbcConnection, err *C.struct_A
return C.AdbcStatusCode(errToAdbcErr(err, conn.cnxn.Close(conn.newContext())))
}

// SAFETY: at each call site, consider whether a copy of the resulting slice must be made
func fromCArr[T, CType any](ptr *CType, sz int) []T {
if ptr == nil || sz == 0 {
return nil
Expand All @@ -1042,6 +1056,14 @@ func fromCArr[T, CType any](ptr *CType, sz int) []T {
return unsafe.Slice((*T)(unsafe.Pointer(ptr)), sz)
}

func checkLengthToInt(length C.size_t, err *C.struct_AdbcError) (int, C.AdbcStatusCode) {
if length > C.size_t(math.MaxInt) {
setErr(err, "Length %d exceeds max Go int %d", length, math.MaxInt)
return 0, C.ADBC_STATUS_INVALID_ARGUMENT
}
return int(length), C.ADBC_STATUS_OK
}

func toCdataStream(ptr *C.struct_ArrowArrayStream) *cdata.CArrowArrayStream {
return (*cdata.CArrowArrayStream)(unsafe.Pointer(ptr))
}
Expand Down Expand Up @@ -1106,7 +1128,11 @@ func {{.Prefix}}ConnectionGetInfo(cnxn *C.struct_AdbcConnection, codes *C.cuint3
return C.ADBC_STATUS_INVALID_STATE
}

infoCodes := fromCArr[adbc.InfoCode](codes, int(len))
var safeLen int
if safeLen, code = checkLengthToInt(len, err); code != C.ADBC_STATUS_OK {
return code
}
infoCodes := slices.Clone(fromCArr[adbc.InfoCode](codes, safeLen))
rdr, e := conn.cnxn.GetInfo(conn.newContext(), infoCodes)
if e != nil {
return C.AdbcStatusCode(errToAdbcErr(err, e))
Expand Down Expand Up @@ -1247,7 +1273,11 @@ func {{.Prefix}}ConnectionReadPartition(cnxn *C.struct_AdbcConnection, serialize
return C.ADBC_STATUS_INVALID_STATE
}

rdr, e := conn.cnxn.ReadPartition(conn.newContext(), fromCArr[byte](serialized, int(serializedLen)))
var safeLen int
if safeLen, code = checkLengthToInt(serializedLen, err); code != C.ADBC_STATUS_OK {
return code
}
rdr, e := conn.cnxn.ReadPartition(conn.newContext(), slices.Clone(fromCArr[byte](serialized, safeLen)))
if e != nil {
return C.AdbcStatusCode(errToAdbcErr(err, e))
}
Expand Down Expand Up @@ -1605,7 +1635,11 @@ func {{.Prefix}}StatementSetSubstraitPlan(stmt *C.struct_AdbcStatement, plan *C.
return C.ADBC_STATUS_INVALID_STATE
}

e := st.stmt.SetSubstraitPlan(st.newContext(), fromCArr[byte](plan, int(length)))
var safeLen int
if safeLen, code = checkLengthToInt(length, err); code != C.ADBC_STATUS_OK {
return code
}
e := st.stmt.SetSubstraitPlan(st.newContext(), slices.Clone(fromCArr[byte](plan, safeLen)))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And so on.... Not gonna comment at every call site

return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1706,7 +1740,12 @@ func {{.Prefix}}StatementSetOptionBytes(db *C.struct_AdbcStatement, key *C.cchar
return C.ADBC_STATUS_NOT_IMPLEMENTED
}

e := opts.SetOptionBytes(st.newContext(), C.GoString(key), fromCArr[byte](value, int(length)))

var safeLen int
if safeLen, code = checkLengthToInt(length, err); code != C.ADBC_STATUS_OK {
return code
}
e := opts.SetOptionBytes(st.newContext(), C.GoString(key), slices.Clone(fromCArr[byte](value, safeLen)))
return C.AdbcStatusCode(errToAdbcErr(err, e))
}

Expand Down Expand Up @@ -1808,6 +1847,7 @@ func {{.Prefix}}StatementExecutePartitions(stmt *C.struct_AdbcStatement, schema
totalLen += len(p)
}
partitions.private_data = C.calloc(C.size_t(totalLen), C.size_t(1))
// SAFETY: no copy of fromCArr because this is written to, not read from
dst := fromCArr[byte]((*byte)(partitions.private_data), totalLen)

partIDs := fromCArr[*C.cuint8_t](partitions.partitions, int(partitions.num_partitions))
Expand All @@ -1829,9 +1869,11 @@ func AdbcDriver{{.Prefix}}Init(version C.int, rawDriver *C.void, err *C.struct_A

switch version {
case C.ADBC_VERSION_1_0_0:
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_0_0_SIZE)
memory.Set(sink, 0)
case C.ADBC_VERSION_1_1_0:
// SAFETY: no copy of fromCArr because this is written to, not read from
sink := fromCArr[byte]((*byte)(unsafe.Pointer(driver)), C.ADBC_DRIVER_1_1_0_SIZE)
memory.Set(sink, 0)
default:
Expand Down
Loading