Skip to content

Commit

Permalink
Merge branch 'main' into feature/http-extra-options-check-response
Browse files Browse the repository at this point in the history
  • Loading branch information
dabla authored Jan 7, 2025
2 parents faefff1 + 0cc2d72 commit f2cc837
Show file tree
Hide file tree
Showing 55 changed files with 1,117 additions and 780 deletions.
2 changes: 1 addition & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def __attrs_post_init__(self):
XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)

def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
# We only use op_kwargs_expand_input so this must always be empty.
if self.expand_input is not EXPAND_INPUT_EMPTY:
Expand Down
44 changes: 3 additions & 41 deletions airflow/macros/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,59 +17,20 @@
# under the License.
from __future__ import annotations

import json # noqa: F401
import time # noqa: F401
import uuid # noqa: F401
from datetime import datetime, timedelta
from random import random # noqa: F401
from datetime import datetime
from typing import TYPE_CHECKING, Any

import dateutil # noqa: F401
from babel import Locale
from babel.dates import LC_TIME, format_datetime

import airflow.utils.yaml as yaml # noqa: F401
from airflow.sdk.definitions.macros import ds_add, ds_format, json, time, uuid # noqa: F401

if TYPE_CHECKING:
from pendulum import DateTime


def ds_add(ds: str, days: int) -> str:
    """
    Shift a ``YYYY-MM-DD`` date string forwards or backwards by whole days.

    :param ds: anchor date in ``YYYY-MM-DD`` format
    :param days: number of days to move the anchor by; negative values move
        it into the past. A falsy value (``0`` or ``None``) returns the
        anchor converted to ``str`` without parsing it.

    >>> ds_add("2015-01-01", 5)
    '2015-01-06'
    >>> ds_add("2015-01-06", -5)
    '2015-01-01'
    """
    anchor = str(ds)
    if days:
        shifted = datetime.strptime(anchor, "%Y-%m-%d") + timedelta(days=days)
        return shifted.strftime("%Y-%m-%d")
    # Nothing to add: hand the anchor back as-is (it is not validated here).
    return anchor


def ds_format(ds: str, input_format: str, output_format: str) -> str:
    """
    Output datetime string in a given format.

    The input is parsed with ``datetime.strptime`` and re-emitted with
    ``strftime``. This function takes no locale argument: names produced by
    directives such as ``%A`` or ``%B`` follow the process's current locale.
    Use ``ds_format_locale`` when an explicit locale is required.

    :param ds: Input string which contains a date.
    :param input_format: Input string format (e.g., '%Y-%m-%d').
    :param output_format: Output string format (e.g., '%Y-%m-%d').
    :raises ValueError: if ``ds`` does not match ``input_format``.

    >>> ds_format("2015-01-01", "%Y-%m-%d", "%m-%d-%y")
    '01-01-15'
    >>> ds_format("1/5/2015", "%m/%d/%Y", "%Y-%m-%d")
    '2015-01-05'
    >>> ds_format("12/07/2024", "%d/%m/%Y", "%Y-%m-%d")
    '2024-07-12'
    """
    return datetime.strptime(str(ds), input_format).strftime(output_format)


def ds_format_locale(
ds: str, input_format: str, output_format: str, locale: Locale | str | None = None
) -> str:
Expand Down Expand Up @@ -99,6 +60,7 @@ def ds_format_locale(
)


# TODO: Task SDK: Move this to the Task SDK once the dependency on "pendulum" has been evaluated
def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str:
"""
Return a human-readable/approximate difference between datetimes.
Expand Down
78 changes: 4 additions & 74 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import datetime
import inspect
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable

Expand All @@ -30,10 +30,9 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.sdk.definitions.abstractoperator import AbstractOperator as TaskSDKAbstractOperator
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.setup_teardown import SetupTeardownContext
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
Expand All @@ -42,8 +41,6 @@
from airflow.utils.weight_rule import WeightRule, db_safe_priority

if TYPE_CHECKING:
from collections.abc import Mapping

import jinja2 # Slow import.
from sqlalchemy.orm import Session

Expand All @@ -52,7 +49,6 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskinstance import TaskInstance
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.node import DAGNode
from airflow.task.priority_strategy import PriorityWeightStrategy
from airflow.triggers.base import StartTriggerArgs
Expand Down Expand Up @@ -88,7 +84,7 @@ class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(Templater, TaskSDKAbstractOperator):
class AbstractOperator(LoggingMixin, TaskSDKAbstractOperator):
"""
Common implementation for operators, including unmapped and mapped.
Expand Down Expand Up @@ -128,72 +124,6 @@ def on_failure_fail_dagrun(self, value):
)
self._on_failure_fail_dagrun = value

def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
    """
    Return the template environment used for rendering templates.

    :param dag: DAG forwarded to the parent implementation; when ``None``,
        the value of ``self.get_dag()`` is forwarded instead.
    """
    effective_dag = self.get_dag() if dag is None else dag
    return super().get_template_env(dag=effective_dag)

def _render(self, template, context, dag: DAG | None = None):
    """
    Delegate rendering of ``template`` to the parent implementation.

    :param template: template object passed through unchanged.
    :param context: rendering context passed through unchanged.
    :param dag: DAG forwarded to the parent; when ``None``, the value of
        ``self.get_dag()`` is forwarded instead.
    """
    effective_dag = self.get_dag() if dag is None else dag
    return super()._render(template, context, dag=effective_dag)

