Skip to content

Commit be9d01b

Browse files
authored
feat: Define workflow signal decorator and definition for registry (#46)
<!-- Describe what has changed in this PR --> **What changed?** Define workflow signal decorator and definition for registry <!-- Tell your future self why have you made these changes --> **Why?** This is a must have component in order to implement signals in cadence workflow. It allows user to define a method in their workflow class that can respond to signal. <!-- How have you verified this change? Tested locally? Added a unit test? Checked in staging env? --> **How did you test it?** Unit tests <!-- Assuming the worst case, what can be broken when deploying this change to production? --> **Potential risks** <!-- Is it notable for release? e.g. schema updates, configuration or data migration required? If so, please mention it, and also update CHANGELOG.md --> **Release notes** <!-- Is there any documentation updates should be made for config, https://cadenceworkflow.io/docs/operation-guide/setup/ ? If so, please open an PR in https://github.com/cadence-workflow/cadence-docs --> **Documentation Changes** --------- Signed-off-by: Tim Li <[email protected]>
1 parent ea98ef7 commit be9d01b

File tree

3 files changed

+335
-2
lines changed

3 files changed

+335
-2
lines changed

cadence/signal.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""
2+
Signal definition for Cadence workflows.
3+
4+
This module provides the SignalDefinition class used internally by WorkflowDefinition
5+
to track signal handler metadata.
6+
"""
7+
8+
import inspect
9+
from dataclasses import dataclass
10+
from functools import update_wrapper
11+
from inspect import Parameter, signature
12+
from typing import (
13+
Callable,
14+
Generic,
15+
ParamSpec,
16+
Type,
17+
TypeVar,
18+
TypedDict,
19+
get_type_hints,
20+
Any,
21+
)
22+
23+
P = ParamSpec("P")
24+
T = TypeVar("T")
25+
26+
27+
@dataclass(frozen=True)
28+
class SignalParameter:
29+
"""Parameter metadata for a signal handler."""
30+
31+
name: str
32+
type_hint: Type | None
33+
has_default: bool
34+
default_value: Any
35+
36+
37+
class SignalDefinitionOptions(TypedDict, total=False):
38+
"""Options for defining a signal."""
39+
40+
name: str
41+
42+
43+
class SignalDefinition(Generic[P, T]):
44+
"""
45+
Definition of a signal handler with metadata.
46+
47+
Similar to ActivityDefinition but for signal handlers.
48+
Provides type safety and metadata for signal handlers.
49+
"""
50+
51+
def __init__(
52+
self,
53+
wrapped: Callable[P, T],
54+
name: str,
55+
params: list[SignalParameter],
56+
is_async: bool,
57+
):
58+
self._wrapped = wrapped
59+
self._name = name
60+
self._params = params
61+
self._is_async = is_async
62+
update_wrapper(self, wrapped)
63+
64+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
65+
"""Call the wrapped signal handler function."""
66+
return self._wrapped(*args, **kwargs)
67+
68+
@property
69+
def name(self) -> str:
70+
"""Get the signal name."""
71+
return self._name
72+
73+
@property
74+
def params(self) -> list[SignalParameter]:
75+
"""Get the signal parameters."""
76+
return self._params
77+
78+
@property
79+
def is_async(self) -> bool:
80+
"""Check if the signal handler is async."""
81+
return self._is_async
82+
83+
@property
84+
def wrapped(self) -> Callable[P, T]:
85+
"""Get the wrapped signal handler function."""
86+
return self._wrapped
87+
88+
@staticmethod
89+
def wrap(
90+
fn: Callable[P, T], opts: SignalDefinitionOptions
91+
) -> "SignalDefinition[P, T]":
92+
"""
93+
Wrap a function as a SignalDefinition.
94+
95+
This is an internal method used by WorkflowDefinition to create signal definitions
96+
from methods decorated with @workflow.signal.
97+
98+
Args:
99+
fn: The signal handler function to wrap
100+
opts: Options for the signal definition
101+
102+
Returns:
103+
A SignalDefinition instance
104+
105+
Raises:
106+
ValueError: If return type is not None
107+
"""
108+
name = opts.get("name") or fn.__qualname__
109+
is_async = inspect.iscoroutinefunction(fn)
110+
params = _get_signal_signature(fn)
111+
_validate_signal_return_type(fn)
112+
113+
return SignalDefinition[P, T](fn, name, params, is_async)
114+
115+
116+
def _validate_signal_return_type(fn: Callable) -> None:
117+
"""
118+
Validate that signal handler returns None.
119+
120+
Args:
121+
fn: The signal handler function
122+
123+
Raises:
124+
ValueError: If return type is not None
125+
"""
126+
try:
127+
hints = get_type_hints(fn)
128+
ret_type = hints.get("return", inspect.Signature.empty)
129+
130+
if ret_type is not None and ret_type is not inspect.Signature.empty:
131+
raise ValueError(
132+
f"Signal handler '{fn.__qualname__}' must return None "
133+
f"(signals cannot return values), got {ret_type}"
134+
)
135+
except NameError:
136+
pass
137+
138+
139+
def _get_signal_signature(fn: Callable[P, T]) -> list[SignalParameter]:
140+
"""
141+
Extract parameter information from a signal handler function.
142+
143+
Args:
144+
fn: The signal handler function
145+
146+
Returns:
147+
List of SignalParameter objects
148+
149+
Raises:
150+
ValueError: If parameters are not positional
151+
"""
152+
sig = signature(fn)
153+
args = sig.parameters
154+
hints = get_type_hints(fn)
155+
params = []
156+
157+
for name, param in args.items():
158+
# Filter out the self parameter for instance methods
159+
if param.name == "self":
160+
continue
161+
162+
has_default = param.default != Parameter.empty
163+
default = param.default if has_default else None
164+
165+
if param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD):
166+
type_hint = hints.get(name, None)
167+
params.append(SignalParameter(name, type_hint, has_default, default))
168+
else:
169+
raise ValueError(
170+
f"Signal handler '{fn.__qualname__}' parameter '{name}' must be positional, "
171+
f"got {param.kind.name}"
172+
)
173+
174+
return params

cadence/workflow.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import inspect
2020

2121
from cadence.data_converter import DataConverter
22+
from cadence.signal import SignalDefinition, SignalDefinitionOptions
2223

2324
ResultType = TypeVar("ResultType")
2425

@@ -60,10 +61,22 @@ class WorkflowDefinition(Generic[C]):
6061
Provides type safety and metadata for workflow classes.
6162
"""
6263

63-
def __init__(self, cls: Type[C], name: str, run_method_name: str):
64+
def __init__(
65+
self,
66+
cls: Type[C],
67+
name: str,
68+
run_method_name: str,
69+
signals: dict[str, SignalDefinition[..., Any]],
70+
):
6471
self._cls: Type[C] = cls
6572
self._name = name
6673
self._run_method_name = run_method_name
74+
self._signals = signals
75+
76+
@property
77+
def signals(self) -> dict[str, SignalDefinition[..., Any]]:
78+
"""Get the signal definitions."""
79+
return self._signals
6780

6881
@property
6982
def name(self) -> str:
@@ -99,6 +112,11 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
99112
name = opts["name"]
100113

101114
# Validate that the class has exactly one run method and find it
115+
# Also validate that class does not have multiple signal methods with the same name
116+
signals: dict[str, SignalDefinition[..., Any]] = {}
117+
signal_names: dict[
118+
str, str
119+
] = {} # Map signal name to method name for duplicate detection
102120
run_method_name = None
103121
for attr_name in dir(cls):
104122
if attr_name.startswith("_"):
@@ -116,10 +134,24 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> "WorkflowDefinition":
116134
)
117135
run_method_name = attr_name
118136

