From 077ce3b5780470b0dcf12d6c19b6e03d9bbf72cf Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 7 Jan 2025 13:12:51 +0530 Subject: [PATCH] fixup! AIP-72: Add Taskflow API support & template rendering in Task SDK --- airflow/models/abstractoperator.py | 4 +- airflow/models/expandinput.py | 11 ++-- airflow/models/mappedoperator.py | 2 +- airflow/models/param.py | 5 +- airflow/models/xcom_arg.py | 22 +++++--- .../airflow/sdk/definitions/baseoperator.py | 4 +- .../src/airflow/sdk/definitions/mixins.py | 4 +- .../src/airflow/sdk/definitions/templater.py | 3 +- .../airflow/sdk/execution_time/task_runner.py | 7 ++- task_sdk/tests/dags/taskflow_api.py | 54 ------------------- 10 files changed, 37 insertions(+), 79 deletions(-) delete mode 100644 task_sdk/tests/dags/taskflow_api.py diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index 134db08d71bb95..f87b6e06b1c075 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -19,7 +19,7 @@ import datetime import inspect -from collections.abc import Iterable, Iterator, Sequence +from collections.abc import Iterable, Iterator, Mapping, Sequence from functools import cached_property from typing import TYPE_CHECKING, Any, Callable @@ -512,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence def render_template_fields( self, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment | None = None, ) -> None: """ diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index b1e4daf78435a8..bf3c6e9505600e 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -35,7 +35,6 @@ from airflow.models.xcom_arg import XComArg from airflow.serialization.serialized_objects import _ExpandInputRef from airflow.typing_compat import TypeGuard - from airflow.utils.context import Context ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] @@ -69,7 +68,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: yield from self._input.iter_references() @provide_session - def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any: + def resolve( + self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION + ) -> Any: data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom) return data[self._key] @@ -166,7 +167,7 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int: return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1) def _expand_mapped_field( - self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool + self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool ) -> Any: if _needs_run_time_resolution(value): value = ( @@ -210,7 +211,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: yield from x.iter_references() def resolve( - self, context: Context, session: Session, *, include_xcom: bool = True + self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True ) -> tuple[Mapping[str, Any], set[int]]: data = { k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom) @@ -260,7 +261,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]: yield from x.iter_references() def resolve( - self, context: Context, session: Session, *, include_xcom: bool = True + self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True ) -> tuple[Mapping[str, Any], set[int]]: map_index = context["ti"].map_index if map_index < 0: diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 524415b848f62b..19173c233352df 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int: def render_template_fields( self, - context: Context, + context: Mapping[str, Any], jinja_env: jinja2.Environment | None = None, ) -> None: """ diff --git a/airflow/models/param.py b/airflow/models/param.py index 4d55706d1ea570..416d9cfb8b9b44 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -20,7 +20,7 @@ import copy import json import logging -from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView +from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView from typing import TYPE_CHECKING, Any, ClassVar from airflow.exceptions import AirflowException, ParamValidationError @@ -31,7 +31,6 @@ from airflow.models.dagrun import DagRun from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG - from airflow.utils.context import Context logger = logging.getLogger(__name__) @@ -295,7 +294,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: """Pull DagParam value from DagRun context. This method is run during ``op.execute()``.""" with contextlib.suppress(KeyError): return context["dag_run"].conf[self._name] diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 9f99450e729ae6..cf4147dcbfcdf5 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -44,7 +44,6 @@ from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG - from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier # Callable objects contained by MapXComArg. We only accept callables from @@ -206,7 +205,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: """ raise NotImplementedError() - def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: + def resolve( + self, context: Mapping[str, Any], session: Session | None = None, *, include_xcom: bool = True + ) -> Any: """ Pull XCom value. @@ -420,7 +421,10 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: return session.scalar(query) # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK - def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: + @provide_session + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: ti = context["ti"] if TYPE_CHECKING: assert isinstance(ti, TaskInstance) @@ -534,7 +538,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: return self.arg.get_task_map_length(run_id, session=session) @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: value = self.arg.resolve(context, session=session, include_xcom=include_xcom) if not isinstance(value, (Sequence, dict)): raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}") @@ -615,7 +621,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: return max(ready_lengths) @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): @@ -690,7 +698,9 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: return sum(ready_lengths) @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + def resolve( + self, context: Mapping[str, Any], session: Session = NEW_SESSION, *, include_xcom: bool = True + ) -> Any: values = [arg.resolve(context, session=session, include_xcom=include_xcom) for arg in self.args] for value in values: if not isinstance(value, (Sequence, dict)): diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 44a152f6aa95fe..8dee46f00e4152 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -24,7 +24,7 @@ import inspect import sys import warnings -from collections.abc import Collection, Iterable, Sequence +from collections.abc import Collection, Iterable, Mapping, Sequence from dataclasses import dataclass, field from datetime import datetime, timedelta from functools import total_ordering, wraps @@ -1244,7 +1244,7 @@ def inherits_from_empty_operator(self): def render_template_fields( self, - context: dict, # TODO: Change to `Context` once we have it + context: Mapping[str, Any], # TODO: Change to `Context` once we have it jinja_env: jinja2.Environment | None = None, ) -> None: """ diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py index 7b1594e697874d..583d8b6491ebb1 100644 --- a/task_sdk/src/airflow/sdk/definitions/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from airflow.sdk.definitions.abstractoperator import AbstractOperator + from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -122,7 +122,7 @@ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: class ResolveMixin: """A runtime-resolved value.""" - def iter_references(self) -> Iterable[tuple[AbstractOperator, str]]: + def iter_references(self) -> Iterable[tuple[Operator, str]]: """ Find underlying XCom references this contains. diff --git a/task_sdk/src/airflow/sdk/definitions/templater.py b/task_sdk/src/airflow/sdk/definitions/templater.py index ac33e7cbed62b3..65e9c70f390697 100644 --- a/task_sdk/src/airflow/sdk/definitions/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/templater.py @@ -36,7 +36,6 @@ from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG - from airflow.utils.context import Context def literal(value: Any) -> LiteralValue: @@ -63,7 +62,7 @@ class LiteralValue(ResolveMixin): def iter_references(self) -> Iterable[tuple[Operator, str]]: return () - def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: return self.value diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 4f4a9e6ec0c9e0..610556ce005e19 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -121,7 +121,7 @@ def get_template_context(self): return context def render_templates( - self, context: dict | None = None, jinja_env: jinja2.Environment | None = None + self, context: dict[str, Any] | None = None, jinja_env: jinja2.Environment | None = None ) -> BaseOperator: """ Render templates in the operator fields. @@ -134,6 +134,9 @@ def render_templates( context = self.get_template_context() original_task = self.task + if TYPE_CHECKING: + assert context + ti = context["ti"] if TYPE_CHECKING: @@ -264,7 +267,7 @@ def xcom_push(self, key: str, value: Any): def get_relevant_upstream_map_indexes( self, upstream: BaseOperator, ti_count: int | None, session: Any - ) -> list[int]: + ) -> int | range | None: # TODO: Implement this method return None diff --git a/task_sdk/tests/dags/taskflow_api.py b/task_sdk/tests/dags/taskflow_api.py deleted file mode 100644 index 3dbd5f99647e5c..00000000000000 --- a/task_sdk/tests/dags/taskflow_api.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pendulum - -from airflow.decorators import dag, task - - -@dag( - schedule=None, - start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), - catchup=False, -) -def taskflow_api(): - @task() - def extract(): - order_data_dict = {"1001": 301.27, "1002": 433.21, "1003": 502.22} - return order_data_dict - - @task(multiple_outputs=True) - def transform(order_data_dict: dict): - total_order_value = 0 - - for value in order_data_dict.values(): - total_order_value += value - - return {"total_order_value": total_order_value} - - @task() - def load(total_order_value: float): - print(f"Total order value is: {total_order_value:.2f}") - - order_data = extract() - order_summary = transform(order_data) - load(order_summary["total_order_value"]) - - -taskflow_api()