Skip to content

Commit 35d6745

Browse files
authored
Merge pull request #111 from tailscale/alisdair/connlogger
sqlite: add optional ConnLogger
2 parents 3a6395a + 731f626 commit 35d6745

File tree

2 files changed

+244
-6
lines changed

2 files changed

+244
-6
lines changed

sqlite.go

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,22 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
127127
}
128128
}
129129

130+
// ConnectorWithLogger returns a [driver.Connector] for the given connection
131+
// parameters. makeLogger is used to create a [ConnLogger] when [Connect] is
132+
// called.
133+
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
134+
return &connector{
135+
name: sqliteURI,
136+
tracer: tracer,
137+
makeLogger: makeLogger,
138+
connInitFunc: connInitFunc,
139+
}
140+
}
141+
130142
type connector struct {
131143
name string
132144
tracer sqliteh.Tracer
145+
makeLogger func() ConnLogger // or nil
133146
connInitFunc ConnInitFunc
134147
}
135148

@@ -152,12 +165,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
152165
}
153166
return nil, err
154167
}
155-
156168
c := &conn{
157169
db: db,
158170
tracer: p.tracer,
159171
id: sqliteh.TraceConnID(maxConnID.Add(1)),
160172
}
173+
if p.makeLogger != nil {
174+
c.logger = p.makeLogger()
175+
}
161176
if p.connInitFunc != nil {
162177
if err := p.connInitFunc(ctx, c); err != nil {
163178
db.Close()
@@ -179,6 +194,7 @@ type conn struct {
179194
db sqliteh.DB
180195
id sqliteh.TraceConnID
181196
tracer sqliteh.Tracer
197+
logger ConnLogger
182198
stmts map[string]*stmt // persisted statements
183199
txState txState
184200
readOnly bool
@@ -202,6 +218,7 @@ func (c *conn) Close() error {
202218
err := reserr(c.db, "Conn.Close", "", c.db.Close())
203219
return err
204220
}
221+
205222
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
206223
persist := ctx.Value(persistQuery{}) != nil
207224
return c.prepare(ctx, query, persist)
@@ -341,6 +358,9 @@ func (c *conn) txInit(ctx context.Context) error {
341358
return err
342359
}
343360
} else {
361+
if c.logger != nil {
362+
c.logger.Begin()
363+
}
344364
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
345365
// semantics via a context annotation function.
346366
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
@@ -351,15 +371,16 @@ func (c *conn) txInit(ctx context.Context) error {
351371
}
352372

353373
func (c *conn) txEnd(ctx context.Context, endStmt string) error {
354-
state, readOnly := c.txState, c.readOnly
355-
c.txState = txStateNone
356-
c.readOnly = false
357-
if state != txStateBegun {
374+
defer func() {
375+
c.txState = txStateNone
376+
c.readOnly = false
377+
}()
378+
if c.txState != txStateBegun {
358379
return nil
359380
}
360381

361382
err := c.execInternal(context.Background(), endStmt)
362-
if readOnly {
383+
if c.readOnly {
363384
if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil {
364385
err = err2
365386
}
@@ -377,10 +398,14 @@ func (tx *connTx) Commit() error {
377398
return ErrClosed
378399
}
379400

401+
readonly := tx.conn.readOnly
380402
err := tx.conn.txEnd(context.Background(), "COMMIT")
381403
if tx.conn.tracer != nil {
382404
tx.conn.tracer.Commit(tx.conn.id, err)
383405
}
406+
if tx.conn.logger != nil && !readonly {
407+
tx.conn.logger.Commit(err)
408+
}
384409
return err
385410
}
386411

@@ -390,10 +415,14 @@ func (tx *connTx) Rollback() error {
390415
return ErrClosed
391416
}
392417

418+
readonly := tx.conn.readOnly
393419
err := tx.conn.txEnd(context.Background(), "ROLLBACK")
394420
if tx.conn.tracer != nil {
395421
tx.conn.tracer.Rollback(tx.conn.id, err)
396422
}
423+
if tx.conn.logger != nil && !readonly {
424+
tx.conn.logger.Rollback()
425+
}
397426
return err
398427
}
399428

@@ -490,6 +519,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490519
if err := s.bindAll(args); err != nil {
491520
return nil, s.reserr("Stmt.Exec(Bind)", err)
492521
}
522+
if s.conn.logger != nil && !s.conn.readOnly {
523+
s.conn.logger.Statement(s.stmt.ExpandedSQL())
524+
}
493525

494526
if ctx.Value(queryCancelKey{}) != nil {
495527
done := make(chan struct{})
@@ -1068,3 +1100,25 @@ func WithQueryCancel(ctx context.Context) context.Context {
10681100

10691101
// queryCancelKey is a context key for query context enforcement.
10701102
type queryCancelKey struct{}
1103+
1104+
// ConnLogger is implemented by the caller to support statement-level logging
1105+
// for write transactions. Only Exec calls are logged, not Query calls, as this
1106+
// is intended as a mechanism to replay failed transactions.
1107+
//
1108+
// Aside from logging only executed statements, ConnLogger also differs from
1109+
// [sqliteh.Tracer] by logging the expanded SQL, instead of the query with
1110+
// placeholders.
1111+
type ConnLogger interface {
1112+
// Begin is called when a writable transaction is opened.
1113+
Begin()
1114+
1115+
// Statement is called with evaluated SQL when a statement is executed.
1116+
Statement(sql string)
1117+
1118+
// Commit is called after a commit statement, with the error resulting
1119+
// from the attempted commit.
1120+
Commit(error)
1121+
1122+
// Rollback is called after a rollback statement.
1123+
Rollback()
1124+
}

sqlite_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"slices"
1617
"strconv"
1718
"strings"
1819
"sync"
@@ -1354,3 +1355,186 @@ func TestDisableFunction(t *testing.T) {
13541355
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
13551356
}
13561357
}
1358+
1359+
type connLogger struct {
1360+
ch chan []string
1361+
statements []string
1362+
panicOnUse bool
1363+
}
1364+
1365+
func (cl *connLogger) Begin() {
1366+
if cl.panicOnUse {
1367+
panic("unexpected connLogger.Begin()")
1368+
}
1369+
cl.statements = nil
1370+
}
1371+
1372+
func (cl *connLogger) Statement(s string) {
1373+
if cl.panicOnUse {
1374+
panic("unexpected connLogger.Statement: " + s)
1375+
}
1376+
cl.statements = append(cl.statements, s)
1377+
}
1378+
1379+
func (cl *connLogger) Commit(err error) {
1380+
if cl.panicOnUse {
1381+
panic("unexpected connLogger.Commit()")
1382+
}
1383+
if err != nil {
1384+
return
1385+
}
1386+
cl.ch <- cl.statements
1387+
}
1388+
1389+
func (cl *connLogger) Rollback() {
1390+
if cl.panicOnUse {
1391+
panic("unexpected connLogger.Rollback()")
1392+
}
1393+
cl.statements = nil
1394+
}
1395+
1396+
func TestConnLogger_writable(t *testing.T) {
1397+
for _, commit := range []bool{true, false} {
1398+
doneStatement := "ROLLBACK"
1399+
if commit {
1400+
doneStatement = "COMMIT"
1401+
}
1402+
t.Run(doneStatement, func(t *testing.T) {
1403+
ctx := context.Background()
1404+
ch := make(chan []string, 1)
1405+
txl := connLogger{ch: ch}
1406+
makeLogger := func() ConnLogger { return &txl }
1407+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1408+
configDB(t, db)
1409+
1410+
tx, err := db.BeginTx(ctx, nil)
1411+
if err != nil {
1412+
t.Fatal(err)
1413+
}
1414+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1415+
t.Fatal(err)
1416+
}
1417+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1418+
t.Fatal(err)
1419+
}
1420+
if _, err := tx.Query("SELECT x FROM T"); err != nil {
1421+
t.Fatal(err)
1422+
}
1423+
done := tx.Rollback
1424+
if commit {
1425+
done = tx.Commit
1426+
}
1427+
if err := done(); err != nil {
1428+
t.Fatal(err)
1429+
}
1430+
if !commit {
1431+
select {
1432+
case got := <-ch:
1433+
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
1434+
default:
1435+
return
1436+
}
1437+
}
1438+
1439+
want := []string{
1440+
"BEGIN IMMEDIATE",
1441+
"CREATE TABLE T (x INTEGER)",
1442+
"INSERT INTO T VALUES (1)",
1443+
doneStatement,
1444+
}
1445+
select {
1446+
case got := <-ch:
1447+
if !slices.Equal(got, want) {
1448+
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
1449+
}
1450+
default:
1451+
t.Fatal("no logged statements after commit")
1452+
}
1453+
})
1454+
}
1455+
}
1456+
1457+
func TestConnLogger_commit_error(t *testing.T) {
1458+
ctx := context.Background()
1459+
ch := make(chan []string, 1)
1460+
txl := connLogger{ch: ch}
1461+
makeLogger := func() ConnLogger { return &txl }
1462+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1463+
configDB(t, db)
1464+
1465+
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
1466+
t.Fatal(err)
1467+
}
1468+
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
1469+
t.Fatal(err)
1470+
}
1471+
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
1472+
t.Fatal(err)
1473+
}
1474+
1475+
tx, err := db.BeginTx(ctx, nil)
1476+
if err != nil {
1477+
t.Fatal(err)
1478+
}
1479+
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
1480+
t.Fatal(err)
1481+
}
1482+
if err := tx.Commit(); err == nil {
1483+
t.Fatal("expected Commit to error, but didn't")
1484+
}
1485+
select {
1486+
case got := <-ch:
1487+
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
1488+
default:
1489+
return
1490+
}
1491+
}
1492+
1493+
func TestConnLogger_read_tx(t *testing.T) {
1494+
ctx := context.Background()
1495+
ch := make(chan []string, 1)
1496+
txl := connLogger{ch: ch}
1497+
makeLogger := func() ConnLogger { return &txl }
1498+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1499+
configDB(t, db)
1500+
1501+
tx, err := db.BeginTx(ctx, nil)
1502+
if err != nil {
1503+
t.Fatal(err)
1504+
}
1505+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1506+
t.Fatal(err)
1507+
}
1508+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1509+
t.Fatal(err)
1510+
}
1511+
if err := tx.Commit(); err != nil {
1512+
t.Fatal(err)
1513+
}
1514+
select {
1515+
case got := <-ch:
1516+
if len(got) == 0 {
1517+
t.Errorf("expected logged statements for write tx")
1518+
}
1519+
default:
1520+
t.Errorf("expected logged statements for write tx")
1521+
}
1522+
1523+
txl.panicOnUse = true
1524+
for _, commit := range []bool{true, false} {
1525+
rx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
1526+
if err != nil {
1527+
t.Fatal(err)
1528+
}
1529+
if _, err := rx.Query("SELECT x FROM T"); err != nil {
1530+
t.Fatal(err)
1531+
}
1532+
done := rx.Rollback
1533+
if commit {
1534+
done = rx.Commit
1535+
}
1536+
if err := done(); err != nil {
1537+
t.Fatal(err)
1538+
}
1539+
}
1540+
}

0 commit comments

Comments
 (0)