Skip to content

Commit

Permalink
test(amazon): make emr sensor tests check class type and deferrable a…
Browse files Browse the repository at this point in the history
…ttribute
  • Loading branch information
Lee-W committed Dec 20, 2023
1 parent 3f9a599 commit b99adf0
Showing 1 changed file with 22 additions and 86 deletions.
108 changes: 22 additions & 86 deletions tests/amazon/aws/sensors/test_emr_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.sensors.emr import (
EmrStepSensor,
)

from astronomer.providers.amazon.aws.sensors.emr import (
EmrContainerSensorAsync,
EmrJobFlowSensorAsync,
EmrStepSensorAsync,
)
from astronomer.providers.amazon.aws.triggers.emr import (
EmrContainerSensorTrigger,
EmrJobFlowSensorTrigger,
EmrStepSensorTrigger,
)

TASK_ID = "test_emr_container_sensor"
Expand All @@ -28,55 +29,28 @@


class TestEmrContainerSensorAsync:
TASK = EmrContainerSensorAsync(
task_id=TASK_ID,
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
poll_interval=5,
max_retries=1,
aws_conn_id=AWS_CONN_ID,
)

@mock.patch(f"{MODULE}.EmrContainerSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=True)
def test_emr_container_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.EmrContainerSensorAsync.poke", return_value=False)
def test_emr_container_sensor_async(self, mock_poke, context):
"""
Asserts that a task is deferred and a EmrContainerSensorTrigger will be fired
when the EmrContainerSensorAsync is executed.
"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(
exc.value.trigger, EmrContainerSensorTrigger
), "Trigger is not a EmrContainerSensorTrigger"

def test_emr_container_sensor_async_execute_failure(self, context):
"""Tests that an AirflowException is raised in case of error event"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_emr_container_sensor_async_execute_complete(self):
"""Asserts that logging occurs as expected"""

assert (
self.TASK.execute_complete(context=None, event={"status": "success", "message": "Job completed"})
is None
def test_init(self):
task = EmrContainerSensorAsync(
task_id=TASK_ID,
virtual_cluster_id=VIRTUAL_CLUSTER_ID,
job_id=JOB_ID,
poll_interval=5,
max_retries=1,
aws_conn_id=AWS_CONN_ID,
)
assert isinstance(task, EmrContainerSensorAsync)
assert task.deferrable is True

def test_emr_container_sensor_async_execute_complete_event_none(self):
"""Asserts that logging occurs as expected"""

assert self.TASK.execute_complete(context=None, event=None) is None
class TestEmrStepSensorAsync:
def test_init(self):
task = EmrStepSensorAsync(
task_id="emr_step_sensor",
job_flow_id=JOB_ID,
step_id=STEP_ID,
)
assert isinstance(task, EmrStepSensor)
assert task.deferrable is True


class TestEmrJobFlowSensorAsync:
Expand Down Expand Up @@ -140,41 +114,3 @@ def test_emr_job_flow_sensor_async_execute_complete_event_none(self):
"""Asserts that logging occurs as expected"""

assert self.TASK.execute_complete(context=None, event=None) is None


class TestEmrStepSensorAsync:
TASK = EmrStepSensorAsync(
task_id="emr_step_sensor",
job_flow_id=JOB_ID,
step_id=STEP_ID,
)

@mock.patch(f"{MODULE}.EmrStepSensorAsync.defer")
@mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=True)
def test_emr_step_sensor_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
self.TASK.execute(context)
assert not mock_defer.called

@mock.patch(f"{MODULE}.EmrStepSensorAsync.poke", return_value=False)
def test_emr_step_sensor_async(self, mock_poke, context):
"""Assert execute method defer for EmrStepSensorAsync sensor"""

with pytest.raises(TaskDeferred) as exc:
self.TASK.execute(context)
assert isinstance(exc.value.trigger, EmrStepSensorTrigger), "Trigger is not a EmrStepSensorTrigger"

def test_emr_step_sensor_execute_complete_success(self):
"""Assert execute_complete log success message when triggerer fire with target state"""

with mock.patch.object(self.TASK.log, "info") as mock_log_info:
self.TASK.execute_complete(
context={}, event={"status": "success", "message": "Job flow currently COMPLETED"}
)
mock_log_info.assert_called_with("%s completed successfully.", "j-T0CT8Z0C20NT")

def test_emr_step_sensor_execute_complete_failure(self):
"""Assert execute_complete method fail"""

with pytest.raises(AirflowException):
self.TASK.execute_complete(context={}, event={"status": "error", "message": ""})

0 comments on commit b99adf0

Please sign in to comment.