diff --git a/src/crystal/task.py b/src/crystal/task.py index 15651360..b02245e6 100644 --- a/src/crystal/task.py +++ b/src/crystal/task.py @@ -18,16 +18,19 @@ bg_affinity, bg_call_later, fg_affinity, fg_call_and_wait, fg_call_later, is_foreground_thread, NoForegroundThreadError ) +from functools import wraps import os from overrides import overrides import shutil import sys from time import sleep +from types import TracebackType import traceback from typing import ( Any, Callable, cast, final, List, Literal, Iterator, Optional, - Sequence, Tuple, TYPE_CHECKING, Union + Sequence, Tuple, TYPE_CHECKING, TypeVar, Union ) +from typing_extensions import Concatenate, ParamSpec from weakref import WeakSet if TYPE_CHECKING: @@ -44,6 +47,59 @@ _PROFILE_SCHEDULER = False +_TK = TypeVar('_TK', bound='Task') +_P = ParamSpec('_P') +_R = TypeVar('_R') + + +# ------------------------------------------------------------------------------ +# Task Crashes + + +CrashReason = BaseException # with .__traceback__ set to a TracebackType + + +def bulkhead( + task_method: 'Callable[Concatenate[_TK, _P], Optional[_R]]' + ) -> 'Callable[Concatenate[_TK, _P], Optional[_R]]': + """ + A method of Task (or a subclass) that captures any exceptions raised in + its interior as the "crash reason" of the task rather than reraising + the exception in its caller. + + If the task was already crashed (with a non-None "crash reason") when + this method is called, this method will immediately abort, returning None. + """ + @wraps(task_method) + def callable_with_crashes_captured(self: '_TK', *args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]: + if self.crash_reason is not None: + # Task has already crashed. Abort. + return None + try: + return task_method(self, *args, **kwargs) + except BaseException as e: + # Crash the task. Abort. + self.crash_reason = e + return None + callable_with_crashes_captured._crashes_captured = True # type: ignore[attr-defined] + return callable_with_crashes_captured + + +def call_bulkhead( + callable_with_crashes_captured: Callable[_P, _R], + /, *args: _P.args, + **kwargs: _P.kwargs + ) -> '_R': + """ + Calls a method marked as @bulkhead, which does not reraise exceptions from its interior. + + Raises AssertionError if the specified method is not actually marked with @bulkhead. + """ + if getattr(callable_with_crashes_captured, '_crashes_captured', False) != True: + raise AssertionError('Expected callable to be decorated with @bulkhead') + return callable_with_crashes_captured(*args, **kwargs) + + # ------------------------------------------------------------------------------ # Task @@ -56,9 +112,6 @@ """One task unit will be executed from each child during a scheduler pass.""" -CrashReason = str - - class Task(ListenableMixin): """ Encapsulates a long-running process that reports its status occasionally. @@ -245,6 +298,7 @@ def future(self): else: raise ValueError('Container tasks do not define a result by default.') + @bulkhead def dispose(self) -> None: """ Replaces this task's future with a new future that raises a @@ -420,6 +474,7 @@ def clear_completed_children(self) -> None: # === Public Operations === + @bulkhead @fg_affinity def try_get_next_task_unit(self) -> Optional[Callable[[], None]]: """ @@ -434,6 +489,11 @@ def try_get_next_task_unit(self) -> Optional[Callable[[], None]]: run on any thread. """ + # If this task previously crashed and either itself or its children are + # in a potentially invalid state, refuse to run this task any further + if self.crash_reason is not None: + return None + if self.complete: return None @@ -523,6 +583,7 @@ def _notify_did_schedule_all_children(self) -> Union[bool, Optional[Callable[[], return False # children did not change @bg_affinity + @bulkhead def _call_self_and_record_result(self): # (Ignore client requests to cancel) if self._future is None: @@ -833,6 +894,8 @@ def get_future(self, wait_for_embedded: bool=False) -> Future: self._download_body_with_embedded_future = Future() return self._download_body_with_embedded_future + @overrides + @bulkhead def dispose(self) -> None: super().dispose() if self._download_body_task is not None: @@ -842,11 +905,13 @@ def dispose(self) -> None: # === Events === + @bulkhead def child_task_subtitle_did_change(self, task: Task) -> None: if task is self._download_body_task: if not task.complete: self.subtitle = task.subtitle + @bulkhead def child_task_did_complete(self, task: Task) -> None: from crystal.model import ( ProjectHasTooManyRevisionsError, RevisionBodyMissingError @@ -1119,6 +1184,8 @@ def record_links() -> List[Resource]: return (links, linked_resources) + @overrides + @bulkhead def dispose(self) -> None: super().dispose() self._resource_revision = None @@ -1262,10 +1329,12 @@ def __init__(self, group: ResourceGroup) -> None: self.task_did_complete(download_task) # (NOTE: self.complete might be True now) + @bulkhead def child_task_subtitle_did_change(self, task: Task) -> None: if not task.complete: self.subtitle = task.subtitle + @bulkhead def child_task_did_complete(self, task: Task) -> None: task.dispose() @@ -1308,15 +1377,20 @@ def __init__(self, group: ResourceGroup) -> None: self._children_loaded = False @overrides + @bulkhead def try_get_next_task_unit(self) -> Optional[Callable[[], None]]: if not self._children_loaded: - def fg_task() -> None: - self._load_children() - self._update_completed_status() - return lambda: fg_call_and_wait(fg_task) + return self._load_children_and_update_completed_status return super().try_get_next_task_unit() + @bulkhead + def _load_children_and_update_completed_status(self) -> None: + def fg_task() -> None: + self._load_children() + self._update_completed_status() + fg_call_and_wait(fg_task) + @fg_affinity def _load_children(self) -> None: if self._children_loaded: @@ -1377,6 +1451,7 @@ def unmaterializeitem(t: DownloadResourceTask) -> None: assert self._children_loaded # because set earlier in this function + @bulkhead def group_did_add_member(self, group: ResourceGroup, member: Resource) -> None: if self._LAZY_LOAD_CHILDREN: self.notify_did_append_child(None) @@ -1396,11 +1471,13 @@ def group_did_add_member(self, group: ResourceGroup, member: Resource) -> None: self.task_did_complete(download_task) # (NOTE: self.complete might be True now) + @bulkhead def group_did_finish_updating(self) -> None: self._done_updating_group = True self._update_subtitle() self._update_completed_status() + @bulkhead def child_task_did_complete(self, task: Task) -> None: task.dispose() @@ -1485,6 +1562,7 @@ def __init__(self, group: ResourceGroup) -> None: def group(self) -> ResourceGroup: return self._update_members_task.group + @bulkhead def child_task_subtitle_did_change(self, task: Task) -> None: if task == self._update_members_task and not self._started_downloading_members: self.subtitle = 'Updating group members...' @@ -1492,6 +1570,7 @@ def child_task_subtitle_did_change(self, task: Task) -> None: self.subtitle = task.subtitle self._started_downloading_members = True + @bulkhead def child_task_did_complete(self, task: Task) -> None: task.dispose() @@ -1554,8 +1633,10 @@ def fg_task() -> None: # NOTE: Must synchronize access to RootTask.children with foreground thread fg_call_and_wait(fg_task) + @overrides + @bulkhead @fg_affinity - def try_get_next_task_unit(self): + def try_get_next_task_unit(self) -> Optional[Callable[[], None]]: if self.complete: return None @@ -1565,13 +1646,19 @@ def try_get_next_task_unit(self): return super().try_get_next_task_unit() + # === Events === + + @bulkhead def child_task_did_complete(self, task: Task) -> None: - task.dispose() + call_bulkhead(task.dispose) + @bulkhead def did_schedule_all_children(self) -> None: # Remove completed children after each scheduling pass self.clear_completed_children() + # === Protected Operations: Finish & Cleanup === + @overrides def clear_children_if_all_complete(self) -> bool: raise NotImplementedError( @@ -1587,6 +1674,8 @@ def fg_task() -> None: # NOTE: Must synchronize access to RootTask.children with foreground thread fg_call_and_wait(fg_task) + # === Public Operations === + def interrupt(self) -> None: """ Stop all descendent tasks, asynchronously, @@ -1594,6 +1683,8 @@ def interrupt(self) -> None: """ self.finish() + # === Utility === + def __repr__(self) -> str: return f'' @@ -1642,7 +1733,7 @@ def bg_task() -> None: with profiling_context as profiler: while True: def fg_task() -> Tuple[Optional[Callable[[], None]], bool]: - return (task.try_get_next_task_unit(), task.complete) + return (call_bulkhead(task.try_get_next_task_unit), task.complete) try: (unit, task_complete) = fg_call_and_wait(fg_task) # traceback: ignore except NoForegroundThreadError: @@ -1654,8 +1745,11 @@ def fg_task() -> Tuple[Optional[Callable[[], None]], bool]: else: sleep(_ROOT_TASK_POLL_INTERVAL) continue + # TODO: All except clauses below are probably dead code, + # since call_bulkhead() doesn't raise exceptions. + # Remove this dead code. try: - unit() # Run unit directly on this bg thread + call_bulkhead(unit) # Run unit directly on this bg thread except NoForegroundThreadError: # Probably the app was closed. Ignore error. return