diff --git a/superset/config.py b/superset/config.py index 8554b2a264512..1182042ec6fec 100644 --- a/superset/config.py +++ b/superset/config.py @@ -1905,6 +1905,15 @@ class ExtraDynamicQueryFilters(TypedDict, total=False): EXTRA_DYNAMIC_QUERY_FILTERS: ExtraDynamicQueryFilters = {} +# The migrations that add catalog permissions might take a considerably long time +# to execute as they have to create permissions for all schemas and catalogs from all +# other catalogs accessible by the credentials. This flag lets you skip the +# creation of these secondary perms, and focus only on permissions for the default +# catalog. These secondary permissions can be created later by editing the DB +# connection via the UI (without downtime). +CATALOGS_SIMPLIFIED_MIGRATION: bool = False + + # ------------------------------------------------------------------- # * WARNING: STOP EDITING HERE * # ------------------------------------------------------------------- diff --git a/superset/migrations/shared/catalogs.py b/superset/migrations/shared/catalogs.py index 2306e58499468..d32975c416c91 100644 --- a/superset/migrations/shared/catalogs.py +++ b/superset/migrations/shared/catalogs.py @@ -23,6 +23,7 @@ import sqlalchemy as sa from alembic import op +from flask import current_app from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session @@ -425,9 +426,13 @@ def upgrade_database_catalogs( # update `schema_perm` and `catalog_perm` for tables and charts update_schema_catalog_perms(session, database, catalog_perm, default_catalog, False) - # add any new catalogs discovered and their schemas - new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session) - pvms.update(new_catalog_pvms) + if ( + not current_app.config["CATALOGS_SIMPLIFIED_MIGRATION"] + and not database.is_oauth2_enabled() + ): + # add any new catalogs discovered and their schemas + new_catalog_pvms = add_non_default_catalogs(database, default_catalog, session) +
pvms.update(new_catalog_pvms)
 
     # add default catalog permission and permissions for any new found schemas, and also
     # permissions for new catalogs and their schemas
diff --git a/tests/unit_tests/migrations/shared/catalogs_test.py b/tests/unit_tests/migrations/shared/catalogs_test.py
index 56d202eaca61c..db06b75a8de9f 100644
--- a/tests/unit_tests/migrations/shared/catalogs_test.py
+++ b/tests/unit_tests/migrations/shared/catalogs_test.py
@@ -18,6 +18,7 @@
 from pytest_mock import MockerFixture
 from sqlalchemy.orm.session import Session
 
