Skip to content

Commit

Permalink
Improve dag_maker compatibility handling (#44125)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Nov 18, 2024
1 parent 6f02fdb commit 007c4b1
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 323 deletions.
7 changes: 1 addition & 6 deletions providers/tests/amazon/aws/operators/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.utils import timezone

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS

TEST_CONN = "aws_test_conn"


Expand Down Expand Up @@ -118,10 +116,7 @@ def test_execute(self, op_kwargs, dag_maker):
with dag_maker("test_aws_base_operator", serialized=True):
FakeS3Operator(task_id="fake-task-id", **op_kwargs)

if AIRFLOW_V_3_0_PLUS:
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
else:
dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
tis["fake-task-id"].run()

Expand Down
7 changes: 1 addition & 6 deletions providers/tests/amazon/aws/sensors/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.utils import timezone

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS

TEST_CONN = "aws_test_conn"


Expand Down Expand Up @@ -120,10 +118,7 @@ def test_execute(self, dag_maker, op_kwargs):
with dag_maker("test_aws_base_sensor", serialized=True):
FakeDynamoDBSensor(task_id="fake-task-id", **op_kwargs, poke_interval=1)

if AIRFLOW_V_3_0_PLUS:
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
else:
dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
tis["fake-task-id"].run()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -733,10 +733,7 @@ def test_resolve_application_file_template_file(dag_maker, tmp_path, session):
kubernetes_conn_id="kubernetes_default_kube_config",
task_id="test_template_body_templating_task",
)
if AIRFLOW_V_3_0_PLUS:
ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
else:
ti = dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
session.add(ti)
session.commit()
ti.render_templates()
Expand Down Expand Up @@ -776,10 +773,7 @@ def test_resolve_application_file_template_non_dictionary(dag_maker, tmp_path, b
kubernetes_conn_id="kubernetes_default_kube_config",
task_id="test_template_body_templating_task",
)
if AIRFLOW_V_3_0_PLUS:
ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
else:
ti = dag_maker.create_dagrun(execution_date=logical_date).task_instances[0]
ti = dag_maker.create_dagrun(logical_date=logical_date).task_instances[0]
session.add(ti)
session.commit()
ti.render_templates()
Expand Down
97 changes: 28 additions & 69 deletions providers/tests/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from airflow.utils import timezone
from airflow.utils.timezone import datetime

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.config import conf_vars

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -184,10 +183,7 @@ def test_file_transfer_with_intermediate_dir_put(self, dag_maker):
command=f"cat {self.test_remote_filepath_int_dir}",
do_xcom_push=True,
)
if AIRFLOW_V_3_0_PLUS:
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
else:
dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"):
tis["test_sftp"].run()
Expand Down Expand Up @@ -220,10 +216,7 @@ def test_json_file_transfer_put(self, dag_maker):
command=f"cat {self.test_remote_filepath}",
do_xcom_push=True,
)
if AIRFLOW_V_3_0_PLUS:
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
else:
dagrun = dag_maker.create_dagrun(execution_date=timezone.utcnow())
dagrun = dag_maker.create_dagrun(logical_date=timezone.utcnow())
tis = {ti.task_id: ti for ti in dagrun.task_instances}
with pytest.warns(AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"):
tis["put_test_task"].run()
Expand All @@ -249,18 +242,11 @@ def test_pickle_file_transfer_get(self, dag_maker, create_remote_file_and_cleanu
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
)
if AIRFLOW_V_3_0_PLUS:
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
else:
for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()

# Test the received content.
with open(self.test_local_filepath, "rb") as file:
Expand All @@ -277,18 +263,11 @@ def test_json_file_transfer_get(self, dag_maker, create_remote_file_and_cleanup)
remote_filepath=self.test_remote_filepath,
operation=SFTPOperation.GET,
)
if AIRFLOW_V_3_0_PLUS:
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
else:
for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()

# Test the received content.
content_received = None
Expand All @@ -307,30 +286,17 @@ def test_file_transfer_no_intermediate_dir_error_get(self, dag_maker, create_rem
operation=SFTPOperation.GET,
)

if AIRFLOW_V_3_0_PLUS:
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
# This should raise an error with "No such file" as the directory
# does not exist.
with (
pytest.raises(AirflowException) as ctx,
pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
),
):
ti.run()
assert "No such file" in str(ctx.value)
else:
for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
# This should raise an error with "No such file" as the directory
# does not exist.
with (
pytest.raises(AirflowException) as ctx,
pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
),
):
ti.run()
assert "No such file" in str(ctx.value)
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
# This should raise an error with "No such file" as the directory
# does not exist.
with (
pytest.raises(AirflowException) as ctx,
pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
),
):
ti.run()
assert "No such file" in str(ctx.value)

@conf_vars({("core", "enable_xcom_pickling"): "True"})
def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker, create_remote_file_and_cleanup):
Expand All @@ -344,18 +310,11 @@ def test_file_transfer_with_intermediate_dir_error_get(self, dag_maker, create_r
create_intermediate_dirs=True,
)

