diff --git a/app/db.py b/app/db.py index e2900c27..cb822f9d 100644 --- a/app/db.py +++ b/app/db.py @@ -6,7 +6,7 @@ from urllib.parse import unquote, urlparse from sqlalchemy import create_engine, event, inspect, text -from sqlalchemy.engine import Engine +from sqlalchemy.engine import Connection, Engine from sqlalchemy.orm import Session, sessionmaker from app.models import Base @@ -46,29 +46,41 @@ def create_schema(database_url: str) -> None: engine.dispose() +def _add_bounty_column_if_missing( + connection: Connection, + existing_columns: set[str], + column_name: str, + column_definition: str, +) -> bool: + if column_name in existing_columns: + return False + connection.execute(text(f"ALTER TABLE bounties ADD COLUMN {column_name} {column_definition}")) + existing_columns.add(column_name) + return True + + def _migrate_schema(engine: Engine) -> None: inspector = inspect(engine) if "bounties" not in inspector.get_table_names(): return bounty_columns = {column["name"] for column in inspector.get_columns("bounties")} + bounty_column_migrations = ( + ("max_awards", "INTEGER NOT NULL DEFAULT 1", None), + ( + "awards_paid", + "INTEGER NOT NULL DEFAULT 0", + "UPDATE bounties SET awards_paid = 1 WHERE status = 'paid'", + ), + ("github_paid_issue_finalized_at", "TIMESTAMP", None), + ("github_paid_issue_finalization", "TEXT", None), + ) with engine.begin() as connection: - if "max_awards" not in bounty_columns: - connection.execute( - text("ALTER TABLE bounties ADD COLUMN max_awards INTEGER NOT NULL DEFAULT 1") - ) - if "awards_paid" not in bounty_columns: - connection.execute( - text("ALTER TABLE bounties ADD COLUMN awards_paid INTEGER NOT NULL DEFAULT 0") - ) - connection.execute(text("UPDATE bounties SET awards_paid = 1 WHERE status = 'paid'")) - if "github_paid_issue_finalized_at" not in bounty_columns: - connection.execute( - text("ALTER TABLE bounties ADD COLUMN github_paid_issue_finalized_at TIMESTAMP") - ) - if "github_paid_issue_finalization" not in bounty_columns: - connection.execute( - text("ALTER TABLE bounties ADD COLUMN github_paid_issue_finalization TEXT") + for column_name, column_definition, backfill_sql in bounty_column_migrations: + added = _add_bounty_column_if_missing( + connection, bounty_columns, column_name, column_definition ) + if added and backfill_sql is not None: + connection.execute(text(backfill_sql)) if "submissions" in inspector.get_table_names(): connection.execute( text( diff --git a/tests/test_ledger.py b/tests/test_ledger.py index 2f558293..6378058f 100644 --- a/tests/test_ledger.py +++ b/tests/test_ledger.py @@ -1,7 +1,7 @@ from __future__ import annotations import pytest -from sqlalchemy import select +from sqlalchemy import inspect, select from app.db import create_schema, make_engine, session_scope from app.ledger.service import ( @@ -628,7 +628,9 @@ def test_bounty_max_awards_must_be_positive(sqlite_url: str) -> None: ) -def test_create_schema_migrates_existing_bounty_award_columns(sqlite_url: str) -> None: +def test_create_schema_migrates_existing_bounty_columns_and_submission_index( + sqlite_url: str, +) -> None: engine = make_engine(sqlite_url) with engine.begin() as connection: connection.exec_driver_sql( @@ -669,6 +671,14 @@ def test_create_schema_migrates_existing_bounty_award_columns(sqlite_url: str) - assert bounty is not None assert bounty.max_awards == 1 assert bounty.awards_paid == 1 + assert bounty.github_paid_issue_finalized_at is None + assert bounty.github_paid_issue_finalization is None + + engine = make_engine(sqlite_url) + submission_indexes = inspect(engine).get_indexes("submissions") + engine.dispose() + + assert any(index["name"] == "uq_submission_bounty_url" for index in submission_indexes) def test_hash_chain_detects_tampering(sqlite_url: str) -> None: