Start porting mapped task to SDK
This PR restructures the Mapped Operator and Mapped Task Group code to live in
the Task SDK at definition time.

The big thing this change _does not do_ is make it possible to execute mapped
tasks via the Task Execution API server etc. -- that is up next.

There were some unavoidable changes to the scheduler/expansion part of mapped
tasks here. Of note:

`BaseOperator.get_mapped_ti_count` has moved from an instance method on
BaseOperator to a class method. The reason is that, with more and more of the
"definition time" code moving into the Task SDK BaseOperator and
AbstractOperator, it is no longer possible to add DB-accessing code to a base
class and have it apply to the subclasses (i.e.
`airflow.models.abstractoperator.AbstractOperator` is now _not always_ in the
MRO for tasks; eventually that class will be deleted, but not yet).
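
At call sites the count is now computed roughly like this (a minimal sketch;
`node`, `run_id`, and `session` are placeholders and the wrapper function is
hypothetical, but the call shape matches the diff below):

```python
from airflow.models.baseoperator import BaseOperator


def mapped_ti_count(node, run_id, session) -> int:
    # Previously an instance method: node.get_mapped_ti_count(run_id=..., session=...)
    # Now a class method on the DB-backed BaseOperator, with the (possibly
    # Task SDK-defined) task or mapped task group passed in explicitly.
    return BaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session)
```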

In a similar vein, XComArg's `get_task_map_length` has also moved to a
singledispatch class method on the TaskMap model, since the definition-time
objects now live in the Task SDK and there is no realistic way to get a
per-type subclass with DB logic (i.e. it is very complex to end up with a
PlainDBXComArg, a MapDBXComArg, etc. that we can attach the method to).

For those who aren't aware, singledispatch (and singledispatchmethod) are part
of the standard library: the type of the first argument is used to determine
which implementation to call. If you are familiar with C++ or Java, this is
very similar to method overloading; the one caveat is that it _only_ examines
the type of the first argument, not the full signature.
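
As a generic illustration of the mechanism (plain stdlib code, not the actual
TaskMap implementation):

```python
from functools import singledispatch


@singledispatch
def get_map_length(value):
    # Fallback used when no more specific implementation has been registered.
    raise NotImplementedError(f"no length rule for {type(value).__name__}")


@get_map_length.register
def _(value: list):
    # Selected purely because the first argument is a list.
    return len(value)


@get_map_length.register
def _(value: dict):
    return len(value)


get_map_length([1, 2, 3])  # 3 -- dispatches to the list implementation
get_map_length({"a": 1})   # 1 -- dispatches to the dict implementation
```
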
ashb committed Jan 20, 2025
1 parent 24b1fe8 commit eb93960
Showing 68 changed files with 2,892 additions and 2,456 deletions.
2 changes: 1 addition & 1 deletion airflow/api/common/mark_tasks.py
@@ -33,8 +33,8 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session as SASession

from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG


@provide_session
2 changes: 1 addition & 1 deletion airflow/api_fastapi/core_api/services/ui/grid.py
@@ -139,7 +139,7 @@ def _get_total_task_count(
node
if isinstance(node, int)
else (
node.get_mapped_ti_count(run_id=run_id, session=session)
BaseOperator.get_mapped_ti_count(node, run_id=run_id, session=session) or 0
if isinstance(node, (MappedTaskGroup, MappedOperator))
else node
)
5 changes: 5 additions & 0 deletions airflow/cli/commands/remote_commands/task_command.py
@@ -206,6 +206,11 @@ def _get_ti(
dag = task.dag
if dag is None:
raise ValueError("Cannot get task instance for a task not assigned to a DAG")
if not isinstance(dag, DAG):
# Shouldn't really happen, and this command will go away before 3.0
raise ValueError(
f"We need a {DAG.__module__}.DAG, but we got {type(dag).__module__}.{type(dag).__name__}!"
)

# this check is imperfect because diff dags could have tasks with same name
# but in a task, dag_id is a property that accesses its dag, and we don't
6 changes: 3 additions & 3 deletions airflow/decorators/base.py
@@ -41,11 +41,11 @@
ListOfDictsExpandInput,
is_mappable,
)
from airflow.models.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions._internal.contextmanager import DagContext, TaskGroupContext
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.definitions.baseoperator import BaseOperator as TaskSDKBaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator, ensure_xcomarg_return_value
from airflow.sdk.definitions.xcom_arg import XComArg
from airflow.typing_compat import ParamSpec
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS
@@ -62,9 +62,9 @@
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import ValidationSource
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import ValidationSource
from airflow.utils.task_group import TaskGroup


212 changes: 13 additions & 199 deletions airflow/models/abstractoperator.py
@@ -19,40 +19,36 @@

import datetime
import inspect
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

import methodtools
from sqlalchemy import select

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.sdk.definitions._internal.abstractoperator import (
AbstractOperator as TaskSDKAbstractOperator,
NotMapped as NotMapped, # Re-export this for compat
)
from airflow.sdk.definitions.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule, db_safe_priority
from airflow.utils.weight_rule import db_safe_priority

if TYPE_CHECKING:
import jinja2 # Slow imports.
from sqlalchemy.orm import Session

from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DAG as SchedulerDAG
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
from airflow.utils.task_group import TaskGroup

TaskStateChangeCallback = Callable[[Context], None]

@@ -71,19 +67,12 @@
)
MAX_RETRY_DELAY: int = conf.getint("core", "max_task_retry_delay", fallback=24 * 60 * 60)

DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
"core", "default_task_execution_timeout"
)


class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
"""
Common implementation for operators, including unmapped and mapped.
@@ -98,124 +87,8 @@ class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
:meta private:
"""

trigger_rule: TriggerRule
weight_rule: PriorityWeightStrategy

@property
def on_failure_fail_dagrun(self):
"""
Whether the operator should fail the dagrun on failure.
:meta private:
"""
return self._on_failure_fail_dagrun

@on_failure_fail_dagrun.setter
def on_failure_fail_dagrun(self, value):
"""
Setter for on_failure_fail_dagrun property.
:meta private:
"""
if value is True and self.is_teardown is not True:
raise ValueError(
f"Cannot set task on_failure_fail_dagrun for "
f"'{self.task_id}' because it is not a teardown task."
)
self._on_failure_fail_dagrun = value

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
For now, this walks the entire DAG to find mapped nodes that has this
current task as an upstream. We cannot use ``downstream_list`` since it
only contains operators, not task groups. In the future, we should
provide a way to record an DAG node's all downstream nodes instead.
Note that this does not guarantee the returned tasks actually use the
current task for task mapping, but only checks those task are mapped
operators, and are downstreams of the current task.
To get a list of tasks that uses the current task for task mapping, use
:meth:`iter_mapped_dependants` instead.
"""
from airflow.models.mappedoperator import MappedOperator
from airflow.utils.task_group import TaskGroup

def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
"""
Recursively walk children in a task group.
This yields all direct children (including both tasks and task
groups), and all children of any task groups.
"""
for key, child in group.children.items():
yield key, child
if isinstance(child, TaskGroup):
yield from _walk_group(child)

dag = self.get_dag()
if not dag:
raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
for key, child in _walk_group(dag.task_group):
if key == self.node_id:
continue
if not isinstance(child, (MappedOperator, MappedTaskGroup)):
continue
if self.node_id in child.upstream_task_ids:
yield child

def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that depend on the current task the expansion.
For now, this walks the entire DAG to find mapped nodes that has this
current task as an upstream. We cannot use ``downstream_list`` since it
only contains operators, not task groups. In the future, we should
provide a way to record an DAG node's all downstream nodes instead.
"""
return (
downstream
for downstream in self._iter_all_mapped_downstreams()
if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
)

def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""
Return mapped task groups this task belongs to.
Groups are returned from the innermost to the outmost.
:meta private:
"""
if (group := self.task_group) is None:
return
# TODO: Task-SDK: this type ignore shouldn't be necessary, revisit once mapping support is fully in the
# SDK
yield from group.iter_mapped_task_groups() # type: ignore[misc]

def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
"""
Get the mapped task group "closest" to this task in the DAG.
:meta private:
"""
return next(self.iter_mapped_task_groups(), None)

def get_needs_expansion(self) -> bool:
"""
Return true if the task is MappedOperator or is in a mapped task group.
:meta private:
"""
if self._needs_expansion is None:
if self.get_closest_mapped_task_group() is not None:
self._needs_expansion = True
else:
self._needs_expansion = False
return self._needs_expansion

def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
"""
Get the "normal" operator from current abstract operator.
@@ -343,43 +216,6 @@ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
return link.get_link(self.unmap(None), ti_key=ti.key)

@methodtools.lru_cache(maxsize=None)
def get_parse_time_mapped_ti_count(self) -> int:
"""
Return the number of mapped task instances that can be created on DAG run creation.
This only considers literal mapped arguments, and would return *None*
when any non-literal values are used for mapping.
:raise NotFullyPopulated: If non-literal mapped arguments are encountered.
:raise NotMapped: If the operator is neither mapped, nor has any parent
mapped task groups.
:return: Total number of mapped TIs this task should have.
"""
group = self.get_closest_mapped_task_group()
if group is None:
raise NotMapped
return group.get_parse_time_mapped_ti_count()

def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
"""
Return the number of mapped TaskInstances that can be created at run time.
This considers both literal and non-literal mapped arguments, and the
result is therefore available when all depended tasks have finished. The
return value should be identical to ``parse_time_mapped_ti_count`` if
all mapped arguments are literal.
:raise NotFullyPopulated: If upstream tasks are not all complete yet.
:raise NotMapped: If the operator is neither mapped, nor has any parent
mapped task groups.
:return: Total number of mapped TIs this task should have.
"""
group = self.get_closest_mapped_task_group()
if group is None:
raise NotMapped
return group.get_mapped_ti_count(run_id, session=session)

def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
"""
Create the mapped task instances for mapped task.
@@ -390,16 +226,20 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
"""
from sqlalchemy import func, or_

from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.settings import task_instance_mutation_hook

if not isinstance(self, (BaseOperator, MappedOperator)):
raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
raise RuntimeError(
f"cannot expand unrecognized operator type {type(self).__module__}.{type(self).__name__}"
)

from airflow.models.baseoperator import BaseOperator as DBBaseOperator

try:
total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
total_length: int | None = DBBaseOperator.get_mapped_ti_count(self, run_id, session=session)
except NotFullyPopulated as e:
# It's possible that the upstream tasks are not yet done, but we
# don't have upstream of upstreams in partial DAGs (possible in the
@@ -509,29 +349,3 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
ti.state = TaskInstanceState.REMOVED
session.flush()
return all_expanded_tis, total_expanded_ti_count - 1

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Template all attributes listed in *self.template_fields*.
If the operator is mapped, this should return the unmapped, fully
rendered, and map-expanded operator. The mapped operator should not be
modified. However, *context* may be modified in-place to reference the
unmapped operator for template rendering.
If the operator is not mapped, this should modify the operator in-place.
"""
raise NotImplementedError()

def __enter__(self):
if not self.is_setup and not self.is_teardown:
raise AirflowException("Only setup/teardown tasks can be used as context managers.")
SetupTeardownContext.push_setup_teardown_task(self)
return SetupTeardownContext

def __exit__(self, exc_type, exc_val, exc_tb):
SetupTeardownContext.set_work_task_roots_and_leaves()