Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Deprecate DbtCloudRunJobOperatorAsync and DbtCloudJobRunSensorAsync
Browse files Browse the repository at this point in the history
This PR deprecates the operator DbtCloudRunJobOperatorAsync and the
sensor DbtCloudJobRunSensorAsync from the dbt provider by proxying
them to their Airflow OSS provider's counterpart.

closes: #1414
pankajkoti committed Jan 23, 2024

Verified

This commit was signed with the committer’s verified signature.
erikmd Erik Martin-Dorel
1 parent ec3e539 commit 15cb47f
Showing 7 changed files with 79 additions and 351 deletions.
13 changes: 10 additions & 3 deletions astronomer/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
@@ -42,9 +42,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

class DbtCloudHookAsync(BaseHook):
"""
Interact with dbt Cloud using the V2 API.
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead.
"""

conn_name_attr = "dbt_cloud_conn_id"
@@ -53,6 +52,14 @@ class DbtCloudHookAsync(BaseHook):
hook_name = "dbt Cloud"

def __init__(self, dbt_cloud_conn_id: str):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead."
),
DeprecationWarning,
stacklevel=2,
)
self.dbt_cloud_conn_id = dbt_cloud_conn_id

async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]:
87 changes: 20 additions & 67 deletions astronomer/providers/dbt/cloud/operators/dbt.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,46 @@
from __future__ import annotations

import time
import warnings
from typing import Any

from airflow import AirflowException
from airflow.exceptions import AirflowFailException
from airflow.providers.dbt.cloud.hooks.dbt import (
DbtCloudHook,
DbtCloudJobRunException,
DbtCloudJobRunStatus,
JobRunInfo,
)
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
from astronomer.providers.utils.typing_compat import Context


