Skip to content

Commit

Permalink
AIP-72: Add Taskflow API support & template rendering in Task SDK
Browse files Browse the repository at this point in the history
closes #45232

The Templater class has been moved to the Task SDK to align with the language-specific aspects of template rendering. Templating logic is inherently tied to Python constructs. By keeping the Templater class within the Task SDK, we ensure that the core templating logic remains coupled with language-specific implementations.

The other options I had were keeping it on the Scheduler side or on the Execution side of the Task SDK, neither of which is ideal, as we would want to change the definition code (such as DAG and Operator) along with how it renders.
  • Loading branch information
kaxil committed Jan 7, 2025
1 parent a6da8df commit fdc3117
Show file tree
Hide file tree
Showing 33 changed files with 941 additions and 700 deletions.
44 changes: 3 additions & 41 deletions airflow/macros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,20 @@
# under the License.
from __future__ import annotations

import json # noqa: F401
import time # noqa: F401
import uuid # noqa: F401
from datetime import datetime, timedelta
from random import random # noqa: F401
from datetime import datetime
from typing import TYPE_CHECKING, Any

import dateutil # noqa: F401
from babel import Locale
from babel.dates import LC_TIME, format_datetime

import airflow.utils.yaml as yaml # noqa: F401
from airflow.sdk.definitions.macros import ds_add, ds_format, json, time, uuid # noqa: F401

if TYPE_CHECKING:
from pendulum import DateTime


def ds_add(ds: str, days: int) -> str:
    """
    Shift a ``YYYY-MM-DD`` date string forwards or backwards by whole days.

    :param ds: anchor date in ``YYYY-MM-DD`` format to shift
    :param days: number of days to add to the ds, you can use negative values

    >>> ds_add("2015-01-01", 5)
    '2015-01-06'
    >>> ds_add("2015-01-06", -5)
    '2015-01-01'
    """
    date_text = str(ds)
    if days:
        shifted = datetime.strptime(date_text, "%Y-%m-%d") + timedelta(days=days)
        return shifted.strftime("%Y-%m-%d")
    # A zero (or otherwise falsy) offset hands the input back without parsing it.
    return date_text


def ds_format(ds: str, input_format: str, output_format: str) -> str:
    """
    Output datetime string in a given format.

    This function takes no locale argument; ``ds_format_locale`` is the
    variant that accepts one.

    :param ds: Input string which contains a date.
    :param input_format: Input string format (e.g., '%Y-%m-%d').
    :param output_format: Output string format (e.g., '%Y-%m-%d').

    >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
    '01-01-15'
    >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
    '2015-01-05'
    >>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y")
    'Friday 12 July 2024'
    """
    # ``str(ds)`` lets callers pass any object whose string form matches ``input_format``.
    parsed = datetime.strptime(str(ds), input_format)
    return parsed.strftime(output_format)


def ds_format_locale(
ds: str, input_format: str, output_format: str, locale: Locale | str | None = None
) -> str:
Expand Down Expand Up @@ -99,6 +60,7 @@ def ds_format_locale(
)


# TODO: Task SDK: Move this to the Task SDK once we evaluate "pendulum"'s dependency
def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str:
"""
Return a human-readable/approximate difference between datetimes.
Expand Down
74 changes: 2 additions & 72 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
Expand All @@ -42,8 +41,6 @@
from airflow.utils.weight_rule import WeightRule, db_safe_priority

if TYPE_CHECKING:
from collections.abc import Mapping

import jinja2 # Slow import.
from sqlalchemy.orm import Session

Expand All @@ -52,7 +49,6 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.node import DAGNode
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
Expand Down Expand Up @@ -88,7 +84,7 @@ class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(Templater, TaskSDKAbstractOperator):
class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
"""
Common implementation for operators, including unmapped and mapped.
Expand Down Expand Up @@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value):
)
self._on_failure_fail_dagrun = value

def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
    """
    Get the template environment for rendering templates.

    :param dag: DAG handed to the parent implementation; when ``None``, the
        result of ``self.get_dag()`` is used instead.
    """
    effective_dag = self.get_dag() if dag is None else dag
    return super().get_template_env(dag=effective_dag)

def _render(self, template, context, dag: DAG | None = None):
    """Delegate rendering to the parent class, defaulting *dag* to ``self.get_dag()`` when it is ``None``."""
    resolved_dag = dag if dag is not None else self.get_dag()
    return super()._render(template, context, dag=resolved_dag)

