Skip to content

Commit

Permalink
Deprecate DbtCloudRunJobOperatorAsync and DbtCloudJobRunSensorAsync
Browse files Browse the repository at this point in the history
This PR deprecates the operator DbtCloudRunJobOperatorAsync and the
sensor DbtCloudJobRunSensorAsync from the dbt provider by proxying
them to their counterparts in the Airflow OSS provider.

closes: #1414
  • Loading branch information
pankajkoti committed Jan 23, 2024
1 parent ec3e539 commit 7f9e19b
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 351 deletions.
13 changes: 10 additions & 3 deletions astronomer/providers/dbt/cloud/hooks/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:

class DbtCloudHookAsync(BaseHook):
"""
Interact with dbt Cloud using the V2 API.
:param dbt_cloud_conn_id: The ID of the :ref:`dbt Cloud connection <howto/connection:dbt-cloud>`.
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead.
"""

conn_name_attr = "dbt_cloud_conn_id"
Expand All @@ -53,6 +52,14 @@ class DbtCloudHookAsync(BaseHook):
hook_name = "dbt Cloud"

def __init__(self, dbt_cloud_conn_id: str):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.hooks.dbt.DbtCloudHook` instead."
),
DeprecationWarning,
stacklevel=2,
)
self.dbt_cloud_conn_id = dbt_cloud_conn_id

async def get_headers_tenants_from_connection(self) -> Tuple[Dict[str, Any], str]:
Expand Down
88 changes: 21 additions & 67 deletions astronomer/providers/dbt/cloud/operators/dbt.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,47 @@
from __future__ import annotations

import time
import warnings
from typing import Any

from airflow import AirflowException
from airflow.exceptions import AirflowFailException
from airflow.providers.dbt.cloud.hooks.dbt import (
DbtCloudHook,
DbtCloudJobRunException,
DbtCloudJobRunStatus,
JobRunInfo,
)
from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
from astronomer.providers.utils.typing_compat import Context


class DbtCloudRunJobOperatorAsync(DbtCloudRunJobOperator):
"""
Executes a dbt Cloud job asynchronously. Trigger the dbt cloud job via worker to dbt and with run id in response
poll for the status in trigger.
.. seealso::
For more information on sync Operator DbtCloudRunJobOperator, take a look at the guide:
:ref:`howto/operator:DbtCloudRunJobOperator`
:param dbt_cloud_conn_id: The connection ID for connecting to dbt Cloud.
:param job_id: The ID of a dbt Cloud job.
:param account_id: Optional. The ID of a dbt Cloud account.
:param trigger_reason: Optional Description of the reason to trigger the job. Dbt requires the trigger reason while
making an API. if it is not provided uses the default reasons.
:param steps_override: Optional. List of dbt commands to execute when triggering the job instead of those
configured in dbt Cloud.
:param schema_override: Optional. Override the destination schema in the configured target for this job.
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param check_interval: Time in seconds to check on a job run's status. Defaults to 60 seconds.
:param additional_run_config: Optional. Any additional parameters that should be included in the API
request when triggering the job.
:return: The ID of the triggered dbt Cloud job run.
This class is deprecated.
Use :class:`~airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator`
and set the `deferrable` param to `True` instead.
"""

def execute(self, context: Context) -> Any:
"""Submits a job which generates a run_id and gets deferred"""
if self.trigger_reason is None:
self.trigger_reason = (
f"Triggered via Apache Airflow by task {self.task_id!r} in the {self.dag.dag_id} DAG."
)
hook = DbtCloudHook(dbt_cloud_conn_id=self.dbt_cloud_conn_id)
trigger_job_response = hook.trigger_job_run(
account_id=self.account_id,
job_id=self.job_id,
cause=self.trigger_reason,
steps_override=self.steps_override,
schema_override=self.schema_override,
additional_run_config=self.additional_run_config,
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.operators.dbt.DbtCloudRunJobOperator` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
run_id = trigger_job_response.json()["data"]["id"]
job_run_url = trigger_job_response.json()["data"]["href"]

context["ti"].xcom_push(key="job_run_url", value=job_run_url)
end_time = time.time() + self.timeout

job_run_info = JobRunInfo(account_id=self.account_id, run_id=run_id)
job_run_status = hook.get_job_run_status(**job_run_info)
if not DbtCloudJobRunStatus.is_terminal(job_run_status):
self.defer(
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
conn_id=self.dbt_cloud_conn_id,
run_id=run_id,
end_time=end_time,
account_id=self.account_id,
poll_interval=self.check_interval,
),
method_name="execute_complete",
)
elif job_run_status == DbtCloudJobRunStatus.SUCCESS.value:
self.log.info("Job run %s has completed successfully.", str(run_id))
return run_id
elif job_run_status in (DbtCloudJobRunStatus.CANCELLED.value, DbtCloudJobRunStatus.ERROR.value):
raise DbtCloudJobRunException(f"Job run {run_id} has failed or has been cancelled.")
super().__init__(*args, deferrable=True, **kwargs)

def execute_complete(self, context: Context, event: dict[str, Any]) -> int:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""

