Commit 4142d75
do not push stale update to related DagRun on TI update after task execution

Signed-off-by: Maciej Obuchowski <[email protected]>
mobuchowski committed Dec 28, 2024
1 parent a540eeb commit 4142d75
Showing 4 changed files with 59 additions and 8 deletions.
2 changes: 1 addition & 1 deletion airflow/dag_processing/processor.py

@@ -111,7 +111,7 @@ def _execute_callbacks(
     dagbag: DagBag, callback_requests: list[CallbackRequest], log: FilteringBoundLogger
 ) -> None:
     for request in callback_requests:
-        log.debug("Processing Callback Request", request=request)
+        log.debug("Processing Callback Request", request=request.to_json())
         if isinstance(request, TaskCallbackRequest):
             raise NotImplementedError(
                 "Haven't coded Task callback yet - https://github.com/apache/airflow/issues/44354!"
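An aside on the logging tweak above: FilteringBoundLogger is a structlog logger, and binding the request object itself logs whatever repr() produces, while serializing it first binds a stable, machine-readable field. A minimal structlog sketch of the difference; the logger setup and the JSON payload here are illustrative, not Airflow's actual configuration:

import structlog

log = structlog.get_logger()

# Binding the object itself logs its repr(), an opaque string for any
# downstream structured-log processing:
log.debug("Processing Callback Request", request=object())

# Binding a serialized form keeps the field stable and parseable (the
# payload below is made up for illustration):
log.debug("Processing Callback Request", request='{"full_filepath": "dags/example.py"}')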
3 changes: 2 additions & 1 deletion airflow/models/taskinstance.py

@@ -1126,7 +1126,7 @@ def _handle_failure(
     )

     if not test_mode:
-        TaskInstance.save_to_db(failure_context["ti"], session)
+        TaskInstance.save_to_db(task_instance, session)

     with Trace.start_span_from_taskinstance(ti=task_instance) as span:
         span.set_attributes(
@@ -3114,6 +3114,7 @@ def fetch_handle_failure_context(
     @staticmethod
     @provide_session
     def save_to_db(ti: TaskInstance, session: Session = NEW_SESSION):
+        ti.get_dagrun().refresh_from_db()
         ti.updated_at = timezone.utcnow()
         session.merge(ti)
         session.flush()
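Why the two taskinstance.py changes matter together: save_to_db ends in session.merge(ti), and SQLAlchemy's merge cascades over relationships by default, so merging a TaskInstance also merges whatever DagRun object the TI happens to carry in memory. If that DagRun is stale, e.g. the scheduler has since marked the run successful, the merge silently writes the stale state back; refreshing the DagRun from the database first makes the cascaded write harmless. A minimal sketch of the failure mode and the refresh-first fix, using toy stand-in models; every name below is hypothetical, not Airflow's:

from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()


class DagRunModel(Base):  # toy stand-in for the DagRun table
    __tablename__ = "dag_run"
    id = Column(Integer, primary_key=True)
    state = Column(String)


class TaskInstanceModel(Base):  # toy stand-in for the TaskInstance table
    __tablename__ = "task_instance"
    id = Column(Integer, primary_key=True)
    state = Column(String)
    dag_run_id = Column(Integer, ForeignKey("dag_run.id"))
    dag_run = relationship(DagRunModel)  # Session.merge() cascades here by default


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

# The "task process" creates its TI together with the running DagRun.
with Session(engine, expire_on_commit=False) as session:
    dr = DagRunModel(id=1, state="running")
    ti = TaskInstanceModel(id=1, state="running", dag_run=dr)
    session.add_all([dr, ti])
    session.commit()

# Meanwhile the "scheduler" finishes the run in the database.
with Session(engine) as session:
    session.get(DagRunModel, 1).state = "success"
    session.commit()

# The task process merges only its TI, but the merge cascades to the stale
# in-memory DagRun and clobbers the scheduler's update:
with Session(engine, expire_on_commit=False) as session:
    ti.state = "failed"
    session.merge(ti)
    session.commit()
    assert session.get(DagRunModel, 1).state == "running"  # "success" was lost

# Restore the scheduler's update, then apply the fix in spirit: reload the
# related row before merging (what ti.get_dagrun().refresh_from_db() now
# does), so the cascade carries current values instead of stale ones.
with Session(engine) as session:
    session.get(DagRunModel, 1).state = "success"
    session.commit()

with Session(engine, expire_on_commit=False) as session:
    ti.dag_run.state = session.get(DagRunModel, 1).state
    session.merge(ti)
    session.commit()
    assert session.get(DagRunModel, 1).state == "success"  # update preserved

The test changes below lean on the same cascade from the other direction: after assigning ti.dag_run = dr, a single session.merge(ti) persists both rows, which is why the explicit session.merge(dr) calls could be dropped.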
52 changes: 46 additions & 6 deletions tests/models/test_taskinstance.py

@@ -33,6 +33,7 @@
 from uuid import uuid4

 import pendulum
+import psutil
 import pytest
 import time_machine
 import uuid6
@@ -80,6 +81,7 @@
 from airflow.sdk.definitions.asset import Asset, AssetAlias
 from airflow.sensors.base import BaseSensorOperator
 from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
+from airflow.settings import reconfigure_orm
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
@@ -2290,11 +2292,11 @@ def test_outlet_assets(self, create_task_instance):
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
         run_id = str(uuid4())
-        dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(dag1.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow())
         task = dag1.get_task("producing_task_1")
         task.bash_command = "echo 1"  # make it go faster
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         ti._run_raw_task()
@@ -2349,10 +2351,12 @@ def test_outlet_assets_failed(self, create_task_instance):
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
         run_id = str(uuid4())
-        dr = DagRun(dag_with_fail_task.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(
+            dag_with_fail_task.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow()
+        )
         task = dag_with_fail_task.get_task("fail_task")
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         with pytest.raises(AirflowFailException):
@@ -2403,10 +2407,12 @@ def test_outlet_assets_skipped(self):
         dagbag.collect_dags(only_if_updated=False, safe_mode=False)
         dagbag.sync_to_db(session=session)
         run_id = str(uuid4())
-        dr = DagRun(dag_with_skip_task.dag_id, run_id=run_id, run_type="anything")
-        session.merge(dr)
+        dr = DagRun(
+            dag_with_skip_task.dag_id, run_id=run_id, run_type="anything", logical_date=timezone.utcnow()
+        )
         task = dag_with_skip_task.get_task("skip_task")
         ti = TaskInstance(task, run_id=run_id)
+        ti.dag_run = dr
         session.merge(ti)
         session.commit()
         ti._run_raw_task()
@@ -3518,6 +3524,40 @@ def test_handle_failure(self, create_dummy_dag, session=None):
         assert "task_instance" in context_arg_3
         mock_on_retry_3.assert_not_called()

+    @provide_session
+    def test_handle_failure_does_not_push_stale_dagrun_model(self, dag_maker, create_dummy_dag, session=None):
+        session = settings.Session()
+        with dag_maker():
+
+            def method(): ...
+
+            task = PythonOperator(task_id="mytask", python_callable=method)
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance(task.task_id)
+        ti.state = State.RUNNING
+
+        assert dr.state == DagRunState.RUNNING
+
+        session.merge(ti)
+        session.flush()
+        session.commit()
+
+        pid = os.fork()
+        if pid:
+            process = psutil.Process(pid)
+            dr.state = DagRunState.SUCCESS
+            session.merge(dr)
+            session.flush()
+            session.commit()
+            process.wait(timeout=5)
+        else:
+            reconfigure_orm(disable_connection_pool=True)
+            ti.handle_failure("should not update related models")
+            os._exit(0)
+
+        dr.refresh_from_db()
+        assert dr.state == DagRunState.SUCCESS
+
     def test_handle_failure_updates_queued_task_updates_state(self, dag_maker):
         session = settings.Session()
         with dag_maker():
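The new test_handle_failure_does_not_push_stale_dagrun_model reproduces the race with two real processes: the parent plays the scheduler and flips the DagRun to SUCCESS while the forked child handles the task failure, and the run must still read SUCCESS afterwards. The child calls reconfigure_orm(disable_connection_pool=True) because pooled database connections must not be shared across fork(), and exits via os._exit() so it skips pytest teardown. A minimal, POSIX-only sketch of that fork-and-wait skeleton, with prints standing in for the real session work:

import os

import psutil

pid = os.fork()
if pid:
    # Parent: play the scheduler; do the conflicting update here, then wait
    # so assertions only run after the child has finished.
    child = psutil.Process(pid)
    print("parent: marking the DagRun successful")
    child.wait(timeout=5)
    print("parent: child done, now assert the DagRun state survived")
else:
    # Child: play the task process; do the conflicting work with fresh DB
    # connections (the real test reconfigures the ORM for this), then exit
    # with os._exit() so atexit/pytest teardown never runs in the child.
    print("child: calling handle_failure on the stale TaskInstance")
    os._exit(0)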
10 changes: 10 additions & 0 deletions tests/sensors/test_external_task_sensor.py

@@ -114,6 +114,16 @@ def setup_method(self):
         self.args = {"owner": "airflow", "start_date": DEFAULT_DATE}
         self.dag = DAG(TEST_DAG_ID, schedule=None, default_args=self.args)
         self.dag_run_id = DagRunType.MANUAL.generate_run_id(DEFAULT_DATE)
+        self.dag_run = DagRun(
+            dag_id=self.dag.dag_id,
+            run_id=self.dag_run_id,
+            run_type=DagRunType.MANUAL,
+            logical_date=DEFAULT_DATE,
+        )
+        with create_session() as session:
+            session.merge(self.dag_run)
+            session.flush()
+            session.commit()

     def add_time_sensor(self, task_id=TEST_TASK_ID):
         op = TimeSensor(task_id=task_id, target_time=time(0), dag=self.dag)