class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator):
"""
Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response
poll for the status in trigger.
.. seealso::
For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide:
:ref:`howto/operator:DbtCloudRunJobOperator`
:param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud.
:param job_id: The ID of a dbt Cloud job.
:param account_id: Optional. The ID of a dbt Cloud account.
:param trigger_reason: Optional Description of the reason to trigger the job. Dbt requires the trigger reason while
making an API. if it is not provided uses the default reasons.
:param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those
configured in dbt Cloud.
:param schema_override: Optional. Override the destination schema in the configured target for this job.
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
request when triggering the job.
:return: The ID of the triggered dbt Cloud job run.
This class is deprecated.
Use :class: `~airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator` instead
and set `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> Any:
"""Submits a job which generates a run_id and gets deferred"""
if self.trigger_reason is None:
self.trigger_reason = (
f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
)
hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
trigger_job_response = hook.trigger_job_run(
account_id=self.account_id,
job_id=self.job_id,
cause=self.trigger_reason,
steps_override=self.steps_override,
schema_override=self.schema_override,
additional_run_config=self.additional_run_config,
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
run_id = trigger_job_response.json()["data"]["id"]
job_run_url = trigger_job_response.json()["data"]["href"]

context["ti"].xcom_push(key="job_run_url", value=job_run_url)
end_time = time.time() + self.timeout

job_run_info = JobRunInfo(account_id=self.account_id, run_id=run_id)
job_run_status = hook.get_job_run_status(**job_run_info)
if not DbtCloudJobRunStatus.is_terminal(job_run_status):
self.defer(
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
conn_id=self.dbt_cloud_conn_id,
run_id=run_id,
end_time=end_time,
account_id=self.account_id,
poll_interval=self.check_interval,
),
method_name="execute_complete",
)
elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
self.log.info("Job run %s has completed successfully.", str(run_id))
return run_id
elif job_run_status in (DbtCloudJobRunStatus.CANCELLED.value, DbtCloudJobRunStatus.ERROR.value):
raise DbtCloudJobRunException(f"Job run {run_id} has failed or has been cancelled.")
super().__init__(*args, deferrable=True, **kwargs)

def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
# We handle the case where the job run is cancelled a bit differently than the OSS operator.
# Essentially, we do not want to retry the task if the job run is cancelled, whereas the OSS operator will
# retry the task if the job run is cancelled. This has been specifically handled here differently based upon
# the feedback from a user. And hence, while we are deprecating this operator, we are not changing the behavior
# of the `execute_complete` method. We can check if the wider OSS community wants this behavior to be changed
# in the future as it is here, and then we can remove this override.
if event["status"] == "cancelled":
self.log.info("Job run %s has been cancelled.", str(event["run_id"]))
self.log.info("Task will not be retried.")
78 changes: 26 additions & 52 deletions astronomer/providers/dbt/cloud/sensors/dbt.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,35 @@
import time
from typing import Any, Dict
from __future__ import annotations

from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
import warnings
from typing import Any

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor


class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor):
"""
Checks the status of a dbt Cloud job run.
.. seealso::
For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide::
:ref:`howto/operator:DbtCloudJobRunSensor`
:param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud.
:param run_id: The job run identifier.
:param account_id: The dbt Cloud account identifier.
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
This class is deprecated.
Use :class: `~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor` instead
and set `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
poll_interval: float = 5,
timeout: float = 60 * 60 * 24 * 7,
**kwargs: Any,
):
self.poll_interval = poll_interval
self.timeout = timeout
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
end_time = time.time() + self.timeout
self.defer(
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
run_id=self.run_id,
conn_id=self.dbt_cloud_conn_id,
account_id=self.account_id,
poll_interval=self.poll_interval,
end_time=end_time,
),
method_name="execute_complete",
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
# TODO: Remove once deprecated
if kwargs.get("poll_interval"):
warnings.warn(

Check warning on line 28 in astronomer/providers/dbt/cloud/sensors/dbt.py

Codecov / codecov/patch

astronomer/providers/dbt/cloud/sensors/dbt.py#L28

Added line #L28 was not covered by tests
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] in ["error", "cancelled"]:
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
return int(event["run_id"])
kwargs["poke_interval"] = kwargs.pop("poll_interval")

Check warning on line 34 in astronomer/providers/dbt/cloud/sensors/dbt.py

Codecov / codecov/patch

astronomer/providers/dbt/cloud/sensors/dbt.py#L34

Added line #L34 was not covered by tests
super().__init__(*args, deferrable=True, **kwargs)
19 changes: 11 additions & 8 deletions astronomer/providers/dbt/cloud/triggers/dbt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudJobRunStatus
@@ -10,14 +11,8 @@

class DbtCloudRunJobTrigger(BaseTrigger):
"""
DbtCloudRunJobTrigger is triggered with run id and account id, makes async Http call to dbt and get the status
for the submitted job with run id in polling interval of time.
:param conn_id: The connection identifier for connecting to Dbt.
:param run_id: The ID of a dbt Cloud job.
:param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: polling period in seconds to check for the status.
This class is deprecated and will be removed in 2.0.0.
Use :class: `~airflow.providers.dbt.triggers.dbt.DbtCloudRunJobTrigger` instead.
"""

def __init__(
@@ -28,6 +23,14 @@ def __init__(
poll_interval: float,
account_id: Optional[int],
):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.triggers.dbt.DbtCloudRunJobTrigger` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.run_id = run_id
self.account_id = account_id
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ databricks =
apache-airflow-providers-databricks>=2.2.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
dbt.cloud =
apache-airflow-providers-dbt-cloud>=2.1.0
apache-airflow-providers-dbt-cloud>=3.5.1
google =
apache-airflow-providers-google>=8.1.0
gcloud-aio-storage
@@ -130,7 +130,7 @@ all =
apache-airflow-providers-microsoft-azure>=8.5.1
asyncssh>=2.12.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
apache-airflow-providers-dbt-cloud>=2.1.0
apache-airflow-providers-dbt-cloud>=3.5.1
gcloud-aio-bigquery
gcloud-aio-storage
kubernetes_asyncio
162 changes: 6 additions & 156 deletions tests/dbt/cloud/operators/test_dbt.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,8 @@
from datetime import datetime
from unittest import mock

