diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py index f3e3b8322cad1..e289d50d7f18a 100644 --- a/airflow/dag_processing/collection.py +++ b/airflow/dag_processing/collection.py @@ -27,9 +27,10 @@ from __future__ import annotations +import json import logging import traceback -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, Any, NamedTuple from sqlalchemy import and_, delete, exists, func, select, tuple_ from sqlalchemy.exc import OperationalError @@ -50,6 +51,7 @@ from airflow.models.errors import ParseImportError from airflow.models.trigger import Trigger from airflow.sdk.definitions.asset import Asset, AssetAlias +from airflow.serialization.serialized_objects import BaseSerialization from airflow.triggers.base import BaseTrigger from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries from airflow.utils.sqlalchemy import with_row_locks @@ -64,6 +66,7 @@ from airflow.models.dagwarning import DagWarning from airflow.serialization.serialized_objects import MaybeSerializedDAG + from airflow.triggers.base import BaseTrigger from airflow.typing_compat import Self log = logging.getLogger(__name__) @@ -652,9 +655,9 @@ def add_asset_trigger_references( self, assets: dict[tuple[str, str], AssetModel], *, session: Session ) -> None: # Update references from assets being used - refs_to_add: dict[tuple[str, str], set[str]] = {} - refs_to_remove: dict[tuple[str, str], set[str]] = {} - triggers: dict[str, BaseTrigger] = {} + refs_to_add: dict[tuple[str, str], set[int]] = {} + refs_to_remove: dict[tuple[str, str], set[int]] = {} + triggers: dict[int, BaseTrigger] = {} # Optimization: if no asset collected, skip fetching active assets active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {} @@ -662,40 +665,40 @@ def add_asset_trigger_references( for name_uri, asset in self.assets.items(): # If the asset belong to a DAG not active or paused, consider there is no watcher associated to it asset_watchers = asset.watchers if name_uri in active_assets else [] - trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = { - repr(trigger): trigger for trigger in asset_watchers + trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = { + self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers } - triggers.update(trigger_repr_to_trigger_dict) - trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict.keys()) + triggers.update(trigger_hash_to_trigger_dict) + trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys()) asset_model = assets[name_uri] - trigger_repr_from_asset_model: set[str] = { - BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers + trigger_hash_from_asset_model: set[int] = { + self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers } # Optimization: no diff between the DB and DAG definitions, no update needed - if trigger_repr_from_asset == trigger_repr_from_asset_model: + if trigger_hash_from_asset == trigger_hash_from_asset_model: continue - diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model - diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset + diff_to_add = trigger_hash_from_asset - trigger_hash_from_asset_model + diff_to_remove = trigger_hash_from_asset_model - trigger_hash_from_asset if diff_to_add: refs_to_add[name_uri] = diff_to_add if diff_to_remove: refs_to_remove[name_uri] = diff_to_remove if refs_to_add: - all_trigger_reprs: set[str] = { - trigger_repr for trigger_reprs in refs_to_add.values() for trigger_repr in trigger_reprs + all_trigger_hashes: set[int] = { + trigger_hash for trigger_hashes in refs_to_add.values() for trigger_hash in trigger_hashes } all_trigger_keys: set[tuple[str, str]] = { - self._encrypt_trigger_kwargs(triggers[trigger_repr]) - for trigger_reprs in refs_to_add.values() - for trigger_repr in trigger_reprs + self._encrypt_trigger_kwargs(triggers[trigger_hash]) + for trigger_hashes in refs_to_add.values() + for trigger_hash in trigger_hashes } - orm_triggers: dict[str, Trigger] = { - BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger + orm_triggers: dict[int, Trigger] = { + self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger for trigger in session.scalars( select(Trigger).where( tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys) @@ -707,32 +710,32 @@ def add_asset_trigger_references( new_trigger_models = [ trigger for trigger in [ - Trigger.from_object(triggers[trigger_repr]) - for trigger_repr in all_trigger_reprs - if trigger_repr not in orm_triggers + Trigger.from_object(triggers[trigger_hash]) + for trigger_hash in all_trigger_hashes + if trigger_hash not in orm_triggers ] ] session.add_all(new_trigger_models) orm_triggers.update( - (BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger) + (self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger) for trigger in new_trigger_models ) # Add new references - for name_uri, trigger_reprs in refs_to_add.items(): + for name_uri, trigger_hashes in refs_to_add.items(): asset_model = assets[name_uri] asset_model.triggers.extend( - [orm_triggers.get(trigger_repr) for trigger_repr in trigger_reprs] + [orm_triggers.get(trigger_hash) for trigger_hash in trigger_hashes] ) if refs_to_remove: # Remove old references - for name_uri, trigger_reprs in refs_to_remove.items(): + for name_uri, trigger_hashes in refs_to_remove.items(): asset_model = assets[name_uri] asset_model.triggers = [ trigger for trigger in asset_model.triggers - if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not in trigger_reprs + if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes ] # Remove references from assets no longer used @@ -747,3 +750,19 @@ def add_asset_trigger_references( def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]: classpath, kwargs = trigger.serialize() return classpath, Trigger.encrypt_kwargs(kwargs) + + @staticmethod + def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int: + """ + Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger. + + We do not want to move this logic in a `__hash__` method in `BaseTrigger` because we do not want to + make the triggers hashable. The reason being, when the triggerer retrieve the list of triggers, we do + not want it dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs. + This is not true for event driven scheduling. + """ + return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8"))) + + def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int: + classpath, kwargs = trigger.serialize() + return self._get_trigger_hash(classpath, kwargs)