Skip to content

Commit

Permalink
add data transformation to migration
Browse files Browse the repository at this point in the history
  • Loading branch information
tianj7 committed Sep 6, 2023
1 parent c9d6b6a commit da1599a
Showing 1 changed file with 112 additions and 10 deletions.
122 changes: 112 additions & 10 deletions migrations/versions/9b3a5a7145d7_authlib_update_1_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,35 @@
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import Column, String
from sqlalchemy.orm import Session
from fence.models import Client
import json

# revision identifiers, used by Alembic.
revision = "9b3a5a7145d7" # pragma: allowlist secret
down_revision = "a04a70296688" # pragma: allowlist secret
branch_labels = None
depends_on = None
from authlib.common.encoding import json_loads, json_dumps


def upgrade():
# Add New Columns for client Table
op.add_column("client", sa.Column("client_metadata", sa.Text(), nullable=True))
op.add_column(
"client", sa.Column("client_secret_expires_at", sa.Integer(), nullable=False)
"client",
sa.Column(
"client_secret_expires_at", sa.Integer(), nullable=False, server_default="0"
),
)

# Modify Columns for client Table
op.alter_column("client", "issued_at", new_column_name="client_id_issued_at")
op.alter_column("client", "client_id", nullable=False, type_=sa.String(48))
op.alter_column("client", "client_secret", nullable=True, type_=sa.String(120))

set_metadata_values(op)

# Delete Columns for client Table
op.drop_column("client", "redirect_uri")
op.drop_column("client", "token_endpoint_auth_method")
Expand Down Expand Up @@ -60,10 +68,6 @@ def upgrade():

def downgrade():

# Add New Columns for client Table
op.drop_column("client", "client_metadata")
op.drop_column("client", "client_secret_expires_at")

# Modify Columns for client Table
op.alter_column("client", "client_id_issued_at", new_column_name="issued_at")
op.alter_column("client", "client_id", nullable=False, type_=sa.String(40))
Expand All @@ -75,9 +79,16 @@ def downgrade():
"client",
sa.Column("token_endpoint_auth_method", sa.String(length=48), nullable=True),
)
op.add_column("client", sa.Column("grant_type", sa.Text(), nullable=False))
op.add_column("client", sa.Column("response_type", sa.Text(), nullable=False))
op.add_column("client", sa.Column("scope", sa.Text(), nullable=False))
op.add_column(
"client", sa.Column("grant_type", sa.Text(), nullable=False, server_default="")
)
op.add_column(
"client",
sa.Column("response_type", sa.Text(), nullable=False, server_default=""),
)
op.add_column(
"client", sa.Column("scope", sa.Text(), nullable=False, server_default="")
)
op.add_column(
"client", sa.Column("client_name", sa.String(length=100), nullable=True)
)
Expand All @@ -95,9 +106,100 @@ def downgrade():
op.add_column(
"client", sa.Column("software_version", sa.String(length=48), nullable=True)
)
op.add_column("client", sa.Column("_allowed_scopes", sa.Text(), nullable=False))
op.add_column(
"client",
sa.Column("_allowed_scopes", sa.Text(), nullable=False, server_default=""),
)
op.add_column("client", sa.Column("_redirect_uris", sa.Text(), nullable=True))

set_old_column_values()

# Drop New Columns for client Table
op.drop_column("client", "client_metadata")
op.drop_column("client", "client_secret_expires_at")

# Remove New Columns for authorization_code Table
op.drop_column("authorization_code", "code_challenge")
op.drop_column("authorization_code", "code_challenge_method")


def set_metadata_values(op):
conn = op.get_bind()
session = Session(bind=conn)
for client in session.query(Client).all():
if client.i18n_metadata:
metadata = json.loads(client.i18n_metadata)
else:
metadata = {}

if client.redirect_uri:
metadata["redirect_uris"] = client.redirect_uri
if client.token_endpoint_auth_method:
metadata["token_endpoint_auth_method"] = client.token_endpoint_auth_method
if client._allowed_scopes:
metadata["scope"] = client._allowed_scopes.split(" ")
if client.grant_type:
metadata["grant_type"] = client.grant_type.splitlines()
if client.response_type:
metadata["response_type"] = client.response_type.splitlines()
if client.client_uri:
metadata["client_uri"] = client.client_uri
if client.logo_uri:
metadata["logo_uri"] = client.logo_uri
if client.contact:
metadata["contact"] = client.contact
if client.contact:
metadata["tos_uri"] = client.tos_uri
if client.contact:
metadata["policy_uri"] = client.policy_uri
if client.contact:
metadata["jwks_uri"] = client.jwks_uri
if client.contact:
metadata["jwks_text"] = client.jwks_text
if client.contact:
metadata["software_id"] = client.software_id
if client.contact:
metadata["software_version"] = client.software_version

client._client_metadata = json_dumps(metadata)
session.commit()


def set_old_column_values():
conn = op.get_bind()
session = Session(bind=conn)
for client in session.query(Client).all():
if client._client_metadata:
metadata = json_loads(client._client_metadata)
client.i18n_metadata = metadata

if client.redirect_uri:
client.redirect_uri = metadata["redirect_uris"]
if client.token_endpoint_auth_method:
client.token_endpoint_auth_method = metadata["token_endpoint_auth_method"]
if client._allowed_scopes:
client._allowed_scopes = " ".join(metadata["scope"])
if client.grant_type:
client.grant_type = "\n".joinmetadata["grant_type"]
if client.response_type:
client.response_type = "\n".join(metadata["response_type"])
if client.client_uri:
client.client_uri = metadata["client_uri"]
if client.logo_uri:
client.logo_uri = metadata["logo_uri"]
if client.contact:
client.contact = metadata["contact"]
if client.contact:
client.tos_uri = metadata["tos_uri"]
if client.contact:
client.policy_uri = metadata["policy_uri"]
if client.contact:
client.jwks_uri = metadata["jwks_uri"]
if client.contact:
client.jwks_text = metadata["jwks_text"]
if client.contact:
client.software_id = metadata["software_id"]
if client.contact:
client.software_version = metadata["software_version"]

session.commit()

0 comments on commit da1599a

Please sign in to comment.