Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions db/db_sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,44 @@ static bool db_sqlite3_setup(struct db *db, bool create)
"PRAGMA foreign_keys = ON;", -1, &stmt, NULL);
err = sqlite3_step(stmt);
sqlite3_finalize(stmt);
return err == SQLITE_DONE;

if (err != SQLITE_DONE)
return false;

if (db->developer) {
sqlite3_prepare_v2(conn2sql(db->conn),
"PRAGMA trusted_schema = OFF;", -1, &stmt, NULL);
sqlite3_step(stmt);
sqlite3_finalize(stmt);

sqlite3_prepare_v2(conn2sql(db->conn),
"PRAGMA cell_size_check = ON;", -1, &stmt, NULL);
sqlite3_step(stmt);
sqlite3_finalize(stmt);
}

return true;
}

static bool db_sqlite3_query(struct db_stmt *stmt)
{
sqlite3_stmt *s;
sqlite3 *conn = conn2sql(stmt->db->conn);
int err;
const char *query = stmt->query->query;
char *modified_query = NULL;

if (stmt->db->developer &&
strncasecmp(query, "CREATE TABLE", 12) == 0 &&
!strstr(query, "STRICT")) {
modified_query = tal_fmt(stmt, "%s STRICT", query);
query = modified_query;
}

err = sqlite3_prepare_v2(conn, query, -1, &s, NULL);

err = sqlite3_prepare_v2(conn, stmt->query->query, -1, &s, NULL);
if (modified_query)
tal_free(modified_query);

for (size_t i=0; i<stmt->query->placeholders; i++) {
struct db_binding *b = &stmt->bindings[i];
Expand Down
2 changes: 2 additions & 0 deletions devtools/sql-rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def rewrite_single(self, query):
r'BIGINT': 'INTEGER',
r'BIGINTEGER': 'INTEGER',
r'BIGSERIAL': 'INTEGER',
r'VARCHAR(?:\(\d+\))?': 'TEXT',
r'\bINT\b': 'INTEGER',
r'CURRENT_TIMESTAMP\(\)': "strftime('%s', 'now')",
r'INSERT INTO[ \t]+(.*)[ \t]+ON CONFLICT.*DO NOTHING;': 'INSERT OR IGNORE INTO \\1;',
# Rewrite "decode('abcd', 'hex')" to become "x'abcd'"
Expand Down
19 changes: 19 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,3 +642,22 @@ def test_channel_htlcs_id_change(bitcoind, node_factory):
# Make some HTLCS
for amt in (100, 500, 1000, 5000, 10000, 50000, 100000):
l1.pay(l3, amt)


@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "STRICT tables are SQLite3 specific")
def test_sqlite_strict_mode(node_factory):
"""Test that STRICT is appended to CREATE TABLE in developer mode."""
l1 = node_factory.get_node(options={'developer': None})

# Query sqlite_master to check table definitions
tables = l1.db_query("SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")

strict_tables = [t for t in tables if t['sql'] and 'STRICT' in t['sql']]
assert len(strict_tables) > 0, f"Expected at least one STRICT table in developer mode, found none out of {len(tables)}"

# Check specific tables we know should be STRICT in developer mode
known_strict_tables = ['version', 'forwards', 'payments', 'local_anchors', 'addresses']
for table_name in known_strict_tables:
table_sql = next((t['sql'] for t in tables if t['name'] == table_name), None)
if table_sql:
assert 'STRICT' in table_sql, f"Expected table '{table_name}' to be STRICT in developer mode"