137+
if hasattr(attr, "_workflow_signal"):
138+
signal_name = getattr(attr, "_workflow_signal")
139+
if signal_name in signal_names:
140+
raise ValueError(
141+
f"Multiple @workflow.signal methods found in class {cls.__name__} "
142+
f"with signal name '{signal_name}': '{attr_name}' and '{signal_names[signal_name]}'"
143+
)
144+
# Create SignalDefinition from the decorated method
145+
signal_def = SignalDefinition.wrap(
146+
attr, SignalDefinitionOptions(name=signal_name)
147+
)
148+
signals[signal_name] = signal_def
149+
signal_names[signal_name] = attr_name
150+
119151
if run_method_name is None:
120152
raise ValueError(f"No @workflow.run method found in class {cls.__name__}")
121153

122-
return WorkflowDefinition(cls, name, run_method_name)
154+
return WorkflowDefinition(cls, name, run_method_name, signals)
123155

124156

125157
def run(func: Optional[T] = None) -> Union[T, Callable[[T], T]]:
@@ -163,6 +195,36 @@ def decorator(f: T) -> T:
163195
return decorator(func)
164196

165197

198+
def signal(name: str | None = None) -> Callable[[T], T]:
199+
"""
200+
Decorator to mark a method as a workflow signal handler.
201+
202+
Example:
203+
@workflow.signal(name="approval_channel")
204+
async def approve(self, approved: bool):
205+
self.approved = approved
206+
207+
Args:
208+
name: The name of the signal
209+
210+
Returns:
211+
The decorated method with workflow signal metadata
212+
213+
Raises:
214+
ValueError: If name is not provided
215+
216+
"""
217+
if name is None:
218+
raise ValueError("name is required")
219+
220+
def decorator(f: T) -> T:
221+
f._workflow_signal = name # type: ignore
222+
return f
223+
224+
# Only allow @workflow.signal(name), require name to be explicitly provided
225+
return decorator
226+
227+
166228
@dataclass(frozen=True)
167229
class WorkflowInfo:
168230
workflow_type: str

