-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1686086
commit d578ccb
Showing
10 changed files
with
728 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import time | ||
from datetime import datetime | ||
|
||
from airflow import DAG | ||
from airflow.decorators import task | ||
from airflow.operators.trigger_dagrun import TriggerDagRunOperator | ||
|
||
from astronomer.providers.core.sensors.astro import ExternalDeploymentSensor | ||
|
||
# Example DAG showing both modes of ExternalDeploymentSensor: waiting on a
# whole DAG run and waiting on a single task of that DAG run.
with DAG(
    dag_id="example_astro_task",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
):
    # The monitored DAG is the companion DAG declared in this same file
    # (dag_id="wait_to_test_example_astro_dag"); the id used here must match it
    # exactly, otherwise the REST lookup targets a DAG that does not exist.

    # DAG-level wait: no external_task_id, so the DAG run state is checked.
    ExternalDeploymentSensor(
        task_id="test1",
        external_dag_id="wait_to_test_example_astro_dag",
    )

    # Task-level wait: only the "wait_for_2_min" task instance is checked.
    ExternalDeploymentSensor(
        task_id="test2",
        external_dag_id="wait_to_test_example_astro_dag",
        external_task_id="wait_for_2_min",
    )
|
||
# Companion DAG for the sensor example: a single task that simply blocks for
# two minutes so there is a running DAG run / task instance to observe.
with DAG(
    dag_id="wait_to_test_example_astro_dag",
    tags=["example", "async", "core"],
    schedule=None,
    catchup=False,
    start_date=datetime(2022, 1, 1),
):

    @task
    def wait_for_2_min() -> None:
        """Wait for 2 min."""
        # 2 minutes expressed in seconds.
        time.sleep(2 * 60)

    wait_for_2_min()
