diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py index be4554818acf2..26b08c8a6b383 100644 --- a/airflow/macros/__init__.py +++ b/airflow/macros/__init__.py @@ -17,11 +17,7 @@ # under the License. from __future__ import annotations -import json # noqa: F401 -import time # noqa: F401 -import uuid # noqa: F401 -from datetime import datetime, timedelta -from random import random # noqa: F401 +from datetime import datetime from typing import TYPE_CHECKING, Any import dateutil # noqa: F401 @@ -29,47 +25,12 @@ from babel.dates import LC_TIME, format_datetime import airflow.utils.yaml as yaml # noqa: F401 +from airflow.sdk.definitions.macros import ds_add, ds_format, json, time, uuid # noqa: F401 if TYPE_CHECKING: from pendulum import DateTime -def ds_add(ds: str, days: int) -> str: - """ - Add or subtract days from a YYYY-MM-DD. - - :param ds: anchor date in ``YYYY-MM-DD`` format to add to - :param days: number of days to add to the ds, you can use negative values - - >>> ds_add("2015-01-01", 5) - '2015-01-06' - >>> ds_add("2015-01-06", -5) - '2015-01-01' - """ - if not days: - return str(ds) - dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days) - return dt.strftime("%Y-%m-%d") - - -def ds_format(ds: str, input_format: str, output_format: str) -> str: - """ - Output datetime string in a given format. - - :param ds: Input string which contains a date. - :param input_format: Input string format (e.g., '%Y-%m-%d'). - :param output_format: Output string format (e.g., '%Y-%m-%d'). 
- - >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y") - '01-01-15' - >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d") - '2015-01-05' - >>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y", "en_US") - 'Friday 12 July 2024' - """ - return datetime.strptime(str(ds), input_format).strftime(output_format) - - def ds_format_locale( ds: str, input_format: str, output_format: str, locale: Locale | str | None = None ) -> str: @@ -99,6 +60,7 @@ def ds_format_locale( ) +# TODO: Task SDK: Move this to the Task SDK once we evaluate "pendulum"'s dependency def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str: """ Return a human-readable/approximate difference between datetimes. diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index aa23bf33e131a..134db08d71bb9 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -30,10 +30,9 @@ from airflow.exceptions import AirflowException from airflow.models.expandinput import NotFullyPopulated from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator -from airflow.template.templater import Templater from airflow.utils.context import Context from airflow.utils.db import exists_query -from airflow.utils.log.secrets_masker import redact +from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState @@ -42,8 +41,6 @@ from airflow.utils.weight_rule import WeightRule, db_safe_priority if TYPE_CHECKING: - from collections.abc import Mapping - import jinja2 # Slow import. 
from sqlalchemy.orm import Session @@ -52,7 +49,6 @@ from airflow.models.mappedoperator import MappedOperator from airflow.models.taskinstance import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator - from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.node import DAGNode from airflow.task.priority_strategy import PriorityWeightStrategy from airflow.triggers.base import StartTriggerArgs @@ -88,7 +84,7 @@ class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(Templater, TaskSDKAbstractOperator): +class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator): """ Common implementation for operators, including unmapped and mapped. @@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value): ) self._on_failure_fail_dagrun = value - def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: - """Get the template environment for rendering templates.""" - if dag is None: - dag = self.get_dag() - return super().get_template_env(dag=dag) - - def _render(self, template, context, dag: DAG | None = None): - if dag is None: - dag = self.get_dag() - return super()._render(template, context, dag=dag) - - def _do_render_template_fields( - self, - parent: Any, - template_fields: Iterable[str], - context: Mapping[str, Any], - jinja_env: jinja2.Environment, - seen_oids: set[int], - ) -> None: - """Override the base to use custom error logging.""" - for attr_name in template_fields: - try: - value = getattr(parent, attr_name) - except AttributeError: - raise AttributeError( - f"{attr_name!r} is configured as a template field " - f"but {parent.task_type} does not have this attribute." 
- ) - try: - if not value: - continue - except Exception: - # This may happen if the templated field points to a class which does not support `__bool__`, - # such as Pandas DataFrames: - # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 - self.log.info( - "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", - type(value).__name__, - self.task_id, - attr_name, - ) - # We may still want to render custom classes which do not support __bool__ - pass - - try: - if callable(value): - rendered_content = value(context=context, jinja_env=jinja_env) - else: - rendered_content = self.render_template( - value, - context, - jinja_env, - seen_oids, - ) - except Exception: - value_masked = redact(name=attr_name, value=value) - self.log.exception( - "Exception rendering Jinja template for task '%s', field '%s'. Template: %r", - self.task_id, - attr_name, - value_masked, - ) - raise - else: - setattr(parent, attr_name, rendered_content) - def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]: """ Return mapped nodes that are direct dependencies of the current task. diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 08839cc0bf720..f28e05584fdf8 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -100,7 +100,6 @@ if TYPE_CHECKING: from types import ClassMethodDescriptorType - import jinja2 # Slow import. from sqlalchemy.orm import Session from airflow.models.abstractoperator import TaskStateChangeCallback @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None): logger=self.log, ).run(context, result) - def render_template_fields( - self, - context: Context, - jinja_env: jinja2.Environment | None = None, - ) -> None: - """ - Template all attributes listed in *self.template_fields*. - - This mutates the attributes in-place and is irreversible. 
- - :param context: Context dict with values to apply on content. - :param jinja_env: Jinja's environment to use for rendering. - """ - if not jinja_env: - jinja_env = self.get_template_env() - self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) - @provide_session def clear( self, diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index 8d86ec193eb4d..b1e4daf78435a 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -25,7 +25,7 @@ import attr -from airflow.utils.mixins import ResolveMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: diff --git a/airflow/models/param.py b/airflow/models/param.py index ab7d2facd7e3b..4d55706d1ea57 100644 --- a/airflow/models/param.py +++ b/airflow/models/param.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from airflow.exceptions import AirflowException, ParamValidationError -from airflow.utils.mixins import ResolveMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a5e50cb0d2cdb..27293fa2d022e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -107,10 +107,10 @@ from airflow.models.xcom import LazyXComSelectSequence, XCom from airflow.plugins_manager import integrate_macros_plugins from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef +from airflow.sdk.definitions.templater import SandboxedEnvironment from airflow.sentry import Sentry from airflow.settings import task_instance_mutation_hook from airflow.stats import Stats -from airflow.templates import SandboxedEnvironment from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, 
RUNNING_DEPS from airflow.traces.tracer import Trace diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 103ddc663323c..9f99450e729ae 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -29,9 +29,9 @@ from airflow.models import MappedOperator, TaskInstance from airflow.models.abstractoperator import AbstractOperator from airflow.models.taskmixin import DependencyMixin +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.sdk.types import NOTSET, ArgNotSet from airflow.utils.db import exists_query -from airflow.utils.mixins import ResolveMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext from airflow.utils.state import State @@ -206,8 +206,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: """ raise NotImplementedError() - @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: """ Pull XCom value. 
@@ -420,8 +419,8 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: ) return session.scalar(query) - @provide_session - def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_xcom: bool = True) -> Any: + # TODO: Task-SDK: Remove session argument once everything is ported over to Task SDK + def resolve(self, context: Context, session: Session | None = None, *, include_xcom: bool = True) -> Any: ti = context["ti"] if TYPE_CHECKING: assert isinstance(ti, TaskInstance) @@ -431,12 +430,12 @@ def resolve(self, context: Context, session: Session = NEW_SESSION, *, include_x context["expanded_ti_count"], session=session, ) + result = ti.xcom_pull( task_ids=task_id, map_indexes=map_indexes, key=self.key, default=NOTSET, - session=session, ) if not isinstance(result, ArgNotSet): return result diff --git a/airflow/notifications/basenotifier.py b/airflow/notifications/basenotifier.py index eaac6d11df36d..398d95cbb8d0a 100644 --- a/airflow/notifications/basenotifier.py +++ b/airflow/notifications/basenotifier.py @@ -21,8 +21,9 @@ from collections.abc import Sequence from typing import TYPE_CHECKING -from airflow.template.templater import Templater +from airflow.sdk.definitions.templater import Templater from airflow.utils.context import context_merge +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: import jinja2 @@ -31,7 +32,7 @@ from airflow.utils.context import Context -class BaseNotifier(Templater): +class BaseNotifier(LoggingMixin, Templater): """BaseNotifier class for sending notifications.""" template_fields: Sequence[str] = () diff --git a/airflow/templates.py b/airflow/templates.py deleted file mode 100644 index 95851253a7d22..0000000000000 --- a/airflow/templates.py +++ /dev/null @@ -1,94 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from typing import TYPE_CHECKING - -import jinja2.nativetypes -import jinja2.sandbox - -if TYPE_CHECKING: - import datetime - - -class _AirflowEnvironmentMixin: - def __init__(self, **kwargs): - super().__init__(**kwargs) - - self.filters.update(FILTERS) - - def is_safe_attribute(self, obj, attr, value): - """ - Allow access to ``_`` prefix vars (but not ``__``). - - Unlike the stock SandboxedEnvironment, we allow access to "private" attributes (ones starting with - ``_``) whilst still blocking internal or truly private attributes (``__`` prefixed ones). 
- """ - return not jinja2.sandbox.is_internal_attribute(obj, attr) - - -class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment): - """NativeEnvironment for Airflow task templates.""" - - -class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment): - """SandboxedEnvironment for Airflow task templates.""" - - -def ds_filter(value: datetime.date | datetime.time | None) -> str | None: - """Date filter.""" - if value is None: - return None - return value.strftime("%Y-%m-%d") - - -def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: - """Date filter without dashes.""" - if value is None: - return None - return value.strftime("%Y%m%d") - - -def ts_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter.""" - if value is None: - return None - return value.isoformat() - - -def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter without dashes.""" - if value is None: - return None - return value.strftime("%Y%m%dT%H%M%S") - - -def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None: - """Timestamp filter with timezone.""" - if value is None: - return None - return value.isoformat().replace("-", "").replace(":", "") - - -FILTERS = { - "ds": ds_filter, - "ds_nodash": ds_nodash_filter, - "ts": ts_filter, - "ts_nodash": ts_nodash_filter, - "ts_nodash_with_tz": ts_nodash_with_tz_filter, -} diff --git a/airflow/utils/mixins.py b/airflow/utils/mixins.py index 99c4b039090d9..78484f8375ef6 100644 --- a/airflow/utils/mixins.py +++ b/airflow/utils/mixins.py @@ -20,14 +20,9 @@ import multiprocessing import multiprocessing.context -import typing from airflow.configuration import conf -if typing.TYPE_CHECKING: - from airflow.models.operator import Operator - from airflow.utils.context import Context - class MultiprocessingStartMethodMixin: """Convenience class to add support for different 
types of multiprocessing.""" @@ -49,25 +44,3 @@ def _get_multiprocessing_start_method(self) -> str: def _get_multiprocessing_context(self) -> multiprocessing.context.DefaultContext: mp_start_method = self._get_multiprocessing_start_method() return multiprocessing.get_context(mp_start_method) # type: ignore - - -class ResolveMixin: - """A runtime-resolved value.""" - - def iter_references(self) -> typing.Iterable[tuple[Operator, str]]: - """ - Find underlying XCom references this contains. - - This is used by the DAG parser to recursively find task dependencies. - - :meta private: - """ - raise NotImplementedError - - def resolve(self, context: Context, *, include_xcom: bool = True) -> typing.Any: - """ - Resolve this value for runtime. - - :meta private: - """ - raise NotImplementedError diff --git a/dev/breeze/tests/test_pytest_args_for_test_types.py b/dev/breeze/tests/test_pytest_args_for_test_types.py index be5138699722f..c6bfa93fefd68 100644 --- a/dev/breeze/tests/test_pytest_args_for_test_types.py +++ b/dev/breeze/tests/test_pytest_args_for_test_types.py @@ -127,7 +127,6 @@ "tests/security", "tests/sensors", "tests/task", - "tests/template", "tests/testconfig", "tests/timetables", ], diff --git a/providers/src/airflow/providers/standard/utils/python_virtualenv.py b/providers/src/airflow/providers/standard/utils/python_virtualenv.py index 9d03e43367a49..66cc92ee16e04 100644 --- a/providers/src/airflow/providers/standard/utils/python_virtualenv.py +++ b/providers/src/airflow/providers/standard/utils/python_virtualenv.py @@ -28,6 +28,7 @@ from jinja2 import select_autoescape from airflow.configuration import conf +from airflow.sdk.definitions.templater import NativeEnvironment from airflow.utils.process_utils import execute_in_subprocess @@ -196,9 +197,7 @@ def write_python_script( template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__)) template_env: jinja2.Environment if render_template_as_native_obj: - template_env = 
jinja2.nativetypes.NativeEnvironment( - loader=template_loader, undefined=jinja2.StrictUndefined - ) + template_env = NativeEnvironment(loader=template_loader, undefined=jinja2.StrictUndefined) else: template_env = jinja2.Environment( loader=template_loader, diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py index 5104827882132..5a107f31146d0 100644 --- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -794,7 +794,10 @@ def test_resolve_application_file_real_file( application_file = application_file.resolve().as_posix() if use_literal_value: - from airflow.template.templater import LiteralValue + try: + from airflow.sdk.definitions.templater import LiteralValue + except ImportError: + from airflow.template.templater import LiteralValue application_file = LiteralValue(application_file) else: @@ -820,7 +823,10 @@ def test_resolve_application_file_real_file( @pytest.mark.db_test def test_resolve_application_file_real_file_not_exists(create_task_instance_of_operator, tmp_path, session): application_file = (tmp_path / "test-application-file.yml").resolve().as_posix() - from airflow.template.templater import LiteralValue + try: + from airflow.sdk.definitions.templater import LiteralValue + except ImportError: + from airflow.template.templater import LiteralValue ti = create_task_instance_of_operator( SparkKubernetesOperator, diff --git a/scripts/cov/other_coverage.py b/scripts/cov/other_coverage.py index 0394e3590bec3..19a061557d70b 100644 --- a/scripts/cov/other_coverage.py +++ b/scripts/cov/other_coverage.py @@ -91,7 +91,6 @@ "tests/security", "tests/sensors", "tests/task", - "tests/template", "tests/testconfig", "tests/timetables", """ diff --git a/tests/template/__init__.py b/task_sdk/__init__.py similarity index 100% rename from tests/template/__init__.py rename to 
task_sdk/__init__.py diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index 0e51a9748e824..0251033cd555f 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -18,10 +18,12 @@ from __future__ import annotations import datetime +import logging from abc import abstractmethod from collections.abc import ( Collection, Iterable, + Mapping, ) from typing import ( TYPE_CHECKING, @@ -31,10 +33,13 @@ from airflow.sdk.definitions.mixins import DependencyMixin from airflow.sdk.definitions.node import DAGNode +from airflow.sdk.definitions.templater import Templater from airflow.utils.trigger_rule import TriggerRule from airflow.utils.weight_rule import WeightRule if TYPE_CHECKING: + import jinja2 + from airflow.models.baseoperatorlink import BaseOperatorLink from airflow.models.operator import Operator from airflow.sdk.definitions.baseoperator import BaseOperator @@ -62,12 +67,14 @@ DEFAULT_WEIGHT_RULE: WeightRule = WeightRule.DOWNSTREAM DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None +log = logging.getLogger(__name__) + class NotMapped(Exception): """Raise if a task is neither mapped nor has any parent mapped groups.""" -class AbstractOperator(DAGNode): +class AbstractOperator(Templater, DAGNode): """ Common implementation for operators, including unmapped and mapped. @@ -265,3 +272,67 @@ def get_upstreams_only_setups(self) -> Iterable[Operator]: for task in self.get_upstreams_only_setups_and_teardowns(): if task.is_setup: yield task + + # TODO: Task-SDK -- Should the following methods be removed? 
+ # get_template_env + # _render + def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: + """Get the template environment for rendering templates.""" + if dag is None: + dag = self.get_dag() + return super().get_template_env(dag=dag) + + def _render(self, template, context, dag: DAG | None = None): + if dag is None: + dag = self.get_dag() + return super()._render(template, context, dag=dag) + + def _do_render_template_fields( + self, + parent: Any, + template_fields: Iterable[str], + context: Mapping[str, Any], + jinja_env: jinja2.Environment, + seen_oids: set[int], + ) -> None: + """Override the base to use custom error logging.""" + for attr_name in template_fields: + try: + value = getattr(parent, attr_name) + except AttributeError: + raise AttributeError( + f"{attr_name!r} is configured as a template field " + f"but {parent.task_type} does not have this attribute." + ) + try: + if not value: + continue + except Exception: + # This may happen if the templated field points to a class which does not support `__bool__`, + # such as Pandas DataFrames: + # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465 + log.info( + "Unable to check if the value of type '%s' is False for task '%s', field '%s'.", + type(value).__name__, + self.task_id, + attr_name, + ) + # We may still want to render custom classes which do not support __bool__ + pass + + try: + if callable(value): + rendered_content = value(context=context, jinja_env=jinja_env) + else: + rendered_content = self.render_template(value, context, jinja_env, seen_oids) + except Exception: + # TODO: Mask the value. Depends on https://github.com/apache/airflow/issues/45438 + log.exception( + "Exception rendering Jinja template for task '%s', field '%s'. 
Template: %r", + self.task_id, + attr_name, + value, + ) + raise + else: + setattr(parent, attr_name, rendered_content) diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index 93f90c2475418..44a152f6aa95f 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -65,6 +65,8 @@ T = TypeVar("T", bound=FunctionType) if TYPE_CHECKING: + import jinja2 + from airflow.models.xcom_arg import XComArg from airflow.sdk.definitions.dag import DAG from airflow.sdk.definitions.taskgroup import TaskGroup @@ -1239,3 +1241,20 @@ def inherits_from_empty_operator(self): # needs to cope when `self` is a Serialized instance of a EmptyOperator or one # of its subclasses (which don't inherit from anything but BaseOperator). return getattr(self, "_is_empty", False) + + def render_template_fields( + self, + context: dict, # TODO: Change to `Context` once we have it + jinja_env: jinja2.Environment | None = None, + ) -> None: + """ + Template all attributes listed in *self.template_fields*. + + This mutates the attributes in-place and is irreversible. + + :param context: Context dict with values to apply on content. + :param jinja_env: Jinja's environment to use for rendering. + """ + if not jinja_env: + jinja_env = self.get_template_env() + self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) diff --git a/task_sdk/src/airflow/sdk/definitions/connection.py b/task_sdk/src/airflow/sdk/definitions/connection.py index 628b72e29be6e..aa8c79986afef 100644 --- a/task_sdk/src/airflow/sdk/definitions/connection.py +++ b/task_sdk/src/airflow/sdk/definitions/connection.py @@ -17,8 +17,15 @@ # under the License. 
from __future__ import annotations +import json +import logging +from contextlib import suppress +from json import JSONDecodeError + import attrs +log = logging.getLogger(__name__) + @attrs.define class Connection: @@ -50,3 +57,30 @@ class Connection: def get_uri(self): ... def get_hook(self): ... + + @property + def extra_dejson(self, nested: bool = False) -> dict: + """ + Deserialize extra property to JSON. + + :param nested: Determines whether nested structures are also deserialized into JSON (default False). + """ + extra_json = {} + + if self.extra: + try: + if nested: + for key, value in json.loads(self.extra).items(): + extra_json[key] = value + if isinstance(value, str): + with suppress(JSONDecodeError): + extra_json[key] = json.loads(value) + else: + extra_json = json.loads(self.extra) + except JSONDecodeError: + log.exception("Failed parsing the json for conn_id %s", self.conn_id) + + # TODO: Mask sensitive keys from this list + # mask_secret(extra) + + return extra_json diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index f49e42aa0eb30..0e0eead4f09c2 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -656,7 +656,7 @@ def resolve_template_files(self): def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environment: """Build a Jinja2 environment.""" - import airflow.templates + from airflow.sdk.definitions.templater import NativeEnvironment, SandboxedEnvironment # Collect directories to search for template files searchpath = [self.folder] @@ -674,9 +674,9 @@ def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environme jinja_env_options.update(self.jinja_environment_kwargs) env: jinja2.Environment if self.render_template_as_native_obj and not force_sandboxed: - env = airflow.templates.NativeEnvironment(**jinja_env_options) + env = NativeEnvironment(**jinja_env_options) else: - env = 
airflow.templates.SandboxedEnvironment(**jinja_env_options) + env = SandboxedEnvironment(**jinja_env_options) # Add any user defined items. Safe to edit globals as long as no templates are rendered yet. # http://jinja.pocoo.org/docs/2.10/api/#jinja2.Environment.globals diff --git a/task_sdk/src/airflow/sdk/definitions/macros.py b/task_sdk/src/airflow/sdk/definitions/macros.py new file mode 100644 index 0000000000000..d423ac504b006 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/macros.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json # noqa: F401 +import time # noqa: F401 +import uuid # noqa: F401 +from datetime import datetime, timedelta +from random import random # noqa: F401 + + +def ds_add(ds: str, days: int) -> str: + """ + Add or subtract days from a YYYY-MM-DD. 
+ + :param ds: anchor date in ``YYYY-MM-DD`` format to add to + :param days: number of days to add to the ds, you can use negative values + + >>> ds_add("2015-01-01", 5) + '2015-01-06' + >>> ds_add("2015-01-06", -5) + '2015-01-01' + """ + if not days: + return str(ds) + dt = datetime.strptime(str(ds), "%Y-%m-%d") + timedelta(days=days) + return dt.strftime("%Y-%m-%d") + + +def ds_format(ds: str, input_format: str, output_format: str) -> str: + """ + Output datetime string in a given format. + + :param ds: Input string which contains a date. + :param input_format: Input string format (e.g., '%Y-%m-%d'). + :param output_format: Output string format (e.g., '%Y-%m-%d'). + + >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y") + '01-01-15' + >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d") + '2015-01-05' + >>> ds_format("12/07/2024", "%d/%m/%Y", "%A %d %B %Y") + 'Friday 12 July 2024' + """ + return datetime.strptime(str(ds), input_format).strftime(output_format) diff --git a/task_sdk/src/airflow/sdk/definitions/mixins.py b/task_sdk/src/airflow/sdk/definitions/mixins.py index de63772615de5..7b1594e697874 100644 --- a/task_sdk/src/airflow/sdk/definitions/mixins.py +++ b/task_sdk/src/airflow/sdk/definitions/mixins.py @@ -18,10 +18,11 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from airflow.sdk.definitions.abstractoperator import AbstractOperator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.edges import EdgeModifier @@ -109,9 +110,6 @@ def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]): def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: from airflow.sdk.definitions.abstractoperator import AbstractOperator - # TODO:Task-SDK - from airflow.utils.mixins import ResolveMixin - if 
isinstance(obj, AbstractOperator): yield obj, "operator" elif isinstance(obj, ResolveMixin): @@ -119,3 +117,25 @@ def _iter_references(cls, obj: Any) -> Iterable[tuple[DependencyMixin, str]]: elif isinstance(obj, Sequence): for o in obj: yield from cls._iter_references(o) + + +class ResolveMixin: + """A runtime-resolved value.""" + + def iter_references(self) -> Iterable[tuple[AbstractOperator, str]]: + """ + Find underlying XCom references this contains. + + This is used by the DAG parser to recursively find task dependencies. + + :meta private: + """ + raise NotImplementedError + + def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any: + """ + Resolve this value for runtime. + + :meta private: + """ + raise NotImplementedError diff --git a/airflow/template/templater.py b/task_sdk/src/airflow/sdk/definitions/templater.py similarity index 74% rename from airflow/template/templater.py rename to task_sdk/src/airflow/sdk/definitions/templater.py index e81c0877c23d5..ac33e7cbed62b 100644 --- a/airflow/template/templater.py +++ b/task_sdk/src/airflow/sdk/definitions/templater.py @@ -17,25 +17,39 @@ from __future__ import annotations +import datetime +import logging from collections.abc import Collection, Iterable, Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any +import jinja2 +import jinja2.nativetypes +import jinja2.sandbox + from airflow.io.path import ObjectStoragePath +from airflow.sdk.definitions.mixins import ResolveMixin from airflow.utils.helpers import render_template_as_native, render_template_to_string -from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.mixins import ResolveMixin if TYPE_CHECKING: from collections.abc import Mapping - import jinja2 - from airflow.models.operator import Operator from airflow.sdk.definitions.dag import DAG from airflow.utils.context import Context +def literal(value: Any) -> LiteralValue: + """ + Wrap a value to ensure it is rendered as-is 
without applying Jinja templating to its contents. + + Designed for use in an operator's template field. + + :param value: The value to be rendered without templating + """ + return LiteralValue(value) + + @dataclass(frozen=True) class LiteralValue(ResolveMixin): """ @@ -53,7 +67,12 @@ def resolve(self, context: Context, *, include_xcom: bool = True) -> Any: return self.value -class Templater(LoggingMixin): +log = logging.getLogger(__name__) + + +# TODO: Task-SDK: Should everything below this line live in `_internal/templater.py`? +# so that it is not exposed to the public API. +class Templater: """ This renders the template fields of object. @@ -70,7 +89,6 @@ def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment: # This is imported locally since Jinja2 is heavy and we don't need it # for most of the functionalities. It is imported by get_template_env() # though, so we don't need to put this after the 'if dag' check. - from airflow.templates import SandboxedEnvironment if dag: return dag.get_template_env(force_sandboxed=False) @@ -94,7 +112,7 @@ def resolve_template_files(self) -> None: try: setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore except Exception: - self.log.exception("Failed to resolve template field %r", field) + log.exception("Failed to resolve template field %r", field) elif isinstance(content, list): env = self.get_template_env() for i, item in enumerate(content): @@ -102,7 +120,7 @@ def resolve_template_files(self) -> None: try: content[i] = env.loader.get_source(env, item)[0] # type: ignore except Exception: - self.log.exception("Failed to get source %s", item) + log.exception("Failed to get source %s", item) self.prepare_template() def _do_render_template_fields( @@ -218,3 +236,71 @@ def _render_nested_template_fields( # content has no inner template fields return self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids) + + +class _AirflowEnvironmentMixin: + def 
__init__(self, **kwargs): + super().__init__(**kwargs) + + self.filters.update(FILTERS) + + def is_safe_attribute(self, obj, attr, value): + """ + Allow access to ``_`` prefix vars (but not ``__``). + + Unlike the stock SandboxedEnvironment, we allow access to "private" attributes (ones starting with + ``_``) whilst still blocking internal or truly private attributes (``__`` prefixed ones). + """ + return not jinja2.sandbox.is_internal_attribute(obj, attr) + + +class NativeEnvironment(_AirflowEnvironmentMixin, jinja2.nativetypes.NativeEnvironment): + """NativeEnvironment for Airflow task templates.""" + + +class SandboxedEnvironment(_AirflowEnvironmentMixin, jinja2.sandbox.SandboxedEnvironment): + """SandboxedEnvironment for Airflow task templates.""" + + +def ds_filter(value: datetime.date | datetime.time | None) -> str | None: + """Date filter.""" + if value is None: + return None + return value.strftime("%Y-%m-%d") + + +def ds_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + """Date filter without dashes.""" + if value is None: + return None + return value.strftime("%Y%m%d") + + +def ts_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter.""" + if value is None: + return None + return value.isoformat() + + +def ts_nodash_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter without dashes.""" + if value is None: + return None + return value.strftime("%Y%m%dT%H%M%S") + + +def ts_nodash_with_tz_filter(value: datetime.date | datetime.time | None) -> str | None: + """Timestamp filter with timezone.""" + if value is None: + return None + return value.isoformat().replace("-", "").replace(":", "") + + +FILTERS = { + "ds": ds_filter, + "ds_nodash": ds_nodash_filter, + "ts": ts_filter, + "ts_nodash": ts_nodash_filter, + "ts_nodash_with_tz": ts_nodash_with_tz_filter, +} diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py 
b/task_sdk/src/airflow/sdk/execution_time/context.py index d96f5aeda2c95..72ac2af225e85 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -135,3 +135,25 @@ def get(self, key, default_var: Any = NOTSET) -> Any: if e.error.error == ErrorType.VARIABLE_NOT_FOUND: return default_var raise + + +class MacrosAccessor: + """Wrapper to access Macros module lazily.""" + + _macros_module = None + + def __getattr__(self, item: str) -> Any: + # Lazily load Macros module + if not self._macros_module: + import airflow.sdk.definitions.macros + + self._macros_module = airflow.sdk.definitions.macros + return getattr(self._macros_module, item) + + def __repr__(self) -> str: + return "<MacrosAccessor (dynamic access to macros)>" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MacrosAccessor): + return False + return True diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 4f8a4f0045c00..4f4a9e6ec0c9e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,9 +44,10 @@ ToTask, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor if TYPE_CHECKING: + import jinja2 from structlog.typing import FilteringBoundLogger as Logger @@ -76,8 +77,9 @@ def get_template_context(self): "ti": self, # "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, + "expanded_ti_count": None, # TODO: Implement this # "inlet_events": InletEventsAccessors(task.inlets, session=session), - # "macros": macros, + "macros": MacrosAccessor(), # "params": validated_params, # "prev_data_interval_start_success": get_prev_data_interval_start_success(), # "prev_data_interval_end_success": get_prev_data_interval_end_success(), @@ -118,6 +120,38
@@ def get_template_context(self): # TODO: We should use/move TypeDict from airflow.utils.context.Context return context + def render_templates( + self, context: dict | None = None, jinja_env: jinja2.Environment | None = None + ) -> BaseOperator: + """ + Render templates in the operator fields. + + If the task was originally mapped, this may replace ``self.task`` with + the unmapped, fully rendered BaseOperator. The original ``self.task`` + before replacement is returned. + """ + if not context: + context = self.get_template_context() + original_task = self.task + + ti = context["ti"] + + if TYPE_CHECKING: + assert original_task + assert self.task + assert ti.task + + # If self.task is mapped, this call replaces self.task to point to the + # unmapped BaseOperator created by this function! This is because the + # MappedOperator is useless for template rendering, and we need to be + # able to access the unmapped task instead. + original_task.render_template_fields(context, jinja_env) + # TODO: Add support for rendering templates in the MappedOperator + # if isinstance(self.task, MappedOperator): + # self.task = context["ti"].task + + return original_task + def xcom_pull( self, task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? 
(breaking change) @@ -228,6 +262,12 @@ def xcom_push(self, key: str, value: Any): ), ) + def get_relevant_upstream_map_indexes( + self, upstream: BaseOperator, ti_count: int | None, session: Any + ) -> list[int]: + # TODO: Implement this method + return None + def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Task-SDK: @@ -383,6 +423,9 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Get a real context object ti.task = ti.task.prepare_for_execution() context = ti.get_template_context() + jinja_env = ti.task.dag.get_template_env() + ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + # TODO: Get things from _execute_task_with_callbacks # - Clearing XCom # - Setting Current Context (set_current_context) diff --git a/airflow/utils/template.py b/task_sdk/tests/dags/taskflow_api.py similarity index 50% rename from airflow/utils/template.py rename to task_sdk/tests/dags/taskflow_api.py index ae6042a7e913c..3dbd5f99647e5 100644 --- a/airflow/utils/template.py +++ b/task_sdk/tests/dags/taskflow_api.py @@ -1,3 +1,4 @@ +# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,17 +17,38 @@ # under the License. 
from __future__ import annotations -from typing import Any +import pendulum + +from airflow.decorators import dag, task + + +@dag( + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, +) +def taskflow_api(): + @task() + def extract(): + order_data_dict = {"1001": 301.27, "1002": 433.21, "1003": 502.22} + return order_data_dict + + @task(multiple_outputs=True) + def transform(order_data_dict: dict): + total_order_value = 0 + + for value in order_data_dict.values(): + total_order_value += value -from airflow.template.templater import LiteralValue + return {"total_order_value": total_order_value} + @task() + def load(total_order_value: float): + print(f"Total order value is: {total_order_value:.2f}") -def literal(value: Any) -> LiteralValue: - """ - Wrap a value to ensure it is rendered as-is without applying Jinja templating to its contents. + order_data = extract() + order_summary = transform(order_data) + load(order_summary["total_order_value"]) - Designed for use in an operator's template field. 
- :param value: The value to be rendered without templating - """ - return LiteralValue(value) +taskflow_api() diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py index f9431b4fe9c3c..e64a31717bc33 100644 --- a/task_sdk/tests/defintions/test_baseoperator.py +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -17,18 +17,55 @@ from __future__ import annotations +import logging +import uuid import warnings -from datetime import datetime, timedelta, timezone +from datetime import date, datetime, timedelta, timezone +from typing import NamedTuple +from unittest import mock +import jinja2 import pytest from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.templater import literal from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=timezone.utc) +class ClassWithCustomAttributes: + """Class for testing purpose: allows to create objects with custom attributes in one single statement.""" + + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __str__(self): + return f"{ClassWithCustomAttributes.__name__}({str(self.__dict__)})" + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __ne__(self, other): + return not self.__eq__(other) + + +# Objects with circular references (for testing purpose) +object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"]) +object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"]) +setattr(object1, "ref", object2) + + +class MockNamedTuple(NamedTuple): + var1: str + var2: str + + # Essentially similar to airflow.models.baseoperator.BaseOperator class FakeOperator(metaclass=BaseOperatorMeta): def __init__(self, 
test_param, params=None, default_args=None): @@ -264,6 +301,205 @@ def test_invalid_trigger_rule(self): ): BaseOperator(task_id="op1", trigger_rule="some_rule") + @pytest.mark.parametrize( + ("content", "context", "expected_output"), + [ + ("{{ foo }}", {"foo": "bar"}, "bar"), + (["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]), + (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), + ( + {"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"}, + {"foo": "bar"}, + {"key1": "bar_1", "key2": "bar_2"}, + ), + ( + {"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"}, + {"foo": "bar"}, + {"key_{{ foo }}_1": 1, "key_2": "bar_2"}, + ), + (date(2018, 12, 6), {"foo": "bar"}, date(2018, 12, 6)), + (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), + (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), + ({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}), + (None, {}, None), + ([], {}, []), + ({}, {}, {}), + ( + # check nested fields can be templated + ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]), + {"foo": "bar"}, + ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]), + ), + ( + # check deep nested fields can be templated + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + {"foo": "bar"}, + ClassWithCustomAttributes( + nested1=ClassWithCustomAttributes( + att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"] + ), + nested2=ClassWithCustomAttributes( + att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] + ), + template_fields=["nested1"], + ), + ), + ( + # check null value on nested template field + ClassWithCustomAttributes(att1=None, 
template_fields=["att1"]), + {}, + ClassWithCustomAttributes(att1=None, template_fields=["att1"]), + ), + ( + # check there is no RecursionError on circular references + object1, + {"foo": "bar"}, + object1, + ), + # By default, Jinja2 drops one (single) trailing newline + ("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"), + (literal("{{ foo }}"), {"foo": "bar"}, "{{ foo }}"), + (literal(["{{ foo }}_1", "{{ foo }}_2"]), {"foo": "bar"}, ["{{ foo }}_1", "{{ foo }}_2"]), + (literal(("{{ foo }}_1", "{{ foo }}_2")), {"foo": "bar"}, ("{{ foo }}_1", "{{ foo }}_2")), + ], + ) + def test_render_template(self, content, context, expected_output): + """Test render_template given various input types.""" + task = BaseOperator(task_id="op1") + + result = task.render_template(content, context) + assert result == expected_output + + @pytest.mark.parametrize( + ("content", "context", "expected_output"), + [ + ("{{ foo }}", {"foo": "bar"}, "bar"), + ("{{ foo }}", {"foo": ["bar1", "bar2"]}, ["bar1", "bar2"]), + (["{{ foo }}", "{{ foo | length}}"], {"foo": ["bar1", "bar2"]}, [["bar1", "bar2"], 2]), + (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), + ("{{ ds }}", {"ds": date(2018, 12, 6)}, date(2018, 12, 6)), + (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), + ("{{ ds }}", {"ds": datetime(2018, 12, 6, 10, 55)}, datetime(2018, 12, 6, 10, 55)), + (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), + ( + ("{{ foo }}", "{{ foo.isoformat() }}"), + {"foo": datetime(2018, 12, 6, 10, 55)}, + (datetime(2018, 12, 6, 10, 55), "2018-12-06T10:55:00"), + ), + (None, {}, None), + ([], {}, []), + ({}, {}, {}), + ], + ) + def test_render_template_with_native_envs(self, content, context, expected_output): + """Test render_template given various input types with Native Python types""" + with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, render_template_as_native_obj=True): + task = 
BaseOperator(task_id="op1") + + result = task.render_template(content, context) + assert result == expected_output + + def test_render_template_fields(self): + """Verify if operator attributes are correctly templated.""" + task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") + + # Assert nothing is templated yet + assert task.arg1 == "{{ foo }}" + assert task.arg2 == "{{ bar }}" + + # Trigger templating and verify if attributes are templated correctly + task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"}) + assert task.arg1 == "footemplated" + assert task.arg2 == "bartemplated" + + def test_render_template_fields_func_using_context(self): + """Verify if operator attributes are correctly templated.""" + + def fn_to_template(context, jinja_env): + tmp = context["task"].render_template("{{ bar }}", context, jinja_env) + return "foo_" + tmp + + task = MockOperator(task_id="op1", arg2=fn_to_template) + + # Trigger templating and verify if attributes are templated correctly + task.render_template_fields(context={"bar": "bartemplated", "task": task}) + assert task.arg2 == "foo_bartemplated" + + def test_render_template_fields_simple_func(self): + """Verify if operator attributes are correctly templated.""" + + def fn_to_template(**kwargs): + a = "foo_" + ("bar" * 3) + return a + + task = MockOperator(task_id="op1", arg2=fn_to_template) + task.render_template_fields({}) + assert task.arg2 == "foo_barbarbar" + + @pytest.mark.parametrize(("content",), [(object(),), (uuid.uuid4(),)]) + def test_render_template_fields_no_change(self, content): + """Tests if non-templatable types remain unchanged.""" + task = BaseOperator(task_id="op1") + + result = task.render_template(content, {"foo": "bar"}) + assert content is result + + def test_nested_template_fields_declared_must_exist(self): + """Test render_template when a nested template field is missing.""" + task = BaseOperator(task_id="op1") + + error_message = ( + "'missing_field' 
is configured as a template field but ClassWithCustomAttributes does not have " + "this attribute." + ) + with pytest.raises(AttributeError, match=error_message): + task.render_template( + ClassWithCustomAttributes( + template_fields=["missing_field"], task_type="ClassWithCustomAttributes" + ), + {}, + ) + + def test_string_template_field_attr_is_converted_to_list(self): + """Verify template_fields attribute is converted to a list if declared as a string.""" + + class StringTemplateFieldsOperator(BaseOperator): + template_fields = "a_string" + + warning_message = ( + "The `template_fields` value for StringTemplateFieldsOperator is a string but should be a " + "list or tuple of string. Wrapping it in a list for execution. Please update " + "StringTemplateFieldsOperator accordingly." + ) + with pytest.warns(UserWarning, match=warning_message) as warnings: + task = StringTemplateFieldsOperator(task_id="op1") + + assert len(warnings) == 1 + assert isinstance(task.template_fields, list) + + def test_jinja_invalid_expression_is_just_propagated(self): + """Test render_template propagates Jinja invalid expression errors.""" + task = BaseOperator(task_id="op1") + + with pytest.raises(jinja2.exceptions.TemplateSyntaxError): + task.render_template("{{ invalid expression }}", {}) + + @mock.patch("airflow.sdk.definitions.templater.SandboxedEnvironment", autospec=True) + def test_jinja_env_creation(self, mock_jinja_env): + """Verify if a Jinja environment is created only once when templating.""" + task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") + + task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) + assert mock_jinja_env.call_count == 1 + def test_init_subclass_args(): class InitSubclassOp(BaseOperator): @@ -330,8 +566,58 @@ def test_dag_level_retry_delay(): assert task1.retry_delay == timedelta(seconds=100) -def test_task_level_retry_delay(): - with DAG(dag_id="test_task_level_retry_delay", default_args={"retry_delay": 
timedelta(seconds=100)}): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=200) - - assert task1.retry_delay == timedelta(seconds=200) +@pytest.mark.parametrize( + ("task", "context", "expected_exception", "expected_rendering", "expected_log", "not_expected_log"), + [ + # Simple success case. + ( + MockOperator(task_id="op1", arg1="{{ foo }}"), + dict(foo="footemplated"), + None, + dict(arg1="footemplated"), + None, + "Exception rendering Jinja template", + ), + # Jinja syntax error. + ( + MockOperator(task_id="op1", arg1="{{ foo"), + dict(), + jinja2.TemplateSyntaxError, + None, + "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo'", + None, + ), + # Type error + ( + MockOperator(task_id="op1", arg1="{{ foo + 1 }}"), + dict(foo="footemplated"), + TypeError, + None, + "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo + 1 }}'", + None, + ), + ], +) +def test_render_template_fields_logging( + caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log +): + """Verify if operator attributes are correctly templated.""" + + # Trigger templating and verify results + def _do_render(): + task.render_template_fields(context=context) + + if expected_exception: + with ( + pytest.raises(expected_exception), + caplog.at_level(logging.ERROR, logger="airflow.sdk.definitions.templater"), + ): + _do_render() + else: + _do_render() + for k, v in expected_rendering.items(): + assert getattr(task, k) == v + if expected_log: + assert expected_log in caplog.text + if not_expected_log: + assert not_expected_log not in caplog.text diff --git a/task_sdk/tests/defintions/test_macros.py b/task_sdk/tests/defintions/test_macros.py new file mode 100644 index 0000000000000..f36fd8d648401 --- /dev/null +++ b/task_sdk/tests/defintions/test_macros.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license 
agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import lazy_object_proxy +import pytest + +from airflow.sdk.definitions import macros + + +@pytest.mark.parametrize( + "ds, days, expected", + [ + ("2015-01-01", 5, "2015-01-06"), + ("2015-01-02", 0, "2015-01-02"), + ("2015-01-06", -5, "2015-01-01"), + (lazy_object_proxy.Proxy(lambda: "2015-01-01"), 5, "2015-01-06"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), 0, "2015-01-02"), + (lazy_object_proxy.Proxy(lambda: "2015-01-06"), -5, "2015-01-01"), + ], +) +def test_ds_add(ds, days, expected): + result = macros.ds_add(ds, days) + assert result == expected + + +@pytest.mark.parametrize( + "ds, input_format, output_format, expected", + [ + ("2015-01-02", "%Y-%m-%d", "%m-%d-%y", "01-02-15"), + ("2015-01-02", "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), + ("1/5/2015", "%m/%d/%Y", "%m-%d-%y", "01-05-15"), + ("1/5/2015", "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%m-%d-%y", "01-02-15"), + (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), + (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%m-%d-%y", "01-05-15"), + (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), + ], +) +def test_ds_format(ds, input_format, 
output_format, expected): + result = macros.ds_format(ds, input_format, output_format) + assert result == expected + + +@pytest.mark.parametrize( + "input_value, expected", + [ + ('{"field1":"value1", "field2":4, "field3":true}', {"field1": "value1", "field2": 4, "field3": True}), + ( + '{"field1": [ 1, 2, 3, 4, 5 ], "field2" : {"mini1" : 1, "mini2" : "2"}}', + {"field1": [1, 2, 3, 4, 5], "field2": {"mini1": 1, "mini2": "2"}}, + ), + ], +) +def test_json_loads(input_value, expected): + result = macros.json.loads(input_value) + assert result == expected diff --git a/tests/template/test_templater.py b/task_sdk/tests/defintions/test_templater.py similarity index 70% rename from tests/template/test_templater.py rename to task_sdk/tests/defintions/test_templater.py index 778ca275e881f..69855b4ac2806 100644 --- a/tests/template/test_templater.py +++ b/task_sdk/tests/defintions/test_templater.py @@ -17,12 +17,13 @@ from __future__ import annotations +from datetime import datetime, timezone + import jinja2 +import pytest -from airflow.io.path import ObjectStoragePath -from airflow.models.dag import DAG -from airflow.template.templater import LiteralValue, Templater -from airflow.utils.context import Context +from airflow.sdk.definitions.dag import DAG +from airflow.sdk.definitions.templater import LiteralValue, SandboxedEnvironment, Templater class TestTemplater: @@ -54,15 +55,18 @@ def test_resolve_template_files_logs_exception(self, caplog): assert "Failed to resolve template field 'message'" in caplog.text def test_render_object_storage_path(self): + # TODO: Move this import to top-level after https://github.com/apache/airflow/issues/45425 + from airflow.io.path import ObjectStoragePath + templater = Templater() path = ObjectStoragePath("s3://bucket/key/{{ ds }}/part") - context = Context({"ds": "2006-02-01"}) # type: ignore + context = {"ds": "2006-02-01"} jinja_env = templater.get_template_env() rendered_content = templater._render_object_storage_path(path, context, 
jinja_env) assert rendered_content == ObjectStoragePath("s3://bucket/key/2006-02-01/part") def test_render_template(self): - context = Context({"name": "world"}) # type: ignore + context = {"name": "world"} templater = Templater() templater.message = "Hello {{ name }}" templater.template_fields = ["message"] @@ -73,7 +77,7 @@ def test_render_template(self): def test_not_render_literal_value(self): templater = Templater() templater.template_ext = [] - context = Context() + context = {} content = LiteralValue("Hello {{ name }}") rendered_content = templater.render_template(content, context) @@ -83,9 +87,42 @@ def test_not_render_literal_value(self): def test_not_render_file_literal_value(self): templater = Templater() templater.template_ext = [".txt"] - context = Context() + context = {} content = LiteralValue("template_file.txt") rendered_content = templater.render_template(content, context) assert rendered_content == "template_file.txt" + + +@pytest.fixture +def env(): + return SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) + + +def test_protected_access(env): + class Test: + _protected = 123 + + assert env.from_string(r"{{ obj._protected }}").render(obj=Test) == "123" + + +def test_private_access(env): + with pytest.raises(jinja2.exceptions.SecurityError): + env.from_string(r"{{ func.__code__ }}").render(func=test_private_access) + + +@pytest.mark.parametrize( + ["name", "expected"], + ( + ("ds", "2012-07-24"), + ("ds_nodash", "20120724"), + ("ts", "2012-07-24T03:04:52+00:00"), + ("ts_nodash", "20120724T030452"), + ("ts_nodash_with_tz", "20120724T030452+0000"), + ), +) +def test_filters(env, name, expected): + when = datetime(2012, 7, 24, 3, 4, 52, tzinfo=timezone.utc) + result = env.from_string("{{ date |" + name + " }}").render(date=when) + assert result == expected diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 8749bb2be1085..e212f9c47bcd0 100644 --- 
a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -46,7 +46,7 @@ TaskState, VariableResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -476,10 +476,10 @@ def test_startup_basic_templated_dag(mocked_parse, make_ti_context, mock_supervi ), ], ) -def test_startup_and_run_dag_with_templated_fields( +def test_startup_and_run_dag_with_rtif( mocked_parse, task_params, expected_rendered_fields, make_ti_context, time_machine, mock_supervisor_comms ): - """Test startup of a DAG with various templated fields.""" + """Test startup of a DAG with various rendered templated fields.""" class CustomOperator(BaseOperator): template_fields = tuple(task_params.keys()) @@ -523,6 +523,42 @@ def execute(self, context): mock_supervisor_comms.assert_has_calls(expected_calls) +@pytest.mark.parametrize( + ["command", "rendered_command"], + [ + ("{{ task.task_id }}", "templated_task"), + ("{{ run_id }}", "c"), + ("{{ logical_date }}", "2024-12-01 01:00:00+00:00"), + ], +) +def test_startup_and_run_dag_with_templated_fields( + command, rendered_command, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + """Test startup of a DAG with various templated fields.""" + + from airflow.providers.standard.operators.bash import BashOperator + + task = BashOperator(task_id="templated_task", bash_command=command) + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), + file="", + requests_fd=0, + ti_context=make_ti_context(), + ) + ti = mocked_parse(what, "basic_dag", task) + ti._ti_context_from_server = make_ti_context( + logical_date="2024-12-01 01:00:00+00:00", + run_id="c", + ) + + instant = timezone.datetime(2024, 
12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + run(ti, log=mock.MagicMock()) + assert ti.task.bash_command == rendered_command + + @pytest.mark.parametrize( ["dag_id", "task_id", "fail_with_exception"], [ @@ -599,7 +635,9 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ }, "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, + "expanded_ti_count": None, "inlets": task.inlets, + "macros": MacrosAccessor(), "map_index_template": task.map_index_template, "outlets": task.outlets, "run_id": "test_run", @@ -637,6 +675,7 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, + "macros": MacrosAccessor(), "map_index_template": task.map_index_template, "outlets": task.outlets, "run_id": "test_run", @@ -649,6 +688,7 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con "logical_date": timezone.datetime(2024, 12, 1, 1, 0, 0), "ds": "2024-12-01", "ds_nodash": "20241201", + "expanded_ti_count": None, "task_instance_key_str": "basic_task__hello__20241201", "ts": "2024-12-01T01:00:00+00:00", "ts_nodash": "20241201T010000", @@ -706,6 +746,62 @@ def test_get_connection_from_context(self, mocked_parse, make_ti_context, mock_s extra='{"extra_key": "extra_value"}', ) + def test_template_render(self, mocked_parse, make_ti_context): + task = BaseOperator(task_id="test_template_render_task") + + ti = TaskInstance( + id=uuid7(), task_id=task.task_id, dag_id="test_template_render", run_id="test_run", try_number=1 + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + template_context = runtime_ti.get_template_context() + result = runtime_ti.task.render_template( + "Task: {{ dag.dag_id }} -> {{ task.task_id }}", template_context + ) + assert result == "Task: test_template_render -> test_template_render_task" + + 
@pytest.mark.parametrize( + ["content", "expected_output"], + [ + ('{{ conn.get("a_connection").host }}', "hostvalue"), + ('{{ conn.get("a_connection", "unused_fallback").host }}', "hostvalue"), + ("{{ conn.a_connection.host }}", "hostvalue"), + ("{{ conn.a_connection.login }}", "loginvalue"), + ("{{ conn.a_connection.password }}", "passwordvalue"), + ('{{ conn.a_connection.extra_dejson["extra__asana__workspace"] }}', "extra1"), + ("{{ conn.a_connection.extra_dejson.extra__asana__workspace }}", "extra1"), + ], + ) + def test_template_with_connection( + self, content, expected_output, make_ti_context, mocked_parse, mock_supervisor_comms + ): + """ + Test the availability of connections in templates + """ + task = BaseOperator(task_id="hello") + + ti = TaskInstance( + id=uuid7(), task_id=task.task_id, dag_id="basic_task", run_id="test_run", try_number=1 + ) + conn = ConnectionResult( + conn_id="a_connection", + conn_type="a_type", + host="hostvalue", + login="loginvalue", + password="passwordvalue", + schema="schemavalues", + extra='{"extra__asana__workspace": "extra1"}', + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + mock_supervisor_comms.get_message.return_value = conn + + context = runtime_ti.get_template_context() + result = runtime_ti.task.render_template(content, context) + assert result == expected_output + @pytest.mark.parametrize( ["accessor_type", "var_value", "expected_value"], [ diff --git a/tests/core/test_templates.py b/tests/core/test_templates.py deleted file mode 100644 index b64d31803ce8b..0000000000000 --- a/tests/core/test_templates.py +++ /dev/null @@ -1,57 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import jinja2 -import jinja2.exceptions -import pendulum -import pytest - -import airflow.templates - - -@pytest.fixture -def env(): - return airflow.templates.SandboxedEnvironment(undefined=jinja2.StrictUndefined, cache_size=0) - - -def test_protected_access(env): - class Test: - _protected = 123 - - assert env.from_string(r"{{ obj._protected }}").render(obj=Test) == "123" - - -def test_private_access(env): - with pytest.raises(jinja2.exceptions.SecurityError): - env.from_string(r"{{ func.__code__ }}").render(func=test_private_access) - - -@pytest.mark.parametrize( - ["name", "expected"], - ( - ("ds", "2012-07-24"), - ("ds_nodash", "20120724"), - ("ts", "2012-07-24T03:04:52+00:00"), - ("ts_nodash", "20120724T030452"), - ("ts_nodash_with_tz", "20120724T030452+0000"), - ), -) -def test_filters(env, name, expected): - when = pendulum.datetime(2012, 7, 24, 3, 4, 52, tz="UTC") - result = env.from_string("{{ date |" + name + " }}").render(date=when) - assert result == expected diff --git a/tests/macros/test_macros.py b/tests/macros/test_macros.py index 5dc0108ae0fdb..2fa4dfd232245 100644 --- a/tests/macros/test_macros.py +++ b/tests/macros/test_macros.py @@ -25,40 +25,6 @@ from airflow.utils import timezone -@pytest.mark.parametrize( - "ds, days, expected", - [ - ("2015-01-01", 5, "2015-01-06"), - ("2015-01-02", 0, "2015-01-02"), - ("2015-01-06", -5, 
"2015-01-01"), - (lazy_object_proxy.Proxy(lambda: "2015-01-01"), 5, "2015-01-06"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), 0, "2015-01-02"), - (lazy_object_proxy.Proxy(lambda: "2015-01-06"), -5, "2015-01-01"), - ], -) -def test_ds_add(ds, days, expected): - result = macros.ds_add(ds, days) - assert result == expected - - -@pytest.mark.parametrize( - "ds, input_format, output_format, expected", - [ - ("2015-01-02", "%Y-%m-%d", "%m-%d-%y", "01-02-15"), - ("2015-01-02", "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), - ("1/5/2015", "%m/%d/%Y", "%m-%d-%y", "01-05-15"), - ("1/5/2015", "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%m-%d-%y", "01-02-15"), - (lazy_object_proxy.Proxy(lambda: "2015-01-02"), "%Y-%m-%d", "%Y-%m-%d", "2015-01-02"), - (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%m-%d-%y", "01-05-15"), - (lazy_object_proxy.Proxy(lambda: "1/5/2015"), "%m/%d/%Y", "%Y-%m-%d", "2015-01-05"), - ], -) -def test_ds_format(ds, input_format, output_format, expected): - result = macros.ds_format(ds, input_format, output_format) - assert result == expected - - @pytest.mark.parametrize( "ds, input_format, output_format, locale, expected", [ @@ -111,21 +77,6 @@ def test_datetime_diff_for_humans(dt, since, expected): assert result == expected -@pytest.mark.parametrize( - "input_value, expected", - [ - ('{"field1":"value1", "field2":4, "field3":true}', {"field1": "value1", "field2": 4, "field3": True}), - ( - '{"field1": [ 1, 2, 3, 4, 5 ], "field2" : {"mini1" : 1, "mini2" : "2"}}', - {"field1": [1, 2, 3, 4, 5], "field2": {"mini1": 1, "mini2": "2"}}, - ), - ], -) -def test_json_loads(input_value, expected): - result = macros.json.loads(input_value) - assert result == expected - - @pytest.mark.parametrize( "input_value, expected", [ diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 2c598edc777ac..bc601099744ff 100644 --- a/tests/models/test_baseoperator.py +++ 
b/tests/models/test_baseoperator.py @@ -18,14 +18,10 @@ from __future__ import annotations import copy -import logging -import uuid from collections import defaultdict -from datetime import date, datetime -from typing import NamedTuple +from datetime import datetime from unittest import mock -import jinja2 import pytest from airflow.decorators import task as task_decorator @@ -44,7 +40,6 @@ from airflow.providers.common.sql.operators import sql from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup -from airflow.utils.template import literal from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -72,17 +67,6 @@ def __ne__(self, other): return not self.__eq__(other) -# Objects with circular references (for testing purpose) -object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"]) -object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"]) -setattr(object1, "ref", object2) - - -class MockNamedTuple(NamedTuple): - var1: str - var2: str - - class TestBaseOperator: def test_trigger_rule_validation(self): from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE @@ -107,211 +91,6 @@ def test_trigger_rule_validation(self): task_id="test_valid_trigger_rule", dag=non_fail_stop_dag, trigger_rule=TriggerRule.ALWAYS ) - @pytest.mark.db_test - @pytest.mark.parametrize( - ("content", "context", "expected_output"), - [ - ("{{ foo }}", {"foo": "bar"}, "bar"), - (["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]), - (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), - ( - {"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"}, - {"foo": "bar"}, - {"key1": "bar_1", "key2": "bar_2"}, - ), - ( - {"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"}, - {"foo": "bar"}, - {"key_{{ foo }}_1": 1, "key_2": "bar_2"}, - ), - (date(2018, 12, 6), {"foo": "bar"}, date(2018, 12, 6)), - (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, 
datetime(2018, 12, 6, 10, 55)), - (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), - ({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}), - (None, {}, None), - ([], {}, []), - ({}, {}, {}), - ( - # check nested fields can be templated - ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]), - {"foo": "bar"}, - ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]), - ), - ( - # check deep nested fields can be templated - ClassWithCustomAttributes( - nested1=ClassWithCustomAttributes( - att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"] - ), - nested2=ClassWithCustomAttributes( - att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] - ), - template_fields=["nested1"], - ), - {"foo": "bar"}, - ClassWithCustomAttributes( - nested1=ClassWithCustomAttributes( - att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"] - ), - nested2=ClassWithCustomAttributes( - att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"] - ), - template_fields=["nested1"], - ), - ), - ( - # check null value on nested template field - ClassWithCustomAttributes(att1=None, template_fields=["att1"]), - {}, - ClassWithCustomAttributes(att1=None, template_fields=["att1"]), - ), - ( - # check there is no RecursionError on circular references - object1, - {"foo": "bar"}, - object1, - ), - # By default, Jinja2 drops one (single) trailing newline - ("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"), - (literal("{{ foo }}"), {"foo": "bar"}, "{{ foo }}"), - (literal(["{{ foo }}_1", "{{ foo }}_2"]), {"foo": "bar"}, ["{{ foo }}_1", "{{ foo }}_2"]), - (literal(("{{ foo }}_1", "{{ foo }}_2")), {"foo": "bar"}, ("{{ foo }}_1", "{{ foo }}_2")), - ], - ) - def test_render_template(self, content, context, expected_output): - """Test render_template given various input types.""" - task = BaseOperator(task_id="op1") - - result = 
task.render_template(content, context) - assert result == expected_output - - @pytest.mark.parametrize( - ("content", "context", "expected_output"), - [ - ("{{ foo }}", {"foo": "bar"}, "bar"), - ("{{ foo }}", {"foo": ["bar1", "bar2"]}, ["bar1", "bar2"]), - (["{{ foo }}", "{{ foo | length}}"], {"foo": ["bar1", "bar2"]}, [["bar1", "bar2"], 2]), - (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")), - ("{{ ds }}", {"ds": date(2018, 12, 6)}, date(2018, 12, 6)), - (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)), - ("{{ ds }}", {"ds": datetime(2018, 12, 6, 10, 55)}, datetime(2018, 12, 6, 10, 55)), - (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")), - ( - ("{{ foo }}", "{{ foo.isoformat() }}"), - {"foo": datetime(2018, 12, 6, 10, 55)}, - (datetime(2018, 12, 6, 10, 55), "2018-12-06T10:55:00"), - ), - (None, {}, None), - ([], {}, []), - ({}, {}, {}), - ], - ) - def test_render_template_with_native_envs(self, content, context, expected_output): - """Test render_template given various input types with Native Python types""" - with DAG("test-dag", schedule=None, start_date=DEFAULT_DATE, render_template_as_native_obj=True): - task = BaseOperator(task_id="op1") - - result = task.render_template(content, context) - assert result == expected_output - - @pytest.mark.db_test - def test_render_template_fields(self): - """Verify if operator attributes are correctly templated.""" - task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") - - # Assert nothing is templated yet - assert task.arg1 == "{{ foo }}" - assert task.arg2 == "{{ bar }}" - - # Trigger templating and verify if attributes are templated correctly - task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"}) - assert task.arg1 == "footemplated" - assert task.arg2 == "bartemplated" - - @pytest.mark.db_test - def test_render_template_fields_func_using_context(self): - """Verify if operator 
attributes are correctly templated.""" - - def fn_to_template(context, jinja_env): - tmp = context["task"].render_template("{{ bar }}", context, jinja_env) - return "foo_" + tmp - - task = MockOperator(task_id="op1", arg2=fn_to_template) - - # Trigger templating and verify if attributes are templated correctly - task.render_template_fields(context={"bar": "bartemplated", "task": task}) - assert task.arg2 == "foo_bartemplated" - - @pytest.mark.db_test - def test_render_template_fields_simple_func(self): - """Verify if operator attributes are correctly templated.""" - - def fn_to_template(**kwargs): - a = "foo_" + ("bar" * 3) - return a - - task = MockOperator(task_id="op1", arg2=fn_to_template) - task.render_template_fields({}) - assert task.arg2 == "foo_barbarbar" - - @pytest.mark.parametrize(("content",), [(object(),), (uuid.uuid4(),)]) - def test_render_template_fields_no_change(self, content): - """Tests if non-templatable types remain unchanged.""" - task = BaseOperator(task_id="op1") - - result = task.render_template(content, {"foo": "bar"}) - assert content is result - - @pytest.mark.db_test - def test_nested_template_fields_declared_must_exist(self): - """Test render_template when a nested template field is missing.""" - task = BaseOperator(task_id="op1") - - error_message = ( - "'missing_field' is configured as a template field but ClassWithCustomAttributes does not have " - "this attribute." 
- ) - with pytest.raises(AttributeError, match=error_message): - task.render_template( - ClassWithCustomAttributes( - template_fields=["missing_field"], task_type="ClassWithCustomAttributes" - ), - {}, - ) - - def test_string_template_field_attr_is_converted_to_list(self): - """Verify template_fields attribute is converted to a list if declared as a string.""" - - class StringTemplateFieldsOperator(BaseOperator): - template_fields = "a_string" - - warning_message = ( - "The `template_fields` value for StringTemplateFieldsOperator is a string but should be a " - "list or tuple of string. Wrapping it in a list for execution. Please update " - "StringTemplateFieldsOperator accordingly." - ) - with pytest.warns(UserWarning, match=warning_message) as warnings: - task = StringTemplateFieldsOperator(task_id="op1") - - assert len(warnings) == 1 - assert isinstance(task.template_fields, list) - - def test_jinja_invalid_expression_is_just_propagated(self): - """Test render_template propagates Jinja invalid expression errors.""" - task = BaseOperator(task_id="op1") - - with pytest.raises(jinja2.exceptions.TemplateSyntaxError): - task.render_template("{{ invalid expression }}", {}) - - @pytest.mark.db_test - @mock.patch("airflow.templates.SandboxedEnvironment", autospec=True) - def test_jinja_env_creation(self, mock_jinja_env): - """Verify if a Jinja environment is created only once when templating.""" - task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") - - task.render_template_fields(context={"foo": "whatever", "bar": "whatever"}) - assert mock_jinja_env.call_count == 1 - def test_cross_downstream(self): """Test if all dependencies between tasks are all set correctly.""" dag = DAG(dag_id="test_dag", schedule=None, start_date=datetime.now()) @@ -632,63 +411,6 @@ def task0(): copy.deepcopy(dag) -@pytest.mark.db_test -@pytest.mark.parametrize( - ("task", "context", "expected_exception", "expected_rendering", "expected_log", "not_expected_log"), - [ - # 
Simple success case. - ( - MockOperator(task_id="op1", arg1="{{ foo }}"), - dict(foo="footemplated"), - None, - dict(arg1="footemplated"), - None, - "Exception rendering Jinja template", - ), - # Jinja syntax error. - ( - MockOperator(task_id="op1", arg1="{{ foo"), - dict(), - jinja2.TemplateSyntaxError, - None, - "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo'", - None, - ), - # Type error - ( - MockOperator(task_id="op1", arg1="{{ foo + 1 }}"), - dict(foo="footemplated"), - TypeError, - None, - "Exception rendering Jinja template for task 'op1', field 'arg1'. Template: '{{ foo + 1 }}'", - None, - ), - ], -) -def test_render_template_fields_logging( - caplog, monkeypatch, task, context, expected_exception, expected_rendering, expected_log, not_expected_log -): - """Verify if operator attributes are correctly templated.""" - - # Trigger templating and verify results - def _do_render(): - task.render_template_fields(context=context) - - logger = logging.getLogger("airflow.task") - monkeypatch.setattr(logger, "propagate", True) - if expected_exception: - with pytest.raises(expected_exception): - _do_render() - else: - _do_render() - for k, v in expected_rendering.items(): - assert getattr(task, k) == v - if expected_log: - assert expected_log in caplog.text - if not_expected_log: - assert not_expected_log not in caplog.text - - @pytest.mark.db_test def test_find_mapped_dependants_in_another_group(dag_maker): from airflow.utils.task_group import TaskGroup diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 53090f6225bc8..03c0c2e1daefa 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -72,8 +72,8 @@ from airflow.sdk import TaskGroup from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAll, AssetAny from airflow.sdk.definitions.contextmanager import TaskGroupContext +from airflow.sdk.definitions.templater import NativeEnvironment, SandboxedEnvironment from airflow.security import 
permissions -from airflow.templates import NativeEnvironment, SandboxedEnvironment from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable from airflow.timetables.simple import ( AssetTriggeredTimetable,