if AIRFLOW_V_3_0_PLUS:
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
else:
for ti in dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()
for ti in dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances:
with pytest.warns(
AirflowProviderDeprecationWarning, match="Parameter `ssh_hook` is deprecated..*"
):
ti.run()

# Test the received content.
content_received = None
Expand Down
32 changes: 8 additions & 24 deletions providers/tests/standard/operators/test_bash.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,6 @@

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

if TYPE_CHECKING:
from airflow.models import TaskInstance

Expand Down Expand Up @@ -111,27 +108,14 @@ def test_echo_env_variables(
)

logical_date = utc_now
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
if AIRFLOW_V_3_0_PLUS:
dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
logical_date=logical_date,
start_date=utc_now,
state=State.RUNNING,
external_trigger=False,
data_interval=(logical_date, logical_date),
**triggered_by_kwargs,
)
else:
dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=logical_date,
start_date=utc_now,
state=State.RUNNING,
external_trigger=False,
data_interval=(logical_date, logical_date),
**triggered_by_kwargs,
)
dag_maker.create_dagrun(
run_type=DagRunType.MANUAL,
logical_date=logical_date,
start_date=utc_now,
state=State.RUNNING,
external_trigger=False,
data_interval=(logical_date, logical_date),
)

with mock.patch.dict(
"os.environ", {"AIRFLOW_HOME": "MY_PATH_TO_AIRFLOW_HOME", "PYTHONPATH": "AWESOME_PYTHONPATH"}
Expand Down
43 changes: 7 additions & 36 deletions providers/tests/standard/operators/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
from airflow.utils.session import create_session
from airflow.utils.state import State

from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType

pytestmark = pytest.mark.db_test

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
Expand Down Expand Up @@ -79,24 +74,12 @@ def base_tests_setup(self, dag_maker):
self.branch_1.set_upstream(self.branch_op)
self.branch_2.set_upstream(self.branch_op)

triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
if AIRFLOW_V_3_0_PLUS:
self.dr = dag_maker.create_dagrun(
run_id="manual__",
start_date=DEFAULT_DATE,
logical_date=DEFAULT_DATE,
state=State.RUNNING,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
**triggered_by_kwargs,
)
else:
self.dr = dag_maker.create_dagrun(
run_id="manual__",
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
state=State.RUNNING,
data_interval=(DEFAULT_DATE, DEFAULT_DATE),
**triggered_by_kwargs,
)

def teardown_method(self):
Expand Down Expand Up @@ -251,25 +234,13 @@ def test_branch_datetime_operator_use_task_logical_date(self, dag_maker, target_
"""Check if BranchDateTimeOperator uses task logical date"""
in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
self.branch_op.use_task_logical_date = True
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
if AIRFLOW_V_3_0_PLUS:
self.dr = dag_maker.create_dagrun(
run_id="manual_exec_date__",
start_date=in_between_date,
logical_date=in_between_date,
state=State.RUNNING,
data_interval=(in_between_date, in_between_date),
**triggered_by_kwargs,
)
else:
self.dr = dag_maker.create_dagrun(
run_id="manual_exec_date__",
start_date=in_between_date,
execution_date=in_between_date,
state=State.RUNNING,
data_interval=(in_between_date, in_between_date),
**triggered_by_kwargs,
)
self.dr = dag_maker.create_dagrun(
run_id="manual_exec_date__",
start_date=in_between_date,
logical_date=in_between_date,
state=State.RUNNING,
data_interval=(in_between_date, in_between_date),
)

self.branch_op.target_lower = target_lower
self.branch_op.target_upper = target_upper
Expand Down
33 changes: 8 additions & 25 deletions providers/tests/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@
from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS, AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.db import clear_db_runs

if AIRFLOW_V_3_0_PLUS:
from airflow.utils.types import DagRunTriggeredByType


if TYPE_CHECKING:
from airflow.models.dagrun import DagRun

Expand Down Expand Up @@ -148,27 +144,14 @@ def default_kwargs(**kwargs):
return kwargs

def create_dag_run(self) -> DagRun:
triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
if AIRFLOW_V_3_0_PLUS:
return self.dag_maker.create_dagrun(
state=DagRunState.RUNNING,
start_date=self.dag_maker.start_date,
session=self.dag_maker.session,
logical_date=self.default_date,
run_type=DagRunType.MANUAL,
data_interval=(self.default_date, self.default_date),
**triggered_by_kwargs, # type: ignore
)
else:
return self.dag_maker.create_dagrun(
state=DagRunState.RUNNING,
start_date=self.dag_maker.start_date,
session=self.dag_maker.session,
execution_date=self.default_date,
run_type=DagRunType.MANUAL,
data_interval=(self.default_date, self.default_date),
**triggered_by_kwargs, # type: ignore
)
return self.dag_maker.create_dagrun(
state=DagRunState.RUNNING,
start_date=self.dag_maker.start_date,
session=self.dag_maker.session,
logical_date=self.default_date,
run_type=DagRunType.MANUAL,
data_interval=(self.default_date, self.default_date),
)

def create_ti(self, fn, **kwargs) -> TI:
"""Create TaskInstance for class defined Operator."""
Expand Down
Loading

0 comments on commit 007c4b1

Please sign in to comment.