Skip to content

Commit a8047d9

Browse files
committed
Reduce QueryContext allocations by reusing the channel
This commit reduces the number of allocations and memory usage of QueryContext by inverting the goroutine: instead of processing the request in the goroutine and having it send the result, we now process the request in the method itself and goroutine is only used to interrupt the query if the context is canceled. The advantage of this approach is that we no longer need to send anything on the channel, but instead can treat the channel as a semaphore (this reduces the amount of memory allocated by this method). Additionally, we now reuse the channel used to communicate with the goroutine which reduces the number of allocations. This commit also adds a test that actually exercises the sqlite3_interrupt logic since the existing tests did not. Those tests cancelled the context before scanning any of the rows and could be made to pass without ever calling sqlite3_interrupt. The below version of SQLiteRows.Next passes the previous tests: ```go func (rc *SQLiteRows) Next(dest []driver.Value) error { rc.s.mu.Lock() defer rc.s.mu.Unlock() if rc.s.closed { return io.EOF } if err := rc.ctx.Err(); err != nil { return err } return rc.nextSyncLocked(dest) } ``` Benchmark results: ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 cpu: Apple M1 Max │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQueryContext/Background-10 3.994µ ± 2% 4.034µ ± 1% ~ (p=0.289 n=10) Suite/BenchmarkQueryContext/WithCancel-10 12.02µ ± 3% 11.56µ ± 4% -3.87% (p=0.003 n=10) geomean 6.930µ 6.829µ -1.46% │ old.txt │ new.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQueryContext/Background-10 400.0 ± 0% 400.0 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 2.376Ki ± 0% 1.025Ki ± 0% -56.87% (p=0.000 n=10) geomean 986.6 647.9 -34.33% ¹ all samples are equal │ old.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQueryContext/Background-10 12.00 ± 0% 12.00 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 38.00 ± 0% 28.00 ± 0% -26.32% (p=0.000 n=10) geomean 21.35 18.33 -14.16% ¹ all samples are equal ```
1 parent 7658c06 commit a8047d9

File tree

3 files changed

+289
-43
lines changed

3 files changed

+289
-43
lines changed

sqlite3.go

+56-36
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,9 @@ type SQLiteRows struct {
399399
decltype []string
400400
ctx context.Context // no better alternative to pass context into Next() method
401401
closemu sync.Mutex
402+
// semaphore to signal the goroutine used to interrupt queries when a
403+
// cancellable context is passed to QueryContext
404+
sema chan struct{}
402405
}
403406

404407
type functionInfo struct {
@@ -2050,36 +2053,37 @@ func isInterruptErr(err error) bool {
20502053

20512054
// exec executes a query that doesn't return rows. Attempts to honor context timeout.
20522055
func (s *SQLiteStmt) exec(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
2053-
if ctx.Done() == nil {
2056+
done := ctx.Done()
2057+
if done == nil {
20542058
return s.execSync(args)
20552059
}
2056-
2057-
type result struct {
2058-
r driver.Result
2059-
err error
2060+
if err := ctx.Err(); err != nil {
2061+
return nil, err // Fast check if the channel is closed
20602062
}
2061-
resultCh := make(chan result)
2062-
defer close(resultCh)
2063+
2064+
sema := make(chan struct{})
20632065
go func() {
2064-
r, err := s.execSync(args)
2065-
resultCh <- result{r, err}
2066-
}()
2067-
var rv result
2068-
select {
2069-
case rv = <-resultCh:
2070-
case <-ctx.Done():
20712066
select {
2072-
case rv = <-resultCh: // no need to interrupt, operation completed in db
2073-
default:
2074-
// this is still racy and can be no-op if executed between sqlite3_* calls in execSync.
2067+
case <-done:
20752068
C.sqlite3_interrupt(s.c.db)
2076-
rv = <-resultCh // wait for goroutine completed
2077-
if isInterruptErr(rv.err) {
2078-
return nil, ctx.Err()
2079-
}
2069+
// Wait until signaled. We need to ensure that this goroutine
2070+
// will not call interrupt after this method returns.
2071+
<-sema
2072+
case <-sema:
20802073
}
2074+
}()
2075+
r, err := s.execSync(args)
2076+
// Signal the goroutine to exit. This send will only succeed at a point
2077+
// where it is impossible for the goroutine to call sqlite3_interrupt.
2078+
//
2079+
// This is necessary to ensure the goroutine does not interrupt an
2080+
// unrelated query if the context is cancelled after this method returns
2081+
// but before the goroutine exits (we don't wait for it to exit).
2082+
sema <- struct{}{}
2083+
if err != nil && isInterruptErr(err) {
2084+
return nil, ctx.Err()
20812085
}
2082-
return rv.r, rv.err
2086+
return r, err
20832087
}
20842088

20852089
func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) {
@@ -2117,6 +2121,9 @@ func (rc *SQLiteRows) Close() error {
21172121
return nil
21182122
}
21192123
rc.s = nil // remove reference to SQLiteStmt
2124+
if rc.sema != nil {
2125+
close(rc.sema)
2126+
}
21202127
s.mu.Lock()
21212128
if s.closed {
21222129
s.mu.Unlock()
@@ -2174,27 +2181,40 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
21742181
return io.EOF
21752182
}
21762183

2177-
if rc.ctx.Done() == nil {
2184+
done := rc.ctx.Done()
2185+
if done == nil {
21782186
return rc.nextSyncLocked(dest)
21792187
}
2180-
resultCh := make(chan error)
2181-
defer close(resultCh)
2188+
if err := rc.ctx.Err(); err != nil {
2189+
return err // Fast check if the channel is closed
2190+
}
2191+
2192+
if rc.sema == nil {
2193+
rc.sema = make(chan struct{})
2194+
}
21822195
go func() {
2183-
resultCh <- rc.nextSyncLocked(dest)
2184-
}()
2185-
select {
2186-
case err := <-resultCh:
2187-
return err
2188-
case <-rc.ctx.Done():
21892196
select {
2190-
case <-resultCh: // no need to interrupt
2191-
default:
2192-
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
2197+
case <-done:
21932198
C.sqlite3_interrupt(rc.s.c.db)
2194-
<-resultCh // ensure goroutine completed
2199+
// Wait until signaled. We need to ensure that this goroutine
2200+
// will not call interrupt after this method returns.
2201+
<-rc.sema
2202+
case <-rc.sema:
21952203
}
2196-
return rc.ctx.Err()
2204+
}()
2205+
2206+
err := rc.nextSyncLocked(dest)
2207+
// Signal the goroutine to exit. This send will only succeed at a point
2208+
// where it is impossible for the goroutine to call sqlite3_interrupt.
2209+
//
2210+
// This is necessary to ensure the goroutine does not interrupt an
2211+
// unrelated query if the context is cancelled after this method returns
2212+
// but before the goroutine exits (we don't wait for it to exit).
2213+
rc.sema <- struct{}{}
2214+
if err != nil && isInterruptErr(err) {
2215+
err = rc.ctx.Err()
21972216
}
2217+
return err
21982218
}
21992219

22002220
// nextSyncLocked moves cursor to next; must be called with locked mutex.

sqlite3_go18_test.go

+147
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ package sqlite3
1111
import (
1212
"context"
1313
"database/sql"
14+
"errors"
1415
"fmt"
1516
"io/ioutil"
1617
"math/rand"
1718
"os"
19+
"strings"
1820
"sync"
1921
"testing"
2022
"time"
@@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
268270
}
269271
}
270272

273+
// Test that we can successfully interrupt a long running query when
274+
// the context is canceled. The previous two QueryRowContext tests
275+
// only test that we handle a previously cancelled context and thus
276+
// do not call sqlite3_interrupt.
277+
func TestQueryRowContextCancelInterrupt(t *testing.T) {
278+
db, err := sql.Open("sqlite3", ":memory:")
279+
if err != nil {
280+
t.Fatal(err)
281+
}
282+
defer db.Close()
283+
284+
// Test that we have the unixepoch function and if not skip the test.
285+
if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
286+
libVersion, libVersionNumber, sourceID := Version()
287+
if strings.Contains(err.Error(), "no such function: unixepoch") {
288+
t.Skip("Skipping the 'unixepoch' function is not implemented in "+
289+
"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
290+
}
291+
t.Fatal(err)
292+
}
293+
294+
const createTableStmt = `
295+
CREATE TABLE timestamps (
296+
ts TIMESTAMP NOT NULL
297+
);`
298+
if _, err := db.Exec(createTableStmt); err != nil {
299+
t.Fatal(err)
300+
}
301+
302+
stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
303+
if err != nil {
304+
t.Fatal(err)
305+
}
306+
defer stmt.Close()
307+
308+
// Computationally expensive query that consumes many rows. This is needed
309+
// to test cancellation because queries are not interrupted immediately.
310+
// Instead, queries are only halted at certain checkpoints where the
311+
// sqlite3.isInterrupted is checked and true.
312+
queryStmt := `
313+
SELECT
314+
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
315+
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
316+
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
317+
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
318+
FROM
319+
timestamps
320+
WHERE datetime(ts, 'unixepoch', 'localtime')
321+
LIKE
322+
?;`
323+
324+
query := func(t *testing.T, timeout time.Duration) (int, error) {
325+
// Create a complicated pattern to match timestamps
326+
const pattern = "%2%0%2%4%-%-%:%:%"
327+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
328+
defer cancel()
329+
rows, err := db.QueryContext(ctx, queryStmt, pattern)
330+
if err != nil {
331+
return 0, err
332+
}
333+
var count int
334+
for rows.Next() {
335+
var n int64
336+
if err := rows.Scan(&n, &n, &n, &n); err != nil {
337+
return count, err
338+
}
339+
count++
340+
}
341+
return count, rows.Err()
342+
}
343+
344+
average := func(n int, fn func()) time.Duration {
345+
start := time.Now()
346+
for i := 0; i < n; i++ {
347+
fn()
348+
}
349+
return time.Since(start) / time.Duration(n)
350+
}
351+
352+
createRows := func(n int) {
353+
t.Logf("Creating %d rows", n)
354+
if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
355+
t.Fatal(err)
356+
}
357+
ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
358+
rr := rand.New(rand.NewSource(1234))
359+
for i := 0; i < n; i++ {
360+
if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
361+
t.Fatal(err)
362+
}
363+
}
364+
}
365+
366+
const TargetRuntime = 200 * time.Millisecond
367+
const N = 5_000 // Number of rows to insert at a time
368+
369+
// Create enough rows that the query takes ~200ms to run.
370+
start := time.Now()
371+
createRows(N)
372+
baseAvg := average(4, func() {
373+
if _, err := query(t, time.Hour); err != nil {
374+
t.Fatal(err)
375+
}
376+
})
377+
t.Log("Base average:", baseAvg)
378+
rowCount := N * (int(TargetRuntime/baseAvg) + 1)
379+
createRows(rowCount)
380+
t.Log("Table setup time:", time.Since(start))
381+
382+
// Set the timeout to 1/10 of the average query time.
383+
avg := average(2, func() {
384+
n, err := query(t, time.Hour)
385+
if err != nil {
386+
t.Fatal(err)
387+
}
388+
if n == 0 {
389+
t.Fatal("scanned zero rows")
390+
}
391+
})
392+
// Guard against the timeout being too short to reliably test.
393+
if avg < TargetRuntime/2 {
394+
t.Fatalf("Average query runtime should be around %s got: %s ",
395+
TargetRuntime, avg)
396+
}
397+
timeout := (avg / 10).Round(100 * time.Microsecond)
398+
t.Logf("Average: %s Timeout: %s", avg, timeout)
399+
400+
for i := 0; i < 10; i++ {
401+
tt := time.Now()
402+
n, err := query(t, timeout)
403+
if !errors.Is(err, context.DeadlineExceeded) {
404+
fn := t.Errorf
405+
if err != nil {
406+
fn = t.Fatalf
407+
}
408+
fn("expected error %v got %v", context.DeadlineExceeded, err)
409+
}
410+
d := time.Since(tt)
411+
t.Logf("%d: rows: %d duration: %s", i, n, d)
412+
if d > timeout*4 {
413+
t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
414+
}
415+
}
416+
}
417+
271418
func TestExecCancel(t *testing.T) {
272419
db, err := sql.Open("sqlite3", ":memory:")
273420
if err != nil {

0 commit comments

Comments
 (0)