22from contextlib import contextmanager
33from contextvars import ContextVar
44from dataclasses import dataclass
5- from typing import Iterator , Callable , TypeVar , TypedDict , Type , cast , Any
6- from functools import wraps
5+ from typing import Iterator , Callable , TypeVar , TypedDict , Type , cast , Any , Optional , Union
6+ import inspect
77
88from cadence .client import Client
99
10- T = TypeVar ('T' )
10+ T = TypeVar ('T' , bound = Callable [..., Any ] )
1111
1212
1313class WorkflowDefinitionOptions (TypedDict , total = False ):
@@ -23,9 +23,10 @@ class WorkflowDefinition:
2323 Provides type safety and metadata for workflow classes.
2424 """
2525
26- def __init__ (self , cls : Type , name : str ):
26+ def __init__ (self , cls : Type , name : str , run_method_name : str ):
2727 self ._cls = cls
2828 self ._name = name
29+ self ._run_method_name = run_method_name
2930
3031 @property
3132 def name (self ) -> str :
@@ -39,13 +40,7 @@ def cls(self) -> Type:
3940
4041 def get_run_method (self , instance : Any ) -> Callable :
4142 """Get the workflow run method from an instance of the workflow class."""
42- for attr_name in dir (instance ):
43- if attr_name .startswith ('_' ):
44- continue
45- attr = getattr (instance , attr_name )
46- if callable (attr ) and hasattr (attr , '_workflow_run' ):
47- return cast (Callable , attr )
48- raise ValueError (f"No @workflow.run method found in class { self ._cls .__name__ } " )
43+ return cast (Callable , getattr (instance , self ._run_method_name ))
4944
5045 @staticmethod
5146 def wrap (cls : Type , opts : WorkflowDefinitionOptions ) -> 'WorkflowDefinition' :
@@ -66,8 +61,8 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition':
6661 if "name" in opts and opts ["name" ]:
6762 name = opts ["name" ]
6863
69- # Validate that the class has exactly one run method
70- run_method_count = 0
64+ # Validate that the class has exactly one run method and find it
65+ run_method_name = None
7166 for attr_name in dir (cls ):
7267 if attr_name .startswith ('_' ):
7368 continue
@@ -78,40 +73,54 @@ def wrap(cls: Type, opts: WorkflowDefinitionOptions) -> 'WorkflowDefinition':
7873
7974 # Check for workflow run method
8075 if hasattr (attr , '_workflow_run' ):
81- run_method_count += 1
76+ if run_method_name is not None :
77+ raise ValueError (f"Multiple @workflow.run methods found in class { cls .__name__ } " )
78+ run_method_name = attr_name
8279
83- if run_method_count == 0 :
80+ if run_method_name is None :
8481 raise ValueError (f"No @workflow.run method found in class { cls .__name__ } " )
85- elif run_method_count > 1 :
86- raise ValueError (f"Multiple @workflow.run methods found in class { cls .__name__ } " )
8782
88- return WorkflowDefinition (cls , name )
83+ return WorkflowDefinition (cls , name , run_method_name )
8984
9085
91- def run (func : Callable [..., T ] ) -> Callable [... , T ]:
86+ def run (func : Optional [ T ] = None ) -> Union [ T , Callable [[ T ] , T ] ]:
9287 """
9388 Decorator to mark a method as the main workflow run method.
9489
90+ Can be used with or without parentheses:
91+ @workflow.run
92+ async def my_workflow(self):
93+ ...
94+
95+ @workflow.run()
96+ async def my_workflow(self):
97+ ...
98+
9599 Args:
96100 func: The method to mark as the workflow run method
97101
98102 Returns:
99103 The decorated method with workflow run metadata
104+
105+ Raises:
106+ ValueError: If the function is not async
100107 """
101- @wraps (func )
102- def wrapper (* args , ** kwargs ):
103- return func (* args , ** kwargs )
104-
105- # Attach metadata to the function
106- wrapper ._workflow_run = True # type: ignore
107- return wrapper
108-
109-
110- # Create a simple namespace object for the workflow decorators
111- class _WorkflowNamespace :
112- run = staticmethod (run )
113-
114- workflow = _WorkflowNamespace ()
108+ def decorator (f : T ) -> T :
109+ # Validate that the function is async
110+ if not inspect .iscoroutinefunction (f ):
111+ raise ValueError (f"Workflow run method '{ f .__name__ } ' must be async" )
112+
113+ # Attach metadata to the function
114+ f ._workflow_run = True # type: ignore
115+ return f
116+
117+ # Support both @workflow.run and @workflow.run()
118+ if func is None :
119+ # Called with parentheses: @workflow.run()
120+ return decorator
121+ else :
122+ # Called without parentheses: @workflow.run
123+ return decorator (func )
115124
116125
117126@dataclass
0 commit comments