diff --git a/lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py b/lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py
index d013fc354f..1e0eb39e73 100644
--- a/lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py
+++ b/lib/dl_connector_ydb/dl_connector_ydb/core/base/adapter.py
@@ -16,7 +16,14 @@
 
 from dl_core import exc
 from dl_core.connection_executors.adapters.adapters_base_sa_classic import BaseClassicAdapter
-from dl_core.connection_models import TableIdent
+from dl_core.connection_executors.models.db_adapter_data import RawColumnInfo
+from dl_core.connection_models import (
+    DBIdent,
+    SATextTableDefinition,
+    TableDefinition,
+    TableIdent,
+)
+from dl_core.utils import sa_plain_text
 import dl_sqlalchemy_ydb.dialect
 
 import dl_connector_ydb.core.base.row_converters
@@ -116,3 +123,54 @@ def make_exc(  # TODO: Move to ErrorTransformer
 
     def get_engine_kwargs(self) -> dict:
         return {}
+
+    def _get_raw_columns_info(self, table_def: TableDefinition) -> tuple[RawColumnInfo, ...]:
+        """
+        Return raw column info for ``table_def``, resolving YDB views through a subselect.
+
+        For a ``TableIdent`` the target path is first described via the YDB scheme client;
+        if it is a view, the columns are taken from a ``SELECT *`` subselect over that path.
+        In every other case the base implementation is used.
+
+        :raises ValueError: if an absolute table name is not located under ``db_name``.
+        """
+        # Check if target path is view
+        if isinstance(table_def, TableIdent):
+            assert table_def.table_name is not None
+
+            db_engine = self.get_db_engine(table_def.db_name)
+            connection = db_engine.connect()
+
+            try:
+                # SA db_engine -> SA connection -> DBAPI connection -> YDB driver
+                driver = connection.connection._driver  # type: ignore  # 2024-01-24 # TODO: "DBAPIConnection" has no attribute "_driver"  [attr-defined]
+                assert driver
+
+                # User can gain access to tables by absolute path instead of relative to db_name root,
+                # so an absolute table name is accepted only when it is located under db_name.
+                if table_def.db_name is None:
+                    table_path = table_def.table_name
+                else:
+                    # Normalize the root once, so that a trailing slash in db_name is treated
+                    # the same way for absolute and relative table names.
+                    db_root = table_def.db_name.rstrip("/") + "/"
+                    if not table_def.table_name.startswith("/"):
+                        table_path = db_root + table_def.table_name
+                    elif table_def.table_name.startswith(db_root):
+                        table_path = table_def.table_name
+                    else:
+                        raise ValueError("absolute table path is not subpath of database path")
+
+                response = driver.scheme_client.async_describe_path(table_path)
+                result = response.result()
+
+                if result.is_view():
+                    return self._get_subselect_table_info(
+                        SATextTableDefinition(
+                            sa_plain_text(f"(SELECT * FROM `{table_path}` LIMIT 1)"),
+                        ),
+                    ).columns
+            finally:
+                connection.close()
+
+        return super()._get_raw_columns_info(table_def)
diff --git a/lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py b/lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py
index df450d16ca..c830d0333c 100644
--- a/lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py
+++ b/lib/dl_connector_ydb/dl_connector_ydb/core/ydb/adapter.py
@@ -122,7 +122,7 @@ def _list_table_names_i(self, db_name: str, show_dot: bool = False) -> Iterable[
             ]
             children.sort()
             for full_path, child in children:
-                if child.is_any_table():
+                if child.is_any_table() or child.is_view():
                     yield full_path.removeprefix(unprefix)
                 elif child.is_directory():
                     queue.append(full_path)
diff --git a/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/base.py b/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/base.py
index 6c22a974b9..d934e792f4 100644
--- a/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/base.py
+++ b/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/base.py
@@ -102,6 +102,36 @@ def dataset_params(self, sample_table: DbTable) -> dict:
         )
 
 
+class YDBViewDatasetTestBase(YDBConnectionTestBase, DatasetTestBase):
+    @pytest.fixture(scope="class")
+    def sample_view_name(
+        self,
+        db: Db,
+        sample_table: DbTable,
+    ) -> str:
+        view_name = sample_table.name + "_view"
+
+        db.get_current_connection().connection.cursor().execute_scheme(
+            f"CREATE VIEW `{view_name}` WITH (security_invoker = TRUE) AS SELECT * FROM `{sample_table.name}`;"
+        )
+
+        yield view_name
+
+        db.get_current_connection().connection.cursor().execute_scheme(f"DROP VIEW `{view_name}`;")
+
+    @pytest.fixture(scope="class")
+    def dataset_params(
+        self,
+        sample_view_name: str,
+    ) -> dict:
+        return dict(
+            source_type=SOURCE_TYPE_YDB_TABLE.name,
+            parameters=dict(
+                table_name=sample_view_name,
+            ),
+        )
+
+
 class YDBDataApiTestBase(YDBDatasetTestBase, StandardizedDataApiTestBase):
     mutation_caches_enabled = False
 
