Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from urllib.parse import unquote, urlparse

from sqlalchemy import create_engine, event, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session, sessionmaker

from app.models import Base
Expand Down Expand Up @@ -46,29 +46,41 @@ def create_schema(database_url: str) -> None:
engine.dispose()


def _add_bounty_column_if_missing(
connection: Connection,
existing_columns: set[str],
column_name: str,
column_definition: str,
) -> bool:
if column_name in existing_columns:
return False
connection.execute(text(f"ALTER TABLE bounties ADD COLUMN {column_name} {column_definition}"))
existing_columns.add(column_name)
return True


def _migrate_schema(engine: Engine) -> None:
inspector = inspect(engine)
if "bounties" not in inspector.get_table_names():
return
bounty_columns = {column["name"] for column in inspector.get_columns("bounties")}
bounty_column_migrations = (
("max_awards", "INTEGER NOT NULL DEFAULT 1", None),
(
"awards_paid",
"INTEGER NOT NULL DEFAULT 0",
"UPDATE bounties SET awards_paid = 1 WHERE status = 'paid'",
),
("github_paid_issue_finalized_at", "TIMESTAMP", None),
("github_paid_issue_finalization", "TEXT", None),
)
with engine.begin() as connection:
if "max_awards" not in bounty_columns:
connection.execute(
text("ALTER TABLE bounties ADD COLUMN max_awards INTEGER NOT NULL DEFAULT 1")
)
if "awards_paid" not in bounty_columns:
connection.execute(
text("ALTER TABLE bounties ADD COLUMN awards_paid INTEGER NOT NULL DEFAULT 0")
)
connection.execute(text("UPDATE bounties SET awards_paid = 1 WHERE status = 'paid'"))
if "github_paid_issue_finalized_at" not in bounty_columns:
connection.execute(
text("ALTER TABLE bounties ADD COLUMN github_paid_issue_finalized_at TIMESTAMP")
)
if "github_paid_issue_finalization" not in bounty_columns:
connection.execute(
text("ALTER TABLE bounties ADD COLUMN github_paid_issue_finalization TEXT")
for column_name, column_definition, backfill_sql in bounty_column_migrations:
added = _add_bounty_column_if_missing(
connection, bounty_columns, column_name, column_definition
)
if added and backfill_sql is not None:
connection.execute(text(backfill_sql))
if "submissions" in inspector.get_table_names():
connection.execute(
text(
Expand Down
14 changes: 12 additions & 2 deletions tests/test_ledger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import pytest
from sqlalchemy import select
from sqlalchemy import inspect, select

from app.db import create_schema, make_engine, session_scope
from app.ledger.service import (
Expand Down Expand Up @@ -628,7 +628,9 @@ def test_bounty_max_awards_must_be_positive(sqlite_url: str) -> None:
)


def test_create_schema_migrates_existing_bounty_award_columns(sqlite_url: str) -> None:
def test_create_schema_migrates_existing_bounty_columns_and_submission_index(
sqlite_url: str,
) -> None:
engine = make_engine(sqlite_url)
with engine.begin() as connection:
connection.exec_driver_sql(
Expand Down Expand Up @@ -669,6 +671,14 @@ def test_create_schema_migrates_existing_bounty_award_columns(sqlite_url: str) -
assert bounty is not None
assert bounty.max_awards == 1
assert bounty.awards_paid == 1
assert bounty.github_paid_issue_finalized_at is None
assert bounty.github_paid_issue_finalization is None

engine = make_engine(sqlite_url)
submission_indexes = inspect(engine).get_indexes("submissions")
engine.dispose()

assert any(index["name"] == "uq_submission_bounty_url" for index in submission_indexes)


def test_hash_chain_detects_tampering(sqlite_url: str) -> None:
Expand Down