diff --git a/temporalio/activity.py b/temporalio/activity.py index c67fa0f38..4a0914bc2 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -34,6 +34,9 @@ overload, ) +import temporalio.bridge +import temporalio.bridge.proto +import temporalio.bridge.proto.activity_task import temporalio.common import temporalio.converter @@ -135,6 +138,34 @@ def _logger_details(self) -> Mapping[str, Any]: _current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity") +@dataclass +class _ActivityCancellationDetailsHolder: + details: Optional[ActivityCancellationDetails] = None + + +@dataclass(frozen=True) +class ActivityCancellationDetails: + """Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set.""" + + not_found: bool = False + cancel_requested: bool = False + paused: bool = False + timed_out: bool = False + worker_shutdown: bool = False + + @staticmethod + def _from_proto( + proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails, + ) -> ActivityCancellationDetails: + return ActivityCancellationDetails( + not_found=proto.is_not_found, + cancel_requested=proto.is_cancelled, + paused=proto.is_paused, + timed_out=proto.is_timed_out, + worker_shutdown=proto.is_worker_shutdown, + ) + + @dataclass class _Context: info: Callable[[], Info] @@ -148,6 +179,7 @@ class _Context: temporalio.converter.PayloadConverter, ] runtime_metric_meter: Optional[temporalio.common.MetricMeter] + cancellation_details: _ActivityCancellationDetailsHolder _logger_details: Optional[Mapping[str, Any]] = None _payload_converter: Optional[temporalio.converter.PayloadConverter] = None _metric_meter: Optional[temporalio.common.MetricMeter] = None @@ -260,6 +292,11 @@ def info() -> Info: return _Context.current().info() +def cancellation_details() -> Optional[ActivityCancellationDetails]: + """Cancellation details of the current activity, if any. Once set, cancellation details do not change.""" + return _Context.current().cancellation_details.details + + def heartbeat(*details: Any) -> None: """Send a heartbeat for the current activity. diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index f5c0aa750..49210c063 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -235,6 +235,9 @@ impl ClientRef { "patch_schedule" => { rpc_call!(retry_client, call, patch_schedule) } + "pause_activity" => { + rpc_call!(retry_client, call, pause_activity) + } "poll_activity_task_queue" => { rpc_call!(retry_client, call, poll_activity_task_queue) } @@ -325,6 +328,9 @@ impl ClientRef { "trigger_workflow_rule" => { rpc_call!(retry_client, call, trigger_workflow_rule) } + "unpause_activity" => { + rpc_call!(retry_client, call, unpause_activity) + } "update_namespace" => { rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace) } diff --git a/temporalio/client.py b/temporalio/client.py index 4cd9d1f19..f46297eb9 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -56,6 +56,7 @@ import temporalio.runtime import temporalio.service import temporalio.workflow +from temporalio.activity import ActivityCancellationDetails from temporalio.service import ( HttpConnectProxyConfig, KeepAliveConfig, @@ -5145,9 +5146,10 @@ def __init__(self) -> None: class AsyncActivityCancelledError(temporalio.exceptions.TemporalError): """Error that occurs when async activity attempted heartbeat but was cancelled.""" - def __init__(self) -> None: + def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None: """Create async activity cancelled error.""" super().__init__("Activity cancelled") + self.details = details class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError): @@ -6287,8 +6289,14 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - if resp_by_id.cancel_requested: - raise AsyncActivityCancelledError() + if resp_by_id.cancel_requested or resp_by_id.activity_paused: + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp_by_id.cancel_requested, + paused=resp_by_id.activity_paused, + ) + ) + else: resp = await self._client.workflow_service.record_activity_task_heartbeat( temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest( @@ -6301,8 +6309,13 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - if resp.cancel_requested: - raise AsyncActivityCancelledError() + if resp.cancel_requested or resp.activity_paused: + raise AsyncActivityCancelledError( + details=ActivityCancellationDetails( + cancel_requested=resp.cancel_requested, + paused=resp.activity_paused, + ) + ) async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None: result = ( diff --git a/temporalio/testing/_activity.py b/temporalio/testing/_activity.py index 19dd3819b..3694dfdc7 100644 --- a/temporalio/testing/_activity.py +++ b/temporalio/testing/_activity.py @@ -74,15 +74,29 @@ def __init__(self) -> None: self._cancelled = False self._worker_shutdown = False self._activities: Set[_Activity] = set() + self._cancellation_details = ( + temporalio.activity._ActivityCancellationDetailsHolder() + ) - def cancel(self) -> None: + def cancel( + self, + cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails( + cancel_requested=True + ), + ) -> None: """Cancel the activity. + Args: + cancellation_details: details about the cancellation. These will + be accessible through temporalio.activity.cancellation_details() + in the activity after cancellation. + This only has an effect on the first call. """ if self._cancelled: return self._cancelled = True + self._cancellation_details.details = cancellation_details for act in self._activities: act.cancel() @@ -154,6 +168,7 @@ def __init__( else self.cancel_thread_raiser.shielded, payload_converter_class_or_instance=env.payload_converter, runtime_metric_meter=env.metric_meter, + cancellation_details=env._cancellation_details, ) self.task: Optional[asyncio.Task] = None diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index f38a27e12..b05f3f6e9 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -15,7 +15,7 @@ import warnings from abc import ABC, abstractmethod from contextlib import contextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import ( Any, @@ -216,7 +216,13 @@ def _cancel( warnings.warn(f"Cannot find activity to cancel for token {task_token!r}") return logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason) - activity.cancel(cancelled_by_request=True) + activity.cancellation_details.details = ( + temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details) + ) + activity.cancel( + cancelled_by_request=cancel.details.is_cancelled + or cancel.details.is_worker_shutdown + ) def _heartbeat(self, task_token: bytes, *details: Any) -> None: # We intentionally make heartbeating non-async, but since the data @@ -303,6 +309,24 @@ async def _run_activity( await self._data_converter.encode_failure( err, completion.result.failed.failure ) + elif ( + isinstance( + err, + (asyncio.CancelledError, temporalio.exceptions.CancelledError), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.paused + ): + temporalio.activity.logger.warning( + f"Completing as failure due to unhandled cancel error produced by activity pause", + ) + await self._data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityPause", + message="Unhandled activity cancel error produced by activity pause", + ), + completion.result.failed.failure, + ) elif ( isinstance( err, @@ -336,7 +360,6 @@ async def _run_activity( await self._data_converter.encode_failure( err, completion.result.failed.failure ) - # For broken executors, we have to fail the entire worker if isinstance(err, concurrent.futures.BrokenExecutor): self._fail_worker_exception_queue.put_nowait(err) @@ -524,6 +547,7 @@ async def _execute_activity( else running_activity.cancel_thread_raiser.shielded, payload_converter_class_or_instance=self._data_converter.payload_converter, runtime_metric_meter=None if sync_non_threaded else self._metric_meter, + cancellation_details=running_activity.cancellation_details, ) ) temporalio.activity.logger.debug("Starting activity") @@ -570,6 +594,9 @@ class _RunningActivity: done: bool = False cancelled_by_request: bool = False cancelled_due_to_heartbeat_error: Optional[Exception] = None + cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder = ( + field(default_factory=temporalio.activity._ActivityCancellationDetailsHolder) + ) def cancel( self, @@ -659,6 +686,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any: # can set the initializer on the executor). ctx = temporalio.activity._Context.current() info = ctx.info() + cancellation_details = ctx.cancellation_details # Heartbeat calls internally use a data converter which is async so # they need to be called on the event loop @@ -717,6 +745,7 @@ async def heartbeat_with_context(*details: Any) -> None: worker_shutdown_event.thread_event, payload_converter_class_or_instance, ctx.runtime_metric_meter, + cancellation_details, input.fn, *input.args, ] @@ -732,7 +761,6 @@ async def heartbeat_with_context(*details: Any) -> None: finally: if shared_manager: await shared_manager.unregister_heartbeater(info.task_token) - # Otherwise for async activity, just run return await input.fn(*input.args) @@ -764,6 +792,7 @@ def _execute_sync_activity( temporalio.converter.PayloadConverter, ], runtime_metric_meter: Optional[temporalio.common.MetricMeter], + cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder, fn: Callable[..., Any], *args: Any, ) -> Any: @@ -795,6 +824,7 @@ def _execute_sync_activity( else cancel_thread_raiser.shielded, payload_converter_class_or_instance=payload_converter_class_or_instance, runtime_metric_meter=runtime_metric_meter, + cancellation_details=cancellation_details, ) ) return fn(*args) diff --git a/tests/conftest.py b/tests/conftest.py index be99e117f..37b1fe89c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,6 +115,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]: "frontend.workerVersioningDataAPIs=true", "--dynamic-config-value", "system.enableDeploymentVersions=true", + "--dynamic-config-value", + "frontend.activityAPIsEnabled=true", ], dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION, ) diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index da5259748..a352877d5 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -13,7 +13,12 @@ ListSearchAttributesRequest, ) from temporalio.api.update.v1 import UpdateRef -from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest +from temporalio.api.workflow.v1 import PendingActivityInfo +from temporalio.api.workflowservice.v1 import ( + PauseActivityRequest, + PollWorkflowExecutionUpdateRequest, + UnpauseActivityRequest, +) from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle from temporalio.common import SearchAttributeKey from temporalio.service import RPCError, RPCStatusCode @@ -210,3 +215,75 @@ async def check_workflow_exists() -> bool: await assert_eq_eventually(True, check_workflow_exists) assert handle is not None return handle + + +async def assert_pending_activity_exists_eventually( + handle: WorkflowHandle, + activity_id: str, + timeout: timedelta = timedelta(seconds=5), +) -> PendingActivityInfo: + """Wait until a pending activity with the given ID exists and return it.""" + + async def check() -> PendingActivityInfo: + act_info = await get_pending_activity_info(handle, activity_id) + if act_info is not None: + return act_info + raise AssertionError( + f"Activity with ID {activity_id} not found in pending activities" + ) + + return await assert_eventually(check, timeout=timeout) + + +async def get_pending_activity_info( + handle: WorkflowHandle, + activity_id: str, +) -> Optional[PendingActivityInfo]: + """Get pending activity info by ID, or None if not found.""" + desc = await handle.describe() + for act in desc.raw_description.pending_activities: + if act.activity_id == activity_id: + return act + return None + + +async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str): + """Pause the given activity and assert it becomes paused.""" + desc = await handle.describe() + req = PauseActivityRequest( + namespace=client.namespace, + execution=WorkflowExecution( + workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, + run_id=desc.raw_description.workflow_execution_info.execution.run_id, + ), + id=activity_id, + ) + await client.workflow_service.pause_activity(req) + + # Assert eventually paused + async def check_paused() -> bool: + info = await assert_pending_activity_exists_eventually(handle, activity_id) + return info.paused + + await assert_eventually(check_paused) + + +async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str): + """Unpause the given activity and assert it is not paused.""" + desc = await handle.describe() + req = UnpauseActivityRequest( + namespace=client.namespace, + execution=WorkflowExecution( + workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id, + run_id=desc.raw_description.workflow_execution_info.execution.run_id, + ), + id=activity_id, + ) + await client.workflow_service.unpause_activity(req) + + # Assert eventually not paused + async def check_unpaused() -> bool: + info = await assert_pending_activity_exists_eventually(handle, activity_id) + return not info.paused + + await assert_eventually(check_unpaused) diff --git a/tests/testing/test_activity.py b/tests/testing/test_activity.py index 29b66c772..ff281d722 100644 --- a/tests/testing/test_activity.py +++ b/tests/testing/test_activity.py @@ -26,7 +26,11 @@ async def via_create_task(): await asyncio.Future() raise RuntimeError("Unreachable") except asyncio.CancelledError: - activity.heartbeat("cancelled") + cancellation_details = activity.cancellation_details() + if cancellation_details: + activity.heartbeat( + f"cancelled={cancellation_details.cancel_requested}", + ) return "done" env = ActivityEnvironment() @@ -37,9 +41,11 @@ async def via_create_task(): task = asyncio.create_task(env.run(do_stuff, "param1")) await waiting.wait() # Cancel and confirm done - env.cancel() + env.cancel( + cancellation_details=activity.ActivityCancellationDetails(cancel_requested=True) + ) assert "done" == await task - assert heartbeats == ["param: param1", "task, type: unknown", "cancelled"] + assert heartbeats == ["param: param1", "task, type: unknown", "cancelled=True"] def test_activity_env_sync(): @@ -72,7 +78,11 @@ def via_thread(): raise RuntimeError("Unexpected") except CancelledError: nonlocal properly_cancelled - properly_cancelled = True + cancellation_details = activity.cancellation_details() + if cancellation_details: + properly_cancelled = cancellation_details.cancel_requested + else: + properly_cancelled = False env = ActivityEnvironment() # Set heartbeat handler to add to list @@ -84,7 +94,9 @@ def via_thread(): waiting.wait() # Cancel and confirm done time.sleep(1) - env.cancel() + env.cancel( + cancellation_details=activity.ActivityCancellationDetails(cancel_requested=True) + ) thread.join() assert heartbeats == ["param: param1", "task, type: unknown"] assert properly_cancelled diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 957347817..6b331455d 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import dataclasses import json import logging @@ -38,6 +39,7 @@ from google.protobuf.timestamp_pb2 import Timestamp from typing_extensions import Literal, Protocol, runtime_checkable +import temporalio.activity import temporalio.worker import temporalio.workflow from temporalio import activity, workflow @@ -52,6 +54,7 @@ from temporalio.bridge.proto.workflow_activation import WorkflowActivation from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion from temporalio.client import ( + AsyncActivityCancelledError, Client, RPCError, RPCStatusCode, @@ -116,11 +119,15 @@ admitted_update_task, assert_eq_eventually, assert_eventually, + assert_pending_activity_exists_eventually, assert_task_fail_eventually, assert_workflow_exists_eventually, ensure_search_attributes_present, find_free_port, + get_pending_activity_info, new_worker, + pause_and_assert, + unpause_and_assert, workflow_update_exists, ) from tests.helpers.external_stack_trace import ( @@ -7622,3 +7629,289 @@ async def test_workflow_missing_local_activity_no_activities(client: Client): handle, message_contains="Activity function say_hello is not registered on this worker, no available activities", ) + + +@activity.defn +async def heartbeat_activity( + catch_err: bool = True, +) -> Optional[temporalio.activity.ActivityCancellationDetails]: + while True: + try: + activity.heartbeat() + # If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause. + if activity.info().heartbeat_details: + return activity.cancellation_details() + await asyncio.sleep(0.1) + except (CancelledError, asyncio.CancelledError) as err: + if not catch_err: + raise err + return activity.cancellation_details() + finally: + activity.heartbeat("finally-complete") + + +@activity.defn +def sync_heartbeat_activity( + catch_err: bool = True, +) -> Optional[temporalio.activity.ActivityCancellationDetails]: + while True: + try: + activity.heartbeat() + # If we have heartbeat details, we are on the second attempt, we have retried due to pause/unpause. + if activity.info().heartbeat_details: + return activity.cancellation_details() + time.sleep(0.1) + except (CancelledError, asyncio.CancelledError) as err: + if not catch_err: + raise err + return activity.cancellation_details() + finally: + activity.heartbeat("finally-complete") + + +@workflow.defn +class ActivityHeartbeatWorkflow: + @workflow.run + async def run( + self, activity_id: str + ) -> list[Optional[temporalio.activity.ActivityCancellationDetails]]: + result = [] + result.append( + await workflow.execute_activity( + sync_heartbeat_activity, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + ) + result.append( + await workflow.execute_activity( + heartbeat_activity, + activity_id=f"{activity_id}-2", + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=2), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + ) + return result + + +async def test_activity_pause_cancellation_details( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ActivityHeartbeatWorkflow], + activities=[heartbeat_activity, sync_heartbeat_activity], + activity_executor=executor, + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + handle = await client.start_workflow( + ActivityHeartbeatWorkflow.run, + test_activity_id, + id=f"test-activity-pause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Wait for sync activity + activity_info_1 = await assert_pending_activity_exists_eventually( + handle, test_activity_id + ) + # Assert not paused + assert not activity_info_1.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_1.activity_id) + + # Wait for async activity + activity_info_2 = await assert_pending_activity_exists_eventually( + handle, f"{test_activity_id}-2" + ) + # Assert not paused + assert not activity_info_2.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_2.activity_id) + + # Assert workflow return value for paused activities that caught the + # cancel error + result = await handle.result() + assert result[0] == temporalio.activity.ActivityCancellationDetails( + paused=True + ) + assert result[1] == temporalio.activity.ActivityCancellationDetails( + paused=True + ) + + +@workflow.defn +class ActivityHeartbeatPauseUnpauseWorkflow: + @workflow.run + async def run( + self, activity_id: str + ) -> list[Optional[temporalio.activity.ActivityCancellationDetails]]: + results = [] + results.append( + await workflow.execute_activity( + sync_heartbeat_activity, + False, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + ) + results.append( + await workflow.execute_activity( + heartbeat_activity, + False, + activity_id=f"{activity_id}-2", + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + ) + return results + + +async def test_activity_pause_unpause(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") + + async def check_heartbeat_details_exist( + handle: WorkflowHandle, + activity_id: str, + ) -> None: + act_info = await get_pending_activity_info(handle, activity_id) + if act_info is None: + raise AssertionError(f"Activity with ID {activity_id} not found.") + if len(act_info.heartbeat_details.payloads) == 0: + raise AssertionError( + f"Activity with ID {activity_id} has no heartbeat details" + ) + + with concurrent.futures.ThreadPoolExecutor() as executor: + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ActivityHeartbeatPauseUnpauseWorkflow], + activities=[heartbeat_activity, sync_heartbeat_activity], + activity_executor=executor, + max_heartbeat_throttle_interval=timedelta(milliseconds=300), + default_heartbeat_throttle_interval=timedelta(milliseconds=300), + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + handle = await client.start_workflow( + ActivityHeartbeatPauseUnpauseWorkflow.run, + test_activity_id, + id=f"test-activity-pause-unpause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + # Wait for sync activity + activity_info_1 = await assert_pending_activity_exists_eventually( + handle, test_activity_id + ) + # Assert not paused + assert not activity_info_1.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_1.activity_id) + + # Wait for heartbeat details to exist. At this point, the activity has finished executing + # due to cancellation from the pause. + await assert_eventually( + lambda: check_heartbeat_details_exist( + handle, activity_info_1.activity_id + ) + ) + + # Unpause activity + await unpause_and_assert(client, handle, activity_info_1.activity_id) + # Expect second activity to have started now + activity_info_2 = await assert_pending_activity_exists_eventually( + handle, f"{test_activity_id}-2" + ) + # Assert not paused + assert not activity_info_2.paused + # Pause activity then assert it is paused + await pause_and_assert(client, handle, activity_info_2.activity_id) + # Wait for heartbeat details to exist. At this point, the activity has finished executing + # due to cancellation from the pause. + await assert_eventually( + lambda: check_heartbeat_details_exist( + handle, activity_info_2.activity_id + ) + ) + # Unpause activity + await unpause_and_assert(client, handle, activity_info_2.activity_id) + + # Check workflow complete + result = await handle.result() + assert result[0] == None + assert result[1] == None + + +@activity.defn +async def external_activity_heartbeat() -> None: + activity.raise_complete_async() + + +@workflow.defn +class ExternalActivityWorkflow: + @workflow.run + async def run(self, activity_id: str) -> None: + await workflow.execute_activity( + external_activity_heartbeat, + activity_id=activity_id, + start_to_close_timeout=timedelta(seconds=10), + heartbeat_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=2), + ) + + +async def test_external_activity_cancellation_details( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time-skipping server does not support pause API yet") + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[ExternalActivityWorkflow], + activities=[external_activity_heartbeat], + ) as worker: + test_activity_id = f"heartbeat-activity-{uuid.uuid4()}" + + wf_handle = await client.start_workflow( + ExternalActivityWorkflow.run, + test_activity_id, + id=f"test-external-activity-pause-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + wf_desc = await wf_handle.describe() + + # Wait for external activity + activity_info = await assert_pending_activity_exists_eventually( + wf_handle, test_activity_id + ) + # Assert not paused + assert not activity_info.paused + + external_activity_handle = client.get_async_activity_handle( + workflow_id=wf_desc.id, run_id=wf_desc.run_id, activity_id=test_activity_id + ) + + # Pause activity then assert it is paused + await pause_and_assert(client, wf_handle, activity_info.activity_id) + + try: + await external_activity_handle.heartbeat() + except AsyncActivityCancelledError as err: + assert err.details == temporalio.activity.ActivityCancellationDetails( + paused=True + )