diff --git a/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/test_dataset.py b/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/test_dataset.py
index 0d3d580de6..59b68bea26 100644
--- a/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/test_dataset.py
+++ b/lib/dl_connector_ydb/dl_connector_ydb_tests/db/api/test_dataset.py
@@ -1,7 +1,10 @@
 from dl_api_client.dsmaker.primitives import Dataset
 from dl_api_lib_testing.connector.dataset_suite import DefaultConnectorDatasetTestSuite
 
-from dl_connector_ydb_tests.db.api.base import YDBDatasetTestBase
+from dl_connector_ydb_tests.db.api.base import (
+    YDBDatasetTestBase,
+    YDBViewDatasetTestBase,
+)
 from dl_connector_ydb_tests.db.config import TABLE_SCHEMA
 
 
@@ -12,3 +15,12 @@ def check_basic_dataset(self, ds: Dataset, annotation: dict) -> None:
         assert field_names == {column[0] for column in TABLE_SCHEMA}
 
         assert ds.annotation == annotation
+
+
+class TestYDBViewDataset(YDBViewDatasetTestBase, DefaultConnectorDatasetTestSuite):
+    def check_basic_dataset(self, ds: Dataset, annotation: dict) -> None:
+        assert ds.id
+        field_names = {field.title for field in ds.result_schema}
+        assert field_names == {column[0] for column in TABLE_SCHEMA}
+
+        assert ds.annotation == annotation
diff --git a/lib/dl_connector_ydb/docker-compose.yml b/lib/dl_connector_ydb/docker-compose.yml
index 84878d3b67..4f45e7410a 100644
--- a/lib/dl_connector_ydb/docker-compose.yml
+++ b/lib/dl_connector_ydb/docker-compose.yml
@@ -9,6 +9,7 @@ services:
       YDB_GRPC_ENABLE_TLS: 1
       GRPC_TLS_PORT: "51902"
       YDB_GRPC_TLS_DATA_PATH: "/ydb_certs"
+      YDB_FEATURE_FLAGS: "enable_views"
     hostname: "db-ydb"
     ports:
       - "51900:51900"
diff --git a/lib/dl_connector_ydb/pyproject.toml b/lib/dl_connector_ydb/pyproject.toml
index 2aee44c9ef..afea6497e6 100644
--- a/lib/dl_connector_ydb/pyproject.toml
+++ b/lib/dl_connector_ydb/pyproject.toml
@@ -19,7 +19,7 @@ dl-formula = {path = "../dl_formula"}
 dl-formula-ref = {path = "../dl_formula_ref"}
 dl-i18n = {path = "../dl_i18n"}
 dl-query-processing = {path = "../dl_query_processing"}
-dl-sqlalchemy-ydb = {path = "../../lib/dl_sqlalchemy_ydb"}
+dl-sqlalchemy-ydb = {path = "../dl_sqlalchemy_ydb"}
 dl-type-transformer = {path = "../dl_type_transformer"}
 dl-utils = {path = "../dl_utils"}
 grpcio = "*"
diff --git a/lib/dl_sqlalchemy_ydb/dl_sqlalchemy_ydb/dialect.py b/lib/dl_sqlalchemy_ydb/dl_sqlalchemy_ydb/dialect.py
index 6fdf295e41..27282ca4e8 100644
--- a/lib/dl_sqlalchemy_ydb/dl_sqlalchemy_ydb/dialect.py
+++ b/lib/dl_sqlalchemy_ydb/dl_sqlalchemy_ydb/dialect.py
@@ -33,7 +33,7 @@ class YqlInterval(sa.types.Interval):
     __visit_name__ = "interval"
 
     def result_processor(self, dialect: sa.engine.Dialect, coltype: typing.Any) -> typing.Any:
-        def process(value: typing.Optional[datetime.timedelta]) -> typing.Optional[int]:
+        def process(value: typing.Union[datetime.timedelta, int, None]) -> typing.Optional[int]:
             if value is None:
                 return None
             if isinstance(value, datetime.timedelta):
diff --git a/metapkg/poetry.lock b/metapkg/poetry.lock
index e88ee0a605..fd1f8cd17b 100644
--- a/metapkg/poetry.lock
+++ b/metapkg/poetry.lock
@@ -2569,7 +2569,7 @@ dl-formula = {path = "../dl_formula"}
 dl-formula-ref = {path = "../dl_formula_ref"}
 dl-i18n = {path = "../dl_i18n"}
 dl-query-processing = {path = "../dl_query_processing"}
-dl-sqlalchemy-ydb = {path = "../../lib/dl_sqlalchemy_ydb"}
+dl-sqlalchemy-ydb = {path = "../dl_sqlalchemy_ydb"}
 dl-type-transformer = {path = "../dl_type_transformer"}
 dl-utils = {path = "../dl_utils"}
 grpcio = "*"