import pytest
from airflow.exceptions import AirflowException, AirflowFailException, TaskDeferred
from airflow.models import DAG, DagRun, TaskInstance
from airflow.providers.dbt.cloud.hooks.dbt import (
DbtCloudJobRunException,
DbtCloudJobRunStatus,
)
from airflow.models import DAG
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
from airflow.utils import timezone
from airflow.utils.types import DagRunType

from astronomer.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperatorAsync
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger


class TestDbtCloudRunJobOperatorAsync:
@@ -24,155 +14,15 @@ class TestDbtCloudRunJobOperatorAsync:
DEFAULT_DATE = timezone.datetime(2021, 1, 1)
dag = DAG("test_dbt_cloud_job_run_op", start_date=DEFAULT_DATE)

def create_context(self, task):
execution_date = datetime(2022, 1, 1, 0, 0, 0)
dag_run = DagRun(
dag_id=self.dag.dag_id,
execution_date=execution_date,
run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
)
task_instance = TaskInstance(task=task)
task_instance.dag_run = dag_run
task_instance.dag_id = self.dag.dag_id
task_instance.xcom_push = mock.Mock()
return {
"dag": self.dag,
"run_id": dag_run.run_id,
"task": task,
"ti": task_instance,
"task_instance": task_instance,
}

@mock.patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
return_value=DbtCloudJobRunStatus.SUCCESS.value,
)
@mock.patch("astronomer.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperatorAsync.defer")
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
def test_dbt_run_job_op_async_succeeded_before_deferred(
self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status
):
dbt_op = DbtCloudRunJobOperatorAsync(
def test_init(self):
task = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=f"{self.TASK_ID}",
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
dag=self.dag,
)
dbt_op.execute(self.create_context(dbt_op))
assert not mock_defer.called

@pytest.mark.parametrize(
"status", (DbtCloudJobRunStatus.CANCELLED.value, DbtCloudJobRunStatus.ERROR.value)
)
@mock.patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
)
@mock.patch("astronomer.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperatorAsync.defer")
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
def test_dbt_run_job_op_async_failed_before_deferred(
self, mock_trigger_job_run, mock_dbt_hook, mock_defer, mock_job_run_status, status
):
mock_job_run_status.return_value = status
dbt_op = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=f"{self.TASK_ID}{status}",
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
dag=self.dag,
)
with pytest.raises(DbtCloudJobRunException):
dbt_op.execute(self.create_context(dbt_op))
assert not mock_defer.called

@pytest.mark.parametrize(
"status",
(
DbtCloudJobRunStatus.QUEUED.value,
DbtCloudJobRunStatus.STARTING.value,
DbtCloudJobRunStatus.RUNNING.value,
),
)
@mock.patch(
"airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_job_run_status",
)
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.get_connection")
@mock.patch("airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook.trigger_job_run")
def test_dbt_run_job_op_async(self, mock_trigger_job_run, mock_dbt_hook, mock_job_run_status, status):
"""
Asserts that a task is deferred and an DbtCloudRunJobTrigger will be fired
when the DbtCloudRunJobOperatorAsync is provided with all required arguments
"""
mock_job_run_status.return_value = status
dbt_op = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=f"{self.TASK_ID}{status}",
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
dag=self.dag,
)
with pytest.raises(TaskDeferred) as exc:
dbt_op.execute(self.create_context(dbt_op))

assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"

def test_dbt_run_job_op_with_exception(self):
"""Test DbtCloudRunJobOperatorAsync to raise exception"""
dbt_op = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
)
with pytest.raises(AirflowException):
dbt_op.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

def test_dbt_run_job_cancelled_exception(self, caplog):
"""Test DbtCloudRunJobOperatorAsync to raise exception when job is cancelled"""
dbt_op = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
)
with pytest.raises(AirflowFailException) as exc:
dbt_op.execute_complete(
context=None,
event={
"status": "cancelled",
"message": f"Job run {self.DBT_RUN_ID} has been cancelled.",
"run_id": self.DBT_RUN_ID,
},
)
assert f"Job run {self.DBT_RUN_ID} has been cancelled." in str(exc.value)
assert "Task will not be retried." in caplog.text

