diff --git a/src/caskade/__init__.py b/src/caskade/__init__.py index 1545957..25a8430 100644 --- a/src/caskade/__init__.py +++ b/src/caskade/__init__.py @@ -3,7 +3,7 @@ from .base import Node from .backend import backend, ArrayLike from .context import ActiveContext, ValidContext, OverrideParam -from .decorators import forward +from .decorators import forward, active_cache from .module import Module from .param import Param, dynamic from .collection import NodeCollection, NodeList, NodeTuple @@ -42,6 +42,7 @@ "ValidContext", "OverrideParam", "forward", + "active_cache", "test", "CaskadeException", "GraphError", diff --git a/src/caskade/context.py b/src/caskade/context.py index 2218973..356b543 100644 --- a/src/caskade/context.py +++ b/src/caskade/context.py @@ -16,12 +16,12 @@ def __enter__(self): self.outer_active = self.module.active if self.outer_active and not self.active: self.outer_params = list(p.value for p in self.module.dynamic_params) - self.module.clear_params() + self.module.clear_state() self.module.active = self.active def __exit__(self, exc_type, exc_value, traceback): if not self.outer_active and self.active: - self.module.clear_params() + self.module.clear_state() self.module.active = self.outer_active if self.outer_active and not self.active: self.module.fill_params(self.outer_params) diff --git a/src/caskade/decorators.py b/src/caskade/decorators.py index 45193bd..255d339 100644 --- a/src/caskade/decorators.py +++ b/src/caskade/decorators.py @@ -2,11 +2,10 @@ import functools from contextlib import ExitStack -from .backend import backend from .context import ActiveContext, OverrideParam from .param import Param -__all__ = ("forward",) +__all__ = ("forward", "active_cache") def forward(method): @@ -99,3 +98,40 @@ def wrapped(self, *args, **kwargs): return method(self, *args, **kwargs) return wrapped + +def active_cache(method): + """ + Decorator to enable caching for a method. This way the method will only be + called once in an active state, and the cache will be dropped when exiting. + + Parameters + ---------- + method: (Callable) + The method to be decorated. + + Returns + ------- + Callable + The decorated method with caching enabled. + """ + + NOVALUE = object() + + cache = NOVALUE + + def hook(self): + nonlocal cache + cache = NOVALUE + self.clear_state_hooks.remove(hook) + + @functools.wraps(method) + def wrapped(self, *args, **kwargs): + nonlocal cache + if not self.active: + return method(self, *args, **kwargs) + elif cache is NOVALUE: + cache = method(self, *args, **kwargs) + self.clear_state_hooks.add(hook) + return cache + + return wrapped \ No newline at end of file diff --git a/src/caskade/module.py b/src/caskade/module.py index 1ed1697..723251a 100644 --- a/src/caskade/module.py +++ b/src/caskade/module.py @@ -75,6 +75,7 @@ def __init__(self, name: Optional[str] = None, **kwargs): self.local_dynamic_params = {} self._type = "module" self.valid_context = False + self.clear_state_hooks = set() def update_graph(self): """Maintain a tuple of dynamic and live parameters at all points lower @@ -289,16 +290,19 @@ def fill_params(self, params: Union[ArrayLike, Sequence, Mapping]): params = self.from_valid(params) self._fill_values(params) - def clear_params(self): + def clear_state(self): """Set all dynamic parameters to None and live parameters to LiveParam. This is to be used on exiting an ``ActiveContext`` and so should not be used by a user.""" if not self.active: - raise ActiveStateError(f"Module {self.name} must be active to clear params") + raise ActiveStateError(f"Module {self.name} must be active to clear state") for param in self.dynamic_params + self.pointer_params: param._value = None + for hook in list(self.clear_state_hooks): + hook(self) + def fill_kwargs(self, keys: tuple[str]) -> dict[str, ArrayLike]: """ Fill the kwargs for an ``@forward`` method with the values of the dynamic diff --git a/tests/test_context.py b/tests/test_context.py index 2749030..58b7dc9 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,4 +1,4 @@ -from caskade import Module, Param, forward, ActiveContext, OverrideParam, backend +from caskade import Module, Param, forward, ActiveContext, OverrideParam, backend, active_cache import numpy as np @@ -69,3 +69,30 @@ def testfunc(self): testsim = TestSim() assert testsim.testfunc(backend.make_array([5.0])).item() == 27.0 assert testsim.a.value.item() == 3.0 + +def test_active_cache(): + if backend.backend == "object": + return + + class TestSim(Module): + def __init__(self): + super().__init__() + self.a = Param("a", 3.0) + + @active_cache + @forward + def testcache(self, x, a): + return x + a + + @active_cache + def testonlycache(self, x): + return 2 * x + + @forward + def testfunc(self): + return self.testcache(1.0) + self.testcache(2.0) + self.testonlycache(3.0) + self.testonlycache(4.0) + + testsim = TestSim() + assert testsim.testfunc().item() == 20.0 + assert testsim.testonlycache(5.0) == 10.0 + assert testsim.testonlycache(6.0) == 12.0 \ No newline at end of file diff --git a/tests/test_module.py b/tests/test_module.py index 816afe8..43867de 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -38,7 +38,7 @@ def test_module_methods(): m1.fill_params([1.0, 2.0, 3.0]) with pytest.raises(ActiveStateError): - m1.clear_params() + m1.clear_state() def test_module_delattr():