diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/common/sql/operators/generic_transfer.py
similarity index 95%
rename from providers/src/airflow/providers/standard/operators/generic_transfer.py
rename to providers/src/airflow/providers/common/sql/operators/generic_transfer.py
index 1a9b5af9b6264..74ac34d44ec2f 100644
--- a/providers/src/airflow/providers/standard/operators/generic_transfer.py
+++ b/providers/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -52,7 +52,7 @@ class GenericTransfer(BaseOperator):
     :param preoperator: sql statement or list of statements to be executed prior to loading the data.
         (templated)
     :param insert_args: extra params for `insert_rows` method.
-    :param chunk_size: number of records to be read in paginated mode (optional).
+    :param page_size: number of records to be read in paginated mode (optional).
     """
 
     template_fields: Sequence[str] = (
@@ -81,7 +81,7 @@ def __init__(
         destination_hook_params: dict | None = None,
         preoperator: str | list[str] | None = None,
         insert_args: dict | None = None,
-        chunk_size: int | None = None,
+        page_size: int | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -93,7 +93,7 @@ def __init__(
         self.destination_hook_params = destination_hook_params
         self.preoperator = preoperator
         self.insert_args = insert_args or {}
-        self.chunk_size = chunk_size
+        self.page_size = page_size
         self._paginated_sql_statement_format = kwargs.get(
             "paginated_sql_statement_format", "{} LIMIT {} OFFSET {}"
         )
@@ -123,7 +123,7 @@ def destination_hook(self) -> DbApiHook:
 
     def get_paginated_sql(self, offset: int) -> str:
         """Format the paginated SQL statement using the current format."""
-        return self._paginated_sql_statement_format.format(self.sql, self.chunk_size, offset)
+        return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset)
 
     def render_template_fields(
         self,
@@ -133,8 +133,8 @@ def render_template_fields(
         super().render_template_fields(context=context, jinja_env=jinja_env)
 
         # Make sure string are converted to integers
-        if isinstance(self.chunk_size, str):
-            self.chunk_size = int(self.chunk_size)
+        if isinstance(self.page_size, str):
+            self.page_size = int(self.page_size)
         commit_every = self.insert_args.get("commit_every")
         if isinstance(commit_every, str):
             self.insert_args["commit_every"] = int(commit_every)
@@ -145,7 +145,7 @@ def execute(self, context: Context):
             self.log.info(self.preoperator)
             self.destination_hook.run(self.preoperator)
 
-        if self.chunk_size and isinstance(self.sql, str):
+        if self.page_size and isinstance(self.sql, str):
             self.defer(
                 trigger=SQLExecuteQueryTrigger(
                     conn_id=self.source_conn_id,
@@ -184,7 +184,7 @@ def execute_complete(
                     map_indexes=map_index,
                     default=0,
                 )
-                + self.chunk_size
+                + self.page_size
             )
 
             self.log.info("Offset increased to %d", offset)
diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py
index 8c154b1a82f08..ee345722154e5 100644
--- a/providers/src/airflow/providers/common/sql/triggers/sql.py
+++ b/providers/src/airflow/providers/common/sql/triggers/sql.py
@@ -76,18 +76,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
 
             self.log.info("Extracting data from %s", self.conn_id)
             self.log.info("Executing: \n %s", self.sql)
-
-            get_records = getattr(hook, "get_records", None)
-
-            if not callable(get_records):
-                raise RuntimeError(
-                    f"Hook for connection {self.conn_id!r} "
-                    f"({type(hook).__name__}) has no `get_records` method"
-                )
-            else:
-                self.log.info("Reading records from %s", self.conn_id)
-                results = get_records(self.sql)
-                self.log.info("Reading records from %s done!", self.conn_id)
+            self.log.info("Reading records from %s", self.conn_id)
+            results = hook.get_records(self.sql)
+            self.log.info("Reading records from %s done!", self.conn_id)
 
             self.log.debug("results: %s", results)
             yield TriggerEvent({"status": "success", "results": results})
diff --git a/providers/tests/standard/operators/test_generic_transfer.py b/providers/tests/common/sql/operators/test_generic_transfer.py
similarity index 75%
rename from providers/tests/standard/operators/test_generic_transfer.py
rename to providers/tests/common/sql/operators/test_generic_transfer.py
index 4ea08e48891e6..4c26cc8f941e0 100644
--- a/providers/tests/standard/operators/test_generic_transfer.py
+++ b/providers/tests/common/sql/operators/test_generic_transfer.py
@@ -19,17 +19,20 @@
 
 import inspect
 from contextlib import closing
-from datetime import datetime
+from datetime import datetime, timedelta
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
-
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.models.connection import Connection
 from airflow.models.dag import DAG
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.postgres.hooks.postgres import PostgresHook
 from airflow.utils import timezone
 
 from tests_common.test_utils.compat import GenericTransfer
+from tests_common.test_utils.operators.run_deferable import execute_operator
 from tests_common.test_utils.providers import get_provider_min_airflow_version
 
 pytestmark = pytest.mark.db_test
@@ -38,6 +41,7 @@
 DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
 DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
 TEST_DAG_ID = "unit_test_dag"
+counter = 0
 
 
 @pytest.mark.backend("mysql")
@@ -193,6 +197,65 @@ def test_templated_fields(self):
         assert operator.preoperator == "my_preoperator"
         assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}
 
+    def test_paginated_read(self):
+        """
+        This unit test is based on the example described in the medium article:
+        https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
+        """
+
+        def create_get_records_side_effect():
+            records = [
+                [[1, 2], [11, 12], [3, 4], [13, 14]],
+                [[3, 4], [13, 14]],
+            ]
+
+            def side_effect(sql: str):
+                if records:
+                    return records.pop(0)
+                return []
+
+            return side_effect
+
+        get_records_side_effect = create_get_records_side_effect()
+
+        def get_hook(conn_id: str, hook_params: dict | None = None):
+            mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
+            mocked_hook.get_records.side_effect = get_records_side_effect
+            return mocked_hook
+
+        def get_connection(conn_id: str):
+            mocked_hook = get_hook(conn_id=conn_id)
+            mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
+            mocked_conn.get_hook.return_value = mocked_hook
+            return mocked_conn
+
+        with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
+            with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
+                operator = GenericTransfer(
+                    task_id="transfer_table",
+                    source_conn_id="my_source_conn_id",
+                    destination_conn_id="my_destination_conn_id",
+                    sql="SELECT * FROM HR.EMPLOYEES",
+                    destination_table="NEW_HR.EMPLOYEES",
+                    page_size=1000,  # Fetch data in chunks of 1000 rows for pagination
+                    insert_args={
+                        "commit_every": 1000,  # Number of rows inserted in each batch
+                        "executemany": True,  # Enable batch inserts
+                        "fast_executemany": True,  # Boost performance for MSSQL inserts
+                        "replace": True,  # Used for upserts/merges if needed
+                    },
+                    execution_timeout=timedelta(hours=1),
+                )
+
+                results, events = execute_operator(operator)
+
+                assert not results
+                assert len(events) == 3
+                assert events[0].payload["results"] == [[1, 2], [11, 12], [3, 4], [13, 14]]
+                assert events[1].payload["results"] == [[3, 4], [13, 14]]
+                assert not events[2].payload["results"]
+
+
     def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
         """
         Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher
diff --git a/tests_common/test_utils/compat.py b/tests_common/test_utils/compat.py
index 3bd4b89dfc1c4..ae9e009a57a8a 100644
--- a/tests_common/test_utils/compat.py
+++ b/tests_common/test_utils/compat.py
@@ -43,8 +43,8 @@
     from airflow.models.baseoperator import BaseOperatorLink
 
 try:
+    from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer
     from airflow.providers.standard.operators.bash import BashOperator
-    from airflow.providers.standard.operators.generic_transfer import GenericTransfer
     from airflow.providers.standard.operators.python import PythonOperator
     from airflow.providers.standard.sensors.bash import BashSensor
     from airflow.providers.standard.sensors.date_time import DateTimeSensor
diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py
index 5a17e058f5f11..0db6aa99e5777 100644
--- a/tests_common/test_utils/mock_context.py
+++ b/tests_common/test_utils/mock_context.py
@@ -57,8 +57,8 @@ def xcom_pull(
             run_id: str | None = None,
         ) -> Any:
             if map_indexes:
-                return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
-            return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
+                return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default)
+            return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)
 
         def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
             values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
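
A minimal usage sketch of the operator under its new import path and renamed `page_size` parameter, mirroring the values used in the test above; the connection IDs, table names, and timeout below are placeholders, and the default "{} LIMIT {} OFFSET {}" pagination format shown in the diff is assumed:

    from datetime import timedelta

    from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer

    transfer = GenericTransfer(
        task_id="transfer_table",
        source_conn_id="my_source_conn_id",  # placeholder source connection
        destination_conn_id="my_destination_conn_id",  # placeholder destination connection
        sql="SELECT * FROM HR.EMPLOYEES",
        destination_table="NEW_HR.EMPLOYEES",
        page_size=1000,  # read 1000 rows per LIMIT/OFFSET page; the operator defers between pages
        insert_args={"commit_every": 1000, "executemany": True},  # passed through to insert_rows
        execution_timeout=timedelta(hours=1),  # optional safety net for long transfers
    )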