Skip to content

Commit

Permalink
Fail a task if an inlet or outlet asset is inactive or an inactive as…
Browse files Browse the repository at this point in the history
…set is added to an asset alias (#44831)

* feat(taskinstance): fail a task if its outlets contain inactive asset

* test(taskinstance): active assets in test cases

* test(taskinstance): add test cases test_run_with_inactive_assets_in_the_same_dag and test_run_with_inactive_assets_in_different_dags

* feat(taskinstance): fail a task if asset is not active in inlets

* refactor(taskinstance): rework warning message

* feat(asset_alias): block adding asset events to assets that can not be active

* feat(asset-alias): handle the case that asset is not active but might be able to be activated when adding

* feat(taskinstance): refactor name_uri as asset key

* refactor(exceptions): move error msg logic to customized exception

* test(taskinstance): add test case test_outlet_asset_alias_asset_inactive

* refactor(taskinstance): remove AssetUniqueKey.to_tuple and use attrs.astuple instead
  • Loading branch information
Lee-W authored Dec 25, 2024
1 parent f1167f4 commit b13ed89
Show file tree
Hide file tree
Showing 4 changed files with 328 additions and 39 deletions.
31 changes: 31 additions & 0 deletions airflow/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from __future__ import annotations

import warnings
from collections.abc import Collection
from datetime import timedelta
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, NamedTuple
Expand All @@ -32,6 +33,7 @@
from collections.abc import Sized

from airflow.models import DagRun
from airflow.sdk.definitions.asset import AssetUniqueKey


class AirflowException(Exception):
Expand Down Expand Up @@ -111,6 +113,35 @@ class AirflowFailException(AirflowException):
"""Raise when the task should be failed without retrying."""


class AirflowExecuteWithInactiveAssetExecption(AirflowFailException):
"""Raise when the task is executed with inactive assets."""

def __init__(self, inactive_asset_unikeys: Collection[AssetUniqueKey]) -> None:
self.inactive_asset_unique_keys = inactive_asset_unikeys

@property
def inactive_assets_error_msg(self):
return ", ".join(
f'Asset(name="{key.name}", uri="{key.uri}")' for key in self.inactive_asset_unique_keys
)


class AirflowInactiveAssetInInletOrOutletException(AirflowExecuteWithInactiveAssetExecption):
"""Raise when the task is executed with inactive assets in its inlet or outlet."""

def __str__(self) -> str:
return f"Task has the following inactive assets in its inlets or outlets: {self.inactive_assets_error_msg}"


class AirflowInactiveAssetAddedToAssetAliasException(AirflowExecuteWithInactiveAssetExecption):
"""Raise when inactive assets are added to an asset alias."""

def __str__(self) -> str:
return (
f"The following assets accessed by an AssetAlias are inactive: {self.inactive_assets_error_msg}"
)


class AirflowOptionalProviderFeatureException(AirflowException):
"""Raise by providers when imports are missing for optional provider features."""

Expand Down
45 changes: 43 additions & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import quote

import attrs
import dill
import jinja2
import lazy_object_proxy
Expand Down Expand Up @@ -79,6 +80,8 @@
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowInactiveAssetAddedToAssetAliasException,
AirflowInactiveAssetInInletOrOutletException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
Expand All @@ -91,7 +94,7 @@
XComForMappingNotPushed,
)
from airflow.listeners.listener import get_listener_manager
from airflow.models.asset import AssetEvent, AssetModel
from airflow.models.asset import AssetActive, AssetEvent, AssetModel
from airflow.models.base import Base, StringID, TaskInstanceDependencies, _sentinel
from airflow.models.dagbag import DagBag
from airflow.models.log import Log
Expand Down Expand Up @@ -263,6 +266,7 @@ def _run_raw_task(
context = ti.get_template_context(ignore_param_exceptions=False, session=session)

try:
ti._validate_inlet_outlet_assets_activeness(session=session)
if not mark_success:
TaskInstance._execute_task_with_callbacks(
self=ti, # type: ignore[arg-type]
Expand Down Expand Up @@ -2749,16 +2753,24 @@ def _register_asset_changes_int(
frozen_extra = frozenset(asset_alias_event.extra.items())
asset_alias_names[(asset_unique_key, frozen_extra)].add(asset_alias_name)

asset_unique_keys = {key for key, _ in asset_alias_names}
asset_models: dict[AssetUniqueKey, AssetModel] = {
AssetUniqueKey.from_asset(asset_obj): asset_obj
for asset_obj in session.scalars(
select(AssetModel).where(
tuple_(AssetModel.name, AssetModel.uri).in_(
(key.name, key.uri) for key, _ in asset_alias_names
attrs.astuple(key) for key in asset_unique_keys
)
)
)
}
inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
asset_unique_keys={key for key in asset_unique_keys if key in asset_models},
session=session,
)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetAddedToAssetAliasException(inactive_asset_unique_keys)

if missing_assets := [
asset_unique_key.to_asset()
for asset_unique_key, _ in asset_alias_names
Expand Down Expand Up @@ -3642,6 +3654,35 @@ def duration_expression_update(
}
)

def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None:
if not self.task or not (self.task.outlets or self.task.inlets):
return

all_asset_unique_keys = {
AssetUniqueKey.from_asset(inlet_or_outlet)
for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets)
if isinstance(inlet_or_outlet, Asset)
}
inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)

@staticmethod
def _get_inactive_asset_unique_keys(
asset_unique_keys: set[AssetUniqueKey], session: Session
) -> set[AssetUniqueKey]:
active_asset_unique_keys = {
AssetUniqueKey(name, uri)
for name, uri in session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_in_condition(
(AssetActive.name, AssetActive.uri), [attrs.astuple(key) for key in asset_unique_keys]
)
)
)
}
return asset_unique_keys - active_asset_unique_keys


def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
"""Given two operators, find their innermost common mapped task group."""
Expand Down
Loading

0 comments on commit b13ed89

Please sign in to comment.