Skip to content

Commit

Permalink
Bulkheads: Protect complex Task subclass methods with @bulkhead
Browse files Browse the repository at this point in the history
  • Loading branch information
davidfstr committed Feb 10, 2024
1 parent e89a582 commit bcdb415
Showing 1 changed file with 106 additions and 12 deletions.
118 changes: 106 additions & 12 deletions src/crystal/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
bg_affinity, bg_call_later, fg_affinity, fg_call_and_wait, fg_call_later,
is_foreground_thread, NoForegroundThreadError
)
from functools import wraps
import os
from overrides import overrides
import shutil
import sys
from time import sleep
from types import TracebackType
import traceback
from typing import (
Any, Callable, cast, final, List, Literal, Iterator, Optional,
Sequence, Tuple, TYPE_CHECKING, Union
Sequence, Tuple, TYPE_CHECKING, TypeVar, Union
)
from typing_extensions import Concatenate, ParamSpec
from weakref import WeakSet

if TYPE_CHECKING:
Expand All @@ -44,6 +47,59 @@
_PROFILE_SCHEDULER = False


# Type variables used in the signatures of bulkhead() and call_bulkhead().
# _TK: the Task (or Task subclass) type whose method is being decorated.
_TK = TypeVar('_TK', bound='Task')
# _P: the parameters of the wrapped callable.
_P = ParamSpec('_P')
# _R: the return type of the wrapped callable.
_R = TypeVar('_R')


# ------------------------------------------------------------------------------
# Task Crashes


# Type of the value that @bulkhead stores in a task's `crash_reason`
# when an exception escapes a protected method.
CrashReason = BaseException  # with .__traceback__ set to a TracebackType


