From fdc3117f58394514c61286fe658170386c649d8b Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 24 Dec 2024 22:32:33 +0530 Subject: [PATCH] AIP-72: Add Taskflow API support & template rendering in Task SDK closes https://github.com/apache/airflow/issues/45232 The Templater class has been moved to the Task SDK to align with the language-specific aspects of template rendering. Templating logic is inherently tied to Python constructs. By keeping the Templater class within the Task SDK, we ensure that the core templating logic remains coupled with language-specific implementations. The options I had were keeping it on the Scheduler or on the Execution side of the Task SDK, neither of which is ideal, as we would want to change the code in definitions like DAG and Operator along with how it renders. --- airflow/macros/__init__.py | 44 +-- airflow/models/abstractoperator.py | 74 +---- airflow/models/baseoperator.py | 18 -- airflow/models/expandinput.py | 2 +- airflow/models/param.py | 2 +- airflow/models/taskinstance.py | 2 +- airflow/models/xcom_arg.py | 11 +- airflow/notifications/basenotifier.py | 5 +- airflow/templates.py | 94 ------ airflow/utils/mixins.py | 27 -- .../tests/test_pytest_args_for_test_types.py | 1 - .../standard/utils/python_virtualenv.py | 5 +- .../operators/test_spark_kubernetes.py | 10 +- scripts/cov/other_coverage.py | 1 - {tests/template => task_sdk}/__init__.py | 0 .../sdk/definitions/abstractoperator.py | 73 ++++- .../airflow/sdk/definitions/baseoperator.py | 19 ++ .../src/airflow/sdk/definitions/connection.py | 34 ++ task_sdk/src/airflow/sdk/definitions/dag.py | 6 +- .../src/airflow/sdk/definitions/macros.py | 61 ++++ .../src/airflow/sdk/definitions/mixins.py | 28 +- .../src/airflow/sdk/definitions}/templater.py | 102 +++++- .../src/airflow/sdk/execution_time/context.py | 22 ++ .../airflow/sdk/execution_time/task_runner.py | 47 ++- .../tests/dags/taskflow_api.py | 40 ++- .../tests/defintions/test_baseoperator.py | 298 +++++++++++++++++- 
task_sdk/tests/defintions/test_macros.py | 72 +++++ .../tests/defintions}/test_templater.py | 53 +++- .../tests/execution_time/test_task_runner.py | 102 +++++- tests/core/test_templates.py | 57 ---- tests/macros/test_macros.py | 49 --- tests/models/test_baseoperator.py | 280 +--------------- tests/models/test_dag.py | 2 +- 33 files changed, 941 insertions(+), 700 deletions(-) delete mode 100644 airflow/templates.py rename {tests/template => task_sdk}/__init__.py (100%) create mode 100644 task_sdk/src/airflow/sdk/definitions/macros.py rename {airflow/template => task_sdk/src/airflow/sdk/definitions}/templater.py (74%) rename airflow/utils/template.py => task_sdk/tests/dags/taskflow_api.py (50%) create mode 100644 task_sdk/tests/defintions/test_macros.py rename {tests/template => task_sdk/tests/defintions}/test_templater.py (70%) delete mode 100644 tests/core/test_templates.py diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index be4554818acf2..26b08c8a6b383 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -17,11 +17,7 @@ # under the License. from __future__ import annotations -import json # noqa: F401 -import time # noqa: F401 -import uuid # noqa: F401 -from datetime import datetime, timedelta -from random import random # noqa: F401 +from datetime import datetime from typing import TYPE_CHECKING, Any import dateutil # noqa: F401 @@ -29,47 +25,12 @@ from babel.dates import LC_TIME, format_datetime import airflow.utils.yaml as yaml # noqa: F401 +from airflow.sdk.definitions.macros import ds_add, ds_format, json, time, uuid # noqa: F401 if TYPE_CHECKING: from pendulum import DateTime -def ds_add(ds: str, days: int) -> str: - """ - Add or subtract days from a YYYY-MM-DD. 
- - :param ds: anchor date in ``YYYY-MM-DD`` format to add to - :param days: number of days to add to the ds, you can use negative values - - >>> ds_add("2015-01-01", 5) - '2015-01-06' - >>> ds_add("2015-01-06", -5) - '2015-01-01' - """ - if not days: - return str(ds) - dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days) - return dt.strftime("%Y-%m-%d") - - -def ds_format(ds: str, input_format: str, output_format: str) -> str: - """ - Output datetime string in a given format. - - :param ds: Input string which contains a date. - :param input_format: Input string format (e.g., '%Y-%m-%d'). - :param output_format: Output string format (e.g., '%Y-%m-%d'). - - >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y") - '01-01-15' - >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d") - '2015-01-05' - >>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y", "en_US") - 'Friday 12 July 2024' - """ - return datetime.strptime(str(ds), input_format).strftime(output_format) - - def ds_format_locale( ds: str, input_format: str, output_format: str, locale: Locale | str | None = None ) -> str: @@ -99,6 +60,7 @@ def ds_format_locale( ) +# TODO: Task SDK: Move this to the Task SDK once we evaluate "pendulum"'s dependency def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str: """ Return a human-readable/approximate difference between datetimes. 
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index aa23bf33e131a..134db08d71bb9 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -30,10 +30,9 @@ from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator -from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query -from airflow.utils.log.secrets_masker import redact +from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState @@ -42,8 +41,6 @@ from airflow.utils.weight_rule import WeightRule, db_safe_priority if TYPE_CHECKING: - from collections.abc import Mapping - import jinja2 # Slow import. from sqlalchemy.orm import Session @@ -52,7 +49,6 @@ from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.node import DAGNode from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs @@ -88,7 +84,7 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(Templater, TaskSDKAbstractOperator): +class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): """ Common implementation for operators, including unmapped and mapped. 
@@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value): ) self._on_failure_fail_dagrun = value - def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: - """Get the template environment for rendering templates.""" - if dag is None: - dag = self.get_dag() - return super().get_template_env(dag=dag) - - def _render(self, template, context, dag: DAG | None = None): - if dag is None: - dag = self.get_dag() - return super()._render(template, context, dag=dag) - - def _do_render_template_fields( - self, - parent: Any, - template_fields: Iterable[str], - context: Mapping[str, Any], - jinja_env: jinja2.Environment, - seen_oids: set[int], - ) -> None: - """Override the base to use custom error logging.""" - for attr_name in template_fields: - try: - value = getattr(parent, attr_name) - except AttributeError: - raise AttributeError( - f"{attr_name!r} is configured as a template field " - f"but {parent.task_type} does not have this attribute." - ) - try: - if not value: - continue - except Exception: - # This may happen if the templated field points to a class which does not support `__bool__`, - # such as Pandas DataFrames: - # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 - self.log.info( - "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", - type(value).__name__, - self.task_id, - attr_name, - ) - # We may still want to render custom classes which do not support __bool__ - pass - - try: - if callable(value): - rendered_content = value(context=context, jinja_env=jinja_env) - else: - rendered_content = self.render_template( - value, - context, - jinja_env, - seen_oids, - ) - except Exception: - value_masked = redact(name=attr_name, value=value) - self.log.exception( - "Exception rendering Jinja template for task '%s', field '%s'. 
Template: %r", - self.task_id, - attr_name, - value_masked, - ) - raise - else: - setattr(parent, attr_name, rendered_content) - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ Return mapped nodes that are direct dependencies of the current task. diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 08839cc0bf720..f28e05584fdf8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -100,7 +100,6 @@ if TYPE_CHECKING: from types import ClassMethodDescriptorType - import jinja2 # Slow import. from sqlalchemy.orm import Session from airflow.models.abstractoperator import TaskStateChangeCallback @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None): logger=self.log, ).run(context, result) - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - This mutates the attributes in-place and is irreversible. - - :param context: Context dict with values to apply on content. - :param jinja_env: Jinja's environment to use for rendering. 
- """ - if not jinja_env: - jinja_env = self.get_template_env() - self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) - @provide_session def clear( self, diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8d86ec193eb4d..b1e4daf78435a 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -25,7 +25,7 @@ import attr -from airflow.utils.mixins import ResolveMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: diff --git a/airflow/models/param.py b/airflow/models/param.py index ab7d2facd7e3b..4d55706d1ea57 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from airflow.exceptions import AirflowException, ParamValidationError -from airflow.utils.mixins import ResolveMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a5e50cb0d2cdb..27293fa2d022e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -107,10 +107,10 @@ from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef +from airflow.sdk.definitions.templater import SandboxedEnvironment from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats -from airflow.templates import SandboxedEnvironment from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS from airflow.traces.tracer import Trace diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 
103ddc663323c..9f99450e729ae 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -29,9 +29,9 @@ from airflow.models import MappedOperator, TaskInstance from airflow.models.abstractoperator import AbstractOperator from airflow.models.taskmixin import DependencyMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.sdk.types import NOTSET, ArgNotSet from airflow.utils.db import exists_query -from airflow.utils.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.state import State @@ -206,8 +206,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: """ raise NotImplementedError() - @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: """ Pull XCom value. 
@@ -420,8 +419,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: ) return session.scalar(query) - @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: ti = context["ti"] if TYPE_CHECKING: assert isinstance(ti, TaskInstance) @@ -431,12 +430,12 @@ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_x context["expanded_ti_count"], session=session, ) + result = ti.xcom_pull( task_ids=task_id, map_indexes=map_indexes, key=self.key, default=NOTSET, - session=session, ) if not isinstance(result, ArgNotSet): return result diff --git a/airflow/notifications/basenotifier.py b/airflow/notifications/basenotifier.py index eaac6d11df36d..398d95cbb8d0a 100644 --- a/airflow/notifications/basenotifier.py +++ b/airflow/notifications/basenotifier.py @@ -21,8 +21,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.template.templater import Templater +from airflow.sdk.definitions.templater import Templater from airflow.utils.context import context_merge +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: import jinja2 @@ -31,7 +32,7 @@ from airflow.utils.context import Context -class BaseNotifier(Templater): +class BaseNotifier(LoggingMixin, Templater): """BaseNotifier class for sending notifications.""" template_fields: Sequence[str] = () diff --git a/airflow/templates.py b/airflow/templates.py deleted file mode 100644 index 95851253a7d22..0000000000000 --- a/airflow/templates.py +++ /dev/null @@ -1,94 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import TYPE_CHECKING - -import jinja2.nativetypes -import jinja2.sandbox - -if TYPE_CHECKING: - import datetime - - -class _AirflowEnvironmentMixin: - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.filters.update(FILTERS) - - def is_safe_attribute(self, obj, attr, value): - """ - Allow access to ``_`` prefix vars (but not ``__``). - - Unlike the stock SandboxedEnvironment, we allow access to "private" attributes (ones starting with - ``_``) whilst still blocking internal or truly private attributes (``__`` prefixed ones). 
- """ - return not jinja2.sandbox.is_internal_attribute(obj, attr) - - -class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment): - """NativeEnvironment for Airflow task templates.""" - - -class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment): - """SandboxedEnvironment for Airflow task templates.""" - - -def ds_filter(value: datetime.date | datetime.time | None) -> str | None: - """Date filter.""" - if value is None: - return None - return value.strftime("%Y-%m-%d") - - -def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: - """Date filter without dashes.""" - if value is None: - return None - return value.strftime("%Y%m%d") - - -def ts_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter.""" - if value is None: - return None - return value.isoformat() - - -def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter without dashes.""" - if value is None: - return None - return value.strftime("%Y%m%dT%H%M%S") - - -def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter with timezone.""" - if value is None: - return None - return value.isoformat().replace("-", "").replace(":", "") - - -FILTERS = { - "ds": ds_filter, - "ds_nodash": ds_nodash_filter, - "ts": ts_filter, - "ts_nodash": ts_nodash_filter, - "ts_nodash_with_tz": ts_nodash_with_tz_filter, -} diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py index 99c4b039090d9..78484f8375ef6 100644 --- a/airflow/utils/mixins.py +++ b/airflow/utils/mixins.py @@ -20,14 +20,9 @@ import multiprocessing import multiprocessing.context -import typing from airflow.configuration import conf -if typing.TYPE_CHECKING: - from airflow.models.operator import Operator - from airflow.utils.context import Context - class MultiprocessingStartMethodMixin: """Convenience class to add support for different 
types of multiprocessing.""" @@ -49,25 +44,3 @@ def _get_multiprocessing_start_method(self) -> str: def _get_multiprocessing_context(self) -> multiprocessing.context.DefaultContext: mp_start_method = self._get_multiprocessing_start_method() return multiprocessing.get_context(mp_start_method) # type: ignore - - -class ResolveMixin: - """A runtime-resolved value.""" - - def iter_references(self) -> typing.Iterable[tuple[Operator, str]]: - """ - Find underlying XCom references this contains. - - This is used by the DAG parser to recursively find task dependencies. - - :meta private: - """ - raise NotImplementedError - - def resolve(self, context: Context, *, include_xcom: bool = True) -> typing.Any: - """ - Resolve this value for runtime. - - :meta private: - """ - raise NotImplementedError diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index be5138699722f..c6bfa93fefd68 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -127,7 +127,6 @@ "tests/security", "tests/sensors", "tests/task", - "tests/template", "tests/testconfig", "tests/timetables", ], diff --git a/providers/src/airflow/providers/standard/utils/python_virtualenv.py b/providers/src/airflow/providers/standard/utils/python_virtualenv.py index 9d03e43367a49..66cc92ee16e04 100644 --- a/providers/src/airflow/providers/standard/utils/python_virtualenv.py +++ b/providers/src/airflow/providers/standard/utils/python_virtualenv.py @@ -28,6 +28,7 @@ from jinja2 import select_autoescape from airflow.configuration import conf +from airflow.sdk.definitions.templater import NativeEnvironment from airflow.utils.process_utils import execute_in_subprocess @@ -196,9 +197,7 @@ def write_python_script( template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) template_env: jinja2.Environment if render_template_as_native_obj: - template_env = 
jinja2.nativetypes.NativeEnvironment( - loader=template_loader, undefined=jinja2.StrictUndefined - ) + template_env = NativeEnvironment(loader=template_loader, undefined=jinja2.StrictUndefined) else: template_env = jinja2.Environment( loader=template_loader, diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py index 5104827882132..5a107f31146d0 100644 --- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -794,7 +794,10 @@ def test_resolve_application_file_real_file( application_file = application_file.resolve().as_posix() if use_literal_value: - from airflow.template.templater import LiteralValue + try: + from airflow.sdk.definitions.templater import LiteralValue + except ImportError: + from airflow.template.templater import LiteralValue application_file = LiteralValue(application_file) else: @@ -820,7 +823,10 @@ def test_resolve_application_file_real_file( @pytest.mark.db_test def test_resolve_application_file_real_file_not_exists(create_task_instance_of_operator, tmp_path, session): application_file = (tmp_path / "test-application-file.yml").resolve().as_posix() - from airflow.template.templater import LiteralValue + try: + from airflow.sdk.definitions.templater import LiteralValue + except ImportError: + from airflow.template.templater import LiteralValue ti = create_task_instance_of_operator( SparkKubernetesOperator, diff --git a/scripts/cov/other_coverage.py b/scripts/cov/other_coverage.py index 0394e3590bec3..19a061557d70b 100644 --- a/scripts/cov/other_coverage.py +++ b/scripts/cov/other_coverage.py @@ -91,7 +91,6 @@ "tests/security", "tests/sensors", "tests/task", - "tests/template", "tests/testconfig", "tests/timetables", """ diff --git a/tests/template/__init__.py b/task_sdk/__init__.py similarity index 100% rename from tests/template/__init__.py rename to 
task_sdk/__init__.py diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index 0e51a9748e824..0251033cd555f 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -18,10 +18,12 @@ from __future__ import annotations import datetime +import logging from abc import abstractmethod from collections.abc import ( Collection, Iterable, + Mapping, ) from typing import ( TYPE_CHECKING, @@ -31,10 +33,13 @@ from airflow.sdk.definitions.mixins import DependencyMixin from airflow.sdk.definitions.node import DAGNode +from airflow.sdk.definitions.templater import Templater from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: + import jinja2 + from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator @@ -62,12 +67,14 @@ DEFAULT_WEIGHT_RULE: WeightRule = WeightRule.DOWNSTREAM DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None +log = logging.getLogger(__name__) + class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(DAGNode): +class AbstractOperator(Templater, DAGNode): """ Common implementation for operators, including unmapped and mapped. @@ -265,3 +272,67 @@ def get_upstreams_only_setups(self) -> Iterable[Operator]: for task in self.get_upstreams_only_setups_and_teardowns(): if task.is_setup: yield task + + # TODO: Task-SDK -- Should the following methods be removed? 
+ # get_template_env + # _render + def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: + """Get the template environment for rendering templates.""" + if dag is None: + dag = self.get_dag() + return super().get_template_env(dag=dag) + + def _render(self, template, context, dag: DAG | None = None): + if dag is None: + dag = self.get_dag() + return super()._render(template, context, dag=dag) + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Mapping[str, Any], + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + """Override the base to use custom error logging.""" + for attr_name in template_fields: + try: + value = getattr(parent, attr_name) + except AttributeError: + raise AttributeError( + f"{attr_name!r} is configured as a template field " + f"but {parent.task_type} does not have this attribute." + ) + try: + if not value: + continue + except Exception: + # This may happen if the templated field points to a class which does not support `__bool__`, + # such as Pandas DataFrames: + # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 + log.info( + "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", + type(value).__name__, + self.task_id, + attr_name, + ) + # We may still want to render custom classes which do not support __bool__ + pass + + try: + if callable(value): + rendered_content = value(context=context, jinja_env=jinja_env) + else: + rendered_content = self.render_template(value, context, jinja_env, seen_oids) + except Exception: + # TODO: Mask the value. Depends on https://github.com/apache/airflow/issues/45438 + log.exception( + "Exception rendering Jinja template for task '%s', field '%s'. 
Template: %r", + self.task_id, + attr_name, + value, + ) + raise + else: + setattr(parent, attr_name, rendered_content) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 93f90c2475418..44a152f6aa95f 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -65,6 +65,8 @@ T = TypeVar("T", bound=FunctionType) if TYPE_CHECKING: + import jinja2 + from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import TaskGroup @@ -1239,3 +1241,20 @@ def inherits_from_empty_operator(self): # needs to cope when `self` is a Serialized instance of a EmptyOperator or one # of its subclasses (which don't inherit from anything but BaseOperator). return getattr(self, "_is_empty", False) + + def render_template_fields( + self, + context: dict, # TODO: Change to `Context` once we have it + jinja_env: jinja2.Environment | None = None, + ) -> None: + """ + Template all attributes listed in *self.template_fields*. + + This mutates the attributes in-place and is irreversible. + + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja's environment to use for rendering. + """ + if not jinja_env: + jinja_env = self.get_template_env() + self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) diff --git a/task_sdk/src/airflow/sdk/definitions/connection.py b/task_sdk/src/airflow/sdk/definitions/connection.py index 628b72e29be6e..aa8c79986afef 100644 --- a/task_sdk/src/airflow/sdk/definitions/connection.py +++ b/task_sdk/src/airflow/sdk/definitions/connection.py @@ -17,8 +17,15 @@ # under the License. 
from __future__ import annotations +import json +import logging +from contextlib import suppress +from json import JSONDecodeError + import attrs +log = logging.getLogger(__name__) + @attrs.define class Connection: @@ -50,3 +57,30 @@ class Connection: def get_uri(self): ... def get_hook(self): ... + + @property + def extra_dejson(self, nested: bool = False) -> dict: + """ + Deserialize extra property to JSON. + + :param nested: Determines whether nested structures are also deserialized into JSON (default False). + """ + extra_json = {} + + if self.extra: + try: + if nested: + for key, value in json.loads(self.extra).items(): + extra_json[key] = value + if isinstance(value, str): + with suppress(JSONDecodeError): + extra_json[key] = json.loads(value) + else: + extra_json = json.loads(self.extra) + except JSONDecodeError: + log.exception("Failed parsing the json for conn_id %s", self.conn_id) + + # TODO: Mask sensitive keys from this list + # mask_secret(extra) + + return extra_json diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index f49e42aa0eb30..0e0eead4f09c2 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -656,7 +656,7 @@ def resolve_template_files(self): def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: """Build a Jinja2 environment.""" - import airflow.templates + from airflow.sdk.definitions.templater import NativeEnvironment, SandboxedEnvironment # Collect directories to search for template files searchpath = [self.folder] @@ -674,9 +674,9 @@ def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environme jinja_env_options.update(self.jinja_environment_kwargs) env: jinja2.Environment if self.render_template_as_native_obj and not force_sandboxed: - env = airflow.templates.NativeEnvironment(**jinja_env_options) + env = NativeEnvironment(**jinja_env_options) else: - env = 
airflow.templates.SandboxedEnvironment(**jinja_env_options) + env = SandboxedEnvironment(**jinja_env_options) # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals diff --git a/task_sdk/src/airflow/sdk/definitions/macros.py b/task_sdk/src/airflow/sdk/definitions/macros.py new file mode 100644 index 0000000000000..d423ac504b006 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/macros.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json # noqa: F401 +import time # noqa: F401 +import uuid # noqa: F401 +from datetime import datetime, timedelta +from random import random # noqa: F401 + + +def ds_add(ds: str, days: int) -> str: + """ + Add or subtract days from a YYYY-MM-DD. 
+ + :param ds: anchor date in ``YYYY-MM-DD`` format to add to + :param days: number of days to add to the ds, you can use negative values + + >>> ds_add("2015-01-01", 5) + '2015-01-06' + >>> ds_add("2015-01-06", -5) + '2015-01-01' + """ + if not days: + return str(ds) + dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days) + return dt.strftime("%Y-%m-%d") + + +def ds_format(ds: str, input_format: str, output_format: str) -> str: + """ + Output datetime string in a given format. + + :param ds: Input string which contains a date. + :param input_format: Input string format (e.g., '%Y-%m-%d'). + :param output_format: Output string format (e.g., '%Y-%m-%d'). + + >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y") + '01-01-15' + >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d") + '2015-01-05' + >>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y") + 'Friday 12 July 2024' + """ + return datetime.strptime(str(ds), input_format).strftime(output_format) diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py index de63772615de5..7b1594e697874 100644 --- a/task_sdk/src/airflow/sdk/definitions/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -18,10 +18,11 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from airflow.sdk.definitions.abstractoperator import AbstractOperator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -109,9 +110,6 @@ def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: from airflow.sdk.definitions.abstractoperator import AbstractOperator - # TODO:Task-SDK - from airflow.utils.mixins import ResolveMixin - if 
isinstance(obj, AbstractOperator): yield obj, "operator" elif isinstance(obj, ResolveMixin): @@ -119,3 +117,25 @@ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: elif isinstance(obj, Sequence): for o in obj: yield from cls._iter_references(o) + + +class ResolveMixin: + """A runtime-resolved value.""" + + def iter_references(self) -> Iterable[tuple[AbstractOperator, str]]: + """ + Find underlying XCom references this contains. + + This is used by the DAG parser to recursively find task dependencies. + + :meta private: + """ + raise NotImplementedError + + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: + """ + Resolve this value for runtime. + + :meta private: + """ + raise NotImplementedError diff --git a/airflow/template/templater.py b/task_sdk/src/airflow/sdk/definitions/templater.py similarity index 74% rename from airflow/template/templater.py rename to task_sdk/src/airflow/sdk/definitions/templater.py index e81c0877c23d5..ac33e7cbed62b 100644 --- a/airflow/template/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/templater.py @@ -17,25 +17,39 @@ from __future__ import annotations +import datetime +import logging from collections.abc import Collection, Iterable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any +import jinja2 +import jinja2.nativetypes +import jinja2.sandbox + from airflow.io.path import ObjectStoragePath +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.helpers import render_template_as_native, render_template_to_string -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.mixins import ResolveMixin if TYPE_CHECKING: from collections.abc import Mapping - import jinja2 - from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG from airflow.utils.context import Context +def literal(value: Any) -> LiteralValue: + """ + Wrap a value to ensure it is rendered as-is 
without applying Jinja templating to its contents. + + Designed for use in an operator's template field. + + :param value: The value to be rendered without templating + """ + return LiteralValue(value) + + @dataclass(frozen=True) class LiteralValue(ResolveMixin): """ @@ -53,7 +67,12 @@ def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: return self.value -class Templater(LoggingMixin): +log = logging.getLogger(__name__) + + +# TODO: Task-SDK: Should everything below this line live in `_internal/templater.py`? +# so that it is not exposed to the public API. +class Templater: """ This renders the template fields of object. @@ -70,7 +89,6 @@ def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: # This is imported locally since Jinja2 is heavy and we don't need it # for most of the functionalities. It is imported by get_template_env() # though, so we don't need to put this after the 'if dag' check. - from airflow.templates import SandboxedEnvironment if dag: return dag.get_template_env(force_sandboxed=False) @@ -94,7 +112,7 @@ def resolve_template_files(self) -> None: try: setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore except Exception: - self.log.exception("Failed to resolve template field %r", field) + log.exception("Failed to resolve template field %r", field) elif isinstance(content, list): env = self.get_template_env() for i, item in enumerate(content): @@ -102,7 +120,7 @@ def resolve_template_files(self) -> None: try: content[i] = env.loader.get_source(env, item)[0] # type: ignore except Exception: - self.log.exception("Failed to get source %s", item) + log.exception("Failed to get source %s", item) self.prepare_template() def _do_render_template_fields( @@ -218,3 +236,71 @@ def _render_nested_template_fields( # content has no inner template fields return self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids) + + +class _AirflowEnvironmentMixin: + def 
__init__(self, **kwargs): + super().__init__(**kwargs) + + self.filters.update(FILTERS) + + def is_safe_attribute(self, obj, attr, value): + """ + Allow access to ``_`` prefix vars (but not ``__``). + + Unlike the stock SandboxedEnvironment, we allow access to "private" attributes (ones starting with + ``_``) whilst still blocking internal or truly private attributes (``__`` prefixed ones). + """ + return not jinja2.sandbox.is_internal_attribute(obj, attr) + + +class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment): + """NativeEnvironment for Airflow task templates.""" + + +class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment): + """SandboxedEnvironment for Airflow task templates.""" + + +def ds_filter(value: datetime.date | datetime.time | None) -> str | None: + """Date filter.""" + if value is None: + return None + return value.strftime("%Y-%m-%d") + + +def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + """Date filter without dashes.""" + if value is None: + return None + return value.strftime("%Y%m%d") + + +def ts_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter.""" + if value is None: + return None + return value.isoformat() + + +def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter without dashes.""" + if value is None: + return None + return value.strftime("%Y%m%dT%H%M%S") + + +def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter with timezone.""" + if value is None: + return None + return value.isoformat().replace("-", "").replace(":", "") + + +FILTERS = { + "ds": ds_filter, + "ds_nodash": ds_nodash_filter, + "ts": ts_filter, + "ts_nodash": ts_nodash_filter, + "ts_nodash_with_tz": ts_nodash_with_tz_filter, +} diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py 
b/task_sdk/src/airflow/sdk/execution_time/context.py index d96f5aeda2c95..72ac2af225e85 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -135,3 +135,25 @@ def get(self, key, default_var: Any = NOTSET) -> Any: if e.error.error == ErrorType.VARIABLE_NOT_FOUND: return default_var raise + + +class MacrosAccessor: + """Wrapper to access Macros module lazily.""" + + _macros_module = None + + def __getattr__(self, item: str) -> Any: + # Lazily load Macros module + if not self._macros_module: + import airflow.sdk.definitions.macros + + self._macros_module = airflow.sdk.definitions.macros + return getattr(self._macros_module, item) + + def __repr__(self) -> str: + return "" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MacrosAccessor): + return False + return True diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 4f8a4f0045c00..4f4a9e6ec0c9e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,9 +44,10 @@ ToTask, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor if TYPE_CHECKING: + import jinja2 from structlog.typing import FilteringBoundLogger as Logger @@ -76,8 +77,9 @@ def get_template_context(self): "ti": self, # "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, + "expanded_ti_count": None, # TODO: Implement this # "inlet_events": InletEventsAccessors(task.inlets, session=session), - # "macros": macros, + "macros": MacrosAccessor(), # "params": validated_params, # "prev_data_interval_start_success": get_prev_data_interval_start_success(), # "prev_data_interval_end_success": get_prev_data_interval_end_success(), @@ -118,6 +120,38 
@@ def get_template_context(self): # TODO: We should use/move TypeDict from airflow.utils.context.Context return context + def render_templates( + self, context: dict | None = None, jinja_env: jinja2.Environment | None = None + ) -> BaseOperator: + """ + Render templates in the operator fields. + + If the task was originally mapped, this may replace ``self.task`` with + the unmapped, fully rendered BaseOperator. The original ``self.task`` + before replacement is returned. + """ + if not context: + context = self.get_template_context() + original_task = self.task + + ti = context["ti"] + + if TYPE_CHECKING: + assert original_task + assert self.task + assert ti.task + + # If self.task is mapped, this call replaces self.task to point to the + # unmapped BaseOperator created by this function! This is because the + # MappedOperator is useless for template rendering, and we need to be + # able to access the unmapped task instead. + original_task.render_template_fields(context, jinja_env) + # TODO: Add support for rendering templates in the MappedOperator + # if isinstance(self.task, MappedOperator): + # self.task = context["ti"].task + + return original_task + def xcom_pull( self, task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? 
(breaking change) @@ -228,6 +262,12 @@ def xcom_push(self, key: str, value: Any): ), ) + def get_relevant_upstream_map_indexes( + self, upstream: BaseOperator, ti_count: int | None, session: Any + ) -> list[int]: + # TODO: Implement this method + return None + def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Task-SDK: @@ -383,6 +423,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Get a real context object ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() + jinja_env = ti.task.dag.get_template_env() + ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + # TODO: Get things from _execute_task_with_callbacks # - Clearing XCom # - Setting Current Context (set_current_context) diff --git a/airflow/utils/template.py b/task_sdk/tests/dags/taskflow_api.py similarity index 50% rename from airflow/utils/template.py rename to task_sdk/tests/dags/taskflow_api.py index ae6042a7e913c..3dbd5f99647e5 100644 --- a/airflow/utils/template.py +++ b/task_sdk/tests/dags/taskflow_api.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,17 +17,38 @@ # under the License. 
from __future__ import annotations -from typing import Any +import pendulum + +from airflow.decorators import dag, task + + +@dag( + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, +) +def taskflow_api(): + @task() + def extract(): + order_data_dict = {"1001": 301.27, "1002": 433.21, "1003": 502.22} + return order_data_dict + + @task(multiple_outputs=True) + def transform(order_data_dict: dict): + total_order_value = 0 + + for value in order_data_dict.values(): + total_order_value += value -from airflow.template.templater import LiteralValue + return {"total_order_value": total_order_value} + @task() + def load(total_order_value: float): + print(f"Total order value is: {total_order_value:.2f}") -def literal(value: Any) -> LiteralValue: - """ - Wrap a value to ensure it is rendered as-is without applying Jinja templating to its contents. + order_data = extract() + order_summary = transform(order_data) + load(order_summary["total_order_value"]) - Designed for use in an operator's template field. 
- :param value: The value to be rendered without templating - """ - return LiteralValue(value) +taskflow_api() diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index f9431b4fe9c3c..e64a31717bc33 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -17,18 +17,55 @@ from __future__ import annotations +import logging +import uuid import warnings -from datetime import datetime, timedelta, timezone +from datetime import date, datetime, timedelta, timezone +from typing import NamedTuple +from unittest import mock +import jinja2 import pytest from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.templater import literal from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) +class ClassWithCustomAttributes: + """Class for testing purpose: allows to create objects with custom attributes in one single statement.""" + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __str__(self): + return f"{ClassWithCustomAttributes.__name__}({str(self.__dict__)})" + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + +# Objects with circular references (for testing purpose) +object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"]) +object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"]) +setattr(object1, "ref", object2) + + +class MockNamedTuple(NamedTuple): + var1: str + var2: str + + # Essentially similar to airflow.models.baseoperator.BaseOperator class FakeOperator(metaclass=BaseOperatorMeta): def __init__(self, 
test_param, params=None, default_args=None): @@ -264,6 +301,205 @@ def test_invalid_trigger_rule(self): ): BaseOperator(task_id="op1", trigger_rule="some_rule") + @pytest.mark.parametrize( + ("content", "context", "expected_output"), + [ + ("{{ foo }}", {"foo": "bar"}, "bar"), + (["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]), + (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), + ( + {"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"}, + {"foo": "bar"}, + {"key1": "bar_1", "key2": "bar_2"}, + ), + ( + {"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"}, + {"foo": "bar"}, + {"key_{{ foo }}_1": 1, "key_2": "bar_2"}, + ), + (date(2018, 12, 6), {"foo": "bar"}, date(2018, 12, 6)), + (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), + (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), + ({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}), + (None, {}, None), + ([], {}, []), + ({}, {}, {}), + ( + # check nested fields can be templated + ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]), + {"foo": "bar"}, + ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]), + ), + ( + # check deep nested fields can be templated + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + {"foo": "bar"}, + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + ), + ( + # check null value on nested template field + ClassWithCustomAttributes(att1=None, 
template_fields=["att1"]), + {}, + ClassWithCustomAttributes(att1=None, template_fields=["att1"]), + ), + ( + # check there is no RecursionError on circular references + object1, + {"foo": "bar"}, + object1, + ), + # By default, Jinja2 drops one (single) trailing newline + ("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"), + (literal("{{ foo }}"), {"foo": "bar"}, "{{ foo }}"), + (literal(["{{ foo }}_1", "{{ foo }}_2"]), {"foo": "bar"}, ["{{ foo }}_1", "{{ foo }}_2"]), + (literal(("{{ foo }}_1", "{{ foo }}_2")), {"foo": "bar"}, ("{{ foo }}_1", "{{ foo }}_2")), + ], + ) + def test_render_template(self, content, context, expected_output): + """Test render_template given various input types.""" + task = BaseOperator(task_id="op1") + + result = task.render_template(content, context) + assert result == expected_output + + @pytest.mark.parametrize( + ("content", "context", "expected_output"), + [ + ("{{ foo }}", {"foo": "bar"}, "bar"), + ("{{ foo }}", {"foo": ["bar1", "bar2"]}, ["bar1", "bar2"]), + (["{{ foo }}", "{{ foo | length}}"], {"foo": ["bar1", "bar2"]}, [["bar1", "bar2"], 2]), + (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), + ("{{ ds }}", {"ds": date(2018, 12, 6)}, date(2018, 12, 6)), + (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), + ("{{ ds }}", {"ds": datetime(2018, 12, 6, 10, 55)}, datetime(2018, 12, 6, 10, 55)), + (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), + ( + ("{{ foo }}", "{{ foo.isoformat() }}"), + {"foo": datetime(2018, 12, 6, 10, 55)}, + (datetime(2018, 12, 6, 10, 55), "2018-12-06T10:55:00"), + ), + (None, {}, None), + ([], {}, []), + ({}, {}, {}), + ], + ) + def test_render_template_with_native_envs(self, content, context, expected_output): + """Test render_template given various input types with Native Python types""" + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, render_template_as_native_obj=True): + task = 
BaseOperator(task_id="op1") + + result = task.render_template(content, context) + assert result == expected_output + + def test_render_template_fields(self): + """Verify if operator attributes are correctly templated.""" + task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") + + # Assert nothing is templated yet + assert task.arg1 == "{{ foo }}" + assert task.arg2 == "{{ bar }}" + + # Trigger templating and verify if attributes are templated correctly + task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"}) + assert task.arg1 == "footemplated" + assert task.arg2 == "bartemplated" + + def test_render_template_fields_func_using_context(self): + """Verify if operator attributes are correctly templated.""" + + def fn_to_template(context, jinja_env): + tmp = context["task"].render_template("{{ bar }}", context, jinja_env) + return "foo_" + tmp + + task = MockOperator(task_id="op1", arg2=fn_to_template) + + # Trigger templating and verify if attributes are templated correctly + task.render_template_fields(context={"bar": "bartemplated", "task": task}) + assert task.arg2 == "foo_bartemplated" + + def test_render_template_fields_simple_func(self): + """Verify if operator attributes are correctly templated.""" + + def fn_to_template(**kwargs): + a = "foo_" + ("bar" * 3) + return a + + task = MockOperator(task_id="op1", arg2=fn_to_template) + task.render_template_fields({}) + assert task.arg2 == "foo_barbarbar" + + @pytest.mark.parametrize(("content",), [(object(),), (uuid.uuid4(),)]) + def test_render_template_fields_no_change(self, content): + """Tests if non-templatable types remain unchanged.""" + task = BaseOperator(task_id="op1") + + result = task.render_template(content, {"foo": "bar"}) + assert content is result + + def test_nested_template_fields_declared_must_exist(self): + """Test render_template when a nested template field is missing.""" + task = BaseOperator(task_id="op1") + + error_message = ( + "'missing_field' 
is configured as a template field but ClassWithCustomAttributes does not have " + "this attribute." + ) + with pytest.raises(AttributeError, match=error_message): + task.render_template( + ClassWithCustomAttributes( + template_fields=["missing_field"], task_type="ClassWithCustomAttributes" + ), + {}, + ) + + def test_string_template_field_attr_is_converted_to_list(self): + """Verify template_fields attribute is converted to a list if declared as a string.""" + + class StringTemplateFieldsOperator(BaseOperator): + template_fields = "a_string" + + warning_message = ( + "The `template_fields` value for StringTemplateFieldsOperator is a string but should be a " + "list or tuple of string. Wrapping it in a list for execution. Please update " + "StringTemplateFieldsOperator accordingly." + ) + with pytest.warns(UserWarning, match=warning_message) as warnings: + task = StringTemplateFieldsOperator(task_id="op1") + + assert len(warnings) == 1 + assert isinstance(task.template_fields, list) + + def test_jinja_invalid_expression_is_just_propagated(self): + """Test render_template propagates Jinja invalid expression errors.""" + task = BaseOperator(task_id="op1") + + with pytest.raises(jinja2.exceptions.TemplateSyntaxError): + task.render_template("{{ invalid expression }}", {}) + + @mock.patch("airflow.sdk.definitions.templater.SandboxedEnvironment", autospec=True) + def test_jinja_env_creation(self, mock_jinja_env): + """Verify if a Jinja environment is created only once when templating.""" + task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") + + task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) + assert mock_jinja_env.call_count == 1 + def test_init_subclass_args(): class InitSubclassOp(BaseOperator): @@ -330,8 +566,58 @@ def test_dag_level_retry_delay(): assert task1.retry_delay == timedelta(seconds=100) -def test_task_level_retry_delay(): - with DAG(dag_id="test_task_level_retry_delay", default_args={"retry_delay": 
timedelta(seconds=100)}): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=200) - - assert task1.retry_delay == timedelta(seconds=200) +@pytest.mark.parametrize( + ("task", "context", "expected_exception", "expected_rendering", "expected_log", "not_expected_log"), + [ + # Simple success case. + ( + MockOperator(task_id="op1", arg1="{{ foo }}"), + dict(foo="footemplated"), + None, + dict(arg1="footemplated"), + None, + "Exception rendering Jinja template", + ), + # Jinja syntax error. + ( + MockOperator(task_id="op1", arg1="{{ foo"), + dict(), + jinja2.TemplateSyntaxError, + None, + "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo'", + None, + ), + # Type error + ( + MockOperator(task_id="op1", arg1="{{ foo + 1 }}"), + dict(foo="footemplated"), + TypeError, + None, + "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo + 1 }}'", + None, + ), + ], +) +def test_render_template_fields_logging( + caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log +): + """Verify if operator attributes are correctly templated.""" + + # Trigger templating and verify results + def _do_render(): + task.render_template_fields(context=context) + + if expected_exception: + with ( + pytest.raises(expected_exception), + caplog.at_level(logging.ERROR, logger="airflow.sdk.definitions.templater"), + ): + _do_render() + else: + _do_render() + for k, v in expected_rendering.items(): + assert getattr(task, k) == v + if expected_log: + assert expected_log in caplog.text + if not_expected_log: + assert not_expected_log not in caplog.text diff --git a/task_sdk/tests/defintions/test_macros.py b/task_sdk/tests/defintions/test_macros.py new file mode 100644 index 0000000000000..f36fd8d648401 --- /dev/null +++ b/task_sdk/tests/defintions/test_macros.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import lazy_object_proxy +import pytest + +from airflow.sdk.definitions import macros + + +@pytest.mark.parametrize( + "ds, days, expected", + [ + ("2015-01-01", 5, "2015-01-06"), + ("2015-01-02", 0, "2015-01-02"), + ("2015-01-06", -5, "2015-01-01"), + (lazy_object_proxy.Proxy(lambda: "2015-01-01"), 5, "2015-01-06"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), 0, "2015-01-02"), + (lazy_object_proxy.Proxy(lambda: "2015-01-06"), -5, "2015-01-01"), + ], +) +def test_ds_add(ds, days, expected): + result = macros.ds_add(ds, days) + assert result == expected + + +@pytest.mark.parametrize( + "ds, input_format, output_format, expected", + [ + ("2015-01-02", "%Y-%m-%d", "%m-%d-%y", "01-02-15"), + ("2015-01-02", "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), + ("1/5/2015", "%m/%d/%Y", "%m-%d-%y", "01-05-15"), + ("1/5/2015", "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%m-%d-%y", "01-02-15"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), + (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%m-%d-%y", "01-05-15"), + (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), + ], +) +def test_ds_format(ds, input_format, 
output_format, expected): + result = macros.ds_format(ds, input_format, output_format) + assert result == expected + + +@pytest.mark.parametrize( + "input_value, expected", + [ + ('{"field1":"value1", "field2":4, "field3":true}', {"field1": "value1", "field2": 4, "field3": True}), + ( + '{"field1": [ 1, 2, 3, 4, 5 ], "field2" : {"mini1" : 1, "mini2" : "2"}}', + {"field1": [1, 2, 3, 4, 5], "field2": {"mini1": 1, "mini2": "2"}}, + ), + ], +) +def test_json_loads(input_value, expected): + result = macros.json.loads(input_value) + assert result == expected diff --git a/tests/template/test_templater.py b/task_sdk/tests/defintions/test_templater.py similarity index 70% rename from tests/template/test_templater.py rename to task_sdk/tests/defintions/test_templater.py index 778ca275e881f..69855b4ac2806 100644 --- a/tests/template/test_templater.py +++ b/task_sdk/tests/defintions/test_templater.py @@ -17,12 +17,13 @@ from __future__ import annotations +from datetime import datetime, timezone + import jinja2 +import pytest -from airflow.io.path import ObjectStoragePath -from airflow.models.dag import DAG -from airflow.template.templater import LiteralValue, Templater -from airflow.utils.context import Context +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.templater import LiteralValue, SandboxedEnvironment, Templater class TestTemplater: @@ -54,15 +55,18 @@ def test_resolve_template_files_logs_exception(self, caplog): assert "Failed to resolve template field 'message'" in caplog.text def test_render_object_storage_path(self): + # TODO: Move this import to top-level after https://github.com/apache/airflow/issues/45425 + from airflow.io.path import ObjectStoragePath + templater = Templater() path = ObjectStoragePath("s3://bucket/key/{{ ds }}/part") - context = Context({"ds": "2006-02-01"}) # type: ignore + context = {"ds": "2006-02-01"} jinja_env = templater.get_template_env() rendered_content = templater._render_object_storage_path(path, context, 
jinja_env) assert rendered_content == ObjectStoragePath("s3://bucket/key/2006-02-01/part") def test_render_template(self): - context = Context({"name": "world"}) # type: ignore + context = {"name": "world"} templater = Templater() templater.message = "Hello {{ name }}" templater.template_fields = ["message"] @@ -73,7 +77,7 @@ def test_render_template(self): def test_not_render_literal_value(self): templater = Templater() templater.template_ext = [] - context = Context() + context = {} content = LiteralValue("Hello {{ name }}") rendered_content = templater.render_template(content, context) @@ -83,9 +87,42 @@ def test_not_render_literal_value(self): def test_not_render_file_literal_value(self): templater = Templater() templater.template_ext = [".txt"] - context = Context() + context = {} content = LiteralValue("template_file.txt") rendered_content = templater.render_template(content, context) assert rendered_content == "template_file.txt" + + +@pytest.fixture +def env(): + return SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) + + +def test_protected_access(env): + class Test: + _protected = 123 + + assert env.from_string(r"{{ obj._protected }}").render(obj=Test) == "123" + + +def test_private_access(env): + with pytest.raises(jinja2.exceptions.SecurityError): + env.from_string(r"{{ func.__code__ }}").render(func=test_private_access) + + +@pytest.mark.parametrize( + ["name", "expected"], + ( + ("ds", "2012-07-24"), + ("ds_nodash", "20120724"), + ("ts", "2012-07-24T03:04:52+00:00"), + ("ts_nodash", "20120724T030452"), + ("ts_nodash_with_tz", "20120724T030452+0000"), + ), +) +def test_filters(env, name, expected): + when = datetime(2012, 7, 24, 3, 4, 52, tzinfo=timezone.utc) + result = env.from_string("{{ date |" + name + " }}").render(date=when) + assert result == expected diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 8749bb2be1085..e212f9c47bcd0 100644 --- 
a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -46,7 +46,7 @@ TaskState, VariableResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -476,10 +476,10 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi ), ], ) -def test_startup_and_run_dag_with_templated_fields( +def test_startup_and_run_dag_with_rtif( mocked_parse, task_params, expected_rendered_fields, make_ti_context, time_machine, mock_supervisor_comms ): - """Test startup of a DAG with various templated fields.""" + """Test startup of a DAG with various rendered templated fields.""" class CustomOperator(BaseOperator): template_fields = tuple(task_params.keys()) @@ -523,6 +523,42 @@ def execute(self, context): mock_supervisor_comms.assert_has_calls(expected_calls) +@pytest.mark.parametrize( + ["command", "rendered_command"], + [ + ("{{ task.task_id }}", "templated_task"), + ("{{ run_id }}", "c"), + ("{{ logical_date }}", "2024-12-01 01:00:00+00:00"), + ], +) +def test_startup_and_run_dag_with_templated_fields( + command, rendered_command, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + """Test startup of a DAG with various templated fields.""" + + from airflow.providers.standard.operators.bash import BashOperator + + task = BashOperator(task_id="templated_task", bash_command=command) + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + ti = mocked_parse(what, "basic_dag", task) + ti._ti_context_from_server = make_ti_context( + logical_date="2024-12-01 01:00:00+00:00", + run_id="c", + ) + + instant = timezone.datetime(2024, 
12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + run(ti, log=mock.MagicMock()) + assert ti.task.bash_command == rendered_command + + @pytest.mark.parametrize( ["dag_id", "task_id", "fail_with_exception"], [ @@ -599,7 +635,9 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ }, "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, + "expanded_ti_count": None, "inlets": task.inlets, + "macros": MacrosAccessor(), "map_index_template": task.map_index_template, "outlets": task.outlets, "run_id": "test_run", @@ -637,6 +675,7 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, + "macros": MacrosAccessor(), "map_index_template": task.map_index_template, "outlets": task.outlets, "run_id": "test_run", @@ -649,6 +688,7 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0), "ds": "2024-12-01", "ds_nodash": "20241201", + "expanded_ti_count": None, "task_instance_key_str": "basic_task__hello__20241201", "ts": "2024-12-01T01:00:00+00:00", "ts_nodash": "20241201T010000", @@ -706,6 +746,62 @@ def test_get_connection_from_context(self, mocked_parse, make_ti_context, mock_s extra='{"extra_key": "extra_value"}', ) + def test_template_render(self, mocked_parse, make_ti_context): + task = BaseOperator(task_id="test_template_render_task") + + ti = TaskInstance( + id=uuid7(), task_id=task.task_id, dag_id="test_template_render", run_id="test_run", try_number=1 + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + template_context = runtime_ti.get_template_context() + result = runtime_ti.task.render_template( + "Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context + ) + assert result == "Task: test_template_render -> test_template_render_task" + + 
@pytest.mark.parametrize( + ["content", "expected_output"], + [ + ('{{ conn.get("a_connection").host }}', "hostvalue"), + ('{{ conn.get("a_connection", "unused_fallback").host }}', "hostvalue"), + ("{{ conn.a_connection.host }}", "hostvalue"), + ("{{ conn.a_connection.login }}", "loginvalue"), + ("{{ conn.a_connection.password }}", "passwordvalue"), + ('{{ conn.a_connection.extra_dejson["extra__asana__workspace"] }}', "extra1"), + ("{{ conn.a_connection.extra_dejson.extra__asana__workspace }}", "extra1"), + ], + ) + def test_template_with_connection( + self, content, expected_output, make_ti_context, mocked_parse, mock_supervisor_comms + ): + """ + Test the availability of connections in templates + """ + task = BaseOperator(task_id="hello") + + ti = TaskInstance( + id=uuid7(), task_id=task.task_id, dag_id="basic_task", run_id="test_run", try_number=1 + ) + conn = ConnectionResult( + conn_id="a_connection", + conn_type="a_type", + host="hostvalue", + login="loginvalue", + password="passwordvalue", + schema="schemavalues", + extra='{"extra__asana__workspace": "extra1"}', + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + mock_supervisor_comms.get_message.return_value = conn + + context = runtime_ti.get_template_context() + result = runtime_ti.task.render_template(content, context) + assert result == expected_output + @pytest.mark.parametrize( ["accessor_type", "var_value", "expected_value"], [ diff --git a/tests/core/test_templates.py b/tests/core/test_templates.py deleted file mode 100644 index b64d31803ce8b..0000000000000 --- a/tests/core/test_templates.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import jinja2 -import jinja2.exceptions -import pendulum -import pytest - -import airflow.templates - - -@pytest.fixture -def env(): - return airflow.templates.SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) - - -def test_protected_access(env): - class Test: - _protected = 123 - - assert env.from_string(r"{{ obj._protected }}").render(obj=Test) == "123" - - -def test_private_access(env): - with pytest.raises(jinja2.exceptions.SecurityError): - env.from_string(r"{{ func.__code__ }}").render(func=test_private_access) - - -@pytest.mark.parametrize( - ["name", "expected"], - ( - ("ds", "2012-07-24"), - ("ds_nodash", "20120724"), - ("ts", "2012-07-24T03:04:52+00:00"), - ("ts_nodash", "20120724T030452"), - ("ts_nodash_with_tz", "20120724T030452+0000"), - ), -) -def test_filters(env, name, expected): - when = pendulum.datetime(2012, 7, 24, 3, 4, 52, tz="UTC") - result = env.from_string("{{ date |" + name + " }}").render(date=when) - assert result == expected diff --git a/tests/macros/test_macros.py b/tests/macros/test_macros.py index 5dc0108ae0fdb..2fa4dfd232245 100644 --- a/tests/macros/test_macros.py +++ b/tests/macros/test_macros.py @@ -25,40 +25,6 @@ from airflow.utils import timezone -@pytest.mark.parametrize( - "ds, days, expected", - [ - ("2015-01-01", 5, "2015-01-06"), - ("2015-01-02", 0, "2015-01-02"), - ("2015-01-06", -5, 
"2015-01-01"), - (lazy_object_proxy.Proxy(lambda: "2015-01-01"), 5, "2015-01-06"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), 0, "2015-01-02"), - (lazy_object_proxy.Proxy(lambda: "2015-01-06"), -5, "2015-01-01"), - ], -) -def test_ds_add(ds, days, expected): - result = macros.ds_add(ds, days) - assert result == expected - - -@pytest.mark.parametrize( - "ds, input_format, output_format, expected", - [ - ("2015-01-02", "%Y-%m-%d", "%m-%d-%y", "01-02-15"), - ("2015-01-02", "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), - ("1/5/2015", "%m/%d/%Y", "%m-%d-%y", "01-05-15"), - ("1/5/2015", "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%m-%d-%y", "01-02-15"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), - (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%m-%d-%y", "01-05-15"), - (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), - ], -) -def test_ds_format(ds, input_format, output_format, expected): - result = macros.ds_format(ds, input_format, output_format) - assert result == expected - - @pytest.mark.parametrize( "ds, input_format, output_format, locale, expected", [ @@ -111,21 +77,6 @@ def test_datetime_diff_for_humans(dt, since, expected): assert result == expected -@pytest.mark.parametrize( - "input_value, expected", - [ - ('{"field1":"value1", "field2":4, "field3":true}', {"field1": "value1", "field2": 4, "field3": True}), - ( - '{"field1": [ 1, 2, 3, 4, 5 ], "field2" : {"mini1" : 1, "mini2" : "2"}}', - {"field1": [1, 2, 3, 4, 5], "field2": {"mini1": 1, "mini2": "2"}}, - ), - ], -) -def test_json_loads(input_value, expected): - result = macros.json.loads(input_value) - assert result == expected - - @pytest.mark.parametrize( "input_value, expected", [ diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 2c598edc777ac..bc601099744ff 100644 --- a/tests/models/test_baseoperator.py +++ 
b/tests/models/test_baseoperator.py @@ -18,14 +18,10 @@ from __future__ import annotations import copy -import logging -import uuid from collections import defaultdict -from datetime import date, datetime -from typing import NamedTuple +from datetime import datetime from unittest import mock -import jinja2 import pytest from airflow.decorators import task as task_decorator @@ -44,7 +40,6 @@ from airflow.providers.common.sql.operators import sql from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup -from airflow.utils.template import literal from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -72,17 +67,6 @@ def __ne__(self, other): return not self.__eq__(other) -# Objects with circular references (for testing purpose) -object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"]) -object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"]) -setattr(object1, "ref", object2) - - -class MockNamedTuple(NamedTuple): - var1: str - var2: str - - class TestBaseOperator: def test_trigger_rule_validation(self): from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE @@ -107,211 +91,6 @@ def test_trigger_rule_validation(self): task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.ALWAYS ) - @pytest.mark.db_test - @pytest.mark.parametrize( - ("content", "context", "expected_output"), - [ - ("{{ foo }}", {"foo": "bar"}, "bar"), - (["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]), - (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), - ( - {"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"}, - {"foo": "bar"}, - {"key1": "bar_1", "key2": "bar_2"}, - ), - ( - {"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"}, - {"foo": "bar"}, - {"key_{{ foo }}_1": 1, "key_2": "bar_2"}, - ), - (date(2018, 12, 6), {"foo": "bar"}, date(2018, 12, 6)), - (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, 
datetime(2018, 12, 6, 10, 55)), - (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), - ({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}), - (None, {}, None), - ([], {}, []), - ({}, {}, {}), - ( - # check nested fields can be templated - ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]), - {"foo": "bar"}, - ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]), - ), - ( - # check deep nested fields can be templated - ClassWithCustomAttributes( - nested1=ClassWithCustomAttributes( - att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"] - ), - nested2=ClassWithCustomAttributes( - att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] - ), - template_fields=["nested1"], - ), - {"foo": "bar"}, - ClassWithCustomAttributes( - nested1=ClassWithCustomAttributes( - att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"] - ), - nested2=ClassWithCustomAttributes( - att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] - ), - template_fields=["nested1"], - ), - ), - ( - # check null value on nested template field - ClassWithCustomAttributes(att1=None, template_fields=["att1"]), - {}, - ClassWithCustomAttributes(att1=None, template_fields=["att1"]), - ), - ( - # check there is no RecursionError on circular references - object1, - {"foo": "bar"}, - object1, - ), - # By default, Jinja2 drops one (single) trailing newline - ("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"), - (literal("{{ foo }}"), {"foo": "bar"}, "{{ foo }}"), - (literal(["{{ foo }}_1", "{{ foo }}_2"]), {"foo": "bar"}, ["{{ foo }}_1", "{{ foo }}_2"]), - (literal(("{{ foo }}_1", "{{ foo }}_2")), {"foo": "bar"}, ("{{ foo }}_1", "{{ foo }}_2")), - ], - ) - def test_render_template(self, content, context, expected_output): - """Test render_template given various input types.""" - task = BaseOperator(task_id="op1") - - result = 
task.render_template(content, context) - assert result == expected_output - - @pytest.mark.parametrize( - ("content", "context", "expected_output"), - [ - ("{{ foo }}", {"foo": "bar"}, "bar"), - ("{{ foo }}", {"foo": ["bar1", "bar2"]}, ["bar1", "bar2"]), - (["{{ foo }}", "{{ foo | length}}"], {"foo": ["bar1", "bar2"]}, [["bar1", "bar2"], 2]), - (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), - ("{{ ds }}", {"ds": date(2018, 12, 6)}, date(2018, 12, 6)), - (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), - ("{{ ds }}", {"ds": datetime(2018, 12, 6, 10, 55)}, datetime(2018, 12, 6, 10, 55)), - (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), - ( - ("{{ foo }}", "{{ foo.isoformat() }}"), - {"foo": datetime(2018, 12, 6, 10, 55)}, - (datetime(2018, 12, 6, 10, 55), "2018-12-06T10:55:00"), - ), - (None, {}, None), - ([], {}, []), - ({}, {}, {}), - ], - ) - def test_render_template_with_native_envs(self, content, context, expected_output): - """Test render_template given various input types with Native Python types""" - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, render_template_as_native_obj=True): - task = BaseOperator(task_id="op1") - - result = task.render_template(content, context) - assert result == expected_output - - @pytest.mark.db_test - def test_render_template_fields(self): - """Verify if operator attributes are correctly templated.""" - task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") - - # Assert nothing is templated yet - assert task.arg1 == "{{ foo }}" - assert task.arg2 == "{{ bar }}" - - # Trigger templating and verify if attributes are templated correctly - task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"}) - assert task.arg1 == "footemplated" - assert task.arg2 == "bartemplated" - - @pytest.mark.db_test - def test_render_template_fields_func_using_context(self): - """Verify if operator 
attributes are correctly templated.""" - - def fn_to_template(context, jinja_env): - tmp = context["task"].render_template("{{ bar }}", context, jinja_env) - return "foo_" + tmp - - task = MockOperator(task_id="op1", arg2=fn_to_template) - - # Trigger templating and verify if attributes are templated correctly - task.render_template_fields(context={"bar": "bartemplated", "task": task}) - assert task.arg2 == "foo_bartemplated" - - @pytest.mark.db_test - def test_render_template_fields_simple_func(self): - """Verify if operator attributes are correctly templated.""" - - def fn_to_template(**kwargs): - a = "foo_" + ("bar" * 3) - return a - - task = MockOperator(task_id="op1", arg2=fn_to_template) - task.render_template_fields({}) - assert task.arg2 == "foo_barbarbar" - - @pytest.mark.parametrize(("content",), [(object(),), (uuid.uuid4(),)]) - def test_render_template_fields_no_change(self, content): - """Tests if non-templatable types remain unchanged.""" - task = BaseOperator(task_id="op1") - - result = task.render_template(content, {"foo": "bar"}) - assert content is result - - @pytest.mark.db_test - def test_nested_template_fields_declared_must_exist(self): - """Test render_template when a nested template field is missing.""" - task = BaseOperator(task_id="op1") - - error_message = ( - "'missing_field' is configured as a template field but ClassWithCustomAttributes does not have " - "this attribute." 
- ) - with pytest.raises(AttributeError, match=error_message): - task.render_template( - ClassWithCustomAttributes( - template_fields=["missing_field"], task_type="ClassWithCustomAttributes" - ), - {}, - ) - - def test_string_template_field_attr_is_converted_to_list(self): - """Verify template_fields attribute is converted to a list if declared as a string.""" - - class StringTemplateFieldsOperator(BaseOperator): - template_fields = "a_string" - - warning_message = ( - "The `template_fields` value for StringTemplateFieldsOperator is a string but should be a " - "list or tuple of string. Wrapping it in a list for execution. Please update " - "StringTemplateFieldsOperator accordingly." - ) - with pytest.warns(UserWarning, match=warning_message) as warnings: - task = StringTemplateFieldsOperator(task_id="op1") - - assert len(warnings) == 1 - assert isinstance(task.template_fields, list) - - def test_jinja_invalid_expression_is_just_propagated(self): - """Test render_template propagates Jinja invalid expression errors.""" - task = BaseOperator(task_id="op1") - - with pytest.raises(jinja2.exceptions.TemplateSyntaxError): - task.render_template("{{ invalid expression }}", {}) - - @pytest.mark.db_test - @mock.patch("airflow.templates.SandboxedEnvironment", autospec=True) - def test_jinja_env_creation(self, mock_jinja_env): - """Verify if a Jinja environment is created only once when templating.""" - task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") - - task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) - assert mock_jinja_env.call_count == 1 - def test_cross_downstream(self): """Test if all dependencies between tasks are all set correctly.""" dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now()) @@ -632,63 +411,6 @@ def task0(): copy.deepcopy(dag) -@pytest.mark.db_test -@pytest.mark.parametrize( - ("task", "context", "expected_exception", "expected_rendering", "expected_log", "not_expected_log"), - [ - # 
Simple success case. - ( - MockOperator(task_id="op1", arg1="{{ foo }}"), - dict(foo="footemplated"), - None, - dict(arg1="footemplated"), - None, - "Exception rendering Jinja template", - ), - # Jinja syntax error. - ( - MockOperator(task_id="op1", arg1="{{ foo"), - dict(), - jinja2.TemplateSyntaxError, - None, - "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo'", - None, - ), - # Type error - ( - MockOperator(task_id="op1", arg1="{{ foo + 1 }}"), - dict(foo="footemplated"), - TypeError, - None, - "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo + 1 }}'", - None, - ), - ], -) -def test_render_template_fields_logging( - caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log -): - """Verify if operator attributes are correctly templated.""" - - # Trigger templating and verify results - def _do_render(): - task.render_template_fields(context=context) - - logger = logging.getLogger("airflow.task") - monkeypatch.setattr(logger, "propagate", True) - if expected_exception: - with pytest.raises(expected_exception): - _do_render() - else: - _do_render() - for k, v in expected_rendering.items(): - assert getattr(task, k) == v - if expected_log: - assert expected_log in caplog.text - if not_expected_log: - assert not_expected_log not in caplog.text - - @pytest.mark.db_test def test_find_mapped_dependants_in_another_group(dag_maker): from airflow.utils.task_group import TaskGroup diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 53090f6225bc8..03c0c2e1daefa 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -72,8 +72,8 @@ from airflow.sdk import TaskGroup from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny from airflow.sdk.definitions.contextmanager import TaskGroupContext +from airflow.sdk.definitions.templater import NativeEnvironment, SandboxedEnvironment from airflow.security import 
permissions -from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( AssetTriggeredTimetable,