From b2ef1a8232361dc9a2810b33d345bf0a7f89f485 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 3 Dec 2025 14:48:03 -0500 Subject: [PATCH 1/8] allow static param to have None value --- src/caskade/__init__.py | 2 ++ src/caskade/errors.py | 6 ++++- src/caskade/module.py | 4 ++-- src/caskade/param.py | 48 ++++++++++++++------------------------- tests/test_context.py | 6 ++--- tests/test_forward.py | 19 ++++++++-------- tests/test_integration.py | 1 + tests/test_module.py | 5 ++-- tests/test_param.py | 21 ++++------------- 9 files changed, 48 insertions(+), 64 deletions(-) diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 097af53..b421298 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -17,6 +17,7 @@ ParamConfigurationError, ParamTypeError, ActiveStateError, + FillParamsError, FillDynamicParamsError, FillDynamicParamsArrayError, FillDynamicParamsSequenceError, @@ -51,6 +52,7 @@ "ParamConfigurationError", "ParamTypeError", "ActiveStateError", + "FillParamsError", "FillDynamicParamsError", "FillDynamicParamsArrayError", "FillDynamicParamsSequenceError", diff --git a/src/caskade/errors.py b/src/caskade/errors.py index 3602805..6c4c479 100644 --- a/src/caskade/errors.py +++ b/src/caskade/errors.py @@ -36,7 +36,11 @@ class ActiveStateError(CaskadeException): """Class for exceptions related to the active state of a node in ``caskade``.""" -class FillDynamicParamsError(CaskadeException): +class FillParamsError(CaskadeException): + """Class for exceptions related to filling parameters in ``caskade``""" + + +class FillDynamicParamsError(FillParamsError): """Class for exceptions related to filling dynamic parameters in ``caskade``.""" diff --git a/src/caskade/module.py b/src/caskade/module.py index 0fee9f9..1fa56cf 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -8,7 +8,7 @@ from .errors import ( ActiveStateError, ParamConfigurationError, - FillDynamicParamsError, + FillParamsError, FillDynamicParamsArrayError, FillDynamicParamsSequenceError, FillDynamicParamsMappingError, @@ -312,7 +312,7 @@ def fill_kwargs(self, keys: tuple[str]) -> dict[str, ArrayLike]: if key in self.children and isinstance(self[key], Param): val = self.children[key].value if val is None: - raise FillDynamicParamsError( + raise FillParamsError( f"Param {key} in Module {self.name} has no value. " "Ensure that the parameter is set before calling the forward method or provided with the params." ) diff --git a/src/caskade/param.py b/src/caskade/param.py index bb57223..14acf65 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -90,7 +90,7 @@ def __init__( cyclic: bool = False, valid: Optional[tuple[Union[ArrayLike, float, int, None]]] = None, units: Optional[str] = None, - dynamic: bool = False, + dynamic: Optional[bool] = None, batched: bool = False, dtype: Optional[Any] = None, device: Optional[Any] = None, @@ -119,10 +119,11 @@ def __init__( self._cyclic = cyclic self.batched = batched self.shape = shape - if dynamic: - self.dynamic_value(value) + if dynamic or (dynamic is None and value is None): + self.to_dynamic() else: - self.value = value + self.to_static() + self.value = value self.valid = valid self.units = units @@ -171,14 +172,7 @@ def to_static(self, **kwargs): try: self.__value = self.__value(self) except: - raise ParamTypeError( - f"Cannot set pointer parameter {self.name} to static with `to_static`. Pointer could not be evaluated because of: \n" - + traceback.format_exc() - ) - if self.__value is None: - raise ParamTypeError( - f"Cannot set dynamic parameter {self.name} to static when no dynamic value is set. Try using `static_value(value)` to provide a value and set to static." - ) + self.__value = None self.node_type = "static" @property @@ -259,18 +253,17 @@ def static_value(self, value): ) # Catch cases where input is invalid - if value is None: - raise ParamTypeError("Cannot set to static with value of None") if isinstance(value, Param) or callable(value): raise ParamTypeError( - f"Cannot set static value to pointer ({self.name}). Try setting `pointer_func(func)` or `pointer_func(param)` to create a pointer." + f"Cannot set static value to pointer ({self.name}). Try setting `pointer_value(func)` or `pointer_value(param)` to create a pointer." ) - value = backend.as_array(value, dtype=self._dtype, device=self._device) + if value is not None: + value = backend.as_array(value, dtype=self._dtype, device=self._device) + self._shape_from_value(tuple(value.shape)) self.__value = value - self.node_type = "static" - self._shape_from_value(tuple(value.shape)) self.is_valid() + self.node_type = "static" def dynamic_value(self, value): # While active no value can be set @@ -279,24 +272,19 @@ def dynamic_value(self, value): f"Cannot set dynamic value of parameter {self.name} while active." ) - # No dynamic value - if value is None: - self.__value = None - self.node_type = "dynamic" - return - # Catch cases where input is invalid if isinstance(value, Param) or callable(value): raise ParamTypeError(f"Cannot set dynamic value to pointer ({self.name})") # Set to dynamic value - value = backend.as_array(value, dtype=self._dtype, device=self._device) + if value is not None: + value = backend.as_array(value, dtype=self._dtype, device=self._device) + self._shape_from_value(tuple(value.shape)) self.__value = value self.node_type = "dynamic" - self._shape_from_value(tuple(value.shape)) self.is_valid() - def pointer_func(self, value: Union["Param", Callable]): + def pointer_value(self, value: Union["Param", Callable]): # While active no value can be set if self.active: raise ActiveStateError( @@ -332,10 +320,8 @@ def value(self, value): if self.active: raise ActiveStateError(f"Cannot set value of parameter {self.name} while active") - if value is None: - self.dynamic_value(None) - elif isinstance(value, Param) or callable(value): - self.pointer_func(value) + if isinstance(value, Param) or callable(value): + self.pointer_value(value) elif self.dynamic: self.dynamic_value(value) else: diff --git a/tests/test_context.py b/tests/test_context.py index f098140..7ec743d 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -7,8 +7,8 @@ class TestSim(Module): def __init__(self): super().__init__() self.a = Param("a", 1.0) - self.b = Param("b", None) - self.c = Param("c", None) + self.b = Param("b", None, dynamic=True) + self.c = Param("c", None, dynamic=True) @forward def testfunc(self, a, b, c): @@ -43,7 +43,7 @@ def __init__(self): self.a = Param("a", 3.0) self.b = Param("b", lambda p: p["a"].value) self.b.link(self.a) - self.c = Param("c", None) + self.c = Param("c", None, dynamic=True) self.a_vals = (backend.make_array(1.0), backend.make_array(2.0)) @forward diff --git a/tests/test_forward.py b/tests/test_forward.py index 71feaa6..9fe056a 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -3,6 +3,7 @@ Param, forward, ValidContext, + FillParamsError, FillDynamicParamsError, FillDynamicParamsSequenceError, FillDynamicParamsMappingError, @@ -20,7 +21,7 @@ class TestSim(Module): def __init__(self, a, b_shape, c, m1): super().__init__("test_sim") self.a = Param("a", a) - self.b = Param("b", None, b_shape) + self.b = Param("b", None, b_shape, dynamic=True) self.c = Param("c", c) self.m1 = m1 @@ -32,9 +33,9 @@ def testfun(self, x, a=None, b=None, c=None): class TestSubSim(Module): def __init__(self, d=None, e=None, f=None): super().__init__() - self.d = Param("d", d) - self.e = Param("e", e) - self.f = Param("f", f) + self.d = Param("d", d, dynamic=True) + self.e = Param("e", e, dynamic=True) + self.f = Param("f", f, dynamic=True) @forward def __call__(self, d=None, e=None, live_c=None): @@ -50,7 +51,7 @@ def __call__(self, d=None, e=None, live_c=None): assert graph is not None, "should return a graphviz object" # Dont provide params - with pytest.raises(FillDynamicParamsError): + with pytest.raises(FillParamsError): main1.testfun() # List as params @@ -71,7 +72,7 @@ def __call__(self, d=None, e=None, live_c=None): assert valid_result.shape == (2, 2) assert backend.all(valid_result == result).item() # Wrong number of params, too few - with pytest.raises(FillDynamicParamsError): + with pytest.raises(FillParamsError): result = main1.testfun(1.0, params=[]) with pytest.raises(FillDynamicParamsSequenceError): result = main1.testfun(1.0, params=params[:3]) @@ -115,7 +116,7 @@ def __call__(self, d=None, e=None, live_c=None): assert valid_result.shape == (2, 2) assert backend.all(valid_result == result).item() # Wrong number of params, too few - with pytest.raises(FillDynamicParamsError): + with pytest.raises(FillParamsError): result = main1.testfun(1.0, backend.as_array([])) with pytest.raises(FillDynamicParamsArrayError): result = main1.testfun(1.0, params[:-3]) @@ -201,7 +202,7 @@ def __call__(self, d=None, e=None, live_c=None): "f": backend.make_array(1.0), }, } - with pytest.raises(FillDynamicParamsError): + with pytest.raises(FillParamsError): result = main1.testfun(1.0, params=params) # All params static @@ -217,7 +218,7 @@ def __call__(self, d=None, e=None, live_c=None): # dynamic with no shape main1.b = None - main1.b.dynamic_value = None + main1.b.dynamic_value(None) main1.b.shape = None with pytest.raises(ParamConfigurationError): main1.testfun(1.0, params=backend.module.ones(4)) diff --git a/tests/test_integration.py b/tests/test_integration.py index 474440f..aea7873 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -31,6 +31,7 @@ def __call__(self, d=None, e=None, f=None): sub1 = TestSubSim(d=1.0, e=lambda s: s.flink.value, f=None) sub1.e.link("flink", sub1.f) main1 = TestSim(a=2.0, b=None, c=None, c_shape=(), m1=sub1) + main1.b.to_dynamic() main1.c = main1.b sub1.f = main1.c diff --git a/tests/test_module.py b/tests/test_module.py index 505eeea..5920011 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -21,7 +21,7 @@ def test_module_creation(): m2 = Module("test2") m1.mod = m2 assert m1["mod"] == m2 - p1 = Param("test1") + p1 = Param("test1", dynamic=True) m2.p = p1 assert m2["p"] == p1 assert m1.mod.p == p1 @@ -68,7 +68,7 @@ def __init__(self, name, param): def test(self, p): return 2 * p - shared_param = Param("shared") + shared_param = Param("shared", dynamic=True) m1 = TestModule("m1", shared_param) m2 = TestModule("m2", shared_param) @@ -257,6 +257,7 @@ def test_module_and_collection(): S.d = D D.p = Param("p") D.p2 = Param("p2") + M.to_dynamic(False) params = { "p": 1.0, diff --git a/tests/test_param.py b/tests/test_param.py index 8ab5bfa..0cbe40d 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -79,12 +79,6 @@ def test_param_creation(): assert p9.valid[0].item() == 0 assert p9.valid[1].item() == 1 - # Invalid dynamic value - with pytest.raises(ParamTypeError): - p10 = Param("test", value=p9, dynamic=True) - with pytest.raises(ParamTypeError): - p11 = Param("test", value=lambda p: p.other.value * 2, dynamic=True) - # Set dynamic from other states p13 = Param("test", 1.0) # static p13.dynamic_value(2.0) @@ -156,7 +150,7 @@ def test_check_npvalue(): def test_value_setter(): # dynamic - p = Param("test") + p = Param("test", dynamic=True) assert p.node_type == "dynamic" # static @@ -184,14 +178,13 @@ def test_times_2(p): # Invalid pointer with pytest.raises(ParamTypeError): - p.pointer_func(1.0) + p.pointer_value(1.0) with pytest.raises(ParamTypeError): - p.pointer_func(None) + p.pointer_value(None) # Invalid static value with pytest.raises(ParamTypeError): - p.static_value(None) - + p.static_value(other) with pytest.raises(ParamTypeError): p.static_value(lambda p: p.other.value) @@ -202,7 +195,7 @@ def test_times_2(p): with pytest.raises(ActiveStateError): p.static_value(1.0) with pytest.raises(ActiveStateError): - p.pointer_func(lambda p: p.other.value) + p.pointer_value(lambda p: p.other.value) def test_param_shape(): @@ -256,15 +249,11 @@ def test_to_dynamic_static(): p.to_static() # from static assert p.static p = Param("test") - with pytest.raises(ParamTypeError): - p.to_static() # from dynamic, fails p.dynamic_value(2.0) p.to_static() # from dynamic with dynamic value assert p.static assert p.value.item() == 2.0 p.value = lambda p: p["other"].value * 2 - with pytest.raises(ParamTypeError): - p.to_static() # from pointer, fails p.link("other", other) p.to_static() # from pointer, succeeds assert p.static From fd0c19e52c293a7b26d8a35b66d2aadc1787681e Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 3 Dec 2025 15:30:14 -0500 Subject: [PATCH 2/8] test pointer to static with failed pointer --- tests/test_param.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_param.py b/tests/test_param.py index 0cbe40d..6e864a4 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -254,6 +254,9 @@ def test_to_dynamic_static(): assert p.static assert p.value.item() == 2.0 p.value = lambda p: p["other"].value * 2 + p.to_static() # Unable to evaluate pointer, becomes None + assert p.value is None + p.value = lambda p: p["other"].value * 2 p.link("other", other) p.to_static() # from pointer, succeeds assert p.static From 0d941096638f417aa4428de76fa609bb0e8a6d94 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 3 Dec 2025 16:05:48 -0500 Subject: [PATCH 3/8] coverage for setting dynamic_value with param --- tests/test_param.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_param.py b/tests/test_param.py index 6e864a4..ecee7be 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -227,6 +227,8 @@ def test_to_dynamic_static(): p.to_dynamic() # from dynamic assert p.dynamic p.dynamic_value(1.0) + with pytest.raises(ParamTypeError): + p.dynamic_value(other) assert p.dynamic p.to_dynamic() # from dynamic with dynamic value assert p.dynamic From 707f845fd31772daa80b1e67f98bd78aa9545e3a Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Wed, 3 Dec 2025 16:35:13 -0500 Subject: [PATCH 4/8] remove unnecessary forced dynamic in tests --- tests/test_context.py | 6 +++--- tests/test_integration.py | 1 - tests/test_module.py | 5 ++--- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_context.py b/tests/test_context.py index 7ec743d..b8ed0c1 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -7,8 +7,8 @@ class TestSim(Module): def __init__(self): super().__init__() self.a = Param("a", 1.0) - self.b = Param("b", None, dynamic=True) - self.c = Param("c", None, dynamic=True) + self.b = Param("b") + self.c = Param("c") @forward def testfunc(self, a, b, c): @@ -43,7 +43,7 @@ def __init__(self): self.a = Param("a", 3.0) self.b = Param("b", lambda p: p["a"].value) self.b.link(self.a) - self.c = Param("c", None, dynamic=True) + self.c = Param("c") self.a_vals = (backend.make_array(1.0), backend.make_array(2.0)) @forward diff --git a/tests/test_integration.py b/tests/test_integration.py index aea7873..474440f 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -31,7 +31,6 @@ def __call__(self, d=None, e=None, f=None): sub1 = TestSubSim(d=1.0, e=lambda s: s.flink.value, f=None) sub1.e.link("flink", sub1.f) main1 = TestSim(a=2.0, b=None, c=None, c_shape=(), m1=sub1) - main1.b.to_dynamic() main1.c = main1.b sub1.f = main1.c diff --git a/tests/test_module.py b/tests/test_module.py index 5920011..505eeea 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -21,7 +21,7 @@ def test_module_creation(): m2 = Module("test2") m1.mod = m2 assert m1["mod"] == m2 - p1 = Param("test1", dynamic=True) + p1 = Param("test1") m2.p = p1 assert m2["p"] == p1 assert m1.mod.p == p1 @@ -68,7 +68,7 @@ def __init__(self, name, param): def test(self, p): return 2 * p - shared_param = Param("shared", dynamic=True) + shared_param = Param("shared") m1 = TestModule("m1", shared_param) m2 = TestModule("m2", shared_param) @@ -257,7 +257,6 @@ def test_module_and_collection(): S.d = D D.p = Param("p") D.p2 = Param("p2") - M.to_dynamic(False) params = { "p": 1.0, From de8f8ce41c129f4a9952c699b6ab2dbe741cd09e Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 4 Dec 2025 13:25:23 -0500 Subject: [PATCH 5/8] make to_dynamic etc primary way to set value and type --- docs/source/notebooks/AdvancedGuide.ipynb | 4 +- docs/source/notebooks/BeginnersGuide.ipynb | 8 +- src/caskade/module.py | 8 +- src/caskade/param.py | 167 ++++++++++----------- tests/test_forward.py | 2 +- tests/test_module.py | 18 +-- tests/test_param.py | 32 ++-- 7 files changed, 118 insertions(+), 121 deletions(-) diff --git a/docs/source/notebooks/AdvancedGuide.ipynb b/docs/source/notebooks/AdvancedGuide.ipynb index 73cd613..53e9514 100644 --- a/docs/source/notebooks/AdvancedGuide.ipynb +++ b/docs/source/notebooks/AdvancedGuide.ipynb @@ -157,7 +157,7 @@ "# Set individual param to dynamic\n", "G1.x0.to_dynamic() # call function to set dynamic\n", "G1.q = None # set to None to make fully dynamic\n", - "G1.q.dynamic_value(0.5) # set with dynamic value\n", + "G1.q.to_dynamic(0.5) # set with dynamic value\n", "C.to_dynamic() # only sets immediate children to dynamic\n", "print(\"Individual params can be set to dynamic\")\n", "display(C.graphviz())\n", @@ -173,7 +173,7 @@ "G1.q.to_static() # Setting to static, uses the earlier value\n", "\n", "# Setting a value and make it static\n", - "G1.I0.static_value(10.0)\n", + "G1.I0.to_static(10.0)\n", "print(\"Individual params can be set to static\")\n", "display(C.graphviz())\n", "\n", diff --git a/docs/source/notebooks/BeginnersGuide.ipynb b/docs/source/notebooks/BeginnersGuide.ipynb index 645fbb7..4291de0 100644 --- a/docs/source/notebooks/BeginnersGuide.ipynb +++ b/docs/source/notebooks/BeginnersGuide.ipynb @@ -298,10 +298,10 @@ "thirdsim.x0 = lambda p: p.time.value * torch.tensor((1, 1)) - 0.5\n", "thirdsim.x0.link(simtime)\n", "\n", - "# Use `static_value` to set the value and set to static\n", - "# Similarly use `dynamic_value` to set value and set dynamic\n", - "secondsim.q.static_value(0.5)\n", - "secondsim.phi.static_value(3.14 / 3)\n", + "# Use `to_static(value)` to set the value and set to static\n", + "# Similarly use `to_dynamic(value)` to set value and set dynamic\n", + "secondsim.q.to_static(0.5)\n", + "secondsim.phi.to_static(3.14 / 3)\n", "\n", "combinedsim.graphviz()" ] diff --git a/src/caskade/module.py b/src/caskade/module.py index 1fa56cf..93175ca 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -165,7 +165,7 @@ def _fill_dict(self, node, params, dynamic_values=False): for key in params: if key in node.children and isinstance(node[key], Param) and node[key].dynamic: if dynamic_values: - node[key].dynamic_value(params[key]) + node[key].to_dynamic(params[key]) else: node[key]._value = params[key] elif ( @@ -229,7 +229,7 @@ def _fill_values( try: val = backend.view(params[..., pos : pos + size], B + param.shape) if dynamic_values: - param.dynamic_value(val) + param.to_dynamic(val) else: param._value = val except (RuntimeError, IndexError, ValueError, TypeError): @@ -244,7 +244,7 @@ def _fill_values( elif len(params) == len(dynamic_params): for param, value in zip(dynamic_params, params): if dynamic_values: - param.dynamic_value(value) + param.to_dynamic(value) else: param._value = value elif len(params) == len(self.dynamic_modules): @@ -332,7 +332,7 @@ def _check_dynamic_values(self, params_type: str = "ArrayLike"): """Check if all dynamic values are set.""" bad_params = [] for param in self.dynamic_params: - if "value" not in param.node_type: + if param.value is None: bad_params.append(param.name) if len(bad_params) > 0: raise ParamConfigurationError( diff --git a/src/caskade/param.py b/src/caskade/param.py index 14acf65..7b3d340 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -22,6 +22,9 @@ def valid_shape(shape, value_shape, batched): return False +NULL = object() + + class Param(Node): """ Node to represent a parameter in the graph. @@ -78,7 +81,6 @@ class Param(Node): graphviz_types = { "static": {"style": "filled", "color": "lightgrey", "shape": "box"}, "dynamic": {"style": "solid", "color": "black", "shape": "box"}, - "dynamic value": {"style": "solid", "color": "#333333", "shape": "box"}, "pointer": {"style": "filled", "color": "lightgrey", "shape": "rarrow"}, } @@ -146,35 +148,90 @@ def node_type(self): @node_type.setter def node_type(self, value): pre_type = self.node_type - if value == "dynamic" and self.__value is not None: - value = "dynamic value" self._node_type = value if pre_type != self.node_type: self.update_graph() - def to_dynamic(self, **kwargs): - """Change this parameter to a dynamic parameter. If the parameter has a - value, this will become a "dynamic value" parameter.""" - if self.pointer: - try: - self.__value = self.__value(self) - except: - self.__value = None + def to_dynamic(self, value=NULL, **kwargs): + """Change this parameter to a dynamic parameter. If a value is provided, this will be set as the dynamic value.""" + # While active no value can be set + if self.active: + raise ActiveStateError(f"Cannot set parameter {self.name} to dynamic while active.") + + # Catch cases where input is invalid + if isinstance(value, Param) or callable(value): + raise ParamTypeError(f"Cannot set dynamic value to pointer ({self.name}).") + + if value is NULL: + if self.pointer: + try: + self.__value = self.__value(self) + except: + self.__value = None + self.node_type = "dynamic" + return + + if value is not None: + value = backend.as_array(value, dtype=self._dtype, device=self._device) + self._shape_from_value(tuple(value.shape)) + self.__value = value self.node_type = "dynamic" + self.is_valid() - def to_static(self, **kwargs): - """Change this parameter to a static parameter. This only works if the - parameter has a dynamic value set, or if the pointer can be - evaluated.""" - if self.static: + def to_static(self, value=NULL, **kwargs): + """Change this parameter to a static parameter. If a value is provided + this will be set as the static value.""" + # While active no value can be set + if self.active: + raise ActiveStateError(f"Cannot set parameter {self.name} to static while active.") + + # Catch cases where input is invalid + if isinstance(value, Param) or callable(value): + raise ParamTypeError(f"Cannot set static value to pointer ({self.name}).") + + if value is NULL: + if self.pointer: + try: + self.__value = self.__value(self) + except: + self.__value = None + self.node_type = "static" return - if self.pointer: - try: - self.__value = self.__value(self) - except: - self.__value = None + + if value is not None: + value = backend.as_array(value, dtype=self._dtype, device=self._device) + self._shape_from_value(tuple(value.shape)) + self.__value = value + self.is_valid() self.node_type = "static" + def to_pointer(self, value=NULL, link=(), **kwargs): + # While active no value can be set + if self.active: + raise ActiveStateError(f"Cannot set parameter {self.name} to pointer while active") + + if value is NULL: + if callable(self.__value): + self.node_type = "pointer" + return + if len(self.children) == 1: + value = next(iter(self.children)) + else: + value = None + + if isinstance(value, Param): + self.link(value) + p_name = value.name + value = lambda p: p[p_name].value + elif value is not None and not callable(value): + raise ParamTypeError(f"Pointer function must be a Param or callable ({self.name})") + elif hasattr(value, "params"): + self.link(value.params) + self.link(link) + self.__value = value + self._shape = None + self.node_type = "pointer" + @property def shape(self) -> Optional[tuple[int, ...]]: try: @@ -245,64 +302,6 @@ def device(self) -> Optional[str]: pass return self._device - def static_value(self, value): - # While active no value can be set - if self.active: - raise ActiveStateError( - f"Cannot set static value of parameter {self.name} while active." - ) - - # Catch cases where input is invalid - if isinstance(value, Param) or callable(value): - raise ParamTypeError( - f"Cannot set static value to pointer ({self.name}). Try setting `pointer_value(func)` or `pointer_value(param)` to create a pointer." - ) - - if value is not None: - value = backend.as_array(value, dtype=self._dtype, device=self._device) - self._shape_from_value(tuple(value.shape)) - self.__value = value - self.is_valid() - self.node_type = "static" - - def dynamic_value(self, value): - # While active no value can be set - if self.active: - raise ActiveStateError( - f"Cannot set dynamic value of parameter {self.name} while active." - ) - - # Catch cases where input is invalid - if isinstance(value, Param) or callable(value): - raise ParamTypeError(f"Cannot set dynamic value to pointer ({self.name})") - - # Set to dynamic value - if value is not None: - value = backend.as_array(value, dtype=self._dtype, device=self._device) - self._shape_from_value(tuple(value.shape)) - self.__value = value - self.node_type = "dynamic" - self.is_valid() - - def pointer_value(self, value: Union["Param", Callable]): - # While active no value can be set - if self.active: - raise ActiveStateError( - f"Cannot set pointer function of parameter {self.name} while active" - ) - - if isinstance(value, Param): - self.link(value) - p_name = value.name - value = lambda p: p[p_name].value - elif not callable(value): - raise ParamTypeError(f"Pointer function must be a Param or callable ({self.name})") - elif hasattr(value, "params"): - self.link(value.params) - self.__value = value - self._shape = None - self.node_type = "pointer" - @property def value(self) -> Union[ArrayLike, None]: if self._value is not None: @@ -321,11 +320,11 @@ def value(self, value): raise ActiveStateError(f"Cannot set value of parameter {self.name} while active") if isinstance(value, Param) or callable(value): - self.pointer_value(value) + self.to_pointer(value) elif self.dynamic: - self.dynamic_value(value) + self.to_dynamic(value) else: - self.static_value(value) + self.to_static(value) @property def npvalue(self) -> ndarray: @@ -435,9 +434,9 @@ def _load_state_hdf5(self, h5group, index: int = -1, _done_load: set = None): value = h5group["value"][()] if "static" in h5group["value"].attrs["node_type"]: - self.static_value(value) + self.to_static(value) elif "dynamic" in h5group["value"].attrs["node_type"]: - self.dynamic_value(value) + self.to_dynamic(value) self.units = h5group["value"].attrs["units"] if "valid_left" in h5group["value"].attrs: self.valid = ( diff --git a/tests/test_forward.py b/tests/test_forward.py index 9fe056a..0718ec2 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -218,7 +218,7 @@ def __call__(self, d=None, e=None, live_c=None): # dynamic with no shape main1.b = None - main1.b.dynamic_value(None) + main1.b.to_dynamic(None) main1.b.shape = None with pytest.raises(ParamConfigurationError): main1.testfun(1.0, params=backend.module.ones(4)) diff --git a/tests/test_module.py b/tests/test_module.py index 505eeea..e0b5603 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -117,7 +117,7 @@ def __call__(self, d=None, e=None, live_c=None): sub1 = TestSubSim(d=2.0, e=2.5, f=None) main1 = TestSim(a=1.0, b_shape=(2,), c=4.0, m1=sub1) - main1.b.static_value(backend.make_array([1.0, 2.0])) + main1.b.to_static(backend.make_array([1.0, 2.0])) # Try to get auto params when not all dynamic values available with pytest.raises(ParamConfigurationError): @@ -128,7 +128,7 @@ def __call__(self, d=None, e=None, live_c=None): p00 = main1.build_params_dict() with pytest.raises(ParamConfigurationError): p00 = sub1.build_params_dict() - sub1.f.dynamic_value(3.0) + sub1.f.to_dynamic(3.0) # Check dynamic value assert main1.c.value.item() == 4.0 @@ -183,7 +183,7 @@ def __call__(self, d=None, e=None, live_c=None): # Check invalid dynamic value with pytest.warns(InvalidValueWarning): - sub1.f.dynamic_value(11.0) + sub1.f.to_dynamic(11.0) # All static make params main1.c.to_static() @@ -221,10 +221,10 @@ def test_batched_build_params_array(): M.p1 = Param("p1") M.p2 = Param("p2") - M.p1.dynamic_value([1.0, 2.0]) + M.p1.to_dynamic([1.0, 2.0]) M.p1.batched = True M.p1.shape = () - M.p2.dynamic_value([3.0, 4.0]) + M.p2.to_dynamic([3.0, 4.0]) M.p2.batched = True M.p2.shape = () @@ -232,15 +232,15 @@ def test_batched_build_params_array(): assert a.shape == (2, 2) with pytest.raises(ParamConfigurationError): - M.p1.dynamic_value([1.0, 2.0]) + M.p1.to_dynamic([1.0, 2.0]) M.p1.shape = (2,) - M.p2.dynamic_value([3.0, 4.0]) + M.p2.to_dynamic([3.0, 4.0]) M.p2.shape = () M.build_params_array() with pytest.raises(ParamConfigurationError): - M.p1.dynamic_value([1.0, 2.0]) + M.p1.to_dynamic([1.0, 2.0]) M.p1.shape = () - M.p2.dynamic_value([1.0, 2.0]) + M.p2.to_dynamic([1.0, 2.0]) M.p2.shape = (2,) M.build_params_array() diff --git a/tests/test_param.py b/tests/test_param.py index ecee7be..2e91878 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -41,7 +41,7 @@ def test_param_creation(): p3.value = 1.0 with pytest.raises(ActiveStateError): p33.active = True - p33.dynamic_value(1.0) + p33.to_dynamic(1.0) # Missmatch value and shape with pytest.raises(ParamConfigurationError): @@ -81,14 +81,14 @@ def test_param_creation(): # Set dynamic from other states p13 = Param("test", 1.0) # static - p13.dynamic_value(2.0) + p13.to_dynamic(2.0) assert p13.value.item() == 2.0 assert p13.dynamic p14 = Param("test") # dynamic - p14.dynamic_value(1.0) + p14.to_dynamic(1.0) assert p14.value.item() == 1.0 p15 = Param("test", p14) # pointer - p15.dynamic_value(2.0) + p15.to_dynamic(2.0) assert p15.value.item() == 2.0 p16 = Param("test", 1.0) # static p16.to_dynamic() @@ -127,7 +127,7 @@ def test_params_sticky_to(): assert p.value.dtype == backend.module.float32 p = p.to(dtype=backend.module.float64, device=device) assert p.value.dtype == backend.module.float64 - p.dynamic_value(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + p.to_dynamic(np.array([1.0, 2.0, 3.0], dtype=np.float32)) assert p.value.dtype == backend.module.float64 # neither dtype or value set p = Param("test", valid=(0, 2)) @@ -154,7 +154,7 @@ def test_value_setter(): assert p.node_type == "dynamic" # static - p.static_value(1.0) + p.to_static(1.0) assert p.node_type == "static" assert p.value.item() == 1.0 @@ -178,24 +178,22 @@ def test_times_2(p): # Invalid pointer with pytest.raises(ParamTypeError): - p.pointer_value(1.0) - with pytest.raises(ParamTypeError): - p.pointer_value(None) + p.to_pointer(1.0) # Invalid static value with pytest.raises(ParamTypeError): - p.static_value(other) + p.to_static(other) with pytest.raises(ParamTypeError): - p.static_value(lambda p: p.other.value) + p.to_static(lambda p: p.other.value) # Cannot update while active p.active = True with pytest.raises(ActiveStateError): - p.dynamic_value(1.0) + p.to_dynamic(1.0) with pytest.raises(ActiveStateError): - p.static_value(1.0) + p.to_static(1.0) with pytest.raises(ActiveStateError): - p.pointer_value(lambda p: p.other.value) + p.to_pointer(lambda p: p.other.value) def test_param_shape(): @@ -226,9 +224,9 @@ def test_to_dynamic_static(): p = Param("test") p.to_dynamic() # from dynamic assert p.dynamic - p.dynamic_value(1.0) + p.to_dynamic(1.0) with pytest.raises(ParamTypeError): - p.dynamic_value(other) + p.to_dynamic(other) assert p.dynamic p.to_dynamic() # from dynamic with dynamic value assert p.dynamic @@ -251,7 +249,7 @@ def test_to_dynamic_static(): p.to_static() # from static assert p.static p = Param("test") - p.dynamic_value(2.0) + p.to_dynamic(2.0) p.to_static() # from dynamic with dynamic value assert p.static assert p.value.item() == 2.0 From 8546de75e0195a56578781e534578cb99750e664 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 4 Dec 2025 13:41:09 -0500 Subject: [PATCH 6/8] more tests for param to_pointer --- src/caskade/param.py | 2 +- tests/test_param.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/caskade/param.py b/src/caskade/param.py index 7b3d340..817d1a9 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -215,7 +215,7 @@ def to_pointer(self, value=NULL, link=(), **kwargs): self.node_type = "pointer" return if len(self.children) == 1: - value = next(iter(self.children)) + value = next(iter(self.children.values())) else: value = None diff --git a/tests/test_param.py b/tests/test_param.py index 2e91878..5e7e02f 100644 --- a/tests/test_param.py +++ b/tests/test_param.py @@ -166,6 +166,20 @@ def test_value_setter(): p.value = other assert p.node_type == "pointer" assert p.shape == other.shape + p.to_pointer() + p.to_static() + assert p.value.item() == 2.0 + p.to_pointer() + assert p.node_type == "pointer" + assert p.value.item() == 2.0 + p.to_pointer(other) + assert p.node_type == "pointer" + p.to_static() + p.unlink(other) + p.to_pointer() + assert p.node_type == "pointer" + with pytest.raises(TypeError): + p.value # function def test_times_2(p): From 9cec43dfc90738e3ed37bd0b6b057c4c1a2ceff5 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 4 Dec 2025 14:22:29 -0500 Subject: [PATCH 7/8] add docstyring --- src/caskade/param.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/caskade/param.py b/src/caskade/param.py index 817d1a9..c78405d 100644 --- a/src/caskade/param.py +++ b/src/caskade/param.py @@ -1,6 +1,5 @@ from typing import Optional, Union, Callable, Any from warnings import warn -import traceback from math import prod from numpy import ndarray @@ -153,7 +152,8 @@ def node_type(self, value): self.update_graph() def to_dynamic(self, value=NULL, **kwargs): - """Change this parameter to a dynamic parameter. If a value is provided, this will be set as the dynamic value.""" + """Change this parameter to a dynamic parameter. If a value is provided, + this will be set as the dynamic value.""" # While active no value can be set if self.active: raise ActiveStateError(f"Cannot set parameter {self.name} to dynamic while active.") @@ -206,6 +206,11 @@ def to_static(self, value=NULL, **kwargs): self.node_type = "static" def to_pointer(self, value=NULL, link=(), **kwargs): + """Change this parameter to a pointer parameter. If a value is provided + this will be set as the pointer. Either provide a Param object to point + to its value, or provide a callable function to be called at runtime. It + is also possible to provide a tuple of nodes to link to while creating + the pointer.""" # While active no value can be set if self.active: raise ActiveStateError(f"Cannot set parameter {self.name} to pointer while active") From 834f1b1cfd5db32d739904e980930501cd12770d Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Thu, 4 Dec 2025 15:35:07 -0500 Subject: [PATCH 8/8] add fill values for node collection --- src/caskade/collection.py | 6 ++++++ tests/test_collection.py | 11 +++++++++++ 2 files changed, 17 insertions(+) diff --git a/src/caskade/collection.py b/src/caskade/collection.py index 7c1b2ab..36bb4a1 100644 --- a/src/caskade/collection.py +++ b/src/caskade/collection.py @@ -1,3 +1,5 @@ +from typing import Iterable + from .base import Node @@ -12,6 +14,10 @@ def to_static(self, **kwargs): if hasattr(node, "to_static"): node.to_static(**kwargs) + def fill_values(self, values: Iterable): + for node, value in zip(self, values): + node.value = value + def copy(self): raise NotImplementedError diff --git a/tests/test_collection.py b/tests/test_collection.py index 984c56d..87386dc 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -111,6 +111,17 @@ def test_node_list_creation(): n4.append(1) +@pytest.mark.parametrize("node_type", [NodeTuple, NodeList]) +def test_node_collection_param_values(node_type): + NL = node_type([Param("p1"), Param("p2"), Param("p3")]) + + NL.fill_values([1, 2, 3]) + + assert NL[0].value.item() == 1.0 + assert NL[1].value.item() == 2.0 + assert NL[2].value.item() == 3.0 + + def test_node_list_manipulation(): params = [Param("ptest1", 1), Param("ptest2", 2)]