def _do_render_template_fields(
    self,
    parent: Any,
    template_fields: Iterable[str],
    context: Mapping[str, Any],
    jinja_env: jinja2.Environment,
    seen_oids: set[int],
) -> None:
    """
    Render each attribute of *parent* named in *template_fields*, in place.

    Overrides the base to use custom error logging.

    :param parent: Object whose attributes are read, rendered and overwritten.
    :param template_fields: Names of the attributes on *parent* to render.
    :param context: Context passed to the renderer (or to a callable field).
    :param jinja_env: Jinja environment passed to the renderer (or to a callable field).
    :param seen_oids: Forwarded unchanged to ``self.render_template``.
    :raises AttributeError: If a name in *template_fields* is not an attribute of *parent*.
    """
    for attr_name in template_fields:
        try:
            value = getattr(parent, attr_name)
        except AttributeError as err:
            # ``parent`` is typed ``Any``, so it may not define ``task_type``. Fall back to
            # the class name so that building this message cannot itself raise an
            # unrelated AttributeError that hides the real problem.
            parent_type = getattr(parent, "task_type", type(parent).__name__)
            raise AttributeError(
                f"{attr_name!r} is configured as a template field "
                f"but {parent_type} does not have this attribute."
            ) from err

        try:
            if not value:
                # Falsy values are skipped and left untouched.
                continue
        except Exception:
            # This may happen if the templated field points to a class which does not support `__bool__`,
            # such as Pandas DataFrames:
            # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
            self.log.info(
                "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                type(value).__name__,
                self.task_id,
                attr_name,
            )
            # Deliberately fall through: we may still want to render custom classes
            # which do not support __bool__.

        try:
            if callable(value):
                # Callable fields render themselves from the context and environment.
                rendered_content = value(context=context, jinja_env=jinja_env)
            else:
                rendered_content = self.render_template(
                    value,
                    context,
                    jinja_env,
                    seen_oids,
                )
        except Exception:
            # Pass the template through ``redact`` before it is written to the log.
            value_masked = redact(name=attr_name, value=value)
            self.log.exception(
                "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                self.task_id,
                attr_name,
                value_masked,
            )
            raise
        else:
            setattr(parent, attr_name, rendered_content)

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
Expand Down
18 changes: 0 additions & 18 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
if TYPE_CHECKING:
from types import ClassMethodDescriptorType

import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.abstractoperator import TaskStateChangeCallback
Expand Down Expand Up @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None):
logger=self.log,
).run(context, result)

def render_template_fields(
    self,
    context: Context,
    jinja_env: jinja2.Environment | None = None,
) -> None:
    """
    Render every attribute named in *self.template_fields*.

    The attributes are overwritten in place, and this cannot be undone.

    :param context: Context dict with values to apply on content.
    :param jinja_env: Jinja's environment to use for rendering. When it is
        falsy, the result of ``self.get_template_env()`` is used.
    """
    env = jinja_env or self.get_template_env()
    self._do_render_template_fields(self, self.template_fields, context, env, set())

@provide_session
def clear(
self,
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import attr

from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
Expand Down
11 changes: 5 additions & 6 deletions airflow/models/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from airflow.models import MappedOperator, TaskInstance
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskmixin import DependencyMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.sdk.types import NOTSET, ArgNotSet
from airflow.utils.db import exists_query
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.state import State
Expand Down Expand Up @@ -206,8 +206,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
"""
raise NotImplementedError()

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
"""
Pull XCom value.
Expand Down Expand Up @@ -420,8 +419,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
)
return session.scalar(query)

@provide_session
def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any:
# TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK
def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any:
ti = context["ti"]
if TYPE_CHECKING:
assert isinstance(ti, TaskInstance)
Expand All @@ -431,12 +430,12 @@ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_x
context["expanded_ti_count"],
session=session,
)

result = ti.xcom_pull(
task_ids=task_id,
map_indexes=map_indexes,
key=self.key,
default=NOTSET,
session=session,
)
if not isinstance(result, ArgNotSet):
return result
Expand Down
5 changes: 3 additions & 2 deletions airflow/notifications/basenotifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING

from airflow.template.templater import Templater
from airflow.sdk.definitions.templater import Templater
from airflow.utils.context import context_merge
from airflow.utils.log.logging_mixin import LoggingMixin

if TYPE_CHECKING:
import jinja2
Expand All @@ -31,7 +32,7 @@
from airflow.utils.context import Context


class BaseNotifier(Templater):
class BaseNotifier(LoggingMixin, Templater):
"""BaseNotifier class for sending notifications."""

template_fields: Sequence[str] = ()
Expand Down
Loading

0 comments on commit fdc3117

Please sign in to comment.