Skip to content

Commit

Permalink
fixup! AIP-72: Add Taskflow API support & template rendering in Task SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Jan 7, 2025
1 parent fdc3117 commit 077ce3b
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 79 deletions.
4 changes: 2 additions & 2 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datetime
import inspect
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

Expand Down Expand Up @@ -512,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
11 changes: 6 additions & 5 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from airflow.models.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.typing_compat import TypeGuard
from airflow.utils.context import Context

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

Expand Down Expand Up @@ -69,7 +68,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from self._input.iter_references()

@provide_session
def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any:
def resolve(
self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION
) -> Any:
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
return data[self._key]

Expand Down Expand Up @@ -166,7 +167,7 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

def _expand_mapped_field(
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool
) -> Any:
if _needs_run_time_resolution(value):
value = (
Expand Down Expand Up @@ -210,7 +211,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
data = {
k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
Expand Down Expand Up @@ -260,7 +261,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
map_index = context["ti"].map_index
if map_index < 0:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
5 changes: 2 additions & 3 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import copy
import json
import logging
from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
Expand All @@ -31,7 +31,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
Expand Down
22 changes: 16 additions & 6 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier

# Callable objects contained by MapXComArg. We only accept callables from
Expand Down Expand Up @@ -206,7 +205,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
"""
raise NotImplementedError()

def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
def resolve(
self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True
) -> Any:
"""
Pull XCom value.
Expand Down Expand Up @@ -420,7 +421,10 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return session.scalar(query)

# TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
@provide_session
def resolve(
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
) -> Any:
ti = context["ti"]
if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
Expand Down Expand Up @@ -534,7 +538,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return self.arg.get_task_map_length(run_id, session=session)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
def resolve(
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
) -> Any:
value = self.arg.resolve(context, session=session, include_xcom=include_xcom)
if not isinstance(value, (Sequence, dict)):
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
Expand Down Expand Up @@ -615,7 +621,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return max(ready_lengths)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
def resolve(
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
) -> Any:
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
Expand Down Expand Up @@ -690,7 +698,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
return sum(ready_lengths)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
def resolve(
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
) -> Any:
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
for value in values:
if not isinstance(value, (Sequence, dict)):
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import inspect
import sys
import warnings
from collections.abc import Collection, Iterable, Sequence
from collections.abc import Collection, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import total_ordering, wraps
Expand Down Expand Up @@ -1244,7 +1244,7 @@ def inherits_from_empty_operator(self):

def render_template_fields(
self,
context: dict, # TODO: Change to `Context` once we have it
context: Mapping[str, Any], # TODO: Change to `Context` once we have it
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from airflow.sdk.definitions.abstractoperator import AbstractOperator
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.edges import EdgeModifier

Expand Down Expand Up @@ -122,7 +122,7 @@ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
class ResolveMixin:
"""A runtime-resolved value."""

def iter_references(self) -> Iterable[tuple[AbstractOperator, str]]:
def iter_references(self) -> Iterable[tuple[Operator, str]]:
"""
Find underlying XCom references this contains.
Expand Down
3 changes: 1 addition & 2 deletions task_sdk/src/airflow/sdk/definitions/templater.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@

from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context


def literal(value: Any) -> LiteralValue:
Expand All @@ -63,7 +62,7 @@ class LiteralValue(ResolveMixin):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
return self.value


Expand Down
7 changes: 5 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def get_template_context(self):
return context

def render_templates(
self, context: dict | None = None, jinja_env: jinja2.Environment | None = None
self, context: dict[str, Any] | None = None, jinja_env: jinja2.Environment | None = None
) -> BaseOperator:
"""
Render templates in the operator fields.
Expand All @@ -134,6 +134,9 @@ def render_templates(
context = self.get_template_context()
original_task = self.task

if TYPE_CHECKING:
assert context

ti = context["ti"]

if TYPE_CHECKING:
Expand Down Expand Up @@ -264,7 +267,7 @@ def xcom_push(self, key: str, value: Any):

def get_relevant_upstream_map_indexes(
self, upstream: BaseOperator, ti_count: int | None, session: Any
) -> list[int]:
) -> int | range | None:
# TODO: Implement this method
return None

Expand Down
54 changes: 0 additions & 54 deletions task_sdk/tests/dags/taskflow_api.py

This file was deleted.

0 comments on commit 077ce3b

Please sign in to comment.