Skip to content

Commit e1319f5

Browse files
committed
Cancellation types for Nexus operations invoked by workflows
1 parent 62604ca commit e1319f5

File tree

5 files changed

+408
-15
lines changed

5 files changed

+408
-15
lines changed

temporalio/worker/_interceptor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
298298
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]]
299299
input: InputT
300300
schedule_to_close_timeout: Optional[timedelta]
301+
cancellation_type: temporalio.workflow.NexusOperationCancellationType
301302
headers: Optional[Mapping[str, str]]
302303
output_type: Optional[Type[OutputT]] = None
303304

temporalio/worker/_workflow_instance.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@
5454
import temporalio.bridge.proto.activity_result
5555
import temporalio.bridge.proto.child_workflow
5656
import temporalio.bridge.proto.common
57+
import temporalio.bridge.proto.nexus
5758
import temporalio.bridge.proto.workflow_activation
5859
import temporalio.bridge.proto.workflow_commands
5960
import temporalio.bridge.proto.workflow_completion
6061
import temporalio.common
6162
import temporalio.converter
6263
import temporalio.exceptions
63-
import temporalio.nexus
6464
import temporalio.workflow
6565
from temporalio.service import __version__
6666

@@ -1502,9 +1502,10 @@ async def workflow_start_nexus_operation(
15021502
service: str,
15031503
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
15041504
input: Any,
1505-
output_type: Optional[Type[OutputT]] = None,
1506-
schedule_to_close_timeout: Optional[timedelta] = None,
1507-
headers: Optional[Mapping[str, str]] = None,
1505+
output_type: Optional[Type[OutputT]],
1506+
schedule_to_close_timeout: Optional[timedelta],
1507+
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
1508+
headers: Optional[Mapping[str, str]],
15081509
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
15091510
# start_nexus_operation
15101511
return await self._outbound.start_nexus_operation(
@@ -1515,6 +1516,7 @@ async def workflow_start_nexus_operation(
15151516
input=input,
15161517
output_type=output_type,
15171518
schedule_to_close_timeout=schedule_to_close_timeout,
1519+
cancellation_type=cancellation_type,
15181520
headers=headers,
15191521
)
15201522
)
@@ -2757,7 +2759,7 @@ def _apply_schedule_command(
27572759
if self._input.retry_policy:
27582760
self._input.retry_policy.apply_to_proto(v.retry_policy)
27592761
v.cancellation_type = cast(
2760-
"temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType",
2762+
temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType,
27612763
int(self._input.cancellation_type),
27622764
)
27632765

@@ -2893,7 +2895,7 @@ def _apply_start_command(self) -> None:
28932895
if self._input.task_timeout:
28942896
v.workflow_task_timeout.FromTimedelta(self._input.task_timeout)
28952897
v.parent_close_policy = cast(
2896-
"temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType",
2898+
temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType,
28972899
int(self._input.parent_close_policy),
28982900
)
28992901
v.workflow_id_reuse_policy = cast(
@@ -2915,7 +2917,7 @@ def _apply_start_command(self) -> None:
29152917
self._input.search_attributes, v.search_attributes
29162918
)
29172919
v.cancellation_type = cast(
2918-
"temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType",
2920+
temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType,
29192921
int(self._input.cancellation_type),
29202922
)
29212923
if self._input.versioning_intent:
@@ -3011,11 +3013,6 @@ def __init__(
30113013

30123014
@property
30133015
def operation_token(self) -> Optional[str]:
3014-
# TODO(nexus-preview): How should this behave?
3015-
# Java has a separate class that only exists if the operation token exists:
3016-
# https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26
3017-
# And Go similar:
3018-
# https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771
30193016
try:
30203017
return self._start_fut.result()
30213018
except BaseException:
@@ -3064,6 +3061,11 @@ def _apply_schedule_command(self) -> None:
30643061
v.schedule_to_close_timeout.FromTimedelta(
30653062
self._input.schedule_to_close_timeout
30663063
)
3064+
v.cancellation_type = cast(
3065+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.ValueType,
3066+
int(self._input.cancellation_type),
3067+
)
3068+
30673069
if self._input.headers:
30683070
for key, val in self._input.headers.items():
30693071
v.nexus_header[key] = val

temporalio/workflow.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,9 +858,10 @@ async def workflow_start_nexus_operation(
858858
service: str,
859859
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
860860
input: Any,
861-
output_type: Optional[Type[OutputT]] = None,
862-
schedule_to_close_timeout: Optional[timedelta] = None,
863-
headers: Optional[Mapping[str, str]] = None,
861+
output_type: Optional[Type[OutputT]],
862+
schedule_to_close_timeout: Optional[timedelta],
863+
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
864+
headers: Optional[Mapping[str, str]],
864865
) -> NexusOperationHandle[OutputT]: ...
865866

866867
@abstractmethod
@@ -5117,6 +5118,46 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType
51175118
ServiceT = TypeVar("ServiceT")
51185119

51195120

5121+
class NexusOperationCancellationType(IntEnum):
5122+
"""Defines behavior of a Nexus operation when the caller workflow initiates cancellation.
5123+
5124+
Pass one of these values to :py:meth:`NexusClient.start_operation` to define cancellation
5125+
behavior.
5126+
5127+
To initiate cancellation, use :py:meth:`NexusOperationHandle.cancel` and then `await` the
5128+
operation handle. This will result in a :py:class:`exceptions.NexusOperationError`. The values
5129+
of this enum define what is guaranteed to have happened by that point.
5130+
"""
5131+
5132+
ABANDON = int(temporalio.bridge.proto.nexus.NexusOperationCancellationType.ABANDON)
5133+
"""Do not send any cancellation request to the operation handler; just report cancellation to the caller"""
5134+
5135+
TRY_CANCEL = int(
5136+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.TRY_CANCEL
5137+
)
5138+
"""Send a cancellation request but immediately report cancellation to the caller. Note that this
5139+
does not guarantee that cancellation is delivered to the operation handler if the caller exits
5140+
before the delivery is done.
5141+
"""
5142+
5143+
# TODO(nexus-preview): core needs to be updated to handle
5144+
# NexusOperationCancelRequestCompleted and NexusOperationCancelRequestFailed
5145+
# see https://github.com/temporalio/sdk-core/issues/911
5146+
# WAIT_REQUESTED = int(
5147+
# temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_REQUESTED
5148+
# )
5149+
# """Send a cancellation request and wait for confirmation that the request was received.
5150+
# Does not wait for the operation to complete.
5151+
# """
5152+
5153+
WAIT_COMPLETED = int(
5154+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_COMPLETED
5155+
)
5156+
"""Send a cancellation request and wait for the operation to complete.
5157+
Note that the operation may not complete as cancelled (for example, if it catches the
5158+
:py:exc:`asyncio.CancelledError` resulting from the cancellation request)."""
5159+
5160+
51205161
class NexusClient(ABC, Generic[ServiceT]):
51215162
"""A client for invoking Nexus operations.
51225163
@@ -5147,6 +5188,7 @@ async def start_operation(
51475188
*,
51485189
output_type: Optional[Type[OutputT]] = None,
51495190
schedule_to_close_timeout: Optional[timedelta] = None,
5191+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51505192
headers: Optional[Mapping[str, str]] = None,
51515193
) -> NexusOperationHandle[OutputT]: ...
51525194

@@ -5160,6 +5202,7 @@ async def start_operation(
51605202
*,
51615203
output_type: Optional[Type[OutputT]] = None,
51625204
schedule_to_close_timeout: Optional[timedelta] = None,
5205+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51635206
headers: Optional[Mapping[str, str]] = None,
51645207
) -> NexusOperationHandle[OutputT]: ...
51655208

@@ -5176,6 +5219,7 @@ async def start_operation(
51765219
*,
51775220
output_type: Optional[Type[OutputT]] = None,
51785221
schedule_to_close_timeout: Optional[timedelta] = None,
5222+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51795223
headers: Optional[Mapping[str, str]] = None,
51805224
) -> NexusOperationHandle[OutputT]: ...
51815225

@@ -5192,6 +5236,7 @@ async def start_operation(
51925236
*,
51935237
output_type: Optional[Type[OutputT]] = None,
51945238
schedule_to_close_timeout: Optional[timedelta] = None,
5239+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51955240
headers: Optional[Mapping[str, str]] = None,
51965241
) -> NexusOperationHandle[OutputT]: ...
51975242

@@ -5208,6 +5253,7 @@ async def start_operation(
52085253
*,
52095254
output_type: Optional[Type[OutputT]] = None,
52105255
schedule_to_close_timeout: Optional[timedelta] = None,
5256+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52115257
headers: Optional[Mapping[str, str]] = None,
52125258
) -> NexusOperationHandle[OutputT]: ...
52135259

@@ -5219,6 +5265,7 @@ async def start_operation(
52195265
*,
52205266
output_type: Optional[Type[OutputT]] = None,
52215267
schedule_to_close_timeout: Optional[timedelta] = None,
5268+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52225269
headers: Optional[Mapping[str, str]] = None,
52235270
) -> Any:
52245271
"""Start a Nexus operation and return its handle.
@@ -5248,6 +5295,7 @@ async def execute_operation(
52485295
*,
52495296
output_type: Optional[Type[OutputT]] = None,
52505297
schedule_to_close_timeout: Optional[timedelta] = None,
5298+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52515299
headers: Optional[Mapping[str, str]] = None,
52525300
) -> OutputT: ...
52535301

@@ -5261,6 +5309,7 @@ async def execute_operation(
52615309
*,
52625310
output_type: Optional[Type[OutputT]] = None,
52635311
schedule_to_close_timeout: Optional[timedelta] = None,
5312+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52645313
headers: Optional[Mapping[str, str]] = None,
52655314
) -> OutputT: ...
52665315

@@ -5277,6 +5326,7 @@ async def execute_operation(
52775326
*,
52785327
output_type: Optional[Type[OutputT]] = None,
52795328
schedule_to_close_timeout: Optional[timedelta] = None,
5329+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52805330
headers: Optional[Mapping[str, str]] = None,
52815331
) -> OutputT: ...
52825332

@@ -5296,6 +5346,7 @@ async def execute_operation(
52965346
*,
52975347
output_type: Optional[Type[OutputT]] = None,
52985348
schedule_to_close_timeout: Optional[timedelta] = None,
5349+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52995350
headers: Optional[Mapping[str, str]] = None,
53005351
) -> OutputT: ...
53015352

@@ -5312,6 +5363,7 @@ async def execute_operation(
53125363
*,
53135364
output_type: Optional[Type[OutputT]] = None,
53145365
schedule_to_close_timeout: Optional[timedelta] = None,
5366+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53155367
headers: Optional[Mapping[str, str]] = None,
53165368
) -> OutputT: ...
53175369

@@ -5323,6 +5375,7 @@ async def execute_operation(
53235375
*,
53245376
output_type: Optional[Type[OutputT]] = None,
53255377
schedule_to_close_timeout: Optional[timedelta] = None,
5378+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53265379
headers: Optional[Mapping[str, str]] = None,
53275380
) -> Any:
53285381
"""Execute a Nexus operation and return its result.
@@ -5374,6 +5427,7 @@ async def start_operation(
53745427
*,
53755428
output_type: Optional[Type] = None,
53765429
schedule_to_close_timeout: Optional[timedelta] = None,
5430+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53775431
headers: Optional[Mapping[str, str]] = None,
53785432
) -> Any:
53795433
return (
@@ -5384,6 +5438,7 @@ async def start_operation(
53845438
input=input,
53855439
output_type=output_type,
53865440
schedule_to_close_timeout=schedule_to_close_timeout,
5441+
cancellation_type=cancellation_type,
53875442
headers=headers,
53885443
)
53895444
)
@@ -5395,13 +5450,15 @@ async def execute_operation(
53955450
*,
53965451
output_type: Optional[Type] = None,
53975452
schedule_to_close_timeout: Optional[timedelta] = None,
5453+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53985454
headers: Optional[Mapping[str, str]] = None,
53995455
) -> Any:
54005456
handle = await self.start_operation(
54015457
operation,
54025458
input,
54035459
output_type=output_type,
54045460
schedule_to_close_timeout=schedule_to_close_timeout,
5461+
cancellation_type=cancellation_type,
54055462
headers=headers,
54065463
)
54075464
return await handle

tests/helpers/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
10+
from temporalio.api.enums.v1 import EventType as EventType
1011
from temporalio.api.enums.v1 import IndexedValueType
12+
from temporalio.api.history.v1 import HistoryEvent
1113
from temporalio.api.operatorservice.v1 import (
1214
AddSearchAttributesRequest,
1315
ListSearchAttributesRequest,
@@ -287,3 +289,40 @@ async def check_unpaused() -> bool:
287289
return not info.paused
288290

289291
await assert_eventually(check_unpaused)
292+
293+
294+
async def print_history(handle: WorkflowHandle):
295+
i = 1
296+
async for evt in handle.fetch_history_events():
297+
event = EventType.Name(evt.event_type).removeprefix("EVENT_TYPE_")
298+
print(f"{i:2}: {event}")
299+
i += 1
300+
301+
302+
async def print_interleaved_histories(*handles: WorkflowHandle) -> None:
303+
"""
304+
Print the interleaved history events from multiple workflow handles in columns.
305+
"""
306+
all_events: list[tuple[WorkflowHandle, HistoryEvent, int]] = []
307+
for handle in handles:
308+
event_num = 1
309+
async for event in handle.fetch_history_events():
310+
all_events.append((handle, event, event_num))
311+
event_num += 1
312+
all_events.sort(key=lambda item: item[1].event_time.ToDatetime())
313+
col_width = 40
314+
315+
def _format_row(items: list[str], truncate: bool = False) -> str:
316+
if truncate:
317+
items = [item[: col_width - 3] for item in items]
318+
return " | ".join(f"{item:<{col_width - 3}}" for item in items)
319+
320+
headers = [handle.id for handle in handles]
321+
print("\n" + _format_row(headers, truncate=True))
322+
print("-" * (col_width * len(handles) + len(handles) - 1))
323+
for handle, event, event_num in all_events:
324+
event_type = EventType.Name(event.event_type).removeprefix("EVENT_TYPE_")
325+
row = [""] * len(handles)
326+
col_idx = handles.index(handle)
327+
row[col_idx] = f"{event_num:2}: {event_type[: col_width - 5]}"
328+
print(_format_row(row))

0 commit comments

Comments
 (0)