refactor: Added unit test for GenericTransfer using deferred pageable reads
davidblain-infrabel committed Jan 7, 2025
1 parent 735a557 commit f092507
Showing 5 changed files with 79 additions and 25 deletions.
@@ -52,7 +52,7 @@ class GenericTransfer(BaseOperator):
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
:param chunk_size: number of records to be read in paginated mode (optional).
:param page_size: number of records to be read in paginated mode (optional).
"""

template_fields: Sequence[str] = (
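
For context, a minimal usage sketch of the page_size parameter documented above; the connection ids, table names, and DAG id below are made up for illustration and are not part of this change:

from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer

with DAG("transfer_employees", start_date=datetime(2025, 1, 1), schedule=None) as dag:
    GenericTransfer(
        task_id="transfer_table",
        source_conn_id="my_source_conn_id",
        destination_conn_id="my_destination_conn_id",
        sql="SELECT * FROM HR.EMPLOYEES",
        destination_table="NEW_HR.EMPLOYEES",
        page_size=1000,  # read the source in deferred pages of 1000 rows
        execution_timeout=timedelta(hours=1),
    )
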
@@ -81,7 +81,7 @@ def __init__(
destination_hook_params: dict | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
chunk_size: int | None = None,
page_size: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -93,7 +93,7 @@ def __init__(
self.destination_hook_params = destination_hook_params
self.preoperator = preoperator
self.insert_args = insert_args or {}
self.chunk_size = chunk_size
self.page_size = page_size
self._paginated_sql_statement_format = kwargs.get(
"paginated_sql_statement_format", "{} LIMIT {} OFFSET {}"
)
@@ -123,7 +123,7 @@ def destination_hook(self) -> DbApiHook:

def get_paginated_sql(self, offset: int) -> str:
"""Format the paginated SQL statement using the current format."""
return self._paginated_sql_statement_format.format(self.sql, self.chunk_size, offset)
return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset)
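
To illustrate the default "{} LIMIT {} OFFSET {}" statement format applied by get_paginated_sql; the page_size and offset values below are made up:

statement_format = "{} LIMIT {} OFFSET {}"
sql = "SELECT * FROM HR.EMPLOYEES"
page_size = 1000

print(statement_format.format(sql, page_size, 0))     # SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0
print(statement_format.format(sql, page_size, 2000))  # SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 2000
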

def render_template_fields(
self,
@@ -133,8 +133,8 @@ def render_template_fields(
super().render_template_fields(context=context, jinja_env=jinja_env)

# Make sure strings are converted to integers
if isinstance(self.chunk_size, str):
self.chunk_size = int(self.chunk_size)
if isinstance(self.page_size, str):
self.page_size = int(self.page_size)
commit_every = self.insert_args.get("commit_every")
if isinstance(commit_every, str):
self.insert_args["commit_every"] = int(commit_every)
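
The coercion above matters because templated fields render to strings; a small illustration (the rendered value is hypothetical):

page_size = "1000"  # e.g. rendered from "{{ var.value.transfer_page_size }}"
if isinstance(page_size, str):
    page_size = int(page_size)
assert page_size == 1000
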
@@ -145,7 +145,7 @@ def execute(self, context: Context):
self.log.info(self.preoperator)
self.destination_hook.run(self.preoperator)

if self.chunk_size and isinstance(self.sql, str):
if self.page_size and isinstance(self.sql, str):
self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
@@ -184,7 +184,7 @@ def execute_complete(
map_indexes=map_index,
default=0,
)
+ self.chunk_size
+ self.page_size
)

self.log.info("Offset increased to %d", offset)
15 changes: 3 additions & 12 deletions providers/src/airflow/providers/common/sql/triggers/sql.py
@@ -76,18 +76,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

self.log.info("Extracting data from %s", self.conn_id)
self.log.info("Executing: \n %s", self.sql)

get_records = getattr(hook, "get_records", None)

if not callable(get_records):
raise RuntimeError(
f"Hook for connection {self.conn_id!r} "
f"({type(hook).__name__}) has no `get_records` method"
)
else:
self.log.info("Reading records from %s", self.conn_id)
results = get_records(self.sql)
self.log.info("Reading records from %s done!", self.conn_id)
self.log.info("Reading records from %s", self.conn_id)
results = hook.get_records(self.sql)
self.log.info("Reading records from %s done!", self.conn_id)

self.log.debug("results: %s", results)
yield TriggerEvent({"status": "success", "results": results})
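
A minimal sketch of driving the trigger by hand; it assumes a configured Airflow connection whose hook exposes get_records, and the connection id is made up:

import asyncio

from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger


async def first_event(trigger):
    async for event in trigger.run():
        return event


trigger = SQLExecuteQueryTrigger(conn_id="my_source_conn_id", sql="SELECT 1")
event = asyncio.run(first_event(trigger))
# event.payload is expected to look like {"status": "success", "results": [[1]]}
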
@@ -19,17 +19,20 @@

import inspect
from contextlib import closing
from datetime import datetime
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils import timezone

from tests_common.test_utils.compat import GenericTransfer
from tests_common.test_utils.operators.run_deferable import execute_operator
from tests_common.test_utils.providers import get_provider_min_airflow_version

pytestmark = pytest.mark.db_test
@@ -38,6 +38,7 @@
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = "unit_test_dag"
counter = 0


@pytest.mark.backend("mysql")
@@ -193,6 +197,65 @@ def test_templated_fields(self):
assert operator.preoperator == "my_preoperator"
assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True}

def test_paginated_read(self):
"""
This unit test is based on the example described in the Medium article:
https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
"""

def create_get_records_side_effect():
records = [
[[1, 2], [11, 12], [3, 4], [13, 14]],
[[3, 4], [13, 14]],
]

def side_effect(sql: str):
if records:
return records.pop(0)
return []

return side_effect

get_records_side_effect = create_get_records_side_effect()

def get_hook(conn_id: str, hook_params: dict | None = None):
mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook)
mocked_hook.get_records.side_effect = get_records_side_effect
return mocked_hook

def get_connection(conn_id: str):
mocked_hook = get_hook(conn_id=conn_id)
mocked_conn = MagicMock(conn_id=conn_id, spec=Connection)
mocked_conn.get_hook.return_value = mocked_hook
return mocked_conn

with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection):
with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook):
operator = GenericTransfer(
task_id="transfer_table",
source_conn_id="my_source_conn_id",
destination_conn_id="my_destination_conn_id",
sql="SELECT * FROM HR.EMPLOYEES",
destination_table="NEW_HR.EMPLOYEES",
page_size=1000, # Fetch data in chunks of 1000 rows for pagination
insert_args={
"commit_every": 1000, # Number of rows inserted in each batch
"executemany": True, # Enable batch inserts
"fast_executemany": True, # Boost performance for MSSQL inserts
"replace": True, # Used for upserts/merges if needed
},
execution_timeout=timedelta(hours=1),
)

results, events = execute_operator(operator)

assert not results
assert len(events) == 3
assert events[0].payload["results"] == [[1, 2], [11, 12], [3, 4], [13, 14]]
assert events[1].payload["results"] == [[3, 4], [13, 14]]
assert not events[2].payload["results"]
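
For reference, a hypothetical sketch of what a deferral-aware driver such as execute_operator might do; the real helper lives in tests_common.test_utils.operators.run_deferable and its implementation may differ:

import asyncio

from airflow.exceptions import TaskDeferred


def drive_deferrable(operator, context):
    """Run an operator, replaying each TaskDeferred through its trigger until it completes."""
    events = []

    async def first_event(trigger):
        async for event in trigger.run():
            return event

    call = lambda: operator.execute(context)
    while True:
        try:
            return call(), events
        except TaskDeferred as deferred:
            event = asyncio.run(first_event(deferred.trigger))
            events.append(event)
            method = getattr(operator, deferred.method_name)
            call = lambda m=method, e=event: m(context=context, event=e.payload)
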


def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self):
"""
Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher
2 changes: 1 addition & 1 deletion tests_common/test_utils/compat.py
@@ -43,8 +43,8 @@
from airflow.models.baseoperator import BaseOperatorLink

try:
from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.generic_transfer import GenericTransfer
from airflow.providers.standard.operators.python import PythonOperator
from airflow.providers.standard.sensors.bash import BashSensor
from airflow.providers.standard.sensors.date_time import DateTimeSensor
4 changes: 2 additions & 2 deletions tests_common/test_utils/mock_context.py
@@ -57,8 +57,8 @@ def xcom_pull(
run_id: str | None = None,
) -> Any:
if map_indexes:
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}")
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}")
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default)
return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default)

def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value