diff --git a/temporalio/bridge/sdk-core b/temporalio/bridge/sdk-core index 871b320c8..199880d2f 160000 --- a/temporalio/bridge/sdk-core +++ b/temporalio/bridge/sdk-core @@ -1 +1 @@ -Subproject commit 871b320c8f51d52cb69fcc31f9c4dcd47b9f3961 +Subproject commit 199880d2f5673895e6437bc39a031243c7b7861c diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 1b412cb7f..358d34090 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -298,6 +298,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]] input: InputT schedule_to_close_timeout: Optional[timedelta] + cancellation_type: temporalio.workflow.NexusOperationCancellationType headers: Optional[Mapping[str, str]] output_type: Optional[Type[OutputT]] = None diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index bf67a9756..ada236ab2 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -54,13 +54,13 @@ import temporalio.bridge.proto.activity_result import temporalio.bridge.proto.child_workflow import temporalio.bridge.proto.common +import temporalio.bridge.proto.nexus import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_commands import temporalio.bridge.proto.workflow_completion import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus import temporalio.workflow from temporalio.service import __version__ @@ -881,9 +881,17 @@ def _apply_resolve_nexus_operation( ) -> None: handle = self._pending_nexus_operations.pop(job.seq, None) if not handle: - raise RuntimeError( - f"Failed to find nexus operation handle for job sequence number {job.seq}" - ) + # One way this can occur is: + # 1. Cancel request issued with cancellation_type=WaitRequested. + # 2. 
Server receives nexus cancel handler task completion and writes + # NexusOperationCancelRequestCompleted / NexusOperationCancelRequestFailed. On + # consuming this event, core sends an activation resolving the handle future as + # completed / failed. + # 3. Subsequently, the nexus operation completes as completed/failed, causing the server + # to write NexusOperationCompleted / NexusOperationFailed. On consuming this event, + # core sends an activation which would attempt to resolve the handle future as + # completed / failed, but it has already been resolved. + return # Handle the four oneof variants of NexusOperationResult result = job.result @@ -1500,9 +1508,10 @@ async def workflow_start_nexus_operation( service: str, operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[OutputT]] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - headers: Optional[Mapping[str, str]] = None, + output_type: Optional[Type[OutputT]], + schedule_to_close_timeout: Optional[timedelta], + cancellation_type: temporalio.workflow.NexusOperationCancellationType, + headers: Optional[Mapping[str, str]], ) -> temporalio.workflow.NexusOperationHandle[OutputT]: # start_nexus_operation return await self._outbound.start_nexus_operation( @@ -1513,6 +1522,7 @@ async def workflow_start_nexus_operation( input=input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) ) @@ -1824,20 +1834,19 @@ async def run_child() -> Any: async def _outbound_start_nexus_operation( self, input: StartNexusOperationInput[Any, OutputT] ) -> _NexusOperationHandle[OutputT]: - # A Nexus operation handle contains two futures: self._start_fut is resolved as a - # result of the Nexus operation starting (activation job: - # resolve_nexus_operation_start), and self._result_fut is resolved as a result of - # the Nexus operation completing (activation job:
resolve_nexus_operation). The - # handle itself corresponds to an asyncio.Task which waits on self.result_fut, - # handling CancelledError by emitting a RequestCancelNexusOperation command. We do - # not return the handle until we receive resolve_nexus_operation_start, like - # ChildWorkflowHandle and unlike ActivityHandle. Note that a Nexus operation may - # complete synchronously (in which case both jobs will be sent in the same - # activation, and start will be resolved without an operation token), or - # asynchronously (in which case start they may be sent in separate activations, - # and start will be resolved with an operation token). See comments in - # tests/worker/test_nexus.py for worked examples of the evolution of the resulting - # handle state machine in the sync and async Nexus response cases. + # A Nexus operation handle contains two futures: self._start_fut is resolved as a result of + # the Nexus operation starting (activation job: resolve_nexus_operation_start), and + # self._result_fut is resolved as a result of the Nexus operation completing (activation + # job: resolve_nexus_operation). The handle itself corresponds to an asyncio.Task which + # waits on self.result_fut, handling CancelledError by emitting a + # RequestCancelNexusOperation command. We do not return the handle until we receive + # resolve_nexus_operation_start, like ChildWorkflowHandle and unlike ActivityHandle. Note + # that a Nexus operation may complete synchronously (in which case both jobs will be sent in + # the same activation, and start will be resolved without an operation token), or + # asynchronously (in which case they may be sent in separate activations, and start will be + # resolved with an operation token). See comments in tests/worker/test_nexus.py for worked + # examples of the evolution of the resulting handle state machine in the sync and async + # Nexus response cases. 
handle: _NexusOperationHandle[OutputT] async def operation_handle_fn() -> OutputT: @@ -2758,7 +2767,7 @@ def _apply_schedule_command( if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) v.cancellation_type = cast( - "temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType", + temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), ) @@ -2894,7 +2903,7 @@ def _apply_start_command(self) -> None: if self._input.task_timeout: v.workflow_task_timeout.FromTimedelta(self._input.task_timeout) v.parent_close_policy = cast( - "temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType", + temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType, int(self._input.parent_close_policy), ) v.workflow_id_reuse_policy = cast( @@ -2916,7 +2925,7 @@ def _apply_start_command(self) -> None: self._input.search_attributes, v.search_attributes ) v.cancellation_type = cast( - "temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType", + temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType, int(self._input.cancellation_type), ) if self._input.versioning_intent: @@ -3012,11 +3021,6 @@ def __init__( @property def operation_token(self) -> Optional[str]: - # TODO(nexus-preview): How should this behave? 
- # Java has a separate class that only exists if the operation token exists: - # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 - # And Go similar: - # https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771 try: return self._start_fut.result() except BaseException: @@ -3065,6 +3069,11 @@ def _apply_schedule_command(self) -> None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout ) + v.cancellation_type = cast( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.ValueType, + int(self._input.cancellation_type), + ) + if self._input.headers: for key, val in self._input.headers.items(): v.nexus_header[key] = val diff --git a/temporalio/workflow.py b/temporalio/workflow.py index f26dcad12..423d5289b 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -859,9 +859,10 @@ async def workflow_start_nexus_operation( service: str, operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[OutputT]] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - headers: Optional[Mapping[str, str]] = None, + output_type: Optional[Type[OutputT]], + schedule_to_close_timeout: Optional[timedelta], + cancellation_type: temporalio.workflow.NexusOperationCancellationType, + headers: Optional[Mapping[str, str]], ) -> NexusOperationHandle[OutputT]: ... @abstractmethod @@ -1322,9 +1323,9 @@ async def sleep( This can be in single-line Temporal markdown format. """ await _Runtime.current().workflow_sleep( - duration=duration.total_seconds() - if isinstance(duration, timedelta) - else duration, + duration=( + duration.total_seconds() if isinstance(duration, timedelta) else duration + ), summary=summary, ) @@ -4413,6 +4414,8 @@ class NexusOperationHandle(Generic[OutputT]): This API is experimental and unstable. 
""" + # TODO(nexus-preview): should attempts to instantiate directly throw? + def cancel(self) -> bool: """Request cancellation of the operation.""" raise NotImplementedError @@ -5138,6 +5141,43 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType ServiceT = TypeVar("ServiceT") +class NexusOperationCancellationType(IntEnum): + """Defines behavior of a Nexus operation when the caller workflow initiates cancellation. + + Pass one of these values to :py:meth:`NexusClient.start_operation` to define cancellation + behavior. + + To initiate cancellation, use :py:meth:`NexusOperationHandle.cancel` and then `await` the + operation handle. This will result in a :py:class:`exceptions.NexusOperationError`. The values + of this enum define what is guaranteed to have happened by that point. + """ + + ABANDON = int(temporalio.bridge.proto.nexus.NexusOperationCancellationType.ABANDON) + """Do not send any cancellation request to the operation handler; just report cancellation to the caller""" + + TRY_CANCEL = int( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.TRY_CANCEL + ) + """Send a cancellation request but immediately report cancellation to the caller. Note that this + does not guarantee that cancellation is delivered to the operation handler if the caller exits + before the delivery is done. + """ + + WAIT_REQUESTED = int( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_REQUESTED + ) + """Send a cancellation request and wait for confirmation that the request was received. + Does not wait for the operation to complete. + """ + + WAIT_COMPLETED = int( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_COMPLETED + ) + """Send a cancellation request and wait for the operation to complete. 
+ Note that the operation may not complete as cancelled (for example, if it catches the + :py:exc:`asyncio.CancelledError` resulting from the cancellation request).""" + + class NexusClient(ABC, Generic[ServiceT]): """A client for invoking Nexus operations. @@ -5168,6 +5208,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5181,6 +5222,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5197,6 +5239,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5213,6 +5256,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5229,6 +5273,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... 
@@ -5240,6 +5285,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: """Start a Nexus operation and return its handle. @@ -5269,6 +5315,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5282,6 +5329,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5298,6 +5346,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5317,6 +5366,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5333,6 +5383,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... 
@@ -5344,6 +5395,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: """Execute a Nexus operation and return its result. @@ -5395,6 +5447,7 @@ async def start_operation( *, output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: return ( @@ -5405,6 +5458,7 @@ async def start_operation( input=input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) ) @@ -5416,6 +5470,7 @@ async def execute_operation( *, output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: handle = await self.start_operation( @@ -5423,6 +5478,7 @@ async def execute_operation( input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) return await handle diff --git a/tests/conftest.py b/tests/conftest.py index 7d9f0157d..fa868530a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -118,6 +118,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "system.enableDeploymentVersions=true", "--dynamic-config-value", "frontend.activityAPIsEnabled=true", + "--dynamic-config-value", + "component.nexusoperations.recordCancelRequestCompletionEvents=true", "--http-port", str(http_port), ], diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index a352877d5..79d3687fd 100644 --- 
a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -3,11 +3,14 @@ import time import uuid from contextlib import closing -from datetime import timedelta -from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, Union from temporalio.api.common.v1 import WorkflowExecution +from temporalio.api.enums.v1 import EventType as EventType from temporalio.api.enums.v1 import IndexedValueType +from temporalio.api.history.v1 import HistoryEvent from temporalio.api.operatorservice.v1 import ( AddSearchAttributesRequest, ListSearchAttributesRequest, @@ -21,6 +24,7 @@ ) from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle from temporalio.common import SearchAttributeKey +from temporalio.converter import DataConverter from temporalio.service import RPCError, RPCStatusCode from temporalio.worker import Worker, WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner @@ -287,3 +291,113 @@ async def check_unpaused() -> bool: return not info.paused await assert_eventually(check_unpaused) + + +async def print_history(handle: WorkflowHandle): + i = 1 + async for evt in handle.fetch_history_events(): + event = EventType.Name(evt.event_type).removeprefix("EVENT_TYPE_") + print(f"{i:2}: {event}") + i += 1 + + +@dataclass +class InterleavedHistoryEvent: + handle: WorkflowHandle + event: Union[HistoryEvent, str] + number: Optional[int] + time: datetime + + +async def print_interleaved_histories( + handles: list[WorkflowHandle], + extra_events: Optional[list[tuple[WorkflowHandle, str, datetime]]] = None, +) -> None: + """ + Print the interleaved history events from multiple workflow handles in columns. + + A column entry looks like + + {number}: {elapsed_ms} {event_type} + + where {elapsed_ms} is the number of milliseconds since the first event in any of the workflows.
+ """ + all_events: list[InterleavedHistoryEvent] = [] + workflow_start_times: dict[WorkflowHandle, datetime] = {} + + for handle in handles: + event_num = 1 + first_event = True + async for history_event in handle.fetch_history_events(): + event_time = history_event.event_time.ToDatetime() + if first_event: + workflow_start_times[handle] = event_time + first_event = False + all_events.append( + InterleavedHistoryEvent(handle, history_event, event_num, event_time) + ) + event_num += 1 + + if extra_events: + for handle, event_str, event_time in extra_events: + # Ensure timezone-naive + if event_time.tzinfo is not None: + event_time = event_time.astimezone(timezone.utc).replace(tzinfo=None) + all_events.append( + InterleavedHistoryEvent(handle, event_str, None, event_time) + ) + + zero_time = min(workflow_start_times.values()) + + all_events.sort(key=lambda item: item.time) + col_width = 50 + + def _format_row(items: list[str], truncate: bool = False) -> str: + if truncate: + items = [item[: col_width - 3] for item in items] + return " | ".join(f"{item:<{col_width - 3}}" for item in items) + + headers = [handle.id for handle in handles] + print("\n" + _format_row(headers, truncate=True)) + print("-" * (col_width * len(handles) + len(handles) - 1)) + + for event in all_events: + elapsed_ms = int((event.time - zero_time).total_seconds() * 1000) + + if isinstance(event.event, str): + event_desc = f" *: {elapsed_ms:>4} {event.event}" + summary = None + else: + event_type = EventType.Name(event.event.event_type).removeprefix( + "EVENT_TYPE_" + ) + event_desc = f"{event.number:2}: {elapsed_ms:>4} {event_type}" + + # Extract summary from user_metadata if present + summary = None + if event.event.HasField( + "user_metadata" + ) and event.event.user_metadata.HasField("summary"): + try: + summary = DataConverter.default.payload_converter.from_payload( + event.event.user_metadata.summary + ) + except Exception: + pass # Ignore decoding errors + + row = [""] * len(handles) + 
col_idx = handles.index(event.handle) + row[col_idx] = event_desc[: col_width - 3] + print(_format_row(row)) + + # Print summary on new line if present + if summary: + summary_row = [""] * len(handles) + # Left-align with event type name (after ": ") + # Calculate the padding needed + if event.number is not None: + padding = len(f"{event.number:2}: {elapsed_ms:>4} ") + else: + padding = len(f" *: {elapsed_ms:>4} ") + summary_row[col_idx] = f"{' ' * padding}[{summary}]"[: col_width - 3] + print(_format_row(summary_row)) diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py new file mode 100644 index 000000000..7cbbb95c2 --- /dev/null +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -0,0 +1,498 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Optional + +import nexusrpc +import nexusrpc.handler._decorators +import pytest + +import temporalio.nexus._operation_handlers +from temporalio import exceptions, nexus, workflow +from temporalio.api.enums.v1 import EventType +from temporalio.api.history.v1 import HistoryEvent +from temporalio.client import ( + WithStartWorkflowOperation, + WorkflowExecutionStatus, + WorkflowFailureError, + WorkflowHandle, +) +from temporalio.common import WorkflowIDConflictPolicy +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name + + +@dataclass +class TestContext: + __test__ = False + cancellation_type: workflow.NexusOperationCancellationType + caller_op_future_resolved: asyncio.Future[datetime] = field( + default_factory=asyncio.Future + ) + cancel_handler_released: asyncio.Future[datetime] = field( + default_factory=asyncio.Future + ) + + +test_context: TestContext + + +@workflow.defn(sandboxed=False) +class HandlerWorkflow: + def __init__(self): + 
self.caller_op_future_resolved = asyncio.Event() + + @workflow.run + async def run(self) -> None: + try: + await asyncio.Future() + except asyncio.CancelledError: + if test_context.cancellation_type in [ + workflow.NexusOperationCancellationType.TRY_CANCEL, + workflow.NexusOperationCancellationType.WAIT_REQUESTED, + ]: + # We want to prove that the caller op future can be resolved before the operation + # (i.e. its backing workflow) is cancelled. + await self.caller_op_future_resolved.wait() + raise + + @workflow.signal + def set_caller_op_future_resolved(self) -> None: + self.caller_op_future_resolved.set() + + +@nexusrpc.service +class Service: + workflow_op: nexusrpc.Operation[None, None] + + +class WorkflowOpHandler( + temporalio.nexus._operation_handlers.WorkflowRunOperationHandler +): + def __init__(self): + pass + + async def start( + self, ctx: nexusrpc.handler.StartOperationContext, input: None + ) -> nexusrpc.handler.StartOperationResultAsync: + tctx = nexus.WorkflowRunOperationContext._from_start_operation_context(ctx) + handle = await tctx.start_workflow( + HandlerWorkflow.run, + id="handler-wf-" + str(uuid.uuid4()), + ) + return nexusrpc.handler.StartOperationResultAsync(token=handle.to_token()) + + async def cancel( + self, ctx: nexusrpc.handler.CancelOperationContext, token: str + ) -> None: + if ( + test_context.cancellation_type + == workflow.NexusOperationCancellationType.TRY_CANCEL + ): + # When this cancel handler returns, a nexus task completion will be sent to the handler + # server, and the handler server will respond to the nexus cancel request that was made + # by the caller server. At that point, the caller server will write + # NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus + # op handle future can be resolved as cancelled before any of that. 
+ await test_context.caller_op_future_resolved + test_context.cancel_handler_released.set_result(datetime.now(timezone.utc)) + await super().cancel(ctx, token) + + +@nexusrpc.handler.service_handler(service=Service) +class ServiceHandler: + @nexusrpc.handler._decorators.operation_handler + def workflow_op(self) -> nexusrpc.handler.OperationHandler[None, None]: + return WorkflowOpHandler() + + +@dataclass +class Input: + endpoint: str + cancellation_type: Optional[workflow.NexusOperationCancellationType] + + +@dataclass +class CancellationResult: + operation_token: str + + +@workflow.defn(sandboxed=False) +class CallerWorkflow: + @workflow.init + def __init__(self, input: Input): + self.nexus_client = workflow.create_nexus_client( + service=Service, + endpoint=input.endpoint, + ) + self.released = False + self.operation_token: Optional[str] = None + + @workflow.signal + def release(self): + self.released = True + + @workflow.update + async def get_operation_token(self) -> str: + await workflow.wait_condition(lambda: self.operation_token is not None) + assert self.operation_token + return self.operation_token + + @workflow.run + async def run(self, input: Input) -> CancellationResult: + op_handle = await ( + self.nexus_client.start_operation( + Service.workflow_op, + input=None, + cancellation_type=input.cancellation_type, + ) + if input.cancellation_type is not None + else self.nexus_client.start_operation(Service.workflow_op, input=None) + ) + self.operation_token = op_handle.operation_token + assert self.operation_token + # Request cancellation of the asyncio task representing the nexus operation. When the handle + # task is awaited, the resulting asyncio.CancelledError is caught, and a + # RequestCancelNexusOperation command is emitted instead (see + # _WorkflowInstanceImpl._outbound_start_nexus_operation). 
+ # + # On processing this command in the activation completion, the sdk-core nexus_operation + # state machine behaves as follows for the different cancellation types: + # + # - Abandon and TryCancel: Immediately resolves with cancellation (i.e. via a second + # activation in the same WFT) + # + # For non-Abandon types, a RequestCancelNexusOperation command is sent to the server: + # + # - TryCancel: Immediately resolve the handle task as cancelled, but also cause the server + # to write NexusOperationCancelRequested. + # + # - WaitCancellationRequested: waits for NexusOperationCancelRequestCompleted (i.e. nexus op + # cancel handler has responded) before sending an activation job to Python + # resolving the nexus operation as cancelled + # - WaitCancellationCompleted: waits for NexusOperationCanceled (e.g. backing workflow has + # closed as cancelled) before sending an activation job to Python resolving the + # nexus operation as cancelled + op_handle.cancel() + if ( + test_context.cancellation_type + == workflow.NexusOperationCancellationType.WAIT_REQUESTED + ): + # For WAIT_REQUESTED, we need core to receive the NexusOperationCancelRequestCompleted + # event. That event should trigger a workflow task, but does not currently due to + # https://github.com/temporalio/temporal/issues/8175. Force a new WFT, allowing time for + # the event hopefully to arrive. + await workflow.sleep(0.1, summary="Force new WFT") + try: + await op_handle + except exceptions.NexusOperationError: + test_context.caller_op_future_resolved.set_result( + datetime.now(timezone.utc) + ) + assert op_handle.operation_token + if input.cancellation_type in [ + workflow.NexusOperationCancellationType.TRY_CANCEL, + workflow.NexusOperationCancellationType.WAIT_REQUESTED, + ]: + # We want to prove that the future can be unblocked before the handler workflow is + # cancelled. Send a signal, so that handler workflow can wait for it.
+ await workflow.get_external_workflow_handle_for( + HandlerWorkflow.run, + workflow_id=( + nexus.WorkflowHandle[None] + .from_token(self.operation_token) + .workflow_id + ), + ).signal(HandlerWorkflow.set_caller_op_future_resolved) + + await workflow.wait_condition(lambda: self.released) + return CancellationResult( + operation_token=op_handle.operation_token, + ) + else: + pytest.fail("Expected NexusOperationError")
+
+
+@pytest.mark.parametrize(
+    "cancellation_type_name",
+    [
+        workflow.NexusOperationCancellationType.ABANDON.name,
+        workflow.NexusOperationCancellationType.TRY_CANCEL.name,
+        workflow.NexusOperationCancellationType.WAIT_REQUESTED.name,
+        workflow.NexusOperationCancellationType.WAIT_COMPLETED.name,
+    ],
+)
+async def test_cancellation_type(
+    env: WorkflowEnvironment,
+    cancellation_type_name: str,
+):
+    """
+    Run the caller/handler scenario for one NexusOperationCancellationType and hand the caller
+    and handler workflow handles to the matching check_behavior_for_* function.
+
+    The test is parametrized by enum member *name*; the name is converted back to the enum
+    member below.
+    """
+    if env.supports_time_skipping:
+        pytest.skip("Nexus tests don't work with time-skipping server")
+
+    cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name]
+    # Fresh per-test state, published through the module-level `test_context` that the
+    # check_behavior_for_* functions read.
+    global test_context
+    test_context = TestContext(cancellation_type=cancellation_type)
+
+    client = env.client
+
+    async with Worker(
+        client,
+        task_queue=str(uuid.uuid4()),
+        workflows=[CallerWorkflow, HandlerWorkflow],
+        nexus_service_handlers=[ServiceHandler()],
+    ) as worker:
+        await create_nexus_endpoint(worker.task_queue, client)
+
+        # Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
+        # token
+        with_start_workflow = WithStartWorkflowOperation(
+            CallerWorkflow.run,
+            Input(
+                endpoint=make_nexus_endpoint_name(worker.task_queue),
+                cancellation_type=cancellation_type,
+            ),
+            id="caller-wf-" + str(uuid.uuid4()),
+            task_queue=worker.task_queue,
+            id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
+        )
+
+        operation_token = await client.execute_update_with_start_workflow(
+            CallerWorkflow.get_operation_token,
+            start_workflow_operation=with_start_workflow,
+        )
+        # Derive a client-side handle for the handler workflow from the operation token.
+        handler_wf = (
+            nexus.WorkflowHandle[None]
+            .from_token(operation_token)
+            ._to_client_workflow_handle(client)
+        )
+        caller_wf = await with_start_workflow.workflow_handle()
+
+        # Each cancellation type has its own set of expectations.
+        if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
+            await check_behavior_for_abandon(caller_wf, handler_wf)
+        elif cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL:
+            await check_behavior_for_try_cancel(caller_wf, handler_wf)
+        elif (
+            cancellation_type == workflow.NexusOperationCancellationType.WAIT_REQUESTED
+        ):
+            await check_behavior_for_wait_cancellation_requested(caller_wf, handler_wf)
+        elif (
+            cancellation_type == workflow.NexusOperationCancellationType.WAIT_COMPLETED
+        ):
+            await check_behavior_for_wait_cancellation_completed(caller_wf, handler_wf)
+        else:
+            pytest.fail(f"Invalid cancellation type: {cancellation_type}")
+
+
+async def check_behavior_for_abandon(
+    caller_wf: WorkflowHandle,
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that a cancellation request is not sent.
+    """
+    # The handler workflow must still be running at this point.
+    handler_status = (await handler_wf.describe()).status
+    assert handler_status == WorkflowExecutionStatus.RUNNING
+    # Let the caller workflow finish so that its history is complete before inspecting it.
+    await caller_wf.signal(CallerWorkflow.release)
+    await caller_wf.result()
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
+        ]
+    )
+    # No cancel request may appear anywhere in the caller workflow history.
+    assert not await has_event(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
+    )
+
+
+async def check_behavior_for_try_cancel(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle[Any, None],
+) -> None:
+    """
+    Check that a cancellation request is sent and the caller workflow nexus op future is unblocked
+    as cancelled before the cancel handler returns (i.e. before the
+    NexusOperationCancelRequestCompleted in the caller workflow history).
+    """
+    # The handler workflow is expected to fail with a cancellation as its cause.
+    try:
+        await handler_wf.result()
+    except WorkflowFailureError as err:
+        assert isinstance(err.__cause__, exceptions.CancelledError)
+    else:
+        pytest.fail("Expected WorkflowFailureError")
+    await caller_wf.signal(CallerWorkflow.release)
+    await caller_wf.result()
+
+    handler_status = (await handler_wf.describe()).status
+    assert handler_status == WorkflowExecutionStatus.CANCELED
+    # Datetime stored on the shared test context; compared against event times below.
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
+        ]
+    )
+    op_cancel_requested_event = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
+    )
+    op_cancel_request_completed_event = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
+    )
+    # The future must have resolved before either cancel-request event was recorded.
+    assert (
+        caller_op_future_resolved
+        < op_cancel_requested_event
+        < op_cancel_request_completed_event
+    )
+
+
+async def check_behavior_for_wait_cancellation_requested(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that a cancellation request is sent and the caller workflow nexus operation future is
+    unblocked as cancelled after the cancel handler returns (i.e. after the
+    NexusOperationCancelRequestCompleted in the caller workflow history) but without waiting for
+    the operation to be canceled.
+    """
+    # The handler workflow is expected to fail with a cancellation as its cause.
+    try:
+        await handler_wf.result()
+    except WorkflowFailureError as err:
+        assert isinstance(err.__cause__, exceptions.CancelledError)
+    else:
+        pytest.fail("Expected WorkflowFailureError")
+
+    await caller_wf.signal(CallerWorkflow.release)
+    await caller_wf.result()
+
+    handler_status = (await handler_wf.describe()).status
+    assert handler_status == WorkflowExecutionStatus.CANCELED
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
+        ]
+    )
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    op_cancel_request_completed = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
+    )
+    # The handler workflow's own WorkflowExecutionCanceled time stands in for "the operation was
+    # canceled".
+    op_canceled = await get_event_time(
+        handler_wf,
+        EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
+    )
+    assert op_cancel_request_completed < caller_op_future_resolved < op_canceled
+
+
+async def check_behavior_for_wait_cancellation_completed(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that a cancellation request is sent and the caller workflow nexus operation future is
+    unblocked after the operation is canceled.
+    """
+    # The handler workflow is expected to fail with a cancellation as its cause.
+    try:
+        await handler_wf.result()
+    except WorkflowFailureError as err:
+        assert isinstance(err.__cause__, exceptions.CancelledError)
+    else:
+        pytest.fail("Expected WorkflowFailureError")
+
+    handler_status = (await handler_wf.describe()).status
+    assert handler_status == WorkflowExecutionStatus.CANCELED
+
+    await caller_wf.signal(CallerWorkflow.release)
+    await caller_wf.result()
+
+    # Expected relative order of events across the caller and handler workflow histories.
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (
+                handler_wf,
+                EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED,
+            ),
+            (handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
+        ]
+    )
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    handler_wf_canceled_event_time = await get_event_time(
+        handler_wf,
+        EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
+    )
+    assert caller_op_future_resolved > handler_wf_canceled_event_time
+
+
+async def has_event(
+    wf_handle: WorkflowHandle, event_type: EventType.ValueType
+) -> bool:
+    """
+    Return True if the workflow's history contains at least one event of the given type.
+    """
+    async for e in wf_handle.fetch_history_events():
+        if e.event_type == event_type:
+            return True
+    return False
+
+
+async def get_event_time(
+    wf_handle: WorkflowHandle,
+    event_type: EventType.ValueType,
+) -> datetime:
+    """
+    Return the time of the first event of the given type in the workflow's history, as a
+    timezone-aware UTC datetime.
+
+    Raises AssertionError if the history contains no event of that type.
+    """
+    async for event in wf_handle.fetch_history_events():
+        if event.event_type == event_type:
+            # ToDatetime() yields a naive datetime; tag it as UTC.
+            return event.event_time.ToDatetime().replace(tzinfo=timezone.utc)
+    event_type_name = EventType.Name(event_type).removeprefix("EVENT_TYPE_")
+    # Raised explicitly rather than via `assert False`: assert statements are removed under
+    # `python -O`, which would let this function fall through and return None.
+    raise AssertionError(f"Event {event_type_name} not found in {wf_handle.id}")
+
+
+async def assert_event_subsequence(
+    expected_events: list[tuple[WorkflowHandle, EventType.ValueType]],
+) -> None:
+    """
+    Given a sequence of (WorkflowHandle, EventType) pairs, assert that the sorted sequence of events
+    from both workflows contains
that subsequence.
+    """
+
+    # Sort key: the event's timestamp (a naive datetime, used only for ordering here).
+    def _event_time(
+        item: tuple[WorkflowHandle, HistoryEvent],
+    ) -> datetime:
+        return item[1].event_time.ToDatetime()
+
+    # Collect every event of every workflow mentioned in expected_events, tagged with its handle.
+    all_events: list[tuple[WorkflowHandle, HistoryEvent]] = []
+    handles = {h for h, _ in expected_events}
+    for h in handles:
+        async for e in h.fetch_history_events():
+            all_events.append((h, e))
+    # A single shared iterator over the time-sorted events: each search below resumes where the
+    # previous match stopped, which is what enforces the relative order of the expected events.
+    _all_events = iter(sorted(all_events, key=_event_time))
+    _expected_events = iter(expected_events)
+
+    # Remember the last matched pair so that a failure message can say what it came after.
+    previous_expected_handle, previous_expected_event_type_name = None, None
+    for expected_handle, expected_event_type in _expected_events:
+        expected_event_type_name = EventType.Name(expected_event_type).removeprefix(
+            "EVENT_TYPE_"
+        )
+        # Advance through the remaining events until the expected one is found (or None if the
+        # iterator is exhausted first).
+        has_expected = next(
+            (
+                (h, e)
+                for h, e in _all_events
+                if h == expected_handle and e.event_type == expected_event_type
+            ),
+            None,
+        )
+        if not has_expected:
+            if previous_expected_handle is not None:
+                prefix = f"After {previous_expected_event_type_name} in {previous_expected_handle.id}, "
+            else:
+                prefix = ""
+            pytest.fail(
+                f"{prefix}expected {expected_event_type_name} in {expected_handle.id}"
+            )
+        previous_expected_event_type_name = expected_event_type_name
+        previous_expected_handle = expected_handle
diff --git a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py
new file mode 100644
index 000000000..9585a8445
--- /dev/null
+++ b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py
@@ -0,0 +1,381 @@
+"""
+See sibling file test_workflow_caller_cancellation_types.py for explanatory comments.
+"""
+
+import asyncio
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Optional
+
+import nexusrpc
+import nexusrpc.handler._decorators
+import pytest
+
+import temporalio.nexus._operation_handlers
+from temporalio import exceptions, nexus, workflow
+from temporalio.api.enums.v1 import EventType
+from temporalio.client import (
+    WithStartWorkflowOperation,
+    WorkflowExecutionStatus,
+    WorkflowHandle,
+)
+from temporalio.common import WorkflowIDConflictPolicy
+from temporalio.testing import WorkflowEnvironment
+from temporalio.worker import Worker
+from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
+from tests.nexus.test_workflow_caller_cancellation_types import (
+    assert_event_subsequence,
+    get_event_time,
+    has_event,
+)
+
+
+@dataclass
+class TestContext:
+    """State shared between the test function, the workflows and the nexus handler."""
+
+    # The class name starts with "Test"; this keeps pytest from collecting it as a test class.
+    __test__ = False
+    # Cancellation type under test; read by HandlerWorkflow.run and CallerWorkflow.run.
+    cancellation_type: workflow.NexusOperationCancellationType
+    # Resolved by CallerWorkflow.run with the wall-clock time at which `await op_handle` finished.
+    caller_op_future_resolved: asyncio.Future[datetime] = field(
+        default_factory=asyncio.Future
+    )
+    # Resolved by WorkflowOpHandler.cancel with the wall-clock time at which it signalled the
+    # handler workflow.
+    cancel_handler_released: asyncio.Future[datetime] = field(
+        default_factory=asyncio.Future
+    )
+
+
+# Assigned by test_cancellation_type at the start of every parametrized run.
+test_context: TestContext
+
+
+@workflow.defn(sandboxed=False)
+class HandlerWorkflow:
+    """Workflow backing the nexus operation; it stays open until signalled."""
+
+    def __init__(self):
+        # Set via the set_cancel_handler_released / set_caller_op_future_resolved signals.
+        self.cancel_handler_released = asyncio.Event()
+        self.caller_op_future_resolved = asyncio.Event()
+
+    @workflow.run
+    async def run(self) -> None:
+        # We want the cancel handler to be invoked, so this workflow must not close before
+        # then.
+        await self.cancel_handler_released.wait()
+        if (
+            test_context.cancellation_type
+            == workflow.NexusOperationCancellationType.WAIT_REQUESTED
+        ):
+            # For WAIT_REQUESTED, we want to prove that the future can be unblocked before the
+            # handler workflow completes.
+            await self.caller_op_future_resolved.wait()
+
+    @workflow.signal
+    def set_cancel_handler_released(self) -> None:
+        """Signal: unblock the first wait in `run`."""
+        self.cancel_handler_released.set()
+
+    @workflow.signal
+    def set_caller_op_future_resolved(self) -> None:
+        """Signal: unblock the second wait in `run` (only awaited for WAIT_REQUESTED)."""
+        self.caller_op_future_resolved.set()
+
+
+@nexusrpc.service
+class Service:
+    """Nexus service definition with a single operation taking and returning None."""
+
+    workflow_op: nexusrpc.Operation[None, None]
+
+
+class WorkflowOpHandler(
+    temporalio.nexus._operation_handlers.WorkflowRunOperationHandler
+):
+    """
+    Operation handler that starts HandlerWorkflow and whose cancel handler always fails.
+    """
+
+    def __init__(self):
+        # NOTE(review): the base class __init__ is not called here; presumably deliberate
+        # because start/cancel are overridden below. Confirm against the base class.
+        pass
+
+    async def start(
+        self, ctx: nexusrpc.handler.StartOperationContext, input: None
+    ) -> nexusrpc.handler.StartOperationResultAsync:
+        """Start a HandlerWorkflow with a random id and return its handle as the token."""
+        tctx = nexus.WorkflowRunOperationContext._from_start_operation_context(ctx)
+        handle = await tctx.start_workflow(
+            HandlerWorkflow.run,
+            id="handler-wf-" + str(uuid.uuid4()),
+        )
+        return nexusrpc.handler.StartOperationResultAsync(token=handle.to_token())
+
+    async def cancel(
+        self, ctx: nexusrpc.handler.CancelOperationContext, token: str
+    ) -> None:
+        """
+        Release the handler workflow, record the time, then fail with a BAD_REQUEST HandlerError.
+
+        The handler workflow is signalled rather than cancelled, so it can run to completion.
+        """
+        client = nexus.client()
+        # The operation token identifies the handler workflow.
+        handler_wf: WorkflowHandle[HandlerWorkflow, None] = (
+            client.get_workflow_handle_for(
+                HandlerWorkflow.run,
+                workflow_id=nexus.WorkflowHandle[None].from_token(token).workflow_id,
+            )
+        )
+        await handler_wf.signal(HandlerWorkflow.set_cancel_handler_released)
+        test_context.cancel_handler_released.set_result(datetime.now(timezone.utc))
+        raise nexusrpc.HandlerError(
+            "Deliberate non-retryable error in cancel handler",
+            type=nexusrpc.HandlerErrorType.BAD_REQUEST,
+        )
+
+
+@nexusrpc.handler.service_handler(service=Service)
+class ServiceHandler:
+    """Service handler exposing WorkflowOpHandler as `workflow_op`."""
+
+    @nexusrpc.handler._decorators.operation_handler
+    def workflow_op(self) -> nexusrpc.handler.OperationHandler[None, None]:
+        return WorkflowOpHandler()
+
+
+@dataclass
+class Input:
+    """Input of CallerWorkflow."""
+
+    # Name of the nexus endpoint to call.
+    endpoint: str
+    # When None, CallerWorkflow.run starts the operation without a cancellation_type argument.
+    cancellation_type: Optional[workflow.NexusOperationCancellationType]
+
+
+@dataclass
+class CancellationResult:
+    """Result of CallerWorkflow."""
+
+    operation_token: str
+    # Class names of the error raised by awaiting the operation handle and of its __cause__;
+    # both stay None when awaiting the handle raised no NexusOperationError.
+    error_type: Optional[str] = None
+    error_cause_type: Optional[str] = None
+
+
+@workflow.defn(sandboxed=False)
+class CallerWorkflow:
+    """
+    Workflow that starts the nexus operation, immediately requests its cancellation, and reports
+    how awaiting the operation handle ended.
+    """
+
+    @workflow.init
+    def __init__(self, input: Input):
+        self.nexus_client = workflow.create_nexus_client(
+            service=Service,
+            endpoint=input.endpoint,
+        )
+        # Set by the `release` signal; `run` does not return before this is True.
+        self.released = False
+        # Set by `run` once the operation has started; served by `get_operation_token`.
+        self.operation_token: Optional[str] = None
+
+    @workflow.signal
+    def release(self) -> None:
+        """Signal: allow `run` to return."""
+        self.released = True
+
+    @workflow.update
+    async def get_operation_token(self) -> str:
+        """Update: wait until the operation token is known, then return it."""
+        await workflow.wait_condition(lambda: self.operation_token is not None)
+        assert self.operation_token
+        return self.operation_token
+
+    @workflow.run
+    async def run(self, input: Input) -> CancellationResult:
+        """
+        Start the operation, cancel it, await it, record the resolution time on the shared test
+        context, then wait for the `release` signal before returning.
+        """
+        # Only pass cancellation_type when one was given, so that the default is exercised
+        # otherwise.
+        op_handle = await (
+            self.nexus_client.start_operation(
+                Service.workflow_op,
+                input=None,
+                cancellation_type=input.cancellation_type,
+            )
+            if input.cancellation_type is not None
+            else self.nexus_client.start_operation(Service.workflow_op, input=None)
+        )
+        self.operation_token = op_handle.operation_token
+        assert self.operation_token
+        op_handle.cancel()
+        if (
+            test_context.cancellation_type
+            == workflow.NexusOperationCancellationType.WAIT_REQUESTED
+        ):
+            # For WAIT_REQUESTED, we need core to receive the event recording the outcome of the
+            # cancel request (in this file the cancel handler always raises, and the checks expect
+            # NexusOperationCancelRequestFailed). That event should trigger a workflow task, but
+            # does not currently due to https://github.com/temporalio/temporal/issues/8175.
+            # Force a new WFT, hopefully allowing time for the event to arrive.
+            await workflow.sleep(0.1, summary="Force new WFT")
+        error_type, error_cause_type = None, None
+        try:
+            await op_handle
+        except exceptions.NexusOperationError as err:
+            # Only the class names are reported back in the workflow result.
+            error_type = err.__class__.__name__
+            error_cause_type = err.__cause__.__class__.__name__
+
+        # Wall-clock time at which the operation future resolved; the check functions compare it
+        # with history event times.
+        test_context.caller_op_future_resolved.set_result(datetime.now(timezone.utc))
+        assert op_handle.operation_token
+        await workflow.wait_condition(lambda: self.released)
+        return CancellationResult(
+            operation_token=op_handle.operation_token,
+            error_type=error_type,
+            error_cause_type=error_cause_type,
+        )
+
+
+@pytest.mark.parametrize(
+    "cancellation_type_name",
+    [
+        workflow.NexusOperationCancellationType.ABANDON.name,
+        workflow.NexusOperationCancellationType.TRY_CANCEL.name,
+        workflow.NexusOperationCancellationType.WAIT_REQUESTED.name,
+        workflow.NexusOperationCancellationType.WAIT_COMPLETED.name,
+    ],
+)
+async def test_cancellation_type(
+    env: WorkflowEnvironment,
+    cancellation_type_name: str,
+):
+    """
+    Run the caller/handler scenario for one NexusOperationCancellationType, with a cancel handler
+    that always fails, and dispatch to the matching check_behavior_for_* function.
+    """
+    if env.supports_time_skipping:
+        pytest.skip("Nexus tests don't work with time-skipping server")
+
+    cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name]
+    # Fresh per-test state, published through the module-level `test_context` read by the
+    # workflows, the cancel handler and the check functions.
+    global test_context
+    test_context = TestContext(cancellation_type=cancellation_type)
+
+    client = env.client
+
+    async with Worker(
+        client,
+        task_queue=str(uuid.uuid4()),
+        workflows=[CallerWorkflow, HandlerWorkflow],
+        nexus_service_handlers=[ServiceHandler()],
+    ) as worker:
+        await create_nexus_endpoint(worker.task_queue, client)
+
+        # Start the caller workflow, wait for the nexus op to have started and retrieve the nexus op
+        # token
+        with_start_workflow = WithStartWorkflowOperation(
+            CallerWorkflow.run,
+            Input(
+                endpoint=make_nexus_endpoint_name(worker.task_queue),
+                cancellation_type=cancellation_type,
+            ),
+            id="caller-wf-" + str(uuid.uuid4()),
+            task_queue=worker.task_queue,
+            id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
+        )
+
+        operation_token = await client.execute_update_with_start_workflow(
+            CallerWorkflow.get_operation_token,
+            start_workflow_operation=with_start_workflow,
+        )
+        # Derive a client-side handle for the handler workflow from the operation token.
+        handler_wf = (
+            nexus.WorkflowHandle[None]
+            .from_token(operation_token)
+            ._to_client_workflow_handle(client)
+        )
+        caller_wf = await with_start_workflow.workflow_handle()
+
+        # Each cancellation type has its own set of expectations.
+        if cancellation_type == workflow.NexusOperationCancellationType.ABANDON:
+            await check_behavior_for_abandon(caller_wf, handler_wf)
+        elif cancellation_type == workflow.NexusOperationCancellationType.TRY_CANCEL:
+            await check_behavior_for_try_cancel(caller_wf, handler_wf)
+        elif (
+            cancellation_type == workflow.NexusOperationCancellationType.WAIT_REQUESTED
+        ):
+            await check_behavior_for_wait_cancellation_requested(caller_wf, handler_wf)
+        elif (
+            cancellation_type == workflow.NexusOperationCancellationType.WAIT_COMPLETED
+        ):
+            await check_behavior_for_wait_cancellation_completed(caller_wf, handler_wf)
+        else:
+            pytest.fail(f"Invalid cancellation type: {cancellation_type}")
+
+
+async def check_behavior_for_abandon(
+    caller_wf: WorkflowHandle,
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that a cancellation request is not sent.
+    """
+    # The handler workflow must still be running at this point.
+    handler_status = (await handler_wf.describe()).status
+    assert handler_status == WorkflowExecutionStatus.RUNNING
+    await caller_wf.signal(CallerWorkflow.release)
+    result = await caller_wf.result()
+    # The caller nevertheless observes the operation as cancelled.
+    assert result.error_type == "NexusOperationError"
+    assert result.error_cause_type == "CancelledError"
+
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
+        ]
+    )
+    # No cancel request may appear anywhere in the caller workflow history.
+    assert not await has_event(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
+    )
+
+
+async def check_behavior_for_try_cancel(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle[Any, None],
+) -> None:
+    """
+    Check that the caller workflow nexus operation future is resolved as cancelled before the
+    cancel request is recorded, and that the failing cancel handler is recorded as
+    NexusOperationCancelRequestFailed in the caller workflow history.
+    """
+    # The cancel handler only signals the handler workflow, so it completes normally.
+    await handler_wf.result()
+    await caller_wf.signal(CallerWorkflow.release)
+    result = await caller_wf.result()
+    assert result.error_type == "NexusOperationError"
+    assert result.error_cause_type == "CancelledError"
+
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED),
+        ]
+    )
+    op_cancel_requested_event = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
+    )
+    op_cancel_request_failed_event = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
+    )
+    # The future must have resolved before either cancel-request event was recorded.
+    assert (
+        caller_op_future_resolved
+        < op_cancel_requested_event
+        < op_cancel_request_failed_event
+    )
+
+
+async def check_behavior_for_wait_cancellation_requested(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that the caller workflow nexus operation future is resolved with the cancel handler's
+    error (NexusOperationError caused by HandlerError), after NexusOperationCancelRequestFailed is
+    recorded and before the handler workflow completes.
+    """
+    await caller_wf.signal(CallerWorkflow.release)
+    result = await caller_wf.result()
+    assert result.error_type == "NexusOperationError"
+    assert result.error_cause_type == "HandlerError"
+    # Only now let the handler workflow finish: for WAIT_REQUESTED, HandlerWorkflow.run blocks
+    # until it receives this signal.
+    await handler_wf.signal(HandlerWorkflow.set_caller_op_future_resolved)
+    await handler_wf.result()
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED),
+            (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
+        ]
+    )
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    op_cancel_request_failed = await get_event_time(
+        caller_wf,
+        EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
+    )
+    handler_wf_completed = await get_event_time(
+        handler_wf,
+        EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
+    )
+    assert op_cancel_request_failed < caller_op_future_resolved < handler_wf_completed
+
+
+async def check_behavior_for_wait_cancellation_completed(
+    caller_wf: WorkflowHandle[Any, CancellationResult],
+    handler_wf: WorkflowHandle,
+) -> None:
+    """
+    Check that the caller workflow nexus operation future resolves without an error, and only
+    after the handler workflow has completed.
+    """
+    await handler_wf.result()
+    await caller_wf.signal(CallerWorkflow.release)
+    result = await caller_wf.result()
+    # The failed cancel request must not surface as an operation error.
+    assert not result.error_type
+    # Note that the relative order of these two events is non-deterministic, since one is the result
+    # of the cancel handler response being processed and the other is the result of the handler
+    # workflow exiting.
+    # (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED)
+    # (handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED)
+    await assert_event_subsequence(
+        [
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
+            (handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
+            (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED),
+        ]
+    )
+    caller_op_future_resolved = test_context.caller_op_future_resolved.result()
+    handler_wf_completed = await get_event_time(
+        handler_wf,
+        EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
+    )
+    assert handler_wf_completed < caller_op_future_resolved