Skip to content

Commit

Permalink
Don't commit for read-only query (#44905)
Browse files Browse the repository at this point in the history
* Don't commit for read-only query

* Pass session from outside when we can

This change does not use the create_session/provide_session/NEW_SESSION
paradigm, because that paradigm requires importing those helpers globally
(at module level), which is not allowed in the SDK. We also do not want to
create a session unless absolutely needed.
  • Loading branch information
uranusjr authored Dec 14, 2024
1 parent 23f59fe commit 4675389
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,7 +2281,7 @@ def dag_ready(dag_id: str, cond: BaseAsset, statuses: dict) -> bool | None:
# we may be dealing with old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
return cond.evaluate(statuses, session=session)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
return None
Expand Down
3 changes: 2 additions & 1 deletion airflow/timetables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy.orm import Session

from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.serialization.dag_dependency import DagDependency
Expand Down Expand Up @@ -52,7 +53,7 @@ def __and__(self, other: BaseAsset) -> BaseAsset:
def as_expression(self) -> Any:
return None

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return False

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
Expand Down
19 changes: 11 additions & 8 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import annotations

import contextlib
import logging
import operator
import os
Expand All @@ -32,6 +33,8 @@
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm import Session

from airflow.models.asset import AssetModel
from airflow.triggers.base import BaseTrigger

Expand Down Expand Up @@ -227,7 +230,7 @@ def as_expression(self) -> Any:
"""
raise NotImplementedError

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
raise NotImplementedError

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
Expand Down Expand Up @@ -385,7 +388,7 @@ def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
def iter_asset_aliases(self) -> Iterator[tuple[str, AssetAlias]]:
return iter(())

def evaluate(self, statuses: dict[str, bool]) -> bool:
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return statuses.get(self.uri, False)

def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDependency]:
Expand Down Expand Up @@ -428,11 +431,11 @@ class AssetAlias(BaseAsset):
name: str = attrs.field(validator=_validate_non_empty_identifier)
group: str = attrs.field(kw_only=True, default="asset", validator=_validate_identifier)

def _resolve_assets(self) -> list[Asset]:
def _resolve_assets(self, session: Session | None = None) -> list[Asset]:
from airflow.models.asset import expand_alias_to_assets
from airflow.utils.session import create_session

with create_session() as session:
with contextlib.nullcontext(session) if session else create_session() as session:
asset_models = expand_alias_to_assets(self.name, session)
return [m.to_public() for m in asset_models]

Expand All @@ -444,8 +447,8 @@ def as_expression(self) -> Any:
"""
return {"alias": {"name": self.name, "group": self.group}}

def evaluate(self, statuses: dict[str, bool]) -> bool:
return any(x.evaluate(statuses=statuses) for x in self._resolve_assets())
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return any(x.evaluate(statuses=statuses, session=session) for x in self._resolve_assets(session))

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
return iter(())
Expand Down Expand Up @@ -495,8 +498,8 @@ def __init__(self, *objects: BaseAsset) -> None:
raise TypeError("expect asset expressions in condition")
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return self.agg_func(x.evaluate(statuses=statuses, session=session) for x in self.objects)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
seen: set[AssetUniqueKey] = set() # We want to keep the first instance.
Expand Down
6 changes: 4 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Collection, Iterator, Mapping

from sqlalchemy.orm import Session

from airflow.io.path import ObjectStoragePath
from airflow.models.param import ParamsDict
from airflow.sdk.definitions.asset import AssetAlias, AssetUniqueKey
Expand Down Expand Up @@ -120,8 +122,8 @@ def __attrs_post_init__(self) -> None:
with self._source.create_dag(dag_id=self._function.__name__):
_AssetMainOperator.from_definition(self)

def evaluate(self, statuses: dict[str, bool]) -> bool:
return all(o.evaluate(statuses=statuses) for o in self._source.outlets)
def evaluate(self, statuses: dict[str, bool], *, session: Session | None = None) -> bool:
return all(o.evaluate(statuses=statuses, session=session) for o in self._source.outlets)

def iter_assets(self) -> Iterator[tuple[AssetUniqueKey, Asset]]:
for o in self._source.outlets:
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/tests/defintions/test_asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,11 +514,11 @@ def test_as_expression(self, request: pytest.FixtureRequest, alias_fixture_name)

def test_evalute_empty(self, asset_alias_1, asset):
assert asset_alias_1.evaluate({asset.uri: True}) is False
assert asset_alias_1._resolve_assets.mock_calls == [mock.call()]
assert asset_alias_1._resolve_assets.mock_calls == [mock.call(None)]

def test_evalute_resolved(self, resolved_asset_alias_2, asset):
assert resolved_asset_alias_2.evaluate({asset.uri: True}) is True
assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call()]
assert resolved_asset_alias_2._resolve_assets.mock_calls == [mock.call(None)]


class TestAssetSubclasses:
Expand Down

0 comments on commit 4675389

Please sign in to comment.