# We handle a cancelled job run differently from the OSS operator: we do not retry the task when the job run
# is cancelled, whereas the OSS operator does retry it. This behavior was added in response to feedback from a
# user, so although we are deprecating this operator, we are not changing the behavior of the
# `execute_complete` method. If the wider OSS community wants the same behavior upstream in the future,
# this override can be removed.
if event["status"] == "cancelled":
self.log.info("Job run %s has been cancelled.", str(event["run_id"]))
self.log.info("Task will not be retried.")
Expand Down
78 changes: 26 additions & 52 deletions astronomer/providers/dbt/cloud/sensors/dbt.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,35 @@
import time
from typing import Any, Dict
from __future__ import annotations

from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor
import warnings
from typing import Any

from astronomer.providers.dbt.cloud.triggers.dbt import DbtCloudRunJobTrigger
from astronomer.providers.utils.sensor_util import poke, raise_error_or_skip_exception
from astronomer.providers.utils.typing_compat import Context
from airflow.providers.dbt.cloud.sensors.dbt import DbtCloudJobRunSensor


class DbtCloudJobRunSensorAsync(DbtCloudJobRunSensor):
"""
Checks the status of a dbt Cloud job run.
.. seealso::
For more information on sync Sensor DbtCloudJobRunSensor, take a look at the guide::
:ref:`howto/operator:DbtCloudJobRunSensor`
:param dbt_cloud_conn_id: The connection identifier for connecting to dbt Cloud.
:param run_id: The job run identifier.
:param account_id: The dbt Cloud account identifier.
:param timeout: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
This class is deprecated.
Use :class:`~airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor`
and set the `deferrable` param to `True` instead.
"""

def __init__(
self,
*,
poll_interval: float = 5,
timeout: float = 60 * 60 * 24 * 7,
**kwargs: Any,
):
self.poll_interval = poll_interval
self.timeout = timeout
super().__init__(**kwargs)

def execute(self, context: "Context") -> None:
"""Defers trigger class to poll for state of the job run until it reaches a failure state or success state"""
if not poke(self, context):
end_time = time.time() + self.timeout
self.defer(
timeout=self.execution_timeout,
trigger=DbtCloudRunJobTrigger(
run_id=self.run_id,
conn_id=self.dbt_cloud_conn_id,
account_id=self.account_id,
poll_interval=self.poll_interval,
end_time=end_time,
),
method_name="execute_complete",
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.cloud.sensors.dbt.DbtCloudJobRunSensor` "
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
# TODO: Remove once deprecated
if kwargs.get("poll_interval"):
warnings.warn(
"Argument `poll_interval` is deprecated and will be removed "
"in a future release. Please use `poke_interval` instead.",
DeprecationWarning,
stacklevel=2,
)

def execute_complete(self, context: "Context", event: Dict[str, Any]) -> int:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event["status"] in ["error", "cancelled"]:
raise_error_or_skip_exception(self.soft_fail, event["message"])
self.log.info(event["message"])
return int(event["run_id"])
kwargs["poke_interval"] = kwargs.pop("poll_interval")
super().__init__(*args, deferrable=True, **kwargs)
19 changes: 11 additions & 8 deletions astronomer/providers/dbt/cloud/triggers/dbt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import time
import warnings
from typing import Any, AsyncIterator, Dict, Optional, Tuple

from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudJobRunStatus
Expand All @@ -10,14 +11,8 @@

class DbtCloudRunJobTrigger(BaseTrigger):
"""
DbtCloudRunJobTrigger is triggered with run id and account id, makes async Http call to dbt and get the status
for the submitted job with run id in polling interval of time.
:param conn_id: The connection identifier for connecting to Dbt.
:param run_id: The ID of a dbt Cloud job.
:param end_time: Time in seconds to wait for a job run to reach a terminal status. Defaults to 7 days.
:param account_id: The ID of a dbt Cloud account.
:param poll_interval: polling period in seconds to check for the status.
This class is deprecated and will be removed in 2.0.0.
Use :class:`~airflow.providers.dbt.cloud.triggers.dbt.DbtCloudRunJobTrigger` instead.
"""

def __init__(
Expand All @@ -28,6 +23,14 @@ def __init__(
poll_interval: float,
account_id: Optional[int],
):
warnings.warn(
(
"This class is deprecated. "
"Use `airflow.providers.dbt.triggers.dbt.DbtCloudRunJobTrigger` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__()
self.run_id = run_id
self.account_id = account_id
Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ databricks =
apache-airflow-providers-databricks>=2.2.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
dbt.cloud =
apache-airflow-providers-dbt-cloud>=2.1.0
apache-airflow-providers-dbt-cloud>=3.5.1
google =
apache-airflow-providers-google>=8.1.0
gcloud-aio-storage
Expand Down Expand Up @@ -130,7 +130,7 @@ all =
apache-airflow-providers-microsoft-azure>=8.5.1
asyncssh>=2.12.0
databricks-sql-connector>=2.0.4;python_version>='3.10'
apache-airflow-providers-dbt-cloud>=2.1.0
apache-airflow-providers-dbt-cloud>=3.5.1
gcloud-aio-bigquery
gcloud-aio-storage
kubernetes_asyncio
Expand Down
Loading

0 comments on commit 7f9e19b

Please sign in to comment.