tests/cadence/worker/test_registry.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from cadence import workflow
1010
from cadence.worker import Registry
1111
from cadence.workflow import WorkflowDefinition
12+
from cadence.signal import SignalDefinition
1213
from tests.cadence import common_activities
1314

1415

@@ -212,3 +213,99 @@ async def run(self, input: str) -> str:
212213
workflow_def = reg.get_workflow("custom_workflow_name")
213214
assert workflow_def.name == "custom_workflow_name"
214215
assert workflow_def.cls == CustomWorkflow
216+
217+
def test_workflow_with_signal(self):
218+
"""Test workflow with signal handler."""
219+
reg = Registry()
220+
221+
@reg.workflow
222+
class WorkflowWithSignal:
223+
@workflow.run
224+
async def run(self):
225+
return "done"
226+
227+
@workflow.signal(name="approval")
228+
async def handle_approval(self, approved: bool):
229+
self.approved = approved
230+
231+
workflow_def = reg.get_workflow("WorkflowWithSignal")
232+
assert isinstance(workflow_def, WorkflowDefinition)
233+
assert len(workflow_def.signals) == 1
234+
assert "approval" in workflow_def.signals
235+
signal_def = workflow_def.signals["approval"]
236+
assert isinstance(signal_def, SignalDefinition)
237+
assert signal_def.name == "approval"
238+
assert signal_def.is_async is True
239+
assert len(signal_def.params) == 1
240+
assert signal_def.params[0].name == "approved"
241+
242+
def test_workflow_with_multiple_signals(self):
243+
"""Test workflow with multiple signal handlers."""
244+
reg = Registry()
245+
246+
@reg.workflow
247+
class WorkflowWithMultipleSignals:
248+
@workflow.run
249+
async def run(self):
250+
return "done"
251+
252+
@workflow.signal(name="approval")
253+
async def handle_approval(self, approved: bool):
254+
self.approved = approved
255+
256+
@workflow.signal(name="cancel")
257+
async def handle_cancel(self):
258+
self.cancelled = True
259+
260+
workflow_def = reg.get_workflow("WorkflowWithMultipleSignals")
261+
assert len(workflow_def.signals) == 2
262+
assert "approval" in workflow_def.signals
263+
assert "cancel" in workflow_def.signals
264+
assert isinstance(workflow_def.signals["approval"], SignalDefinition)
265+
assert isinstance(workflow_def.signals["cancel"], SignalDefinition)
266+
assert workflow_def.signals["approval"].name == "approval"
267+
assert workflow_def.signals["cancel"].name == "cancel"
268+
269+
def test_signal_decorator_requires_name(self):
270+
"""Test that signal decorator requires name parameter."""
271+
with pytest.raises(ValueError, match="name is required"):
272+
273+
@workflow.signal()
274+
async def test_signal(self):
275+
pass
276+
277+
def test_workflow_without_signals(self):
278+
"""Test that workflow without signals has empty signals dict."""
279+
reg = Registry()
280+
281+
@reg.workflow
282+
class WorkflowWithoutSignals:
283+
@workflow.run
284+
async def run(self):
285+
return "done"
286+
287+
workflow_def = reg.get_workflow("WorkflowWithoutSignals")
288+
assert isinstance(workflow_def.signals, dict)
289+
assert len(workflow_def.signals) == 0
290+
291+
def test_duplicate_signal_names_error(self):
292+
"""Test that duplicate signal names raise ValueError."""
293+
reg = Registry()
294+
295+
with pytest.raises(
296+
ValueError, match="Multiple.*signal.*found.*with signal name 'approval'"
297+
):
298+
299+
@reg.workflow
300+
class WorkflowWithDuplicateSignalNames:
301+
@workflow.run
302+
async def run(self):
303+
return "done"
304+
305+
@workflow.signal(name="approval")
306+
async def handle_approval(self, approved: bool):
307+
self.approved = approved
308+
309+
@workflow.signal(name="approval")
310+
async def handle_approval_different(self):
311+
self.also_approved = True

0 commit comments

Comments
 (0)