|
| 1 | +import sqlite3 |
| 2 | +import os |
| 3 | +from string.templatelib import Template, Interpolation |
| 4 | + |
| 5 | +# Setup: Create demo database |
| 6 | +DB_PATH = "demo.db" |
| 7 | + |
| 8 | + |
| 9 | +def reset_database(): |
| 10 | + """Reset the database to initial state.""" |
| 11 | + if os.path.exists(DB_PATH): |
| 12 | + os.remove(DB_PATH) |
| 13 | + |
| 14 | + conn = sqlite3.connect(DB_PATH) |
| 15 | + cursor = conn.cursor() |
| 16 | + |
| 17 | + cursor.execute(""" |
| 18 | + CREATE TABLE users ( |
| 19 | + id INTEGER PRIMARY KEY, |
| 20 | + name TEXT NOT NULL, |
| 21 | + email TEXT NOT NULL |
| 22 | + ) |
| 23 | + """) |
| 24 | + |
| 25 | + users = [ |
| 26 | + |
| 27 | + |
| 28 | + |
| 29 | + ] |
| 30 | + |
| 31 | + cursor.executemany("INSERT INTO users (name, email) VALUES (?, ?)", users) |
| 32 | + conn.commit() |
| 33 | + conn.close() |
| 34 | + |
| 35 | + |
| 36 | +def get_table_count(cursor): |
| 37 | + """Return number of tables in database.""" |
| 38 | + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") |
| 39 | + return len(cursor.fetchall()) |
| 40 | + |
| 41 | + |
| 42 | +# Test unsafe f-string |
| 43 | +reset_database() |
| 44 | +conn = sqlite3.connect(DB_PATH) |
| 45 | +cursor = conn.cursor() |
| 46 | + |
| 47 | +user_input = "admin'; DROP TABLE users; --" |
| 48 | + |
| 49 | +# Unsafe f-string |
| 50 | +unsafe_query = f"SELECT * FROM users WHERE name = '{user_input}'" |
| 51 | +print(f"Query: {unsafe_query}") |
| 52 | +cursor.executescript(unsafe_query) |
| 53 | +print(f"Tables: {get_table_count(cursor)}\n") |
| 54 | + |
| 55 | +# Reset and test safe version |
| 56 | +reset_database() |
| 57 | +conn = sqlite3.connect(DB_PATH) |
| 58 | +cursor = conn.cursor() |
| 59 | + |
| 60 | + |
| 61 | +def safe_sql(template: Template) -> str: |
| 62 | + """Escape SQL values to prevent injection.""" |
| 63 | + result = [] |
| 64 | + for item in template: |
| 65 | + if isinstance(item, Interpolation): |
| 66 | + # Escape single quotes by doubling them |
| 67 | + safe_value = str(item.value).replace("'", "''") |
| 68 | + result.append(f"'{safe_value}'") |
| 69 | + else: |
| 70 | + result.append(item) |
| 71 | + return "".join(result) |
| 72 | + |
| 73 | + |
| 74 | +# Safe t-string |
| 75 | +query = t"SELECT * FROM users WHERE name = {user_input}" |
| 76 | + |
| 77 | +safe_query = safe_sql(query) |
| 78 | +print(f"Query: {safe_query}") |
| 79 | +cursor.execute(safe_query) |
| 80 | +print(f"Tables: {get_table_count(cursor)}") |
| 81 | + |
| 82 | +conn.close() |
0 commit comments