def bulkhead(
        task_method: 'Callable[Concatenate[_TK, _P], Optional[_R]]'
        ) -> 'Callable[Concatenate[_TK, _P], Optional[_R]]':
    """
    Decorator for a method of Task (or a subclass) that contains crashes.

    Any exception raised inside the decorated method is stored as the
    "crash reason" of the task (self.crash_reason) rather than being
    reraised in the caller. In that case the call returns None.

    If the task was already crashed (with a non-None "crash reason") when
    the decorated method is called, the method body is not run at all and
    None is returned immediately.

    The returned wrapper carries a `_crashes_captured` attribute set to True,
    which is what call_bulkhead() checks for.
    """
    @wraps(task_method)
    def guarded_method(self: '_TK', *args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
        # Task has already crashed. Refuse to run it any further.
        if self.crash_reason is not None:
            return None
        try:
            result = task_method(self, *args, **kwargs)
        except BaseException as crash:
            # Crash the task instead of propagating the exception. Abort.
            self.crash_reason = crash
            return None
        else:
            return result
    # Tag the wrapper so that call_bulkhead() can recognize it
    setattr(guarded_method, '_crashes_captured', True)
    return guarded_method


def call_bulkhead(
        callable_with_crashes_captured: Callable[_P, _R],
        /, *args: _P.args,
        **kwargs: _P.kwargs
        ) -> '_R':
    """
    Calls a callable marked as @bulkhead, which does not reraise exceptions
    from its interior, forwarding all remaining arguments to it and
    returning its result.

    Raises:
    * AssertionError --
        if the specified callable is not actually marked with @bulkhead.
    """
    # Only the exact marker value True, as set by @bulkhead, is accepted.
    # An identity check is used rather than comparing to True with `!=`.
    if getattr(callable_with_crashes_captured, '_crashes_captured', False) is not True:
        raise AssertionError('Expected callable to be decorated with @bulkhead')
    return callable_with_crashes_captured(*args, **kwargs)


# ------------------------------------------------------------------------------
# Task

Expand All @@ -56,9 +112,6 @@
"""One task unit will be executed from each child during a scheduler pass."""


CrashReason = str


class Task(ListenableMixin):
"""
Encapsulates a long-running process that reports its status occasionally.
Expand Down Expand Up @@ -245,6 +298,7 @@ def future(self):
else:
raise ValueError('Container tasks do not define a result by default.')

@bulkhead
def dispose(self) -> None:
"""
Replaces this task's future with a new future that raises a
Expand Down Expand Up @@ -420,6 +474,7 @@ def clear_completed_children(self) -> None:

# === Public Operations ===

@bulkhead
@fg_affinity
def try_get_next_task_unit(self) -> Optional[Callable[[], None]]:
"""
Expand All @@ -434,6 +489,11 @@ def try_get_next_task_unit(self) -> Optional[Callable[[], None]]:
run on any thread.
"""

# If this task previously crashed and either itself or its children are
# in a potentially invalid state, refuse to run this task any further
if self.crash_reason is not None:
return None

if self.complete:
return None

Expand Down Expand Up @@ -523,6 +583,7 @@ def _notify_did_schedule_all_children(self) -> Union[bool, Optional[Callable[[],
return False # children did not change

@bg_affinity
@bulkhead
def _call_self_and_record_result(self):
# (Ignore client requests to cancel)
if self._future is None:
Expand Down Expand Up @@ -833,6 +894,8 @@ def get_future(self, wait_for_embedded: bool=False) -> Future:
self._download_body_with_embedded_future = Future()
return self._download_body_with_embedded_future

@overrides
@bulkhead
def dispose(self) -> None:
super().dispose()
if self._download_body_task is not None:
Expand All @@ -842,11 +905,13 @@ def dispose(self) -> None:

# === Events ===

@bulkhead
def child_task_subtitle_did_change(self, task: Task) -> None:
    """
    Copies the subtitle of the body-download child task onto this task,
    as long as that child has not completed.
    """
    # Subtitle changes from any other child are ignored
    if task is not self._download_body_task:
        return
    # A completed child's subtitle is not propagated
    if task.complete:
        return
    self.subtitle = task.subtitle

@bulkhead
def child_task_did_complete(self, task: Task) -> None:
from crystal.model import (
ProjectHasTooManyRevisionsError, RevisionBodyMissingError
Expand Down Expand Up @@ -1119,6 +1184,8 @@ def record_links() -> List[Resource]:

return (links, linked_resources)

@overrides
@bulkhead
def dispose(self) -> None:
super().dispose()
self._resource_revision = None
Expand Down Expand Up @@ -1262,10 +1329,12 @@ def __init__(self, group: ResourceGroup) -> None:
self.task_did_complete(download_task)
# (NOTE: self.complete might be True now)

@bulkhead
def child_task_subtitle_did_change(self, task: Task) -> None:
    """Adopts the subtitle of a child task that has not completed yet."""
    if task.complete:
        # Completed children no longer drive this task's subtitle
        return
    self.subtitle = task.subtitle

@bulkhead
def child_task_did_complete(self, task: Task) -> None:
task.dispose()

Expand Down Expand Up @@ -1308,15 +1377,20 @@ def __init__(self, group: ResourceGroup) -> None:
self._children_loaded = False

@overrides
@bulkhead
def try_get_next_task_unit(self) -> Optional[Callable[[], None]]:
if not self._children_loaded:
def fg_task() -> None:
self._load_children()
self._update_completed_status()
return lambda: fg_call_and_wait(fg_task)
return self._load_children_and_update_completed_status

return super().try_get_next_task_unit()

@bulkhead
def _load_children_and_update_completed_status(self) -> None:
    """
    Loads this task's children and then updates its completed status.

    Both steps are bundled into a single callable that is run through
    fg_call_and_wait().
    """
    def load_then_update() -> None:
        self._load_children()
        self._update_completed_status()
    fg_call_and_wait(load_then_update)

@fg_affinity
def _load_children(self) -> None:
if self._children_loaded:
Expand Down Expand Up @@ -1377,6 +1451,7 @@ def unmaterializeitem(t: DownloadResourceTask) -> None:

assert self._children_loaded # because set earlier in this function

@bulkhead
def group_did_add_member(self, group: ResourceGroup, member: Resource) -> None:
if self._LAZY_LOAD_CHILDREN:
self.notify_did_append_child(None)
Expand All @@ -1396,11 +1471,13 @@ def group_did_add_member(self, group: ResourceGroup, member: Resource) -> None:
self.task_did_complete(download_task)
# (NOTE: self.complete might be True now)

@bulkhead
def group_did_finish_updating(self) -> None:
    """
    Records that the group has finished updating, then refreshes this
    task's subtitle and completed status.
    """
    # Set the flag first, before the subtitle and completed status are refreshed
    self._done_updating_group = True
    self._update_subtitle()
    self._update_completed_status()

@bulkhead
def child_task_did_complete(self, task: Task) -> None:
task.dispose()

Expand Down Expand Up @@ -1485,13 +1562,15 @@ def __init__(self, group: ResourceGroup) -> None:
def group(self) -> ResourceGroup:
return self._update_members_task.group

@bulkhead
def child_task_subtitle_did_change(self, task: Task) -> None:
    """
    Derives this task's subtitle from whichever child reported a change.

    While member downloading has not started, a change from the
    update-members child shows a fixed "updating" message. A change from
    the download-members child is mirrored directly and marks member
    downloading as started.
    """
    if task == self._update_members_task and not self._started_downloading_members:
        self.subtitle = 'Updating group members...'
        return
    if task == self._download_members_task:
        self.subtitle = task.subtitle
        self._started_downloading_members = True

@bulkhead
def child_task_did_complete(self, task: Task) -> None:
task.dispose()

Expand Down Expand Up @@ -1554,8 +1633,10 @@ def fg_task() -> None:
# NOTE: Must synchronize access to RootTask.children with foreground thread
fg_call_and_wait(fg_task)

@overrides
@bulkhead
@fg_affinity
def try_get_next_task_unit(self):
def try_get_next_task_unit(self) -> Optional[Callable[[], None]]:
if self.complete:
return None

Expand All @@ -1565,13 +1646,19 @@ def try_get_next_task_unit(self):

return super().try_get_next_task_unit()

# === Events ===

@bulkhead
def child_task_did_complete(self, task: Task) -> None:
task.dispose()
call_bulkhead(task.dispose)

@bulkhead
def did_schedule_all_children(self) -> None:
    """
    Prunes this task's completed children.

    NOTE(review): presumably invoked once per scheduling pass, after every
    child has been scheduled -- confirm against the caller.
    """
    # Remove completed children after each scheduling pass
    self.clear_completed_children()

# === Protected Operations: Finish & Cleanup ===

@overrides
def clear_children_if_all_complete(self) -> bool:
raise NotImplementedError(
Expand All @@ -1587,13 +1674,17 @@ def fg_task() -> None:
# NOTE: Must synchronize access to RootTask.children with foreground thread
fg_call_and_wait(fg_task)

# === Public Operations ===

def interrupt(self) -> None:
    """
    Stop all descendant tasks, asynchronously,
    by interrupting the scheduler thread.

    Implemented by calling self.finish().
    """
    self.finish()

# === Utility ===

def __repr__(self) -> str:
    """Returns a debug representation containing this object's id() in hexadecimal."""
    address = hex(id(self))  # already carries the '0x' prefix
    return '<RootTask at ' + address + '>'

Expand Down Expand Up @@ -1642,7 +1733,7 @@ def bg_task() -> None:
with profiling_context as profiler:
while True:
def fg_task() -> Tuple[Optional[Callable[[], None]], bool]:
return (task.try_get_next_task_unit(), task.complete)
return (call_bulkhead(task.try_get_next_task_unit), task.complete)
try:
(unit, task_complete) = fg_call_and_wait(fg_task) # traceback: ignore
except NoForegroundThreadError:
Expand All @@ -1654,8 +1745,11 @@ def fg_task() -> Tuple[Optional[Callable[[], None]], bool]:
else:
sleep(_ROOT_TASK_POLL_INTERVAL)
continue
# TODO: All except clauses below are probably dead code,
# since call_bulkhead() doesn't raise exceptions.
# Remove this dead code.
try:
unit() # Run unit directly on this bg thread
call_bulkhead(unit) # Run unit directly on this bg thread
except NoForegroundThreadError:
# Probably the app was closed. Ignore error.
return
Expand Down

0 comments on commit bcdb415

Please sign in to comment.