Skip to content

Commit d3c66c9

Browse files
committed
Improve the performance of QueryContext by reusing the result channel
This commit improves the performance of QueryContext by changing it to reuse the result channel instead of creating a new one for each query. This is particularly impactful for queries that scan more than one row. It also adds a test that actually exercises the sqlite3_interrupt logic since the existing tests did not. Those tests cancelled the context before scanning any of the rows and could be made to pass without ever calling sqlite3_interrupt. The below version of SQLiteRows.Next passes the previous tests: ```go func (rc *SQLiteRows) Next(dest []driver.Value) error { rc.s.mu.Lock() defer rc.s.mu.Unlock() if rc.s.closed { return io.EOF } if err := rc.ctx.Err(); err != nil { return err } return rc.nextSyncLocked(dest) } ``` Benchmark results: ``` goos: darwin goarch: arm64 pkg: github.com/mattn/go-sqlite3 cpu: Apple M1 Max │ b.txt │ n.txt │ │ sec/op │ sec/op vs base │ Suite/BenchmarkQueryContext/Background-10 4.088µ ± 1% 4.154µ ± 3% +1.60% (p=0.011 n=10) Suite/BenchmarkQueryContext/WithCancel-10 12.84µ ± 3% 11.67µ ± 3% -9.08% (p=0.000 n=10) geomean 7.245µ 6.963µ -3.89% │ b.txt │ n.txt │ │ B/op │ B/op vs base │ Suite/BenchmarkQueryContext/Background-10 400.0 ± 0% 400.0 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 2.547Ki ± 0% 1.282Ki ± 0% -49.67% (p=0.000 n=10) geomean 1021.4 724.6 -29.06% ¹ all samples are equal │ b.txt │ n.txt │ │ allocs/op │ allocs/op vs base │ Suite/BenchmarkQueryContext/Background-10 12.00 ± 0% 12.00 ± 0% ~ (p=1.000 n=10) ¹ Suite/BenchmarkQueryContext/WithCancel-10 49.00 ± 0% 28.00 ± 0% -42.86% (p=0.000 n=10) geomean 24.25 18.33 -24.41% ¹ all samples are equal ```
1 parent 41871ea commit d3c66c9

File tree

3 files changed

+236
-15
lines changed

3 files changed

+236
-15
lines changed

sqlite3.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ type SQLiteRows struct {
399399
cls bool
400400
closed bool
401401
ctx context.Context // no better alternative to pass context into Next() method
402+
resultCh chan error
402403
}
403404

404405
type functionInfo struct {
@@ -2172,24 +2173,29 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
21722173
return io.EOF
21732174
}
21742175

