-
Notifications
You must be signed in to change notification settings - Fork 589
/
Copy pathdb_mig.py
103 lines (79 loc) · 3.48 KB
/
db_mig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Database migration utilities."""
import logging
from typing import Callable
from sqlalchemy import Column, MetaData, inspect, text
from models.base import Base
# Module-level logger, registered under this module's import name.
logger: logging.Logger = logging.getLogger(name=__name__)
async def add_column_if_not_exists(
    conn, dialect, table_name: str, column: Column
) -> None:
    """Add ``column`` to ``table_name`` if the table does not have it yet.

    The live table is inspected first. When ``column.name`` is absent, an
    ``ALTER TABLE ... ADD COLUMN`` statement is built from the column's type
    (compiled for ``dialect``) and executed on ``conn``. If the column has a
    non-callable Python-side default (``column.default.arg``), a ``DEFAULT``
    clause is appended.

    Args:
        conn: SQLAlchemy async connection (``run_sync`` and ``execute`` are
            awaited on it).
        dialect: SQLAlchemy dialect used to compile the column type and to
            quote the table/column identifiers.
        table_name: Name of the table to alter.
        column: Column to add.
    """

    # Inspection is synchronous, so it is run through run_sync.
    def _get_columns(connection):
        inspector = inspect(connection)
        return [c["name"] for c in inspector.get_columns(table_name)]

    columns = await conn.run_sync(_get_columns)
    if column.name in columns:
        return

    # Let the dialect quote identifiers that need it (reserved words such as
    # "order", mixed-case names, special characters). Plain lowercase names
    # are returned unchanged, so existing DDL output stays the same for them.
    preparer = getattr(dialect, "identifier_preparer", None)

    def _quote(ident: str) -> str:
        return preparer.quote(ident) if preparer is not None else ident

    column_def = f"{_quote(column.name)} {column.type.compile(dialect)}"

    default = column.default
    if default is not None and hasattr(default, "arg"):
        value = default.arg
        # Callable defaults have no literal SQL form, so no DEFAULT clause is
        # emitted for them.
        if not callable(value):
            if isinstance(value, bool):
                literal = str(value).lower()
            elif isinstance(value, str):
                # Double embedded single quotes so the SQL string literal stays valid.
                escaped = value.replace("'", "''")
                literal = f"'{escaped}'"
            elif isinstance(value, (list, dict)):
                # NOTE(review): every list/dict default, empty or not, is written
                # as the literal '{}' — confirm this matches the column types used.
                literal = "'{}'"
            else:
                literal = value
            column_def += f" DEFAULT {literal}"

    ddl = f"ALTER TABLE {_quote(table_name)} ADD COLUMN {column_def}"
    await conn.execute(text(ddl))
    logger.info("Added column %s to table %s", column.name, table_name)
async def update_table_schema(conn, dialect, model_cls) -> None:
    """Add columns declared on ``model_cls`` that are missing from its table.

    Classes without a ``__table__`` attribute are ignored. Primary-key columns
    and the column keyed ``"id"`` are never added. The live table is inspected
    once, and each declared column not found there is passed to
    ``add_column_if_not_exists``.

    Args:
        conn: SQLAlchemy async connection.
        dialect: SQLAlchemy dialect, forwarded to ``add_column_if_not_exists``.
        model_cls: SQLAlchemy model class to check for new columns.
    """
    table = getattr(model_cls, "__table__", None)
    if table is None:
        return
    # Read the name from the Table object itself. A model defined via an
    # explicit __table__ need not declare __tablename__.
    table_name = table.name

    # Skip the primary key, as the original intent states. The key "id" is
    # still skipped by name to preserve the previous behaviour.
    candidates = [
        column
        for key, column in table.columns.items()
        if key != "id" and not column.primary_key
    ]
    if not candidates:
        return

    # Inspect the live table once rather than once per declared column.
    def _existing_columns(sync_conn):
        return {c["name"] for c in inspect(sync_conn).get_columns(table_name)}

    existing = await conn.run_sync(_existing_columns)
    for column in candidates:
        if column.name not in existing:
            await add_column_if_not_exists(conn, dialect, table_name, column)
async def safe_migrate(engine) -> None:
    """Create missing tables and add missing columns for every mapped model.

    Everything runs inside one ``engine.begin()`` block:

    1. ``Base.metadata.create_all`` creates tables that do not exist yet.
    2. The database is reflected.
    3. Each mapped class with a ``__tablename__`` present in the reflection
       is passed to ``update_table_schema``.

    Any error is logged with its traceback and re-raised.

    Args:
        engine: SQLAlchemy async engine.
    """
    logger.info("Starting database schema migration")
    dialect = engine.dialect
    async with engine.begin() as conn:
        try:
            # Create tables that don't exist yet.
            await conn.run_sync(Base.metadata.create_all)

            # Reflect what is actually present in the database now.
            metadata = MetaData()
            await conn.run_sync(metadata.reflect)

            # Add missing columns for every mapped model whose table exists.
            for mapper in Base.registry.mappers:
                model_cls = mapper.class_
                if not hasattr(model_cls, "__tablename__"):
                    continue
                if model_cls.__tablename__ not in metadata.tables:
                    continue
                await update_table_schema(conn, dialect, model_cls)
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error updating database schema: %s", e)
            raise
    logger.info("Database schema updated successfully")