Convert the LocalExecutor to run tasks using new Task SDK supervisor code

This also lays the groundwork for a more general-purpose "workload" execution
system, making a single interface for executors to run tasks and callbacks.
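
Very roughly, the interface amounts to executors accepting a tagged union of
workload models. The sketch below is illustrative only: `ExecuteTask` and
`queue_workload` come from this commit, while `RunCallback` is a hypothetical
placeholder for the callback side that does not exist yet.

```python
from typing import Literal, Union

from pydantic import BaseModel, Field


class ExecuteTask(BaseModel):
    """Ask a worker to run one task instance (fields trimmed for this sketch)."""

    kind: Literal["ExecuteTask"] = Field(default="ExecuteTask")
    dag_id: str
    task_id: str
    run_id: str


class RunCallback(BaseModel):
    """Hypothetical future workload kind -- NOT part of this commit."""

    kind: Literal["RunCallback"] = Field(default="RunCallback")
    callback_path: str


All = Union[ExecuteTask, RunCallback]


class BaseExecutor:
    def queue_workload(self, workload: All) -> None:
        # Executors that haven't been migrated yet reject workloads outright,
        # mirroring the default added to BaseExecutor in this commit.
        raise ValueError(f"Un-handled workload kind {type(workload).__name__!r}")
```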

Also in this PR we set up the supervise function to send Task logs to a file,
and handle the task log template rendering in the scheduler before queueing
the workload.
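
Concretely, the ordering is "render in the scheduler, ship the finished string
to the worker". A minimal sketch, assuming the helpers used in the diff below
(`log_filename_template_renderer` and `ExecuteTask.make`); the `enqueue`
function itself is hypothetical:

```python
from airflow.executors import workloads
from airflow.utils.helpers import log_filename_template_renderer


def enqueue(ti, executor) -> None:
    # Render the per-attempt log filename while we still have the ORM
    # TaskInstance; the worker only ever sees the finished string.
    suffix = log_filename_template_renderer()(ti=ti)
    # ExecuteTask.make() performs this same rendering internally and stores
    # the result on the workload as log_filename_suffix.
    workload = workloads.ExecuteTask.make(ti)
    assert workload.log_filename_suffix == suffix
    executor.queue_workload(workload)
```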

Additionally, we don't pass the activity directly to `supervise()`; instead we
pass its properties/fields, to reduce the coupling between the SDK and the
Executor. (More separation will appear in PRs over the next few weeks.)

The big change of note here is that rather than sending an airflow command
line to execute (`["airflow", "tasks", "run", ...]`) and going back in via the
CLI parser, we go directly to a special-purpose function. Much simpler.
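
Side by side, the worker's job changes roughly like this (the old path is
paraphrased; the new `supervise()` call mirrors the LocalExecutor diff below,
including the hard-coded execution server URL, and the `run` wrapper is just
for illustration). Because only plain fields cross the boundary, the executor
never hands an SDK activity object to `supervise()`:

```python
from airflow.sdk.execution_time.supervisor import supervise


def run(workload) -> None:
    # Old path (paraphrased): build ["airflow", "tasks", "run", ...], feed it
    # back through the CLI parser, and let args.func(args) do the work.
    #
    # New path: call the Task SDK supervisor directly with the workload's fields.
    supervise(
        ti=workload.ti,                                     # minimal TaskInstance model
        dag_path=workload.dag_path,                         # where the DAG file lives
        token=workload.token,                               # identity token for this workload
        server="http://localhost:9091/execution/",          # execution API endpoint (hard-coded in this commit)
        log_filename_suffix=workload.log_filename_suffix,   # rendered by the scheduler
    )
```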

It doesn't remove any of the old behaviour (CeleryExecutor still uses
LocalTaskJob via the CLI parser etc.), nor does anything currently send
callback requests via this new workload mechanism.

The `airflow.executors.workloads` module currently needs to be shared between
the Scheduler (or more specifically the Executor) and the "worker" side of
things. In the future these will be separate python dists and this module will
need to live somewhere else.

Right now we check whether `executor.queue_workload` differs from the
BaseExecutor version (which just raises an error) to see which executors
support this new mechanism. That check will be removed as soon as all the
in-tree executors have been migrated.
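
In code, the transitional dispatch in the scheduler (visible in the
scheduler_job_runner.py diff below) boils down to the following; the
`dispatch` wrapper is hypothetical, but the check itself is taken from the diff:

```python
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor


def dispatch(executor, ti) -> bool:
    # An executor that has overridden queue_workload understands the new
    # workload mechanism; anything still on the BaseExecutor default does not.
    if executor.queue_workload.__func__ is not BaseExecutor.queue_workload:
        executor.queue_workload(workloads.ExecuteTask.make(ti))
        return True
    # Caller falls back to the legacy ["airflow", "tasks", "run", ...] command.
    return False
```
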
ashb committed Nov 27, 2024
1 parent 43adccf commit 17f1a5e
Showing 13 changed files with 355 additions and 180 deletions.
2 changes: 1 addition & 1 deletion airflow/config_templates/config.yml
@@ -908,7 +908,7 @@ logging:
is_template: true
default: "dag_id={{ ti.dag_id }}/run_id={{ ti.run_id }}/task_id={{ ti.task_id }}/\
{%% if ti.map_index >= 0 %%}map_index={{ ti.map_index }}/{%% endif %%}\
attempt={{ try_number }}.log"
attempt={{ ti.try_number }}.log"
log_processor_filename_template:
description: |
Formatting for how airflow generates file names for log
7 changes: 7 additions & 0 deletions airflow/executors/base_executor.py
@@ -49,6 +49,7 @@
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.cli.cli_config import GroupCommand
from airflow.executors import workloads
from airflow.executors.executor_utils import ExecutorName
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -170,6 +171,9 @@ def queue_command(
else:
self.log.error("could not queue task %s", task_instance.key)

def queue_workload(self, workload: workloads.All) -> None:
raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}")

def queue_task_instance(
self,
task_instance: TaskInstance,
@@ -409,6 +413,9 @@ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
self.running.add(key)

# TODO: This should not be using `TaskInstanceState` here, this is just "did the process complete, or did
# it die". It is possible for the task itself to finish with success, but the state of the task to be set
# to FAILED. By using TaskInstanceState enum here it confuses matters!
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None, remove_running=True
) -> None:
136 changes: 36 additions & 100 deletions airflow/executors/local_executor.py
@@ -30,32 +30,24 @@
import multiprocessing
import multiprocessing.sharedctypes
import os
import subprocess
from multiprocessing import Queue, SimpleQueue
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional

from setproctitle import setproctitle

from airflow import settings
from airflow.executors.base_executor import PARALLELISM, BaseExecutor
from airflow.traces.tracer import add_span
from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
from airflow.executors.base_executor import CommandType
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.executors import workloads

# This is a work to be executed by a worker.
# It can Key and Command - but it can also be None, None which is actually a
# "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly.
ExecutorWorkType = Optional[tuple[TaskInstanceKey, CommandType]]
TaskInstanceStateType = tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]]
TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Optional[Exception]]


def _run_worker(
logger_name: str,
input: SimpleQueue[ExecutorWorkType],
input: SimpleQueue[workloads.All | None],
output: Queue[TaskInstanceStateType],
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
):
@@ -65,16 +57,16 @@ def _run_worker(
signal.signal(signal.SIGINT, signal.SIG_IGN)

log = logging.getLogger(logger_name)
log.info("Worker starting up pid=%d", os.getpid())

# We know we've just started a new process, so lets disconnect from the metadata db now
settings.engine.pool.dispose()
settings.engine.dispose()

setproctitle("airflow worker -- LocalExecutor: <idle>")

while True:
setproctitle("airflow worker -- LocalExecutor: <idle>")
try:
item = input.get()
workload = input.get()
except EOFError:
log.info(
"Failed to read tasks from the task queue because the other "
@@ -83,88 +75,47 @@
)
break

if item is None:
if workload is None:
# Received poison pill, no more tasks to run
return

# Decrement this as soon as we pick up a message off the queue
with unread_messages:
unread_messages.value -= 1
key = None
if ti := getattr(workload, "ti", None):
key = ti.key
else:
raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}")

(key, command) = item
try:
state = _execute_work(log, key, command)
_execute_work(log, workload)

output.put((key, state, None))
output.put((key, TaskInstanceState.SUCCESS, None))
except Exception as e:
log.exception("uhoh")
output.put((key, TaskInstanceState.FAILED, e))


def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState:
def _execute_work(log: logging.Logger, workload: workloads.ExecuteTask) -> None:
"""
Execute command received and stores result state in queue.
:param key: the key to identify the task instance
:param command: the command to execute
"""
setproctitle(f"airflow worker -- LocalExecutor: {command}")
dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command)
try:
with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
return _execute_work_in_subprocess(log, command)
else:
return _execute_work_in_fork(log, command)
finally:
# Remove the command since the worker is done executing the task
setproctitle("airflow worker -- LocalExecutor: <idle>")


def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState:
try:
subprocess.check_call(command, close_fds=True)
return TaskInstanceState.SUCCESS
except subprocess.CalledProcessError as e:
log.error("Failed to execute task %s.", e)
return TaskInstanceState.FAILED


def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInstanceState:
pid = os.fork()
if pid:
# In parent, wait for the child
pid, ret = os.waitpid(pid, 0)
return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED

from airflow.sentry import Sentry

ret = 1
try:
import signal

from airflow.cli.cli_parser import get_parser

signal.signal(signal.SIGINT, signal.SIG_IGN)
signal.signal(signal.SIGTERM, signal.SIG_DFL)
signal.signal(signal.SIGUSR2, signal.SIG_DFL)
from airflow.sdk.execution_time.supervisor import supervise

parser = get_parser()
# [1:] - remove "airflow" from the start of the command
args = parser.parse_args(command[1:])
args.shut_down_logging = False

setproctitle(f"airflow task supervisor: {command}")

args.func(args)
ret = 0
return TaskInstanceState.SUCCESS
except Exception as e:
log.exception("Failed to execute task %s.", e)
return TaskInstanceState.FAILED
finally:
Sentry.flush()
logging.shutdown()
os._exit(ret)
setproctitle(f"airflow worker -- LocalExecutor: {workload.ti.id}")
# This will return the exit code of the task process, but we don't care about that, just if the
# _supervisor_ had an error reporting the state back (which will result in an exception.)
supervise(
ti=workload.ti,
dag_path=workload.dag_path,
token=workload.token,
server="http://localhost:9091/execution/",
log_filename_suffix=workload.log_filename_suffix,
)


class LocalExecutor(BaseExecutor):
@@ -180,7 +131,7 @@ class LocalExecutor(BaseExecutor):

serve_logs: bool = True

activity_queue: SimpleQueue[ExecutorWorkType]
activity_queue: SimpleQueue[workloads.All | None]
result_queue: SimpleQueue[TaskInstanceStateType]
workers: dict[int, multiprocessing.Process]
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]
@@ -203,22 +154,7 @@ def start(self) -> None:
# (it looks like an int to python)
self._unread_messages = multiprocessing.Value(ctypes.c_uint) # type: ignore[assignment]

@add_span
def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
queue: str | None = None,
executor_config: Any | None = None,
) -> None:
"""Execute asynchronously."""
self.validate_airflow_tasks_run_command(command)
self.activity_queue.put((key, command))
with self._unread_messages:
self._unread_messages.value += 1
self._check_workers(can_start=True)

def _check_workers(self, can_start: bool = True):
def _check_workers(self):
# Reap any dead workers
to_remove = set()
for pid, proc in self.workers.items():
@@ -270,12 +206,6 @@ def _read_results(self):
while not self.result_queue.empty():
key, state, exc = self.result_queue.get()

if exc:
# TODO: This needs a better stacktrace, it appears from here
if hasattr(exc, "add_note"):
exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)")
raise exc

self.change_state(key, state)

def end(self) -> None:
@@ -306,3 +236,9 @@ def end(self) -> None:

def terminate(self):
"""Terminate the executor is not doing anything."""

def queue_workload(self, workload: workloads.All):
self.activity_queue.put(workload)
with self._unread_messages:
self._unread_messages.value += 1
self._check_workers()
96 changes: 96 additions & 0 deletions airflow/executors/workloads.py
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
import uuid
from typing import TYPE_CHECKING, Literal, Union

from pydantic import BaseModel, Field

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance as TIModel
from airflow.models.taskinstancekey import TaskInstanceKey


__all__ = [
"All",
"ExecuteTask",
]


class BaseActivity(BaseModel):
token: str
"""The identity token for this workload"""


class TaskInstance(BaseModel):
"""Schema for TaskInstance with minimal required fields needed for Executors and Task SDK."""

id: uuid.UUID

task_id: str
dag_id: str
run_id: str
try_number: int
map_index: int | None = None

# TODO: Task-SDK: see if we can replace TIKey with this class entirely?
@property
def key(self) -> TaskInstanceKey:
from airflow.models.taskinstancekey import TaskInstanceKey

return TaskInstanceKey(
dag_id=self.dag_id,
task_id=self.task_id,
run_id=self.run_id,
try_number=self.try_number,
map_index=-1 if self.map_index is None else self.map_index,
)


class ExecuteTask(BaseActivity):
"""Execute the given Task."""

ti: TaskInstance
"""The TaskInstance to execute"""
dag_path: os.PathLike[str]
"""The filepath where the DAG can be found (likely prefixed with `DAG_FOLDER/`)"""

log_filename_suffix: str | None
"""The rendered log filename template the task logs should be written to"""

kind: Literal["ExecuteTask"] = Field(init=False, default="ExecuteTask")

@classmethod
def make(cls, ti: TIModel) -> ExecuteTask:
from pathlib import Path

from airflow.utils.helpers import log_filename_template_renderer

ser_ti = TaskInstance.model_validate(ti, from_attributes=True)
path = Path(ti.dag_run.dag_model.relative_fileloc)

if path and not path.is_absolute():
# TODO: What about multiple dag sub folders
path = "DAGS_FOLDER" / path

fname = log_filename_template_renderer()(ti=ti)
return cls(ti=ser_ti, dag_path=path, token="", log_filename_suffix=fname)


All = Union[ExecuteTask]
11 changes: 10 additions & 1 deletion airflow/jobs/scheduler_job_runner.py
@@ -44,6 +44,8 @@
from airflow.callbacks.pipe_callback_sink import PipeCallbackSink
from airflow.configuration import conf
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.jobs.job import Job, perform_heartbeat
@@ -91,7 +93,6 @@
from sqlalchemy.orm import Query, Session

from airflow.dag_processing.manager import DagFileProcessorAgent
from airflow.executors.base_executor import BaseExecutor
from airflow.executors.executor_utils import ExecutorName
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.sqlalchemy import (
@@ -654,6 +655,14 @@ def _enqueue_task_instances_with_queued_state(
if ti.dag_run.state in State.finished_dr_states:
ti.set_state(None, session=session)
continue

# TODO: Task-SDK: This check is transitional. Remove once all executors are ported over.
# i.e. the executor has a real queue_workload implemented
if executor.queue_workload.__func__ is not BaseExecutor.queue_workload: # type: ignore[attr-defined]
workload = workloads.ExecuteTask.make(ti)
executor.queue_workload(workload)
continue

command = ti.command_as_list(
local=True,
)