diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 7fbc540adf045..3e151cb49c778 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -28,6 +28,7 @@
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep, TIDepStatus
 from airflow.utils.state import TaskInstanceState
+from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule as TR
 
 if TYPE_CHECKING:
@@ -131,6 +132,20 @@ def _get_expanded_ti_count() -> int:
             """
             return ti.task.get_mapped_ti_count(ti.run_id, session=session)
 
+        def _iter_expansion_dependencies() -> Iterator[str]:
+            """Yield task IDs from the mapped dependencies of the task and its mapped task groups."""
+            from airflow.models.mappedoperator import MappedOperator
+
+            if isinstance(ti.task, MappedOperator):
+                for op in ti.task.iter_mapped_dependencies():
+                    yield op.task_id
+            task_group = ti.task.task_group
+            # iter_mapped_task_groups() is not truth-tested: an iterator object is always truthy.
+            if task_group:
+                for tg in task_group.iter_mapped_task_groups():
+                    for op in tg.iter_mapped_dependencies():
+                        yield op.task_id
+
         @functools.lru_cache
         def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None:
             """Get the given task's map indexes relevant to the current ti.
@@ -141,6 +156,9 @@ def _get_relevant_upstream_map_indexes(upstream_id: str) -> int | range | None: """ if TYPE_CHECKING: assert isinstance(ti.task.dag, DAG) + if isinstance(ti.task.task_group, MappedTaskGroup): + if upstream_id not in set(_iter_expansion_dependencies()): + return None try: expanded_ti_count = _get_expanded_ti_count() except (NotFullyPopulated, NotMapped): diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 5509634d17dd4..bbadc57bd0651 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -1303,8 +1303,8 @@ def file_transforms(filename): states = self.get_states(dr) expected = { "file_transforms.my_setup": {0: "success", 1: "failed", 2: "skipped"}, - "file_transforms.my_work": {0: "success", 1: "upstream_failed", 2: "skipped"}, - "file_transforms.my_teardown": {0: "success", 1: "upstream_failed", 2: "skipped"}, + "file_transforms.my_work": {2: "upstream_failed", 1: "upstream_failed", 0: "upstream_failed"}, + "file_transforms.my_teardown": {2: "success", 1: "success", 0: "success"}, } assert states == expected diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index 2168215ed3ab6..3df5f893be8a7 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -1164,19 +1164,23 @@ def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]: tis = _one_scheduling_decision_iteration() assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)] - # After running the first t1, the first t2 becomes immediately available. + # After running the first t1, the remaining t1 must be run before t2 is available. tis["tg.t1", 0].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)] + assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2)] - # Similarly for the subsequent t2 instances. 
+ # After running all t1, t2 is available. + tis["tg.t1", 1].run() tis["tg.t1", 2].run() tis = _one_scheduling_decision_iteration() - assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)] + assert sorted(tis) == [("tg.t2", 0), ("tg.t2", 1), ("tg.t2", 2)] - # But running t2 partially does not make t3 available. - tis["tg.t1", 1].run() + # Similarly for t2 instances. They all have to complete before t3 is available. tis["tg.t2", 0].run() + tis = _one_scheduling_decision_iteration() + assert sorted(tis) == [("tg.t2", 1), ("tg.t2", 2)] + + # But running t2 partially does not make t3 available. tis["tg.t2", 2].run() tis = _one_scheduling_decision_iteration() assert sorted(tis) == [("tg.t2", 1)] @@ -1406,3 +1410,34 @@ def w2(): (status,) = self.get_dep_statuses(dr, "w2", flag_upstream_failed=True, session=session) assert status.reason.startswith("All setup tasks must complete successfully") assert self.get_ti(dr, "w2").state == expected + + +def test_mapped_tasks_in_mapped_task_group_waits_for_upstreams_to_complete(dag_maker, session): + """Test that a one_failed task in a mapped task group has no state after only t1 has run.""" + with dag_maker() as dag: + + @dag.task + def t1(): + return [1, 2, 3] + + @task_group("tg1") + def tg1(a): + @dag.task() + def t2(a): + return a + + @dag.task(trigger_rule=TriggerRule.ONE_FAILED) + def t3(a): + return a + + t2(a) >> t3(a) + + t = t1() + tg1.expand(a=t) + + dr = dag_maker.create_dagrun() + ti = dr.get_task_instance(task_id="t1") + ti.run() + dr.task_instance_scheduling_decisions() + ti3 = dr.get_task_instance(task_id="tg1.t3") + assert not ti3.state