11
11
import temporalio .nexus ._operation_handlers
12
12
from temporalio import exceptions , nexus , workflow
13
13
from temporalio .api .enums .v1 import EventType
14
- from temporalio .api .history .v1 import HistoryEvent
15
14
from temporalio .client import (
16
15
WithStartWorkflowOperation ,
17
16
WorkflowExecutionStatus ,
21
20
from temporalio .testing import WorkflowEnvironment
22
21
from temporalio .worker import Worker
23
22
from tests .helpers .nexus import create_nexus_endpoint , make_nexus_endpoint_name
23
+ from tests .nexus .test_workflow_caller_cancellation_types import (
24
+ assert_event_subsequence ,
25
+ get_event_time ,
26
+ has_event ,
27
+ )
24
28
25
29
26
30
@dataclass
@@ -257,13 +261,13 @@ async def check_behavior_for_abandon(
257
261
assert result .error_type == "NexusOperationError"
258
262
assert result .error_cause_type == "CancelledError"
259
263
260
- await _assert_event_subsequence (
264
+ await assert_event_subsequence (
261
265
[
262
266
(caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
263
267
(caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED ),
264
268
]
265
269
)
266
- assert not await _has_event (
270
+ assert not await has_event (
267
271
caller_wf ,
268
272
EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ,
269
273
)
@@ -280,18 +284,18 @@ async def check_behavior_for_try_cancel(
280
284
assert result .error_cause_type == "CancelledError"
281
285
282
286
caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
283
- await _assert_event_subsequence (
287
+ await assert_event_subsequence (
284
288
[
285
289
(caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
286
290
(caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
287
291
(caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ),
288
292
]
289
293
)
290
- op_cancel_requested_event = await _get_event_time (
294
+ op_cancel_requested_event = await get_event_time (
291
295
caller_wf ,
292
296
EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ,
293
297
)
294
- op_cancel_request_failed_event = await _get_event_time (
298
+ op_cancel_request_failed_event = await get_event_time (
295
299
caller_wf ,
296
300
EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ,
297
301
)
@@ -311,7 +315,7 @@ async def check_behavior_for_wait_cancellation_requested(
311
315
result = await caller_wf .result ()
312
316
assert result .error_type == "NexusOperationError"
313
317
assert result .error_cause_type == "HandlerError"
314
- await _assert_event_subsequence (
318
+ await assert_event_subsequence (
315
319
[
316
320
(caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
317
321
(caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
@@ -320,7 +324,7 @@ async def check_behavior_for_wait_cancellation_requested(
320
324
]
321
325
)
322
326
caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
323
- op_cancel_request_failed = await _get_event_time (
327
+ op_cancel_request_failed = await get_event_time (
324
328
caller_wf ,
325
329
EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ,
326
330
)
@@ -335,7 +339,7 @@ async def check_behavior_for_wait_cancellation_completed(
335
339
await caller_wf .signal (CallerWorkflow .release )
336
340
result = await caller_wf .result ()
337
341
assert not result .error_type
338
- await _assert_event_subsequence (
342
+ await assert_event_subsequence (
339
343
[
340
344
(caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
341
345
(caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ),
@@ -344,72 +348,8 @@ async def check_behavior_for_wait_cancellation_completed(
344
348
]
345
349
)
346
350
caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
347
- handler_wf_completed = await _get_event_time (
351
+ handler_wf_completed = await get_event_time (
348
352
handler_wf ,
349
353
EventType .EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED ,
350
354
)
351
355
assert handler_wf_completed < caller_op_future_resolved
352
-
353
-
354
- async def _has_event (wf_handle : WorkflowHandle , event_type : EventType .ValueType ):
355
- async for e in wf_handle .fetch_history_events ():
356
- if e .event_type == event_type :
357
- return True
358
- return False
359
-
360
-
361
- async def _get_event_time (
362
- wf_handle : WorkflowHandle ,
363
- event_type : EventType .ValueType ,
364
- ) -> datetime :
365
- async for event in wf_handle .fetch_history_events ():
366
- if event .event_type == event_type :
367
- return event .event_time .ToDatetime ().replace (tzinfo = timezone .utc )
368
- event_type_name = EventType .Name (event_type ).removeprefix ("EVENT_TYPE_" )
369
- assert False , f"Event { event_type_name } not found in { wf_handle .id } "
370
-
371
-
372
- async def _assert_event_subsequence (
373
- expected_events : list [tuple [WorkflowHandle , EventType .ValueType ]],
374
- ) -> None :
375
- """
376
- Given a sequence of (WorkflowHandle, EventType) pairs, assert that the sorted sequence of events
377
- from both workflows contains that subsequence.
378
- """
379
-
380
- def _event_time (
381
- item : tuple [WorkflowHandle , HistoryEvent ],
382
- ) -> datetime :
383
- return item [1 ].event_time .ToDatetime ()
384
-
385
- all_events = []
386
- handles = {h for h , _ in expected_events }
387
- for h in handles :
388
- async for e in h .fetch_history_events ():
389
- all_events .append ((h , e ))
390
- _all_events = iter (sorted (all_events , key = _event_time ))
391
- _expected_events = iter (expected_events )
392
-
393
- previous_expected_handle , previous_expected_event_type_name = None , None
394
- for expected_handle , expected_event_type in _expected_events :
395
- expected_event_type_name = EventType .Name (expected_event_type ).removeprefix (
396
- "EVENT_TYPE_"
397
- )
398
- has_expected = next (
399
- (
400
- (h , e )
401
- for h , e in _all_events
402
- if h == expected_handle and e .event_type == expected_event_type
403
- ),
404
- None ,
405
- )
406
- if not has_expected :
407
- if previous_expected_handle is not None :
408
- prefix = f"After { previous_expected_event_type_name } in { previous_expected_handle .id } , "
409
- else :
410
- prefix = ""
411
- pytest .fail (
412
- f"{ prefix } expected { expected_event_type_name } in { expected_handle .id } "
413
- )
414
- previous_expected_event_type_name = expected_event_type_name
415
- previous_expected_handle = expected_handle
0 commit comments