77"""
88
99import logging
10- from typing import Callable , Dict , Optional , Unpack , TypedDict , Sequence , overload
10+ from typing import Callable , Dict , Optional , Unpack , TypedDict , overload , Type , Union , TypeVar
1111from cadence .activity import ActivityDefinitionOptions , ActivityDefinition , ActivityDecorator , P , T
12+ from cadence .workflow import WorkflowDefinition , WorkflowDefinitionOptions
1213
1314logger = logging .getLogger (__name__ )
1415
16+ # TypeVar for workflow class types
17+ W = TypeVar ('W' )
18+
1519
1620class RegisterWorkflowOptions (TypedDict , total = False ):
1721 """Options for registering a workflow."""
@@ -28,53 +32,58 @@ class Registry:
2832
2933 def __init__ (self ) -> None :
3034 """Initialize the registry."""
31- self ._workflows : Dict [str , Callable ] = {}
35+ self ._workflows : Dict [str , WorkflowDefinition ] = {}
3236 self ._activities : Dict [str , ActivityDefinition ] = {}
3337 self ._workflow_aliases : Dict [str , str ] = {} # alias -> name mapping
3438
3539 def workflow (
3640 self ,
37- func : Optional [Callable ] = None ,
41+ cls : Optional [Type [ W ] ] = None ,
3842 ** kwargs : Unpack [RegisterWorkflowOptions ]
39- ) -> Callable :
43+ ) -> Union [ Type [ W ], Callable [[ Type [ W ]], Type [ W ]]] :
4044 """
41- Register a workflow function .
42-
45+ Register a workflow class .
46+
4347 This method can be used as a decorator or called directly.
44-
48+ Only supports class-based workflows.
49+
4550 Args:
46- func : The workflow function to register
51+ cls : The workflow class to register
4752 **kwargs: Options for registration (name, alias)
48-
53+
4954 Returns:
50- The decorated function or the function itself
51-
55+ The decorated class
56+
5257 Raises:
5358 KeyError: If workflow name already exists
59+ ValueError: If class workflow is invalid
5460 """
5561 options = RegisterWorkflowOptions (** kwargs )
56-
57- def decorator (f : Callable ) -> Callable :
58- workflow_name = options .get ('name' ) or f .__name__
59-
62+
63+ def decorator (target : Type [ W ] ) -> Type [ W ] :
64+ workflow_name = options .get ('name' ) or target .__name__
65+
6066 if workflow_name in self ._workflows :
6167 raise KeyError (f"Workflow '{ workflow_name } ' is already registered" )
62-
63- self ._workflows [workflow_name ] = f
64-
68+
69+ # Create WorkflowDefinition with type information
70+ workflow_opts = WorkflowDefinitionOptions (name = workflow_name )
71+ workflow_def = WorkflowDefinition .wrap (target , workflow_opts )
72+ self ._workflows [workflow_name ] = workflow_def
73+
6574 # Register alias if provided
6675 alias = options .get ('alias' )
6776 if alias :
6877 if alias in self ._workflow_aliases :
6978 raise KeyError (f"Workflow alias '{ alias } ' is already registered" )
7079 self ._workflow_aliases [alias ] = workflow_name
71-
80+
7281 logger .info (f"Registered workflow '{ workflow_name } '" )
73- return f
74-
75- if func is None :
82+ return target
83+
84+ if cls is None :
7685 return decorator
77- return decorator (func )
86+ return decorator (cls )
7887
7988 @overload
8089 def activity (self , func : Callable [P , T ]) -> ActivityDefinition [P , T ]:
@@ -135,25 +144,25 @@ def _register_activity(self, defn: ActivityDefinition) -> None:
135144 self ._activities [defn .name ] = defn
136145
137146
138- def get_workflow (self , name : str ) -> Callable :
147+ def get_workflow (self , name : str ) -> WorkflowDefinition :
139148 """
140149 Get a registered workflow by name.
141-
150+
142151 Args:
143152 name: Name or alias of the workflow
144-
153+
145154 Returns:
146- The workflow function
147-
155+ The workflow definition
156+
148157 Raises:
149158 KeyError: If workflow is not found
150159 """
151160 # Check if it's an alias
152161 actual_name = self ._workflow_aliases .get (name , name )
153-
162+
154163 if actual_name not in self ._workflows :
155164 raise KeyError (f"Workflow '{ name } ' not found in registry" )
156-
165+
157166 return self ._workflows [actual_name ]
158167
159168 def get_activity (self , name : str ) -> ActivityDefinition :
@@ -188,7 +197,7 @@ def of(*args: 'Registry') -> 'Registry':
188197
189198 return result
190199
191- def _find_activity_definitions (instance : object ) -> Sequence [ActivityDefinition ]:
200+ def _find_activity_definitions (instance : object ) -> list [ActivityDefinition ]:
192201 attr_to_def = {}
193202 for t in instance .__class__ .__mro__ :
194203 for attr in dir (t ):
@@ -200,10 +209,7 @@ def _find_activity_definitions(instance: object) -> Sequence[ActivityDefinition]
200209 raise ValueError (f"'{ attr } ' was overridden with a duplicate activity definition" )
201210 attr_to_def [attr ] = value
202211
203- # Create new definitions, copying the attributes from the declaring type but using the function
204- # from the specific object. This allows for the decorator to be applied to the base class and the
205- # function to be overridden
206- result = []
212+ result : list [ActivityDefinition ] = []
207213 for attr , definition in attr_to_def .items ():
208214 result .append (ActivityDefinition (getattr (instance , attr ), definition .name , definition .strategy , definition .params ))
209215
0 commit comments