66import sys
77import types
88import typing
9+ from abc import ABC
910from typing import (
1011 Any ,
1112 AsyncGenerator ,
@@ -42,6 +43,16 @@ def reads(self) -> list[str]:
4243 """
4344 pass
4445
46+ @property
47+ def default_reads (self ) -> Dict [str , Any ]:
48+ """Default values to read from state if they are not there already.
49+ This just fills out the gaps in state. This must be a subset
50+ of the ``reads`` value.
51+
52+ :return:
53+ """
54+ return {}
55+
4556 @abc .abstractmethod
4657 def run (self , state : State , ** run_kwargs ) -> dict :
4758 """Runs the function on the given state and returns the result.
@@ -122,18 +133,68 @@ def writes(self) -> list[str]:
122133 """
123134 pass
124135
136+ @property
137+ def default_writes (self ) -> Dict [str , Any ]:
138+ """Default state writes for the reducer. If nothing writes this field from within
139+ the reducer, then this will be written. Note that this is not (currently)
140+ intended to work with append/increment operations.
141+
142+ This must be a subset of the ``writes`` value.
143+
144+ :return: A key/value dictionary of default writes.
145+ """
146+ return {}
147+
125148 @abc .abstractmethod
126149 def update (self , result : dict , state : State ) -> State :
127150 pass
128151
129152
130- class Action (Function , Reducer , abc .ABC ):
153+ class _PostValidator (abc .ABCMeta ):
154+ """Metaclass to allow for __post_init__ to be called after __init__.
155+ While this is general we're keeping it here for now as it is only used
156+ by the Action class. This enables us to ensure that the default_reads are correct.
157+ """
158+
159+ def __call__ (cls , * args , ** kwargs ):
160+ instance = super ().__call__ (* args , ** kwargs )
161+ if post := getattr (cls , "__post_init__" , None ):
162+ post (instance )
163+ return instance
164+
165+
166+ class Action (Function , Reducer , ABC , metaclass = _PostValidator ):
131167 def __init__ (self ):
132168 """Represents an action in a state machine. This is the base class from which
133169 actions extend. Note that this class needs to have a name set after the fact.
134170 """
135171 self ._name = None
136172
173+ def __post_init__ (self ):
174+ self ._validate_defaults ()
175+
176+ def _validate_defaults (self ):
177+ reads = set (self .reads )
178+ missing_default_reads = {key for key in self .default_reads .keys () if key not in reads }
179+ if missing_default_reads :
180+ raise ValueError (
181+ f"The following default state reads are not in the set of reads for action: { self } : { ', ' .join (missing_default_reads )} . "
182+ f"Every read in default_reads must be in the reads list."
183+ )
184+ writes = self .writes
185+ missing_default_writes = {key for key in self .default_writes .keys () if key not in writes }
186+ if missing_default_writes :
187+ raise ValueError (
188+ f"The following default state writes are not in the set of writes for action: { self } : { ', ' .join (missing_default_writes )} . "
189+ f"Every write in default_writes must be in the writes list."
190+ )
191+ default_writes_also_in_reads = {key for key in self .default_writes .keys () if key in reads }
192+ if default_writes_also_in_reads :
193+ raise ValueError (
194+ f"The following default state writes are also in the reads for action: { self } : { ', ' .join (default_writes_also_in_reads )} . "
195+ f"Every write in default_writes must not be in the reads list -- this leads to undefined behavior."
196+ )
197+
137198 def with_name (self , name : str ) -> Self :
138199 """Returns a copy of the given action with the given name. Why do we need this?
139200 We instantiate actions without names, and then set them later. This is a way to
@@ -484,6 +545,8 @@ def __init__(
484545 fn : Callable ,
485546 reads : List [str ],
486547 writes : List [str ],
548+ default_reads : Dict [str , Any ] = None ,
549+ default_writes : Dict [str , Any ] = None ,
487550 bound_params : dict = None ,
488551 ):
489552 """Instantiates a function-based action with the given function, reads, and writes.
@@ -499,11 +562,21 @@ def __init__(
499562 self ._writes = writes
500563 self ._bound_params = bound_params if bound_params is not None else {}
501564 self ._inputs = _get_inputs (self ._bound_params , self ._fn )
565+ self ._default_reads = default_reads if default_reads is not None else {}
566+ self ._default_writes = default_writes if default_writes is not None else {}
502567
503568 @property
504569 def fn (self ) -> Callable :
505570 return self ._fn
506571
572+ @property
573+ def default_reads (self ) -> Dict [str , Any ]:
574+ return self ._default_reads
575+
576+ @property
577+ def default_writes (self ) -> Dict [str , Any ]:
578+ return self ._default_writes
579+
507580 @property
508581 def reads (self ) -> list [str ]:
509582 return self ._reads
@@ -526,7 +599,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
526599 :return:
527600 """
528601 return FunctionBasedAction (
529- self ._fn , self ._reads , self ._writes , {** self ._bound_params , ** kwargs }
602+ self ._fn ,
603+ self ._reads ,
604+ self ._writes ,
605+ self .default_reads ,
606+ self ._default_writes ,
607+ {** self ._bound_params , ** kwargs },
530608 )
531609
532610 def run_and_update (self , state : State , ** run_kwargs ) -> tuple [dict , State ]:
@@ -918,6 +996,8 @@ def __init__(
918996 ],
919997 reads : List [str ],
920998 writes : List [str ],
999+ default_reads : Optional [Dict [str , Any ]] = None ,
1000+ default_writes : Optional [Dict [str , Any ]] = None ,
9211001 bound_params : dict = None ,
9221002 ):
9231003 """Instantiates a function-based streaming action with the given function, reads, and writes.
@@ -931,6 +1011,8 @@ def __init__(
9311011 self ._fn = fn
9321012 self ._reads = reads
9331013 self ._writes = writes
1014+ self ._default_reads = default_reads if default_reads is not None else {}
1015+ self ._default_writes = default_writes if default_writes is not None else {}
9341016 self ._bound_params = bound_params if bound_params is not None else {}
9351017
9361018 async def _a_stream_run_and_update (
@@ -957,6 +1039,14 @@ def reads(self) -> list[str]:
9571039 def writes (self ) -> list [str ]:
9581040 return self ._writes
9591041
1042+ @property
1043+ def default_writes (self ) -> Dict [str , Any ]:
1044+ return self ._default_writes
1045+
1046+ @property
1047+ def default_reads (self ) -> Dict [str , Any ]:
1048+ return self ._default_reads
1049+
9601050 @property
9611051 def streaming (self ) -> bool :
9621052 return True
@@ -969,7 +1059,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
9691059 :return:
9701060 """
9711061 return FunctionBasedStreamingAction (
972- self ._fn , self ._reads , self ._writes , {** self ._bound_params , ** kwargs }
1062+ self ._fn ,
1063+ self ._reads ,
1064+ self ._writes ,
1065+ self ._default_reads ,
1066+ self ._default_writes ,
1067+ {** self ._bound_params , ** kwargs },
9731068 )
9741069
9751070 @property
@@ -999,7 +1094,10 @@ def bind(self, **kwargs: Any) -> Self:
9991094 ...
10001095
10011096
1002- def copy_func (f : types .FunctionType ) -> types .FunctionType :
1097+ T = TypeVar ("T" , bound = types .FunctionType )
1098+
1099+
1100+ def copy_func (f : T ) -> T :
10031101 """Copies a function. This is used internally to bind parameters to a function
10041102 so we don't accidentally overwrite them.
10051103
@@ -1033,7 +1131,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]:
10331131 return self
10341132
10351133
1036- def action (reads : List [str ], writes : List [str ]) -> Callable [[Callable ], FunctionRepresentingAction ]:
1134+ def action (
1135+ reads : List [str ],
1136+ writes : List [str ],
1137+ default_reads : Dict [str , Any ] = None ,
1138+ default_writes : Dict [str , Any ] = None ,
1139+ ) -> Callable [[Callable ], FunctionRepresentingAction ]:
10371140 """Decorator to create a function-based action. This is user-facing.
10381141 Note that, in the future, with typed state, we may not need this for
10391142 all cases.
@@ -1044,19 +1147,38 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function
10441147
10451148 :param reads: Items to read from the state
10461149 :param writes: Items to write to the state
1150+ :param default_reads: Default values for reads. If nothing upstream produces these, they will
1151+ be filled automatically. This is equivalent to adding
1152+ ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
1153+ at the beginning of your function.
1154+ :param default_writes: Default values for writes. If the action's state update does not write to this,
1155+ they will be filled automatically with the default values. Leaving blank will have no default values.
1156+ This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
1157+ Note that this will not work as intended with append/increment operations, so be careful.
10471158 :return: The decorator to assign the function as an action
10481159 """
1160+ default_reads = default_reads if default_reads is not None else {}
1161+ default_writes = default_writes if default_writes is not None else {}
10491162
10501163 def decorator (fn ) -> FunctionRepresentingAction :
1051- setattr (fn , FunctionBasedAction .ACTION_FUNCTION , FunctionBasedAction (fn , reads , writes ))
1164+ setattr (
1165+ fn ,
1166+ FunctionBasedAction .ACTION_FUNCTION ,
1167+ FunctionBasedAction (
1168+ fn , reads , writes , default_reads = default_reads , default_writes = default_writes
1169+ ),
1170+ )
10521171 setattr (fn , "bind" , types .MethodType (bind , fn ))
10531172 return fn
10541173
10551174 return decorator
10561175
10571176
10581177def streaming_action (
1059- reads : List [str ], writes : List [str ]
1178+ reads : List [str ],
1179+ writes : List [str ],
1180+ default_reads : Optional [Dict [str , Any ]] = None ,
1181+ default_writes : Optional [Dict [str , Any ]] = None ,
10601182) -> Callable [[Callable ], FunctionRepresentingAction ]:
10611183 """Decorator to create a streaming function-based action. This is user-facing.
10621184
@@ -1090,14 +1212,28 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]
10901212 # return the final result
10911213 return {'response': full_response}, state.update(response=full_response)
10921214
1215+ :param reads: Items to read from the state
1216+ :param writes: Items to write to the state
1217+ :param default_reads: Default values for reads. If nothing upstream produces these, they will
1218+ be filled automatically. This is equivalent to adding
1219+ ``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
1220+ at the beginning of your function.
1221+ :param default_writes: Default values for writes. If the action's state update does not write to this,
1222+ they will be filled automatically with the default values. Leaving blank will have no default values.
1223+ This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
1224+ Note that this will not work as intended with append/increment operations, so be careful.
1225+ :return: The decorator to assign the function as an action
1226+
10931227 """
1228+ default_reads = default_reads if default_reads is not None else {}
1229+ default_writes = default_writes if default_writes is not None else {}
10941230
10951231 def wrapped (fn ) -> FunctionRepresentingAction :
10961232 fn = copy_func (fn )
10971233 setattr (
10981234 fn ,
10991235 FunctionBasedAction .ACTION_FUNCTION ,
1100- FunctionBasedStreamingAction (fn , reads , writes ),
1236+ FunctionBasedStreamingAction (fn , reads , writes , default_reads , default_writes ),
11011237 )
11021238 setattr (fn , "bind" , types .MethodType (bind , fn ))
11031239 return fn
0 commit comments