Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/caskade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .base import Node
from .backend import backend, ArrayLike
from .context import ActiveContext, ValidContext, OverrideParam
from .decorators import forward
from .decorators import forward, active_cache
from .module import Module
from .param import Param, dynamic
from .collection import NodeCollection, NodeList, NodeTuple
Expand Down Expand Up @@ -42,6 +42,7 @@
"ValidContext",
"OverrideParam",
"forward",
"active_cache",
"test",
"CaskadeException",
"GraphError",
Expand Down
4 changes: 2 additions & 2 deletions src/caskade/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def __enter__(self):
self.outer_active = self.module.active
if self.outer_active and not self.active:
self.outer_params = list(p.value for p in self.module.dynamic_params)
self.module.clear_params()
self.module.clear_state()
self.module.active = self.active

def __exit__(self, exc_type, exc_value, traceback):
if not self.outer_active and self.active:
self.module.clear_params()
self.module.clear_state()
self.module.active = self.outer_active
if self.outer_active and not self.active:
self.module.fill_params(self.outer_params)
Expand Down
40 changes: 38 additions & 2 deletions src/caskade/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import functools
from contextlib import ExitStack

from .backend import backend
from .context import ActiveContext, OverrideParam
from .param import Param

__all__ = ("forward",)
__all__ = ("forward", "active_cache")


def forward(method):
Expand Down Expand Up @@ -99,3 +98,40 @@ def wrapped(self, *args, **kwargs):
return method(self, *args, **kwargs)

return wrapped

def active_cache(method):
"""
Decorator to enable caching for a method. This way the method will only be
called once in an active state, and the cache will be dropped when exiting.

Parameters
----------
method: (Callable)
The method to be decorated.

Returns
-------
Callable
The decorated method with caching enabled.
"""

NOVALUE = object()

cache = NOVALUE

def hook(self):
nonlocal cache
cache = NOVALUE
self.clear_state_hooks.remove(hook)

@functools.wraps(method)
def wrapped(self, *args, **kwargs):
nonlocal cache
if not self.active:
return method(self, *args, **kwargs)
elif cache is NOVALUE:
cache = method(self, *args, **kwargs)
self.clear_state_hooks.add(hook)
return cache

return wrapped
8 changes: 6 additions & 2 deletions src/caskade/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, name: Optional[str] = None, **kwargs):
self.local_dynamic_params = {}
self._type = "module"
self.valid_context = False
self.clear_state_hooks = set()

def update_graph(self):
"""Maintain a tuple of dynamic and live parameters at all points lower
Expand Down Expand Up @@ -289,16 +290,19 @@ def fill_params(self, params: Union[ArrayLike, Sequence, Mapping]):
params = self.from_valid(params)
self._fill_values(params)

def clear_params(self):
def clear_state(self):
"""Set all dynamic parameters to None and live parameters to LiveParam.
This is to be used on exiting an ``ActiveContext`` and so should not be
used by a user."""
if not self.active:
raise ActiveStateError(f"Module {self.name} must be active to clear params")
raise ActiveStateError(f"Module {self.name} must be active to clear state")

for param in self.dynamic_params + self.pointer_params:
param._value = None

for hook in list(self.clear_state_hooks):
hook(self)

def fill_kwargs(self, keys: tuple[str]) -> dict[str, ArrayLike]:
"""
Fill the kwargs for an ``@forward`` method with the values of the dynamic
Expand Down
29 changes: 28 additions & 1 deletion tests/test_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from caskade import Module, Param, forward, ActiveContext, OverrideParam, backend
from caskade import Module, Param, forward, ActiveContext, OverrideParam, backend, active_cache
import numpy as np


Expand Down Expand Up @@ -69,3 +69,30 @@ def testfunc(self):
testsim = TestSim()
assert testsim.testfunc(backend.make_array([5.0])).item() == 27.0
assert testsim.a.value.item() == 3.0

def test_active_cache():
if backend.backend == "object":
return

class TestSim(Module):
def __init__(self):
super().__init__()
self.a = Param("a", 3.0)

@active_cache
@forward
def testcache(self, x, a):
return x + a

@active_cache
def testonlycache(self, x):
return 2 * x

@forward
def testfunc(self):
return self.testcache(1.0) + self.testcache(2.0) + self.testonlycache(3.0) + self.testonlycache(4.0)

testsim = TestSim()
assert testsim.testfunc().item() == 20.0
assert testsim.testonlycache(5.0) == 10.0
assert testsim.testonlycache(6.0) == 12.0
2 changes: 1 addition & 1 deletion tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_module_methods():
m1.fill_params([1.0, 2.0, 3.0])

with pytest.raises(ActiveStateError):
m1.clear_params()
m1.clear_state()


def test_module_delattr():
Expand Down
Loading