def _do_render_template_fields(
    self,
    parent: Any,
    template_fields: Iterable[str],
    context: Mapping[str, Any],
    jinja_env: jinja2.Environment,
    seen_oids: set[int],
) -> None:
    """
    Render each named template field of ``parent`` in place, with custom error logging.

    :param parent: object whose attributes are read and overwritten.
    :param template_fields: names of the attributes on ``parent`` to render.
    :param context: context passed to the renderer (or to a callable field).
    :param jinja_env: Jinja environment passed to the renderer (or to a callable field).
    :param seen_oids: set forwarded to ``self.render_template``.
    :raises AttributeError: if a listed field does not exist on ``parent``.
    """
    for field in template_fields:
        try:
            raw = getattr(parent, field)
        except AttributeError:
            raise AttributeError(
                f"{field!r} is configured as a template field "
                f"but {parent.task_type} does not have this attribute."
            )

        # Falsy values are skipped. Truth-testing itself may raise for classes that do not
        # support `__bool__`, such as Pandas DataFrames:
        # https://github.com/pandas-dev/pandas/blob/9135c3aaf12d26f857fcc787a5b64d521c51e379/pandas/core/generic.py#L1465
        try:
            is_empty = not raw
        except Exception:
            self.log.info(
                "Unable to check if the value of type '%s' is False for task '%s', field '%s'.",
                type(raw).__name__,
                self.task_id,
                field,
            )
            # Such custom classes are still rendered.
            is_empty = False
        if is_empty:
            continue

        try:
            if callable(raw):
                rendered = raw(context=context, jinja_env=jinja_env)
            else:
                rendered = self.render_template(raw, context, jinja_env, seen_oids)
        except Exception:
            # Pass the template through redact() before it is written to the log.
            self.log.exception(
                "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
                self.task_id,
                field,
                redact(name=field, value=raw),
            )
            raise

        # Only reached when rendering succeeded, because the handler above always re-raises.
        setattr(parent, field, rendered)

def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
"""
Return mapped nodes that are direct dependencies of the current task.
Expand Down Expand Up @@ -582,7 +512,7 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
18 changes: 0 additions & 18 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@
if TYPE_CHECKING:
from types import ClassMethodDescriptorType

import jinja2 # Slow import.
from sqlalchemy.orm import Session

from airflow.models.abstractoperator import TaskStateChangeCallback
Expand Down Expand Up @@ -738,23 +737,6 @@ def post_execute(self, context: Any, result: Any = None):
logger=self.log,
).run(context, result)

def render_template_fields(
    self,
    context: Context,
    jinja_env: jinja2.Environment | None = None,
) -> None:
    """
    Render every attribute named in *self.template_fields*.

    The attributes are overwritten in place; this cannot be undone.

    :param context: Context dict with values to apply on content.
    :param jinja_env: Jinja's environment to use for rendering. When falsy,
        the environment returned by ``self.get_template_env()`` is used.
    """
    env = jinja_env or self.get_template_env()
    # A fresh set is passed as the "seen" argument for each top-level call.
    self._do_render_template_fields(self, self.template_fields, context, env, set())

@provide_session
def clear(
self,
Expand Down
13 changes: 7 additions & 6 deletions airflow/models/expandinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import attr

from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
Expand All @@ -35,7 +35,6 @@
from airflow.models.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.typing_compat import TypeGuard
from airflow.utils.context import Context

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

Expand Down Expand Up @@ -69,7 +68,9 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from self._input.iter_references()

@provide_session
def resolve(self, context: Context, *, include_xcom: bool = True, session: Session = NEW_SESSION) -> Any:
def resolve(
self, context: Mapping[str, Any], *, include_xcom: bool = True, session: Session = NEW_SESSION
) -> Any:
data, _ = self._input.resolve(context, session=session, include_xcom=include_xcom)
return data[self._key]

Expand Down Expand Up @@ -166,7 +167,7 @@ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)

def _expand_mapped_field(
self, key: str, value: Any, context: Context, *, session: Session, include_xcom: bool
self, key: str, value: Any, context: Mapping[str, Any], *, session: Session, include_xcom: bool
) -> Any:
if _needs_run_time_resolution(value):
value = (
Expand Down Expand Up @@ -210,7 +211,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
data = {
k: self._expand_mapped_field(k, v, context, session=session, include_xcom=include_xcom)
Expand Down Expand Up @@ -260,7 +261,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
yield from x.iter_references()

def resolve(
self, context: Context, session: Session, *, include_xcom: bool = True
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool = True
) -> tuple[Mapping[str, Any], set[int]]:
map_index = context["ti"].map_index
if map_index < 0:
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
return DagAttributeTypes.OP, self.task_id

def _expand_mapped_kwargs(
self, context: Context, session: Session, *, include_xcom: bool
self, context: Mapping[str, Any], session: Session, *, include_xcom: bool
) -> tuple[Mapping[str, Any], set[int]]:
"""
Get the kwargs to create the unmapped operator.
Expand Down Expand Up @@ -869,7 +869,7 @@ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:

def render_template_fields(
self,
context: Context,
context: Mapping[str, Any],
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
7 changes: 3 additions & 4 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
import copy
import json
import logging
from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
from airflow.utils.mixins import ResolveMixin
from airflow.sdk.definitions.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet

if TYPE_CHECKING:
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.sdk.definitions.dag import DAG
from airflow.utils.context import Context

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -295,7 +294,7 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
Expand Down
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@
from airflow.models.xcom import LazyXComSelectSequence, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.templates import SandboxedEnvironment
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.traces.tracer import Trace
Expand Down
Loading

0 comments on commit f2cc837

Please sign in to comment.