Skip to content

Commit

Permalink
AIP-82 Use hash instead of repr (#44797)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincbeck authored Dec 20, 2024
1 parent 5e0aeef commit cd2ad3c
Showing 1 changed file with 47 additions and 28 deletions.
75 changes: 47 additions & 28 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@

from __future__ import annotations

import json
import logging
import traceback
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple

from sqlalchemy import and_, delete, exists, func, select, tuple_
from sqlalchemy.exc import OperationalError
Expand All @@ -50,6 +51,7 @@
from airflow.models.errors import ParseImportError
from airflow.models.trigger import Trigger
from airflow.sdk.definitions.asset import Asset, AssetAlias
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.triggers.base import BaseTrigger
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.sqlalchemy import with_row_locks
Expand All @@ -64,6 +66,7 @@

from airflow.models.dagwarning import DagWarning
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.triggers.base import BaseTrigger
from airflow.typing_compat import Self

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -652,50 +655,50 @@ def add_asset_trigger_references(
self, assets: dict[tuple[str, str], AssetModel], *, session: Session
) -> None:
# Update references from assets being used
refs_to_add: dict[tuple[str, str], set[str]] = {}
refs_to_remove: dict[tuple[str, str], set[str]] = {}
triggers: dict[str, BaseTrigger] = {}
refs_to_add: dict[tuple[str, str], set[int]] = {}
refs_to_remove: dict[tuple[str, str], set[int]] = {}
triggers: dict[int, BaseTrigger] = {}

# Optimization: if no asset collected, skip fetching active assets
active_assets = _find_active_assets(self.assets.keys(), session=session) if self.assets else {}

for name_uri, asset in self.assets.items():
# If the asset belong to a DAG not active or paused, consider there is no watcher associated to it
asset_watchers = asset.watchers if name_uri in active_assets else []
trigger_repr_to_trigger_dict: dict[str, BaseTrigger] = {
repr(trigger): trigger for trigger in asset_watchers
trigger_hash_to_trigger_dict: dict[int, BaseTrigger] = {
self._get_base_trigger_hash(trigger): trigger for trigger in asset_watchers
}
triggers.update(trigger_repr_to_trigger_dict)
trigger_repr_from_asset: set[str] = set(trigger_repr_to_trigger_dict.keys())
triggers.update(trigger_hash_to_trigger_dict)
trigger_hash_from_asset: set[int] = set(trigger_hash_to_trigger_dict.keys())

asset_model = assets[name_uri]
trigger_repr_from_asset_model: set[str] = {
BaseTrigger.repr(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
trigger_hash_from_asset_model: set[int] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs) for trigger in asset_model.triggers
}

# Optimization: no diff between the DB and DAG definitions, no update needed
if trigger_repr_from_asset == trigger_repr_from_asset_model:
if trigger_hash_from_asset == trigger_hash_from_asset_model:
continue

diff_to_add = trigger_repr_from_asset - trigger_repr_from_asset_model
diff_to_remove = trigger_repr_from_asset_model - trigger_repr_from_asset
diff_to_add = trigger_hash_from_asset - trigger_hash_from_asset_model
diff_to_remove = trigger_hash_from_asset_model - trigger_hash_from_asset
if diff_to_add:
refs_to_add[name_uri] = diff_to_add
if diff_to_remove:
refs_to_remove[name_uri] = diff_to_remove

if refs_to_add:
all_trigger_reprs: set[str] = {
trigger_repr for trigger_reprs in refs_to_add.values() for trigger_repr in trigger_reprs
all_trigger_hashes: set[int] = {
trigger_hash for trigger_hashes in refs_to_add.values() for trigger_hash in trigger_hashes
}

all_trigger_keys: set[tuple[str, str]] = {
self._encrypt_trigger_kwargs(triggers[trigger_repr])
for trigger_reprs in refs_to_add.values()
for trigger_repr in trigger_reprs
self._encrypt_trigger_kwargs(triggers[trigger_hash])
for trigger_hashes in refs_to_add.values()
for trigger_hash in trigger_hashes
}
orm_triggers: dict[str, Trigger] = {
BaseTrigger.repr(trigger.classpath, trigger.kwargs): trigger
orm_triggers: dict[int, Trigger] = {
self._get_trigger_hash(trigger.classpath, trigger.kwargs): trigger
for trigger in session.scalars(
select(Trigger).where(
tuple_(Trigger.classpath, Trigger.encrypted_kwargs).in_(all_trigger_keys)
Expand All @@ -707,32 +710,32 @@ def add_asset_trigger_references(
new_trigger_models = [
trigger
for trigger in [
Trigger.from_object(triggers[trigger_repr])
for trigger_repr in all_trigger_reprs
if trigger_repr not in orm_triggers
Trigger.from_object(triggers[trigger_hash])
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
]
]
session.add_all(new_trigger_models)
orm_triggers.update(
(BaseTrigger.repr(trigger.classpath, trigger.kwargs), trigger)
(self._get_trigger_hash(trigger.classpath, trigger.kwargs), trigger)
for trigger in new_trigger_models
)

# Add new references
for name_uri, trigger_reprs in refs_to_add.items():
for name_uri, trigger_hashes in refs_to_add.items():
asset_model = assets[name_uri]
asset_model.triggers.extend(
[orm_triggers.get(trigger_repr) for trigger_repr in trigger_reprs]
[orm_triggers.get(trigger_hash) for trigger_hash in trigger_hashes]
)

if refs_to_remove:
# Remove old references
for name_uri, trigger_reprs in refs_to_remove.items():
for name_uri, trigger_hashes in refs_to_remove.items():
asset_model = assets[name_uri]
asset_model.triggers = [
trigger
for trigger in asset_model.triggers
if BaseTrigger.repr(trigger.classpath, trigger.kwargs) not in trigger_reprs
if self._get_trigger_hash(trigger.classpath, trigger.kwargs) not in trigger_hashes
]

# Remove references from assets no longer used
Expand All @@ -747,3 +750,19 @@ def add_asset_trigger_references(
def _encrypt_trigger_kwargs(trigger: BaseTrigger) -> tuple[str, str]:
classpath, kwargs = trigger.serialize()
return classpath, Trigger.encrypt_kwargs(kwargs)

@staticmethod
def _get_trigger_hash(classpath: str, kwargs: dict[str, Any]) -> int:
"""
Return the hash of the trigger classpath and kwargs. This is used to uniquely identify a trigger.
We do not want to move this logic in a `__hash__` method in `BaseTrigger` because we do not want to
make the triggers hashable. The reason being, when the triggerer retrieve the list of triggers, we do
not want it dedupe them. When used to defer tasks, 2 triggers can have the same classpath and kwargs.
This is not true for event driven scheduling.
"""
return hash((classpath, json.dumps(BaseSerialization.serialize(kwargs)).encode("utf-8")))

def _get_base_trigger_hash(self, trigger: BaseTrigger) -> int:
classpath, kwargs = trigger.serialize()
return self._get_trigger_hash(classpath, kwargs)

0 comments on commit cd2ad3c

Please sign in to comment.