|
||
|
||
# Driver DAG: starts the long-running companion DAG first, then the DAG that
# holds the sensors, so the sensors have an active run to monitor.
with DAG(
    dag_id="trigger_astro_test_and_example",
    start_date=datetime(2022, 1, 1),
    schedule=None,
    catchup=False,
    tags=["example", "async", "core"],
):
    # Must reference the companion DAG declared in this file
    # (dag_id="wait_to_test_example_astro_dag").
    run_wait_dag = TriggerDagRunOperator(
        task_id="run_wait_dag",
        trigger_dag_id="wait_to_test_example_astro_dag",
        # Fire and forget: the sensors need the run to still be in progress.
        wait_for_completion=False,
    )

    run_astro_dag = TriggerDagRunOperator(
        task_id="run_astro_dag",
        trigger_dag_id="example_astro_task",
        wait_for_completion=False,
    )

    run_wait_dag >> run_astro_dag
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from typing import Any | ||
from urllib.parse import quote | ||
|
||
import requests | ||
from airflow.exceptions import AirflowException | ||
from airflow.hooks.base import BaseHook | ||
|
||
|
||
class AstroHook(BaseHook):
    """
    Custom Apache Airflow Hook for interacting with Astro Cloud API.

    The hook talks to the Airflow REST API (``/api/v1``) of the deployment whose
    base URL is stored in the connection host, authenticating with a bearer token
    stored in the connection password.

    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
        Defaults to ``default_conn_name`` ("astro_cloud_default").
    """

    conn_name_attr = "astro_cloud_conn_id"
    default_conn_name = "astro_cloud_default"
    conn_type = "Astro Cloud"
    hook_name = "Astro Cloud"

    # The default is the declared default connection name, not the name of the
    # attribute that stores it.
    def __init__(self, astro_cloud_conn_id: str = default_conn_name):
        super().__init__()
        self.astro_cloud_conn_id = astro_cloud_conn_id

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Returns UI field behavior customization for the Astro Cloud connection.

        This method defines hidden fields, relabeling, and placeholders for UI display.
        """
        return {
            "hidden_fields": ["login", "port", "schema", "extra"],
            "relabeling": {
                "password": "Astro Cloud API Token",
            },
            "placeholders": {
                "host": "https://clmkpsyfc010391acjie00t1l.astronomer.run/d5lc9c9x",
                "password": "JWT API Token",
            },
        }

    def get_conn(self) -> tuple[str, str]:
        """
        Retrieves the Astro Cloud connection details.

        The base URL comes from the connection host, falling back to the
        ``AIRFLOW__WEBSERVER__BASE_URL`` environment variable.

        :return: ``(base_url, token)``; ``base_url`` has no trailing slash.
        :raises AirflowException: if the base URL or the API token is missing or empty.
        """
        conn = BaseHook.get_connection(self.astro_cloud_conn_id)
        base_url = conn.host or os.environ.get("AIRFLOW__WEBSERVER__BASE_URL")
        # Falsy check so an empty string is rejected here rather than producing
        # an invalid request URL later.
        if not base_url:
            raise AirflowException(f"Airflow host is missing in connection {self.astro_cloud_conn_id}")
        token = conn.password
        if not token:
            raise AirflowException(f"Astro API token is missing in connection {self.astro_cloud_conn_id}")
        # API paths start with "/", so drop trailing slashes to avoid "//api/...".
        return base_url.rstrip("/"), token

    @staticmethod
    def _build_headers(token: str) -> dict[str, str]:
        """Build the request headers for the given bearer token."""
        return {"accept": "application/json", "Authorization": f"Bearer {token}"}

    @property
    def _headers(self) -> dict[str, str]:
        """Generates and returns headers for Astro Cloud API requests."""
        _, token = self.get_conn()
        return self._build_headers(token)

    def _get_json(self, path: str, params: dict[str, Any] | None = None) -> Any:
        """
        Issue a GET request against the deployment and return the decoded JSON body.

        The connection is resolved once per request and used for both the URL
        and the authorization header.

        :param path: API path starting with ``/``.
        :param params: Optional query-string parameters.
        :raises requests.HTTPError: if the response status indicates an error.
        """
        base_url, token = self.get_conn()
        response = requests.get(f"{base_url}{path}", headers=self._build_headers(token), params=params)
        response.raise_for_status()
        return response.json()

    def get_dag_runs(self, external_dag_id: str) -> list[dict[str, str]]:
        """
        Retrieves information about running or queued DAG runs.

        The request asks for at most one run (``limit=1``) in state ``running``
        or ``queued``.

        :param external_dag_id: External ID of the DAG.
        :return: The ``dag_runs`` list from the API response.
        """
        params: dict[str, int | str | list[str]] = {"limit": 1, "state": ["running", "queued"]}
        data: dict[str, list[dict[str, str]]] = self._get_json(
            f"/api/v1/dags/{external_dag_id}/dagRuns", params=params
        )
        return data["dag_runs"]

    def get_dag_run(self, external_dag_id: str, dag_run_id: str) -> dict[str, Any] | None:
        """
        Retrieves information about a specific DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :return: The decoded JSON description of the DAG run.
        """
        # Run IDs may contain characters such as ":" and "+", so they are
        # percent-encoded before being placed in the URL path.
        encoded_run_id = quote(dag_run_id)
        dr: dict[str, Any] = self._get_json(f"/api/v1/dags/{external_dag_id}/dagRuns/{encoded_run_id}")
        return dr

    def get_task_instance(
        self, external_dag_id: str, dag_run_id: str, external_task_id: str
    ) -> dict[str, Any] | None:
        """
        Retrieves information about a specific task instance within a DAG run.

        :param external_dag_id: External ID of the DAG.
        :param dag_run_id: ID of the DAG run.
        :param external_task_id: External ID of the task.
        :return: The decoded JSON description of the task instance.
        """
        encoded_run_id = quote(dag_run_id)
        ti: dict[str, Any] = self._get_json(
            f"/api/v1/dags/{external_dag_id}/dagRuns/{encoded_run_id}/taskInstances/{external_task_id}"
        )
        return ti
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
|
||
# import time | ||
from typing import Any, cast | ||
|
||
from airflow.exceptions import AirflowException, AirflowSkipException | ||
from airflow.sensors.base import BaseSensorOperator, PokeReturnValue | ||
|
||
from astronomer.providers.core.hooks.astro import AstroHook | ||
from astronomer.providers.core.triggers.astro import AstroDeploymentTrigger | ||
from astronomer.providers.utils.typing_compat import Context | ||
|
||
|
||
class ExternalDeploymentSensor(BaseSensorOperator):
    """
    Custom Apache Airflow sensor for monitoring external deployments using Astro Cloud.

    The sensor checks once in ``execute``; if the monitored DAG run / task has not
    finished yet, it defers to ``AstroDeploymentTrigger`` and resumes in
    ``execute_complete``.

    :param external_dag_id: External ID of the DAG being monitored.
    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials.
        Defaults to "astro_cloud_default".
    :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
    :param kwargs: Additional keyword arguments passed to the BaseSensorOperator constructor.
    """

    def __init__(
        self,
        external_dag_id: str,
        astro_cloud_conn_id: str = "astro_cloud_default",
        external_task_id: str | None = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.external_task_id = external_task_id
        self.external_dag_id = external_dag_id
        # Filled in by poke() with the ID of the DAG run being watched, then
        # handed to the trigger when the sensor defers.
        self._dag_run_id: str = ""

    def poke(self, context: Context) -> bool | PokeReturnValue:
        """
        Check the status of a DAG/task in another deployment.

        Queries Airflow's REST API for the status of the specified DAG or task instance.
        Returns True if there is nothing to wait for (no running or queued DAG run)
        or the monitored DAG/task has already finished successfully, False otherwise.

        :param context: The task execution context.
        """
        hook = AstroHook(self.astro_cloud_conn_id)
        dag_runs: list[dict[str, Any]] = hook.get_dag_runs(self.external_dag_id)
        if not dag_runs:
            self.log.info("No DAG runs found for DAG %s", self.external_dag_id)
            return True

        current_run = dag_runs[0]
        self._dag_run_id = cast(str, current_run["dag_run_id"])

        if self.external_task_id is None:
            # DAG-level monitoring: use the state reported for the run itself.
            return current_run.get("state") == "success"

        task_instance = hook.get_task_instance(self.external_dag_id, self._dag_run_id, self.external_task_id)
        task_state = task_instance.get("state") if task_instance else None
        # Same "done" states as AstroDeploymentTrigger, so an already skipped
        # task does not cause a deferral that would complete immediately anyway.
        return task_state in ("success", "skipped")

    def execute(self, context: Context) -> Any:
        """
        Executes the sensor.

        If the external deployment is not successful, it defers the execution using an AstroDeploymentTrigger.

        :param context: The task execution context.
        """
        if not self.poke(context):
            self.defer(
                timeout=datetime.timedelta(seconds=self.timeout),
                trigger=AstroDeploymentTrigger(
                    astro_cloud_conn_id=self.astro_cloud_conn_id,
                    external_task_id=self.external_task_id,
                    external_dag_id=self.external_dag_id,
                    poke_interval=self.poke_interval,
                    dag_run_id=self._dag_run_id,
                ),
                method_name="execute_complete",
            )

    def execute_complete(self, context: Context, event: dict[str, str]) -> None:
        """
        Handles the completion event from the deferred execution.

        Raises AirflowSkipException if the upstream job failed and `soft_fail` is True.
        Otherwise, raises AirflowException. Any other event status completes the task.

        :param context: The task execution context.
        :param event: The event dictionary received from the deferred execution.
        """
        if event.get("status") == "failed":
            if self.soft_fail:
                raise AirflowSkipException("Upstream job failed. Skipping the task.")
            raise AirflowException("Upstream job failed.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
from typing import Any, AsyncIterator | ||
|
||
from airflow.triggers.base import BaseTrigger, TriggerEvent | ||
|
||
from astronomer.providers.core.hooks.astro import AstroHook | ||
|
||
|
||
class AstroDeploymentTrigger(BaseTrigger):
    """
    Custom Apache Airflow trigger for monitoring the completion status of an external deployment using Astro Cloud.

    :param external_dag_id: External ID of the DAG being monitored.
    :param dag_run_id: ID of the DAG run being monitored.
    :param external_task_id: External ID of the task being monitored. If None, monitors the entire DAG.
    :param astro_cloud_conn_id: The connection ID to retrieve Astro Cloud credentials. Defaults to "astro_cloud_default".
    :param poke_interval: Time in seconds to wait between consecutive checks for completion status.
    :param kwargs: Additional keyword arguments passed to the BaseTrigger constructor.
    """

    def __init__(
        self,
        external_dag_id: str,
        dag_run_id: str,
        external_task_id: str | None = None,
        astro_cloud_conn_id: str = "astro_cloud_default",
        poke_interval: float = 5.0,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.external_dag_id = external_dag_id
        self.dag_run_id = dag_run_id
        self.external_task_id = external_task_id
        self.astro_cloud_conn_id = astro_cloud_conn_id
        self.poke_interval = poke_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger for storage in the database."""
        return (
            "astronomer.providers.core.triggers.astro.AstroDeploymentTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "external_task_id": self.external_task_id,
                "dag_run_id": self.dag_run_id,
                "astro_cloud_conn_id": self.astro_cloud_conn_id,
                "poke_interval": self.poke_interval,
            },
        )

    def _fetch_state(self, hook: AstroHook) -> str | None:
        """Return the current state of the monitored task instance, or of the DAG run if no task is set."""
        if self.external_task_id is not None:
            task_instance = hook.get_task_instance(self.external_dag_id, self.dag_run_id, self.external_task_id)
            return task_instance.get("state") if task_instance else None
        dag_run = hook.get_dag_run(self.external_dag_id, self.dag_run_id)
        return dag_run.get("state") if dag_run else None

    def _terminal_status(self, state: str | None) -> str | None:
        """Map a remote state to the event status ("done"/"failed"), or None while still in progress."""
        if self.external_task_id is not None:
            done_states: tuple[str, ...] = ("success", "skipped")
            failed_states: tuple[str, ...] = ("failed", "upstream_failed")
        else:
            done_states = ("success",)
            failed_states = ("failed",)
        if state in done_states:
            return "done"
        if state in failed_states:
            return "failed"
        return None

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Asynchronously runs the trigger and yields completion status events.

        Checks the status of the external deployment using Astro Cloud at regular intervals.
        Yields a single TriggerEvent with the status "done" if successful, "failed" if failed,
        and then stops polling.
        """
        hook = AstroHook(self.astro_cloud_conn_id)
        loop = asyncio.get_running_loop()
        while True:
            # The hook performs blocking HTTP calls; run them in the default
            # executor so the event loop is not stalled while waiting.
            state = await loop.run_in_executor(None, self._fetch_state, hook)
            status = self._terminal_status(state)
            if status is not None:
                yield TriggerEvent({"status": status})
                # Stop after the terminal event instead of falling through to
                # another log/sleep/poll cycle.
                return
            self.log.info("Job status is %s sleeping for %s seconds.", state, self.poke_interval)
            await asyncio.sleep(self.poke_interval)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.