Skip to content

Commit 99e7b9c

Browse files
Merge pull request marcboeker#333 from apecloud/expose-stmt-bind
Expose `Bind` and `(Query|Exec)Bound` on `Stmt` for advanced usage
2 parents 0614b2d + b8a948d commit 99e7b9c

File tree

3 files changed

+129
-3
lines changed

3 files changed

+129
-3
lines changed

errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ var (
106106
errPrepare = errors.New("could not prepare query")
107107
errMissingPrepareContext = errors.New("missing context for multi-statement query: try using PrepareContext")
108108
errEmptyQuery = errors.New("empty query")
109+
errCouldNotBind = errors.New("could not bind parameter")
110+
errActiveRows = errors.New("ExecContext or QueryContext with active Rows")
111+
errNotBound = errors.New("parameters have not been bound")
109112
errBeginTx = errors.New("could not begin transaction")
110113
errMultipleTx = errors.New("multiple transactions")
111114
errReadOnlyTxNotSupported = errors.New("read-only transactions are not supported")

statement.go

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ type Stmt struct {
5353
c *Conn
5454
stmt *C.duckdb_prepared_statement
5555
closeOnRowsClose bool
56+
bound bool
5657
closed bool
5758
rows bool
5859
}
@@ -131,6 +132,18 @@ func (s *Stmt) StatementType() (StmtType, error) {
131132
return StmtType(C.duckdb_prepared_statement_type(*s.stmt)), nil
132133
}
133134

135+
// Bind binds the parameters to the statement.
136+
// WARNING: This is a low-level API and should be used with caution.
137+
func (s *Stmt) Bind(args []driver.NamedValue) error {
138+
if s.closed {
139+
return errors.Join(errCouldNotBind, errClosedStmt)
140+
}
141+
if s.stmt == nil {
142+
return errors.Join(errCouldNotBind, errUninitializedStmt)
143+
}
144+
return s.bind(args)
145+
}
146+
134147
func (s *Stmt) bind(args []driver.NamedValue) error {
135148
if s.NumInput() > len(args) {
136149
return fmt.Errorf("incorrect argument count for command: have %d want %d", len(args), s.NumInput())
@@ -258,6 +271,7 @@ func (s *Stmt) bind(args []driver.NamedValue) error {
258271
}
259272
}
260273

274+
s.bound = true
261275
return nil
262276
}
263277

@@ -279,6 +293,30 @@ func (s *Stmt) ExecContext(ctx context.Context, nargs []driver.NamedValue) (driv
279293
return &result{ra}, nil
280294
}
281295

296+
// ExecBound executes a bound query that doesn't return rows, such as an INSERT or UPDATE.
297+
// It can only be used after Bind has been called.
298+
// WARNING: This is a low-level API and should be used with caution.
299+
func (s *Stmt) ExecBound(ctx context.Context) (driver.Result, error) {
300+
if s.closed {
301+
return nil, errClosedCon
302+
}
303+
if s.rows {
304+
return nil, errActiveRows
305+
}
306+
if !s.bound {
307+
return nil, errNotBound
308+
}
309+
310+
res, err := s.executeBound(ctx)
311+
if err != nil {
312+
return nil, err
313+
}
314+
defer C.duckdb_destroy_result(res)
315+
316+
ra := int64(C.duckdb_value_int64(res, 0, 0))
317+
return &result{ra}, nil
318+
}
319+
282320
// Deprecated: Use QueryContext instead.
283321
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
284322
return s.QueryContext(context.Background(), argsToNamedArgs(args))
@@ -295,6 +333,28 @@ func (s *Stmt) QueryContext(ctx context.Context, nargs []driver.NamedValue) (dri
295333
return newRowsWithStmt(*res, s), nil
296334
}
297335

336+
// QueryBound executes a bound query that may return rows, such as a SELECT.
337+
// It can only be used after Bind has been called.
338+
// WARNING: This is a low-level API and should be used with caution.
339+
func (s *Stmt) QueryBound(ctx context.Context) (driver.Rows, error) {
340+
if s.closed {
341+
return nil, errClosedCon
342+
}
343+
if s.rows {
344+
return nil, errActiveRows
345+
}
346+
if !s.bound {
347+
return nil, errNotBound
348+
}
349+
350+
res, err := s.executeBound(ctx)
351+
if err != nil {
352+
return nil, err
353+
}
354+
s.rows = true
355+
return newRowsWithStmt(*res, s), nil
356+
}
357+
298358
// This method executes the query in steps and checks if context is cancelled before executing each step.
299359
// It uses Pending Result Interface C APIs to achieve this. Reference - https://duckdb.org/docs/api/c/api#pending-result-interface
300360
func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb_result, error) {
@@ -304,11 +364,13 @@ func (s *Stmt) execute(ctx context.Context, args []driver.NamedValue) (*C.duckdb
304364
if s.rows {
305365
panic("database/sql/driver: misuse of duckdb driver: ExecContext or QueryContext with active Rows")
306366
}
307-
308367
if err := s.bind(args); err != nil {
309368
return nil, err
310369
}
370+
return s.executeBound(ctx)
371+
}
311372

373+
func (s *Stmt) executeBound(ctx context.Context) (*C.duckdb_result, error) {
312374
var pendingRes C.duckdb_pending_result
313375
if state := C.duckdb_pending_prepared(*s.stmt, &pendingRes); state == C.DuckDBError {
314376
dbErr := getDuckDBError(C.GoString(C.duckdb_pending_error(pendingRes)))
@@ -360,5 +422,3 @@ func argsToNamedArgs(values []driver.Value) []driver.NamedValue {
360422
}
361423
return args
362424
}
363-
364-
var errCouldNotBind = errors.New("could not bind parameter")

statement_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package duckdb
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"errors"
78
"testing"
89

@@ -56,6 +57,27 @@ func TestPrepareQuery(t *testing.T) {
5657
require.ErrorContains(t, err, paramIndexErrMsg)
5758
require.Equal(t, TYPE_INVALID, paramType)
5859

60+
rows, err := stmt.QueryBound(context.Background())
61+
require.Nil(t, rows)
62+
require.ErrorIs(t, err, errNotBound)
63+
64+
err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
65+
require.NoError(t, err)
66+
67+
rows, err = stmt.QueryBound(context.Background())
68+
require.NoError(t, err)
69+
require.NotNil(t, rows)
70+
71+
badRows, err := stmt.QueryBound(context.Background())
72+
require.ErrorIs(t, err, errActiveRows)
73+
require.Nil(t, badRows)
74+
75+
badResults, err := stmt.ExecBound(context.Background())
76+
require.ErrorIs(t, err, errActiveRows)
77+
require.Nil(t, badResults)
78+
79+
require.NoError(t, rows.Close())
80+
5981
require.NoError(t, stmt.Close())
6082

6183
stmtType, err = stmt.StatementType()
@@ -66,6 +88,10 @@ func TestPrepareQuery(t *testing.T) {
6688
require.ErrorIs(t, err, errClosedStmt)
6789
require.Equal(t, TYPE_INVALID, paramType)
6890

91+
err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
92+
require.ErrorIs(t, err, errCouldNotBind)
93+
require.ErrorIs(t, err, errClosedStmt)
94+
6995
return nil
7096
})
7197
require.NoError(t, err)
@@ -146,6 +172,17 @@ func TestPrepareQueryPositional(t *testing.T) {
146172
require.ErrorContains(t, err, paramIndexErrMsg)
147173
require.Equal(t, TYPE_INVALID, paramType)
148174

175+
result, err := stmt.ExecBound(context.Background())
176+
require.Nil(t, result)
177+
require.ErrorIs(t, err, errNotBound)
178+
179+
err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}})
180+
require.NoError(t, err)
181+
182+
result, err = stmt.ExecBound(context.Background())
183+
require.NoError(t, err)
184+
require.NotNil(t, result)
185+
149186
require.NoError(t, stmt.Close())
150187