@pytest.mark.parametrize(
"mock_event",
[
({"status": "success", "message": "Job run 48617 has completed successfully.", "run_id": 1234}),
],
)
def test_dbt_job_execute_complete(self, mock_event):
"""Test DbtCloudRunJobOperatorAsync by mocking the success response and assert the log and return value"""
dbt_op = DbtCloudRunJobOperatorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
job_id=self.DBT_RUN_ID,
check_interval=self.CHECK_INTERVAL,
timeout=self.TIMEOUT,
)

with mock.patch.object(dbt_op.log, "info") as mock_log_info:
assert dbt_op.execute_complete(context=None, event=mock_event) == self.DBT_RUN_ID

mock_log_info.assert_called_with("Job run 48617 has completed successfully.")
assert isinstance(task, DbtCloudRunJobOperator)
assert task.deferrable is True
67 changes: 4 additions & 63 deletions tests/dbt/cloud/sensors/test_dbt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
from unittest import mock

import pytest
from airflow import AirflowException
from airflow.exceptions import TaskDeferred
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor

from astronomer.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensorAsync
from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger

MODULE = "astronomer.providers.dbt.cloud.sensors.dbt"


class TestDbtCloudJobRunSensorAsync:
@@ -16,65 +9,13 @@ class TestDbtCloudJobRunSensorAsync:
DBT_RUN_ID = 1234
TIMEOUT = 300

@mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.defer")
@mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.poke", return_value=True)
def test_DbtCloudJobRunSensorAsync_async_finish_before_deferred(self, mock_poke, mock_defer, context):
"""Assert task is not deferred when it receives a finish status before deferring"""
task = DbtCloudJobRunSensorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
run_id=self.DBT_RUN_ID,
timeout=self.TIMEOUT,
)
task.execute(context)

assert not mock_defer.called

@mock.patch(f"{MODULE}.DbtCloudJobRunSensorAsync.poke", return_value=False)
def test_dbt_job_run_sensor_async(self, context):
"""Assert execute method defer for Dbt cloud job run status sensors"""
task = DbtCloudJobRunSensorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
run_id=self.DBT_RUN_ID,
timeout=self.TIMEOUT,
)
with pytest.raises(TaskDeferred) as exc:
task.execute(context)
assert isinstance(exc.value.trigger, DbtCloudRunJobTrigger), "Trigger is not a DbtCloudRunJobTrigger"

def test_dbt_job_run_sensor_async_execute_complete_success(self):
"""Assert execute_complete log success message when trigger fire with target status"""
def test_init(self):
task = DbtCloudJobRunSensorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
run_id=self.DBT_RUN_ID,
timeout=self.TIMEOUT,
)

msg = f"Job run {self.DBT_RUN_ID} has completed successfully."
with mock.patch.object(task.log, "info") as mock_log_info:
task.execute_complete(
context={}, event={"status": "success", "message": msg, "run_id": self.DBT_RUN_ID}
)
mock_log_info.assert_called_with(msg)

@pytest.mark.parametrize(
"mock_status, mock_message",
[
("cancelled", "Job run 1234 has been cancelled."),
("error", "Job run 1234 has failed."),
],
)
def test_dbt_job_run_sensor_async_execute_complete_failure(self, mock_status, mock_message):
"""Assert execute_complete method to raise exception on the cancelled and error status"""
task = DbtCloudJobRunSensorAsync(
dbt_cloud_conn_id=self.CONN_ID,
task_id=self.TASK_ID,
run_id=self.DBT_RUN_ID,
timeout=self.TIMEOUT,
)
with pytest.raises(AirflowException):
task.execute_complete(
context={}, event={"status": mock_status, "message": mock_message, "run_id": self.DBT_RUN_ID}
)
assert isinstance(task, DbtCloudJobRunSensor)
assert task.deferrable is True

0 comments on commit 15cb47f

Please sign in to comment.