diff --git a/flou/alembic.ini b/flou/alembic.ini index 297d168..7ce0928 100644 --- a/flou/alembic.ini +++ b/flou/alembic.ini @@ -12,7 +12,7 @@ script_location = migrations file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s # sys.path path, will be prepended to sys.path if present. -# defaults to the current working directory. +#_utils defaults to the current working directory. prepend_sys_path = . # timezone to use when rendering the date within the migration file @@ -83,7 +83,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne # Logging configuration [loggers] -keys = root,sqlalchemy,alembic +keys = root,sqlalchemy,alembic,alembic_utils [handlers] keys = console @@ -106,6 +106,11 @@ level = INFO handlers = qualname = alembic +[logger_alembic_utils] +level = INFO +handlers = +qualname = alembic_utils + [handler_console] class = StreamHandler args = (sys.stderr,) diff --git a/flou/flou/experiments/models.py b/flou/flou/experiments/models.py index d768208..b330e18 100644 --- a/flou/flou/experiments/models.py +++ b/flou/flou/experiments/models.py @@ -4,6 +4,8 @@ from sqlalchemy import ForeignKey, text from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.types import String +from alembic_utils.pg_function import PGFunction +from alembic_utils.pg_trigger import PGTrigger from flou.database.models import Base from flou.database.utils import JSONType @@ -15,7 +17,7 @@ class Experiment(Base): id: Mapped[uuid.UUID] = mapped_column( primary_key=True, server_default=text("gen_random_uuid()") ) - index: Mapped[str] = mapped_column(default=0, nullable=False) + index: Mapped[int] = mapped_column(default=0, nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(), nullable=False) inputs: Mapped[dict] = mapped_column(JSONType(), default=dict, nullable=False) @@ -24,13 +26,43 @@ class Experiment(Base): trials: Mapped[List["Trial"]] = relationship(back_populates="experiment") +# Define the trigger function using alembic_utils +experiments_set_index = PGFunction( + schema="public", + signature="experiments_set_index()", + definition=""" + RETURNS trigger AS $$ + BEGIN + NEW.index := COALESCE( + (SELECT MAX(index) FROM experiments_experiments), -1 + ) + 1; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """, +) + + +# Define the trigger using alembic_utils +experiments_set_index_trigger = PGTrigger( + schema="public", + signature="experiments_set_index_trigger", + on_entity="public.experiments_experiments", + is_constraint=False, + definition=""" + BEFORE INSERT ON public.experiments_experiments + FOR EACH ROW EXECUTE FUNCTION public.experiments_set_index(); + """, +) + + class Trial(Base): __tablename__ = "experiments_trials" id: Mapped[uuid.UUID] = mapped_column( primary_key=True, server_default=text("gen_random_uuid()") ) - index: Mapped[str] = mapped_column(default=0, nullable=False) + index: Mapped[int] = mapped_column(default=0, nullable=False) experiment_id: Mapped[int] = mapped_column(ForeignKey("experiments_experiments.id")) name: Mapped[str] = mapped_column(String(255), nullable=False) ltm_id: Mapped[int] = mapped_column(ForeignKey("ltm_ltms.id"), nullable=False) @@ -40,3 +72,33 @@ class Trial(Base): outputs: Mapped[dict] = mapped_column(JSONType(), default=dict, nullable=False) experiment: Mapped[Experiment] = relationship("Experiment", back_populates="trials") + + +# Define the trigger function using alembic_utils +trials_set_index = PGFunction( + schema="public", + signature="trials_set_index()", + definition=""" + RETURNS trigger AS $$ + BEGIN + NEW.index := COALESCE( + (SELECT MAX(index) FROM experiments_trials WHERE experiment_id = NEW.experiment_id), -1 + ) + 1; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; + """, +) + + +# Define the trigger using alembic_utils +trials_set_index_trigger = PGTrigger( + schema="public", + signature="trials_set_index_trigger", + on_entity="public.experiments_trials", + is_constraint=False, + definition=""" + BEFORE INSERT ON public.experiments_trials + FOR EACH ROW EXECUTE FUNCTION public.trials_set_index(); + """, +) diff --git a/flou/migrations/env.py b/flou/migrations/env.py index 7fcc050..53a54a0 100644 --- a/flou/migrations/env.py +++ b/flou/migrations/env.py @@ -7,6 +7,9 @@ from sqlalchemy.exc import InvalidRequestError from alembic import context +from alembic_utils.replaceable_entity import register_entities +from alembic_utils.pg_trigger import PGTrigger +from alembic_utils.pg_function import PGFunction # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -21,20 +24,34 @@ # for 'autogenerate' support # from myapp import mymodel # target_metadata = mymodel.Base.metadata +import flou from flou.database.models import Base target_metadata = Base.metadata # Function to dynamically import all `models.py` and `models/` from apps def import_all_models(): root_package = 'flou' + alembic_utils_entities = [] + package = importlib.import_module(root_package) for importer, modname, ispkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'): if modname.endswith('.models'): - importlib.import_module(modname) + models_module = importlib.import_module(modname) + # look for PGTriggers which need to be manually added to + # alembic_utils' `register_entitis` + for _, variable in models_module.__dict__.items(): + + if isinstance(variable, (PGTrigger, PGFunction, )): + # Ensure variable is not a subclass + if variable.__class__ in (PGTrigger, PGFunction, ): + alembic_utils_entities.append(variable) + + register_entities(alembic_utils_entities) # register all entities +# setup models & triggers try: - import_all_models() + import_all_models() # add every model to the DeclarativeBase except InvalidRequestError: pass # don't break on tests diff --git a/flou/migrations/versions/2024_11_04_2001-d17bb320f4d3_.py b/flou/migrations/versions/2024_11_04_2001-d17bb320f4d3_.py deleted file mode 100644 index c33dae4..0000000 --- a/flou/migrations/versions/2024_11_04_2001-d17bb320f4d3_.py +++ /dev/null @@ -1,61 +0,0 @@ -"""empty message - -Revision ID: d17bb320f4d3 -Revises: bb20ac30a018 -Create Date: 2024-11-04 20:01:47.492583 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - -import flou - - -# revision identifiers, used by Alembic. -revision: str = 'd17bb320f4d3' -down_revision: Union[str, None] = 'bb20ac30a018' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('experiments_experiments', - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('index', sa.String(), nullable=False), - sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('description', sa.String(), nullable=False), - sa.Column('ltm_id', sa.Integer(), nullable=False), - sa.Column('rollback_index', sa.Integer(), nullable=False), - sa.Column('snapshot_index', sa.Integer(), nullable=False), - sa.Column('inputs', flou.database.utils.JSONType(), nullable=False), - sa.Column('outputs', flou.database.utils.JSONType(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['ltm_id'], ['ltm_ltms.id'], ), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('experiments_trials', - sa.Column('id', sa.Uuid(), nullable=False), - sa.Column('index', sa.String(), nullable=False), - sa.Column('experiment_id', sa.Uuid(), nullable=False), - sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('rollback_index', sa.Integer(), nullable=False), - sa.Column('snapshot_index', sa.Integer(), nullable=False), - sa.Column('inputs', flou.database.utils.JSONType(), nullable=False), - sa.Column('outputs', flou.database.utils.JSONType(), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), - sa.ForeignKeyConstraint(['experiment_id'], ['experiments_experiments.id'], ), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('experiments_trials') - op.drop_table('experiments_experiments') - # ### end Alembic commands ### diff --git a/flou/migrations/versions/2024_11_06_1430-69c9354bb7ff_.py b/flou/migrations/versions/2024_11_06_1430-69c9354bb7ff_.py deleted file mode 100644 index 5c3bc1e..0000000 --- a/flou/migrations/versions/2024_11_06_1430-69c9354bb7ff_.py +++ /dev/null @@ -1,40 +0,0 @@ -"""empty message - -Revision ID: 69c9354bb7ff -Revises: d17bb320f4d3 -Create Date: 2024-11-06 14:30:21.018516 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '69c9354bb7ff' -down_revision: Union[str, None] = 'd17bb320f4d3' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint('experiments_experiments_ltm_id_fkey', 'experiments_experiments', type_='foreignkey') - op.drop_column('experiments_experiments', 'rollback_index') - op.drop_column('experiments_experiments', 'ltm_id') - op.drop_column('experiments_experiments', 'snapshot_index') - op.add_column('experiments_trials', sa.Column('ltm_id', sa.Integer(), nullable=False)) - op.create_foreign_key(None, 'experiments_trials', 'ltm_ltms', ['ltm_id'], ['id']) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_constraint(None, 'experiments_trials', type_='foreignkey') - op.drop_column('experiments_trials', 'ltm_id') - op.add_column('experiments_experiments', sa.Column('snapshot_index', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('experiments_experiments', sa.Column('ltm_id', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('experiments_experiments', sa.Column('rollback_index', sa.INTEGER(), autoincrement=False, nullable=False)) - op.create_foreign_key('experiments_experiments_ltm_id_fkey', 'experiments_experiments', 'ltm_ltms', ['ltm_id'], ['id']) - # ### end Alembic commands ### diff --git a/flou/migrations/versions/2024_11_06_1432-076b9aea5f59_.py b/flou/migrations/versions/2024_11_06_1432-076b9aea5f59_.py deleted file mode 100644 index 350bb08..0000000 --- a/flou/migrations/versions/2024_11_06_1432-076b9aea5f59_.py +++ /dev/null @@ -1,28 +0,0 @@ -"""empty message - -Revision ID: 076b9aea5f59 -Revises: 69c9354bb7ff -Create Date: 2024-11-06 14:32:55.161297 - -""" -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision: str = '076b9aea5f59' -down_revision: Union[str, None] = '69c9354bb7ff' -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - op.alter_column('experiments_experiments', 'id', server_default=sa.text('gen_random_uuid()')) - op.alter_column('experiments_trials', 'id', server_default=sa.text('gen_random_uuid()')) - - -def downgrade() -> None: - op.alter_column('experiments_experiments', 'id', server_default=None) - op.alter_column('experiments_trials', 'id', server_default=None) diff --git a/flou/pyproject.toml b/flou/pyproject.toml index a4ec71c..e030cd1 100644 --- a/flou/pyproject.toml +++ b/flou/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "watchdog", "sqlalchemy[asyncio]", "alembic", + "alembic_utils", "psycopg", ]