Skip to content

Commit 308be71

Browse files
committed
Adds default reads/writes to burr actions
This allows you to specify defaults if your action does not write. In the majority of cases they will be none, but this allows simple (static) arbitrary values. This specifically helps with the branching case -- e.g. where you have two options, and want to null out anything it doesn't write. For instance, an error and a result -- you'll only ever produce one or the other. This works both in the function and class-based approaches -- in the function-based it is part of the two decorators (@action/@streaming_action). In the class-based it is part of the class, overriding the default_reads and default_writes property function We add a bunch of new tests for default (as the code to handle multiple action types is fairly dispersed, for now), and also make the naming of the other tests/content more consistent. Note that this does not currently work with settings defaults to append/increment operations -- it will produce strange behavior. This is documented in all appropriate signatures. This also does not work (or even make sense) in the case that the function writes a default that it also reads. In that case, it will clobber the current value with the write value. To avoid this, we just error out if that is the case beforehand.
1 parent bb2c446 commit 308be71

File tree

7 files changed

+753
-75
lines changed

7 files changed

+753
-75
lines changed

burr/core/action.py

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import sys
77
import types
88
import typing
9+
from abc import ABC
910
from typing import (
1011
Any,
1112
AsyncGenerator,
@@ -42,6 +43,16 @@ def reads(self) -> list[str]:
4243
"""
4344
pass
4445

46+
@property
47+
def default_reads(self) -> Dict[str, Any]:
48+
"""Default values to read from state if they are not there already.
49+
This just fills out the gaps in state. This must be a subset
50+
of the ``reads`` value.
51+
52+
:return:
53+
"""
54+
return {}
55+
4556
@abc.abstractmethod
4657
def run(self, state: State, **run_kwargs) -> dict:
4758
"""Runs the function on the given state and returns the result.
@@ -122,18 +133,68 @@ def writes(self) -> list[str]:
122133
"""
123134
pass
124135

136+
@property
137+
def default_writes(self) -> Dict[str, Any]:
138+
"""Default state writes for the reducer. If nothing writes this field from within
139+
the reducer, then this will be written. Note that this is not (currently)
140+
intended to work with append/increment operations.
141+
142+
This must be a subset of the ``writes`` value.
143+
144+
:return: A key/value dictionary of default writes.
145+
"""
146+
return {}
147+
125148
@abc.abstractmethod
126149
def update(self, result: dict, state: State) -> State:
127150
pass
128151

129152

130-
class Action(Function, Reducer, abc.ABC):
153+
class _PostValidator(abc.ABCMeta):
154+
"""Metaclass to allow for __post_init__ to be called after __init__.
155+
While this is general we're keeping it here for now as it is only used
156+
by the Action class. This enables us to ensure that the default_reads are correct.
157+
"""
158+
159+
def __call__(cls, *args, **kwargs):
160+
instance = super().__call__(*args, **kwargs)
161+
if post := getattr(cls, "__post_init__", None):
162+
post(instance)
163+
return instance
164+
165+
166+
class Action(Function, Reducer, ABC, metaclass=_PostValidator):
131167
def __init__(self):
132168
"""Represents an action in a state machine. This is the base class from which
133169
actions extend. Note that this class needs to have a name set after the fact.
134170
"""
135171
self._name = None
136172

173+
def __post_init__(self):
174+
self._validate_defaults()
175+
176+
def _validate_defaults(self):
177+
reads = set(self.reads)
178+
missing_default_reads = {key for key in self.default_reads.keys() if key not in reads}
179+
if missing_default_reads:
180+
raise ValueError(
181+
f"The following default state reads are not in the set of reads for action: {self}: {', '.join(missing_default_reads)}. "
182+
f"Every read in default_reads must be in the reads list."
183+
)
184+
writes = self.writes
185+
missing_default_writes = {key for key in self.default_writes.keys() if key not in writes}
186+
if missing_default_writes:
187+
raise ValueError(
188+
f"The following default state writes are not in the set of writes for action: {self}: {', '.join(missing_default_writes)}. "
189+
f"Every write in default_writes must be in the writes list."
190+
)
191+
default_writes_also_in_reads = {key for key in self.default_writes.keys() if key in reads}
192+
if default_writes_also_in_reads:
193+
raise ValueError(
194+
f"The following default state writes are also in the reads for action: {self}: {', '.join(default_writes_also_in_reads)}. "
195+
f"Every write in default_writes must not be in the reads list -- this leads to undefined behavior."
196+
)
197+
137198
def with_name(self, name: str) -> Self:
138199
"""Returns a copy of the given action with the given name. Why do we need this?
139200
We instantiate actions without names, and then set them later. This is a way to
@@ -484,6 +545,8 @@ def __init__(
484545
fn: Callable,
485546
reads: List[str],
486547
writes: List[str],
548+
default_reads: Dict[str, Any] = None,
549+
default_writes: Dict[str, Any] = None,
487550
bound_params: dict = None,
488551
):
489552
"""Instantiates a function-based action with the given function, reads, and writes.
@@ -499,11 +562,21 @@ def __init__(
499562
self._writes = writes
500563
self._bound_params = bound_params if bound_params is not None else {}
501564
self._inputs = _get_inputs(self._bound_params, self._fn)
565+
self._default_reads = default_reads if default_reads is not None else {}
566+
self._default_writes = default_writes if default_writes is not None else {}
502567

503568
@property
504569
def fn(self) -> Callable:
505570
return self._fn
506571

572+
@property
573+
def default_reads(self) -> Dict[str, Any]:
574+
return self._default_reads
575+
576+
@property
577+
def default_writes(self) -> Dict[str, Any]:
578+
return self._default_writes
579+
507580
@property
508581
def reads(self) -> list[str]:
509582
return self._reads
@@ -526,7 +599,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedAction":
526599
:return:
527600
"""
528601
return FunctionBasedAction(
529-
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
602+
self._fn,
603+
self._reads,
604+
self._writes,
605+
self.default_reads,
606+
self._default_writes,
607+
{**self._bound_params, **kwargs},
530608
)
531609

532610
def run_and_update(self, state: State, **run_kwargs) -> tuple[dict, State]:
@@ -918,6 +996,8 @@ def __init__(
918996
],
919997
reads: List[str],
920998
writes: List[str],
999+
default_reads: Optional[Dict[str, Any]] = None,
1000+
default_writes: Optional[Dict[str, Any]] = None,
9211001
bound_params: dict = None,
9221002
):
9231003
"""Instantiates a function-based streaming action with the given function, reads, and writes.
@@ -931,6 +1011,8 @@ def __init__(
9311011
self._fn = fn
9321012
self._reads = reads
9331013
self._writes = writes
1014+
self._default_reads = default_reads if default_reads is not None else {}
1015+
self._default_writes = default_writes if default_writes is not None else {}
9341016
self._bound_params = bound_params if bound_params is not None else {}
9351017

9361018
async def _a_stream_run_and_update(
@@ -957,6 +1039,14 @@ def reads(self) -> list[str]:
9571039
def writes(self) -> list[str]:
9581040
return self._writes
9591041

1042+
@property
1043+
def default_writes(self) -> Dict[str, Any]:
1044+
return self._default_writes
1045+
1046+
@property
1047+
def default_reads(self) -> Dict[str, Any]:
1048+
return self._default_reads
1049+
9601050
@property
9611051
def streaming(self) -> bool:
9621052
return True
@@ -969,7 +1059,12 @@ def with_params(self, **kwargs: Any) -> "FunctionBasedStreamingAction":
9691059
:return:
9701060
"""
9711061
return FunctionBasedStreamingAction(
972-
self._fn, self._reads, self._writes, {**self._bound_params, **kwargs}
1062+
self._fn,
1063+
self._reads,
1064+
self._writes,
1065+
self._default_reads,
1066+
self._default_writes,
1067+
{**self._bound_params, **kwargs},
9731068
)
9741069