2175-
if rc.ctx.Done() == nil {
2176+
done := rc.ctx.Done()
2177+
if done == nil {
21762178
return rc.nextSyncLocked(dest)
21772179
}
2178-
resultCh := make(chan error)
2179-
defer close(resultCh)
2180+
if err := rc.ctx.Err(); err != nil {
2181+
return err // Fast check if the channel is closed
2182+
}
2183+
if rc.resultCh == nil {
2184+
rc.resultCh = make(chan error)
2185+
}
21802186
go func() {
2181-
resultCh <- rc.nextSyncLocked(dest)
2187+
rc.resultCh <- rc.nextSyncLocked(dest)
21822188
}()
21832189
select {
2184-
case err := <-resultCh:
2190+
case err := <-rc.resultCh:
21852191
return err
2186-
case <-rc.ctx.Done():
2192+
case <-done:
21872193
select {
2188-
case <-resultCh: // no need to interrupt
2194+
case <-rc.resultCh: // no need to interrupt
21892195
default:
21902196
// this is still racy and can be no-op if executed between sqlite3_* calls in nextSyncLocked.
21912197
C.sqlite3_interrupt(rc.s.c.db)
2192-
<-resultCh // ensure goroutine completed
2198+
<-rc.resultCh // ensure goroutine completed
21932199
}
21942200
return rc.ctx.Err()
21952201
}

sqlite3_go18_test.go

+147
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ package sqlite3
1111
import (
1212
"context"
1313
"database/sql"
14+
"errors"
1415
"fmt"
1516
"io/ioutil"
1617
"math/rand"
1718
"os"
19+
"strings"
1820
"sync"
1921
"testing"
2022
"time"
@@ -268,6 +270,151 @@ func TestQueryRowContextCancelParallel(t *testing.T) {
268270
}
269271
}
270272

273+
// Test that we can successfully interrupt a long running query when
274+
// the context is canceled. The previous two QueryRowContext tests
275+
// only test that we handle a previously cancelled context and thus
276+
// do not call sqlite3_interrupt.
277+
func TestQueryRowContextCancelInterrupt(t *testing.T) {
278+
db, err := sql.Open("sqlite3", ":memory:")
279+
if err != nil {
280+
t.Fatal(err)
281+
}
282+
defer db.Close()
283+
284+
// Test that we have the unixepoch function and if not skip the test.
285+
if _, err := db.Exec(`SELECT unixepoch(datetime(100000, 'unixepoch', 'localtime'))`); err != nil {
286+
libVersion, libVersionNumber, sourceID := Version()
287+
if strings.Contains(err.Error(), "no such function: unixepoch") {
288+
t.Skip("Skipping the 'unixepoch' function is not implemented in "+
289+
"this version of sqlite3:", libVersion, libVersionNumber, sourceID)
290+
}
291+
t.Fatal(err)
292+
}
293+
294+
const createTableStmt = `
295+
CREATE TABLE timestamps (
296+
ts TIMESTAMP NOT NULL
297+
);`
298+
if _, err := db.Exec(createTableStmt); err != nil {
299+
t.Fatal(err)
300+
}
301+
302+
stmt, err := db.Prepare(`INSERT INTO timestamps VALUES (?);`)
303+
if err != nil {
304+
t.Fatal(err)
305+
}
306+
defer stmt.Close()
307+
308+
// Computationally expensive query that consumes many rows. This is needed
309+
// to test cancellation because queries are not interrupted immediately.
310+
// Instead, queries are only halted at certain checkpoints where the
311+
// sqlite3.isInterrupted is checked and true.
312+
queryStmt := `
313+
SELECT
314+
SUM(unixepoch(datetime(ts + 10, 'unixepoch', 'localtime'))) AS c1,
315+
SUM(unixepoch(datetime(ts + 20, 'unixepoch', 'localtime'))) AS c2,
316+
SUM(unixepoch(datetime(ts + 30, 'unixepoch', 'localtime'))) AS c3,
317+
SUM(unixepoch(datetime(ts + 40, 'unixepoch', 'localtime'))) AS c4
318+
FROM
319+
timestamps
320+
WHERE datetime(ts, 'unixepoch', 'localtime')
321+
LIKE
322+
?;`
323+
324+
query := func(t *testing.T, timeout time.Duration) (int, error) {
325+
// Create a complicated pattern to match timestamps
326+
const pattern = "%2%0%2%4%-%-%:%:%"
327+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
328+
defer cancel()
329+
rows, err := db.QueryContext(ctx, queryStmt, pattern)
330+
if err != nil {
331+
return 0, err
332+
}
333+
var count int
334+
for rows.Next() {
335+
var n int64
336+
if err := rows.Scan(&n, &n, &n, &n); err != nil {
337+
return count, err
338+
}
339+
count++
340+
}
341+
return count, rows.Err()
342+
}
343+
344+
average := func(n int, fn func()) time.Duration {
345+
start := time.Now()
346+
for i := 0; i < n; i++ {
347+
fn()
348+
}
349+
return time.Since(start) / time.Duration(n)
350+
}
351+
352+
createRows := func(n int) {
353+
t.Logf("Creating %d rows", n)
354+
if _, err := db.Exec(`DELETE FROM timestamps; VACUUM;`); err != nil {
355+
t.Fatal(err)
356+
}
357+
ts := time.Date(2024, 6, 6, 8, 9, 10, 12345, time.UTC).Unix()
358+
rr := rand.New(rand.NewSource(1234))
359+
for i := 0; i < n; i++ {
360+
if _, err := stmt.Exec(ts + rr.Int63n(10_000) - 5_000); err != nil {
361+
t.Fatal(err)
362+
}
363+
}
364+
}
365+
366+
const TargetRuntime = 200 * time.Millisecond
367+
const N = 5_000 // Number of rows to insert at a time
368+
369+
// Create enough rows that the query takes ~200ms to run.
370+
start := time.Now()
371+
createRows(N)
372+
baseAvg := average(4, func() {
373+
if _, err := query(t, time.Hour); err != nil {
374+
t.Fatal(err)
375+
}
376+
})
377+
t.Log("Base average:", baseAvg)
378+
rowCount := N * (int(TargetRuntime/baseAvg) + 1)
379+
createRows(rowCount)
380+
t.Log("Table setup time:", time.Since(start))
381+
382+
// Set the timeout to 1/10 of the average query time.
383+
avg := average(2, func() {
384+
n, err := query(t, time.Hour)
385+
if err != nil {
386+
t.Fatal(err)
387+
}
388+
if n == 0 {
389+
t.Fatal("scanned zero rows")
390+
}
391+
})
392+
// Guard against the timeout being too short to reliably test.
393+
if avg < TargetRuntime/2 {
394+
t.Fatalf("Average query runtime should be around %s got: %s ",
395+
TargetRuntime, avg)
396+
}
397+
timeout := (avg / 10).Round(100 * time.Microsecond)
398+
t.Logf("Average: %s Timeout: %s", avg, timeout)
399+
400+
for i := 0; i < 10; i++ {
401+
tt := time.Now()
402+
n, err := query(t, timeout)
403+
if !errors.Is(err, context.DeadlineExceeded) {
404+
fn := t.Errorf
405+
if err != nil {
406+
fn = t.Fatalf
407+
}
408+
fn("expected error %v got %v", context.DeadlineExceeded, err)
409+
}
410+
d := time.Since(tt)
411+
t.Logf("%d: rows: %d duration: %s", i, n, d)
412+
if d > timeout*4 {
413+
t.Errorf("query was cancelled after %s but did not abort until: %s", timeout, d)
414+
}
415+
}
416+
}
417+
271418
func TestExecCancel(t *testing.T) {
272419
db, err := sql.Open("sqlite3", ":memory:")
273420
if err != nil {

sqlite3_test.go

+75-7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ package sqlite3
1010

1111
import (
1212
"bytes"
13+
"context"
1314
"database/sql"
1415
"database/sql/driver"
1516
"errors"
@@ -2030,7 +2031,7 @@ func BenchmarkCustomFunctions(b *testing.B) {
20302031
}
20312032

20322033
func TestSuite(t *testing.T) {
2033-
initializeTestDB(t)
2034+
initializeTestDB(t, false)
20342035
defer freeTestDB()
20352036

20362037
for _, test := range tests {
@@ -2039,7 +2040,7 @@ func TestSuite(t *testing.T) {
20392040
}
20402041

20412042
func BenchmarkSuite(b *testing.B) {
2042-
initializeTestDB(b)
2043+
initializeTestDB(b, true)
20432044
defer freeTestDB()
20442045

20452046
for _, benchmark := range benchmarks {
@@ -2068,8 +2069,13 @@ type TestDB struct {
20682069

20692070
var db *TestDB
20702071

2071-
func initializeTestDB(t testing.TB) {
2072-
tempFilename := TempFilename(t)
2072+
func initializeTestDB(t testing.TB, memory bool) {
2073+
var tempFilename string
2074+
if memory {
2075+
tempFilename = ":memory:"
2076+
} else {
2077+
tempFilename = TempFilename(t)
2078+
}
20732079
d, err := sql.Open("sqlite3", tempFilename+"?_busy_timeout=99999")
20742080
if err != nil {
20752081
os.Remove(tempFilename)
@@ -2084,9 +2090,11 @@ func freeTestDB() {
20842090
if err != nil {
20852091
panic(err)
20862092
}
2087-
err = os.Remove(db.tempFilename)
2088-
if err != nil {
2089-
panic(err)
2093+
if db.tempFilename != "" && db.tempFilename != ":memory:" {
2094+
err := os.Remove(db.tempFilename)
2095+
if err != nil {
2096+
panic(err)
2097+
}
20902098
}
20912099
}
20922100

@@ -2107,6 +2115,7 @@ var tests = []testing.InternalTest{
21072115
var benchmarks = []testing.InternalBenchmark{
21082116
{Name: "BenchmarkExec", F: benchmarkExec},
21092117
{Name: "BenchmarkQuery", F: benchmarkQuery},
2118+
{Name: "BenchmarkQueryContext", F: benchmarkQueryContext},
21102119
{Name: "BenchmarkParams", F: benchmarkParams},
21112120
{Name: "BenchmarkStmt", F: benchmarkStmt},
21122121
{Name: "BenchmarkRows", F: benchmarkRows},
@@ -2479,6 +2488,65 @@ func benchmarkQuery(b *testing.B) {
24792488
}
24802489
}
24812490

2491+
// benchmarkQueryContext is benchmark for QueryContext
2492+
func benchmarkQueryContext(b *testing.B) {
2493+
const createTableStmt = `
2494+
CREATE TABLE IF NOT EXISTS query_context(
2495+
id INTEGER PRIMARY KEY
2496+
);
2497+
DELETE FROM query_context;
2498+
VACUUM;`
2499+
test := func(ctx context.Context, b *testing.B) {
2500+
if _, err := db.Exec(createTableStmt); err != nil {
2501+
b.Fatal(err)
2502+
}
2503+
for i := 0; i < 10; i++ {
2504+
_, err := db.Exec("INSERT INTO query_context VALUES (?);", int64(i))
2505+
if err != nil {
2506+
db.Fatal(err)
2507+
}
2508+
}
2509+
stmt, err := db.PrepareContext(ctx, `SELECT id FROM query_context;`)
2510+
if err != nil {
2511+
b.Fatal(err)
2512+
}
2513+
b.Cleanup(func() { stmt.Close() })
2514+
2515+
var n int
2516+
for i := 0; i < b.N; i++ {
2517+
rows, err := stmt.QueryContext(ctx)
2518+
if err != nil {
2519+
b.Fatal(err)
2520+
}
2521+
for rows.Next() {
2522+
if err := rows.Scan(&n); err != nil {
2523+
b.Fatal(err)
2524+
}
2525+
}
2526+
if err := rows.Err(); err != nil {
2527+
b.Fatal(err)
2528+
}
2529+
}
2530+
}
2531+
2532+
// When the context does not have a Done channel we should use
2533+
// the fast path that directly handles the query instead of
2534+
// handling it in a goroutine. This benchmark also serves to
2535+
// highlight the performance impact of using a cancelable
2536+
// context.
2537+
b.Run("Background", func(b *testing.B) {
2538+
test(context.Background(), b)
2539+
})
2540+
2541+
// Benchmark a query with a context that can be canceled. This
2542+
// requires using a goroutine and is thus much slower.
2543+
b.Run("WithCancel", func(b *testing.B) {
2544+
ctx, cancel := context.WithCancel(context.Background())
2545+
defer cancel()
2546+
test(ctx, b)
2547+
})
2548+
}
2549+
24822550
// benchmarkParams is benchmark for params
24832551
func benchmarkParams(b *testing.B) {
24842552
for i := 0; i < b.N; i++ {

0 commit comments

Comments
 (0)