Skip to content

Commit

Permalink
Stateful operations now support creating states multiple times, such …
Browse files Browse the repository at this point in the history
…that they are compatible with the original model.
  • Loading branch information
patrick-kidger committed Sep 12, 2023
1 parent 8da3f19 commit a7d03e9
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 8 deletions.
42 changes: 34 additions & 8 deletions equinox/nn/_stateful.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from collections.abc import Callable
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TYPE_CHECKING, TypeVar, Union
from typing_extensions import ParamSpec

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree

from .._module import Module
from .._module import field, Module
from .._pretty_print import bracketed, named_objs, text, tree_pformat
from .._tree import tree_at

Expand Down Expand Up @@ -42,7 +42,9 @@ def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
[`equinox.nn.BatchNorm`][] for further reference.
""" # noqa: E501

marker: object
# Starts off as an `object` when initialised; later replaced with an `int` inside
# `make_with_state`.
marker: Union[object, int] = field(static=True)
init: _Value

def __init__(self, init: _Value):
Expand Down Expand Up @@ -334,10 +336,34 @@ def __init__(self, foo, bar):
```
"""

def make_with_state_impl(*args: _P.args, **kwargs: _P.kwargs) -> tuple[_T, State]:
model = make_model(*args, **kwargs)
state = State(model)
model = delete_init_state(model)
return model, state
# _P.{args, kwargs} not supported by beartype
if TYPE_CHECKING:

def make_with_state_impl(
*args: _P.args, **kwargs: _P.kwargs
) -> tuple[_T, State]:
...

else:

def make_with_state_impl(*args, **kwargs) -> tuple[_T, State]:
model = make_model(*args, **kwargs)

# Replace all markers with `int`s. This is needed to ensure that two calls
# to `make_with_state` produce compatible models and states.
leaves, treedef = jtu.tree_flatten(model, is_leaf=_is_index)
counter = 0
new_leaves = []
for leaf in leaves:
if _is_index(leaf):
leaf = StateIndex(leaf.init)
object.__setattr__(leaf, "marker", counter)
counter += 1
new_leaves.append(leaf)
model = jtu.tree_unflatten(treedef, new_leaves)

state = State(model)
model = delete_init_state(model)
return model, state

return make_with_state_impl
56 changes: 56 additions & 0 deletions tests/test_stateful.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import pytest

Expand All @@ -16,3 +19,56 @@ def test_delete_init_state():
leaves = [x for x in jtu.tree_leaves(model) if eqx.is_array(x)]
leaves2 = [x for x in jtu.tree_leaves(model2) if eqx.is_array(x)]
assert len(leaves) == len(leaves2) + 3


def test_double_state():
# From https://github.com/patrick-kidger/equinox/issues/450#issuecomment-1714501666

class Counter(eqx.Module):
index: eqx.nn.StateIndex

def __init__(self):
init_state = jnp.array(0)
self.index = eqx.nn.StateIndex(init_state)

def __call__(self, x, state):
value = state.get(self.index)
new_x = x + value
new_state = state.set(self.index, value + 1)
return new_x, new_state

class Model(eqx.Module):
linear: eqx.nn.Linear
counter: Counter
v_counter: Counter

def __init__(self, key):
# Not-stateful layer
self.linear = eqx.nn.Linear(2, 2, key=key)
# Stateful layer.
self.counter = Counter()
# Vmap'd stateful layer. (Whose initial state will include a batch
# dimension.)
self.v_counter = eqx.filter_vmap(Counter, axis_size=2)()

def __call__(self, x, state):
assert x.shape == (2,)
x = self.linear(x)
x, state = self.counter(x, state)
substate = state.substate(self.v_counter)
x, substate = eqx.filter_vmap(self.v_counter)(x, substate)
state = state.update(substate)
return x, state

key = jr.PRNGKey(0)
model, state = eqx.nn.make_with_state(Model)(key)
x = jnp.array([5.0, -1.0])
model(x, state)

@jax.jit
def make_state(key):
_, state = eqx.nn.make_with_state(Model)(key)
return state

new_state = make_state(jr.PRNGKey(1))
model(x, new_state)

0 comments on commit a7d03e9

Please sign in to comment.