+from superset import app
 from superset.migrations.shared.catalogs import (
     downgrade_catalog_perms,
     upgrade_catalog_perms,
@@ -329,3 +330,253 @@ def test_upgrade_catalog_perms_graceful(
         ("[my_db].[my_table](id:1)",),
         ("[my_db].[public]",),
     ]
+
+
+def test_upgrade_catalog_perms_oauth_connection(
+    mocker: MockerFixture,
+    session: Session,
+) -> None:
+    """
+    Test the `upgrade_catalog_perms` function when the DB is set up using OAuth.
+
+    During the migration we try to connect to the analytical database to get the list of
+    schemas. This step should be skipped if the database is set up using OAuth and not
+    raise an exception.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.models.core import Database
+    from superset.models.slice import Slice
+    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
+
+    engine = session.get_bind()
+    Database.metadata.create_all(engine)
+
+    db = mocker.patch("superset.migrations.shared.catalogs.db")
+    db.Session.return_value = session
+    add_non_default_catalogs = mocker.patch(
+        "superset.migrations.shared.catalogs.add_non_default_catalogs"
+    )
+    # replace the migration module's alembic `op` with the test session
+    mocker.patch("superset.migrations.shared.catalogs.op", session)
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="bigquery://my-test-project",
+        encrypted_extra='{"oauth2_client_info": "fake_mock_oauth_conn"}',
+    )
+    dataset = SqlaTable(
+        table_name="my_table",
+        database=database,
+        catalog=None,
+        schema="public",
+        schema_perm="[my_db].[public]",
+    )
+    session.add(dataset)
+    session.commit()
+
+    chart = Slice(
+        slice_name="my_chart",
+        datasource_type="table",
+        datasource_id=dataset.id,
+    )
+    query = Query(
+        client_id="foo",
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    saved_query = SavedQuery(
+        database=database,
+        sql="SELECT * FROM public.t",
+        catalog=None,
+        schema="public",
+    )
+    tab_state = TabState(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    table_schema = TableSchema(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    session.add_all([chart, query, saved_query, tab_state, table_schema])
+    session.commit()
+
+    # before migration
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]
+
+    upgrade_catalog_perms()
+    session.commit()
+
+    # after migration
+    assert dataset.catalog == "my-test-project"
+    assert query.catalog == "my-test-project"
+    assert saved_query.catalog == "my-test-project"
+    assert tab_state.catalog == "my-test-project"
+    assert table_schema.catalog == "my-test-project"
+    assert dataset.schema_perm == "[my_db].[my-test-project].[public]"
+    assert chart.schema_perm == "[my_db].[my-test-project].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[my-test-project].[public]",),
+        ("[my_db].[my-test-project]",),
+    ]
+
+    add_non_default_catalogs.assert_not_called()
+
+    downgrade_catalog_perms()
+    session.commit()
+
+    # revert
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]
+
+
+def test_upgrade_catalog_perms_simplified_migration(
+    mocker: MockerFixture,
+    session: Session,
+) -> None:
+    """
+    Test the `upgrade_catalog_perms` function when the ``CATALOGS_SIMPLIFIED_MIGRATION``
+    config is set to ``True``.
+
+    This should only update existing permissions + create a new permission
+    for the default catalog.
+    """
+    from superset.connectors.sqla.models import SqlaTable
+    from superset.models.core import Database
+    from superset.models.slice import Slice
+    from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
+
+    engine = session.get_bind()
+    Database.metadata.create_all(engine)
+
+    db = mocker.patch("superset.migrations.shared.catalogs.db")
+    db.Session.return_value = session
+    add_non_default_catalogs = mocker.patch(
+        "superset.migrations.shared.catalogs.add_non_default_catalogs"
+    )
+    # replace the migration module's alembic `op` with the test session
+    mocker.patch("superset.migrations.shared.catalogs.op", session)
+
+    database = Database(
+        database_name="my_db",
+        sqlalchemy_uri="bigquery://my-test-project",
+    )
+    dataset = SqlaTable(
+        table_name="my_table",
+        database=database,
+        catalog=None,
+        schema="public",
+        schema_perm="[my_db].[public]",
+    )
+    session.add(dataset)
+    session.commit()
+
+    chart = Slice(
+        slice_name="my_chart",
+        datasource_type="table",
+        datasource_id=dataset.id,
+    )
+    query = Query(
+        client_id="foo",
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    saved_query = SavedQuery(
+        database=database,
+        sql="SELECT * FROM public.t",
+        catalog=None,
+        schema="public",
+    )
+    tab_state = TabState(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    table_schema = TableSchema(
+        database=database,
+        catalog=None,
+        schema="public",
+    )
+    session.add_all([chart, query, saved_query, tab_state, table_schema])
+    session.commit()
+
+    # before migration
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]
+
+    # `patch.dict` reverts the override on teardown so it can't leak into other tests
+    mocker.patch.dict(app.config, {"CATALOGS_SIMPLIFIED_MIGRATION": True})
+    with app.test_request_context():
+        upgrade_catalog_perms()
+        session.commit()
+
+    # after migration
+    assert dataset.catalog == "my-test-project"
+    assert query.catalog == "my-test-project"
+    assert saved_query.catalog == "my-test-project"
+    assert tab_state.catalog == "my-test-project"
+    assert table_schema.catalog == "my-test-project"
+    assert dataset.schema_perm == "[my_db].[my-test-project].[public]"
+    assert chart.schema_perm == "[my_db].[my-test-project].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[my-test-project].[public]",),
+        ("[my_db].[my-test-project]",),
+    ]
+
+    add_non_default_catalogs.assert_not_called()
+
+    downgrade_catalog_perms()
+    session.commit()
+
+    # revert
+    assert dataset.catalog is None
+    assert query.catalog is None
+    assert saved_query.catalog is None
+    assert tab_state.catalog is None
+    assert table_schema.catalog is None
+    assert dataset.schema_perm == "[my_db].[public]"
+    assert chart.schema_perm == "[my_db].[public]"
+    assert session.query(ViewMenu.name).all() == [
+        ("[my_db].(id:1)",),
+        ("[my_db].[my_table](id:1)",),
+        ("[my_db].[public]",),
+    ]