Skip to content

Commit 077ce3b

Browse files
committed
fixup! AIP-72: Add Taskflow API support & template rendering in Task SDK
1 parent fdc3117 commit 077ce3b

File tree

10 files changed

+37
-79
lines changed

10 files changed

+37
-79
lines changed

airflow/models/abstractoperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import datetime
2121
import inspect
22-
from collections.abc import Iterable, Iterator, Sequence
22+
from collections.abc import Iterable, Iterator, Mapping, Sequence
2323
from functools import cached_property
2424
from typing import TYPE_CHECKING, Any, Callable
2525

@@ -512,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
512512

513513
def render_template_fields(
514514
self,
515-
context: Context,
515+
context: Mapping[str, Any],
516516
jinja_env: jinja2.Environment | None = None,
517517
) -> None:
518518
"""

airflow/models/expandinput.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from airflow.models.xcom_arg import XComArg
3636
from airflow.serialization.serialized_objects import _ExpandInputRef
3737
from airflow.typing_compat import TypeGuard
38-
from airflow.utils.context import Context
3938

4039
ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
4140

@@ -69,7 +68,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
6968
yield from self._input.iter_references()
7069

7170
@provide_session
72-
def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any:
71+
def resolve(
72+
self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION
73+
) -> Any:
7374
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
7475
return data[self._key]
7576

@@ -166,7 +167,7 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
166167
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
167168

168169
def _expand_mapped_field(
169-
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
170+
self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool
170171
) -> Any:
171172
if _needs_run_time_resolution(value):
172173
value = (
@@ -210,7 +211,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
210211
yield from x.iter_references()
211212

212213
def resolve(
213-
self, context: Context, session: Session, *, include_xcom: bool = True
214+
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
214215
) -> tuple[Mapping[str, Any], set[int]]:
215216
data = {
216217
k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
@@ -260,7 +261,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
260261
yield from x.iter_references()
261262

262263
def resolve(
263-
self, context: Context, session: Session, *, include_xcom: bool = True
264+
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
264265
) -> tuple[Mapping[str, Any], set[int]]:
265266
map_index = context["ti"].map_index
266267
if map_index < 0:

airflow/models/mappedoperator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
869869

870870
def render_template_fields(
871871
self,
872-
context: Context,
872+
context: Mapping[str, Any],
873873
jinja_env: jinja2.Environment | None = None,
874874
) -> None:
875875
"""

airflow/models/param.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import copy
2121
import json
2222
import logging
23-
from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
23+
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
2424
from typing import TYPE_CHECKING, Any, ClassVar
2525

2626
from airflow.exceptions import AirflowException, ParamValidationError
@@ -31,7 +31,6 @@
3131
from airflow.models.dagrun import DagRun
3232
from airflow.models.operator import Operator
3333
from airflow.sdk.definitions.dag import DAG
34-
from airflow.utils.context import Context
3534

3635
logger = logging.getLogger(__name__)
3736

@@ -295,7 +294,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
295294
def iter_references(self) -> Iterable[tuple[Operator, str]]:
296295
return ()
297296

298-
def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
297+
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
299298
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
300299
with contextlib.suppress(KeyError):
301300
return context["dag_run"].conf[self._name]

airflow/models/xcom_arg.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from airflow.models.operator import Operator
4545
from airflow.sdk.definitions.baseoperator import BaseOperator
4646
from airflow.sdk.definitions.dag import DAG
47-
from airflow.utils.context import Context
4847
from airflow.utils.edgemodifier import EdgeModifier
4948

5049
# Callable objects contained by MapXComArg. We only accept callables from
@@ -206,7 +205,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
206205
"""
207206
raise NotImplementedError()
208207

209-
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
208+
def resolve(
209+
self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True
210+
) -> Any:
210211
"""
211212
Pull XCom value.
212213
@@ -420,7 +421,10 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
420421
return session.scalar(query)
421422

422423
# TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK
423-
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
424+
@provide_session
425+
def resolve(
426+
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
427+
) -> Any:
424428
ti = context["ti"]
425429
if TYPE_CHECKING:
426430
assert isinstance(ti, TaskInstance)
@@ -534,7 +538,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
534538
return self.arg.get_task_map_length(run_id, session=session)
535539