151188
stmtType, err = stmt.StatementType()
@@ -160,6 +197,10 @@ func TestPrepareQueryPositional(t *testing.T) {
160197
require.ErrorIs(t, err, errClosedStmt)
161198
require.Equal(t, TYPE_INVALID, paramType)
162199

200+
err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}, {Ordinal: 2, Value: "hello"}})
201+
require.ErrorIs(t, err, errCouldNotBind)
202+
require.ErrorIs(t, err, errClosedStmt)
203+
163204
return nil
164205
})
165206
require.NoError(t, err)
@@ -245,6 +286,17 @@ func TestPrepareQueryNamed(t *testing.T) {
245286
require.ErrorContains(t, err, paramIndexErrMsg)
246287
require.Equal(t, TYPE_INVALID, paramType)
247288

289+
result, err := stmt.ExecBound(context.Background())
290+
require.Nil(t, result)
291+
require.ErrorIs(t, err, errNotBound)
292+
293+
err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}})
294+
require.NoError(t, err)
295+
296+
result, err = stmt.ExecBound(context.Background())
297+
require.NoError(t, err)
298+
require.NotNil(t, result)
299+
248300
require.NoError(t, stmt.Close())
249301

250302
stmtType, err = stmt.StatementType()
@@ -259,6 +311,10 @@ func TestPrepareQueryNamed(t *testing.T) {
259311
require.ErrorIs(t, err, errClosedStmt)
260312
require.Equal(t, TYPE_INVALID, paramType)
261313

314+
err = stmt.Bind([]driver.NamedValue{{Name: "bar", Value: "hello"}, {Name: "baz", Value: 0}})
315+
require.ErrorIs(t, err, errCouldNotBind)
316+
require.ErrorIs(t, err, errClosedStmt)
317+
262318
return nil
263319
})
264320
require.NoError(t, err)
@@ -280,6 +336,13 @@ func TestUninitializedStmt(t *testing.T) {
280336
paramName, err := stmt.ParamName(1)
281337
require.ErrorIs(t, err, errUninitializedStmt)
282338
require.Equal(t, "", paramName)
339+
340+
err = stmt.Bind([]driver.NamedValue{{Ordinal: 1, Value: 0}})
341+
require.ErrorIs(t, err, errCouldNotBind)
342+
require.ErrorIs(t, err, errUninitializedStmt)
343+
344+
_, err = stmt.ExecBound(context.Background())
345+
require.ErrorIs(t, err, errNotBound)
283346
}
284347

285348
func TestPrepareWithError(t *testing.T) {

0 commit comments

Comments
 (0)