9751070
@property
@@ -999,7 +1094,10 @@ def bind(self, **kwargs: Any) -> Self:
9991094
...
10001095

10011096

1002-
def copy_func(f: types.FunctionType) -> types.FunctionType:
1097+
T = TypeVar("T", bound=types.FunctionType)
1098+
1099+
1100+
def copy_func(f: T) -> T:
10031101
"""Copies a function. This is used internally to bind parameters to a function
10041102
so we don't accidentally overwrite them.
10051103
@@ -1033,7 +1131,12 @@ def my_action(state: State, z: int) -> tuple[dict, State]:
10331131
return self
10341132

10351133

1036-
def action(reads: List[str], writes: List[str]) -> Callable[[Callable], FunctionRepresentingAction]:
1134+
def action(
1135+
reads: List[str],
1136+
writes: List[str],
1137+
default_reads: Dict[str, Any] = None,
1138+
default_writes: Dict[str, Any] = None,
1139+
) -> Callable[[Callable], FunctionRepresentingAction]:
10371140
"""Decorator to create a function-based action. This is user-facing.
10381141
Note that, in the future, with typed state, we may not need this for
10391142
all cases.
@@ -1044,19 +1147,38 @@ def action(reads: List[str], writes: List[str]) -> Callable[[Callable], Function
10441147
10451148
:param reads: Items to read from the state
10461149
:param writes: Items to write to the state
1150+
:param default_reads: Default values for reads. If nothing upstream produces these, they will
1151+
be filled automatically. This is equivalent to adding
1152+
``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
1153+
at the beginning of your function.
1154+
:param default_writes: Default values for writes. If the action's state update does not write to this,
1155+
they will be filled automatically with the default values. Leaving blank will have no default values.
1156+
This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
1157+
Note that this will not work as intended with append/increment operations, so be careful.
10471158
:return: The decorator to assign the function as an action
10481159
"""
1160+
default_reads = default_reads if default_reads is not None else {}
1161+
default_writes = default_writes if default_writes is not None else {}
10491162

10501163
def decorator(fn) -> FunctionRepresentingAction:
1051-
setattr(fn, FunctionBasedAction.ACTION_FUNCTION, FunctionBasedAction(fn, reads, writes))
1164+
setattr(
1165+
fn,
1166+
FunctionBasedAction.ACTION_FUNCTION,
1167+
FunctionBasedAction(
1168+
fn, reads, writes, default_reads=default_reads, default_writes=default_writes
1169+
),
1170+
)
10521171
setattr(fn, "bind", types.MethodType(bind, fn))
10531172
return fn
10541173

10551174
return decorator
10561175

10571176

10581177
def streaming_action(
1059-
reads: List[str], writes: List[str]
1178+
reads: List[str],
1179+
writes: List[str],
1180+
default_reads: Optional[Dict[str, Any]] = None,
1181+
default_writes: Optional[Dict[str, Any]] = None,
10601182
) -> Callable[[Callable], FunctionRepresentingAction]:
10611183
"""Decorator to create a streaming function-based action. This is user-facing.
10621184
@@ -1090,14 +1212,28 @@ def streaming_response(state: State) -> Generator[dict, None, tuple[dict, State]
10901212
# return the final result
10911213
return {'response': full_response}, state.update(response=full_response)
10921214
1215+
:param reads: Items to read from the state
1216+
:param writes: Items to write to the state
1217+
:param default_reads: Default values for reads. If nothing upstream produces these, they will
1218+
be filled automatically. This is equivalent to adding
1219+
``state = state.update(**{key: value for key, value in default_reads.items() if key not in state})``
1220+
at the beginning of your function.
1221+
:param default_writes: Default values for writes. If the action's state update does not write to this,
1222+
they will be filled automatically with the default values. Leaving blank will have no default values.
1223+
This is equivalent to adding state = state.update(***deafult_writes) at the beginning of the function.
1224+
Note that this will not work as intended with append/increment operations, so be careful.
1225+
:return: The decorator to assign the function as an action
1226+
10931227
"""
1228+
default_reads = default_reads if default_reads is not None else {}
1229+
default_writes = default_writes if default_writes is not None else {}
10941230

10951231
def wrapped(fn) -> FunctionRepresentingAction:
10961232
fn = copy_func(fn)
10971233
setattr(
10981234
fn,
10991235
FunctionBasedAction.ACTION_FUNCTION,
1100-
FunctionBasedStreamingAction(fn, reads, writes),
1236+
FunctionBasedStreamingAction(fn, reads, writes, default_reads, default_writes),
11011237
)
11021238
setattr(fn, "bind", types.MethodType(bind, fn))
11031239
return fn

0 commit comments

Comments
 (0)