536540
@provide_session
537-
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
541+
def resolve(
542+
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
543+
) -> Any:
538544
value = self.arg.resolve(context, session=session, include_xcom=include_xcom)
539545
if not isinstance(value, (Sequence, dict)):
540546
raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
@@ -615,7 +621,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
615621
return max(ready_lengths)
616622

617623
@provide_session
618-
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
624+
def resolve(
625+
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
626+
) -> Any:
619627
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
620628
for value in values:
621629
if not isinstance(value, (Sequence, dict)):
@@ -690,7 +698,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
690698
return sum(ready_lengths)
691699

692700
@provide_session
693-
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
701+
def resolve(
702+
self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True
703+
) -> Any:
694704
values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args]
695705
for value in values:
696706
if not isinstance(value, (Sequence, dict)):

task_sdk/src/airflow/sdk/definitions/baseoperator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import inspect
2525
import sys
2626
import warnings
27-
from collections.abc import Collection, Iterable, Sequence
27+
from collections.abc import Collection, Iterable, Mapping, Sequence
2828
from dataclasses import dataclass, field
2929
from datetime import datetime, timedelta
3030
from functools import total_ordering, wraps
@@ -1244,7 +1244,7 @@ def inherits_from_empty_operator(self):
12441244

12451245
def render_template_fields(
12461246
self,
1247-
context: dict, # TODO: Change to `Context` once we have it
1247+
context: Mapping[str, Any], # TODO: Change to `Context` once we have it
12481248
jinja_env: jinja2.Environment | None = None,
12491249
) -> None:
12501250
"""

task_sdk/src/airflow/sdk/definitions/mixins.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from typing import TYPE_CHECKING, Any
2323

2424
if TYPE_CHECKING:
25-
from airflow.sdk.definitions.abstractoperator import AbstractOperator
25+
from airflow.models.operator import Operator
2626
from airflow.sdk.definitions.baseoperator import BaseOperator
2727
from airflow.sdk.definitions.edges import EdgeModifier
2828

@@ -122,7 +122,7 @@ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]:
122122
class ResolveMixin:
123123
"""A runtime-resolved value."""
124124

125-
def iter_references(self) -> Iterable[tuple[AbstractOperator, str]]:
125+
def iter_references(self) -> Iterable[tuple[Operator, str]]:
126126
"""
127127
Find underlying XCom references this contains.
128128

task_sdk/src/airflow/sdk/definitions/templater.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
from airflow.models.operator import Operator
3838
from airflow.sdk.definitions.dag import DAG
39-
from airflow.utils.context import Context
4039

4140

4241
def literal(value: Any) -> LiteralValue:
@@ -63,7 +62,7 @@ class LiteralValue(ResolveMixin):
6362
def iter_references(self) -> Iterable[tuple[Operator, str]]:
6463
return ()
6564

66-
def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
65+
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
6766
return self.value
6867

6968

task_sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_template_context(self):
121121
return context
122122

123123
def render_templates(
124-
self, context: dict | None = None, jinja_env: jinja2.Environment | None = None
124+
self, context: dict[str, Any] | None = None, jinja_env: jinja2.Environment | None = None
125125
) -> BaseOperator:
126126
"""
127127
Render templates in the operator fields.
@@ -134,6 +134,9 @@ def render_templates(
134134
context = self.get_template_context()
135135
original_task = self.task
136136

137+
if TYPE_CHECKING:
138+
assert context
139+
137140
ti = context["ti"]
138141

139142
if TYPE_CHECKING:
@@ -264,7 +267,7 @@ def xcom_push(self, key: str, value: Any):
264267

265268
def get_relevant_upstream_map_indexes(
266269
self, upstream: BaseOperator, ti_count: int | None, session: Any
267-
) -> list[int]:
270+
) -> int | range | None:
268271
# TODO: Implement this method
269272
return None
270273

task_sdk/tests/dags/taskflow_api.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

0 commit comments

Comments
 (0)