diff --git a/docs/api/manipulation.md b/docs/api/manipulation.md
index 06a023d1..bab25133 100644
--- a/docs/api/manipulation.md
+++ b/docs/api/manipulation.md
@@ -20,10 +20,6 @@
 
 ---
 
-::: equinox.tree_inference
-
----
-
 ::: equinox.tree_flatten_one_level
 
 ---
diff --git a/docs/api/nn/inference.md b/docs/api/nn/inference.md
new file mode 100644
index 00000000..1c743385
--- /dev/null
+++ b/docs/api/nn/inference.md
@@ -0,0 +1,3 @@
+# Training/Inference
+
+::: equinox.nn.inference_mode
diff --git a/equinox/__init__.py b/equinox/__init__.py
index 12522da5..3adb6404 100644
--- a/equinox/__init__.py
+++ b/equinox/__init__.py
@@ -53,7 +53,6 @@
     tree_check as tree_check,
     tree_equal as tree_equal,
     tree_flatten_one_level as tree_flatten_one_level,
-    tree_inference as tree_inference,
 )
 from ._update import apply_updates as apply_updates
 from ._vmap_pmap import (
@@ -61,6 +60,7 @@
     filter_vmap as filter_vmap,
     if_array as if_array,
 )
+from .nn import inference_mode as tree_inference  # noqa: F401 - backward compatibility
 
 
 __version__ = importlib.metadata.version("equinox")
diff --git a/equinox/_tree.py b/equinox/_tree.py
index 52311d6c..c0e1d39e 100644
--- a/equinox/_tree.py
+++ b/equinox/_tree.py
@@ -270,75 +270,6 @@ def tree_equal(*pytrees: PyTree) -> Union[bool, np.bool_, Bool[Array, ""]]:
     return out
 
 
-def _inferences(pytree):
-    is_leaf = lambda x: hasattr(x, "inference") and x is not pytree
-
-    out = [pytree.inference] if hasattr(pytree, "inference") else []
-
-    leaves = [x for x in jtu.tree_leaves(pytree, is_leaf=is_leaf) if is_leaf(x)]
-    # Nodes with an inference flag might have sub-nodes with an inference flag.
-
-    for x in leaves:
-        out.extend(_inferences(x))
-    return out
-
-
-def tree_inference(pytree: PyTree, value: bool) -> PyTree:
-    """Convenience function for setting all `inference` attributes on a PyTree.
-
-    `inference` flags are used to toggle the behaviour of a number of the pre-built
-    neural network layers, such as [`equinox.nn.Dropout`][] or
-    [`equinox.nn.BatchNorm`][].
-
-    !!! Example
-
-        ```python
-        class Model(eqx.Module):
-            norm: eqx.nn.BatchNorm
-            dropout: eqx.nn.Dropout
-            linear: eqx.nn.Linear
-
-            def __init__(self, key):
-                key1, key2 = jax.random.split(key)
-                self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
-                self.dropout = eqx.nn.Dropout(0.4)
-                self.linear = eqx.nn.Linear(3, 1, key=key2)
-
-            def __call__(self, x, ctx, *, key):
-                x, ctx = self.norm(x, ctx)
-                x = self.dropout(x, key=key)
-                x = self.linear(x)
-                return x, ctx
-
-        training_model = Model(jax.random.PRNGKey(0))
-        inference_model = eqx.tree_inference(training_model, value=True)
-        training_model_again = eqx.tree_inference(inference_model, value=False)
-        ```
-
-    Equivalent to:
-    ```python
-    has_inference = lambda leaf: hasattr(leaf, "inference")
-
-    def where(pytree):
-        return tuple(x.inference
-                     for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
-                     if has_inference(x))
-
-    equinox.tree_at(where, pytree, replace_fn=lambda _: value)
-    ```
-
-    **Arguments:**
-
-    - `pytree`: the PyTree to modify.
-    - `value`: the value to set all `inference` attributes to.
-
-    **Returns:**
-
-    A copy of `pytree` with all `inference` flags set to `value`.
-    """
-    return tree_at(_inferences, pytree, replace_fn=lambda _: value)
-
-
 def tree_flatten_one_level(
     pytree: PyTree,
 ) -> tuple[list[PyTree], PyTreeDef]:  # pyright: ignore
diff --git a/equinox/nn/__init__.py b/equinox/nn/__init__.py
index 2f70587c..96b60232 100644
--- a/equinox/nn/__init__.py
+++ b/equinox/nn/__init__.py
@@ -13,6 +13,7 @@
 )
 from ._dropout import Dropout as Dropout
 from ._embedding import Embedding as Embedding
+from ._inference import inference_mode as inference_mode
 from ._linear import Identity as Identity, Linear as Linear
 from ._mlp import MLP as MLP
 from ._normalisation import GroupNorm as GroupNorm, LayerNorm as LayerNorm
diff --git a/equinox/nn/_attention.py b/equinox/nn/_attention.py
index 435f949f..d4b4d5b6 100644
--- a/equinox/nn/_attention.py
+++ b/equinox/nn/_attention.py
@@ -168,7 +168,7 @@ def __init__(
         - `dropout_p`: Dropout probability on attention weights.
         - `inference`: Whether to actually apply dropout at all. If `True` then dropout
             is not applied. If `False` then dropout is applied. This may be toggled
-            with [`equinox.tree_inference`][] or overridden during
+            with [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.MultiheadAttention.__call__`][].
         - `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
             initialisation. (Keyword only argument.)
diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py
index 66c9d338..b12c30a6 100644
--- a/equinox/nn/_batch_norm.py
+++ b/equinox/nn/_batch_norm.py
@@ -38,7 +38,7 @@ class BatchNorm(StatefulLayer):
     training then statistics are computed using the input data, and the running
     statistics updated. During inference then just the running statistics are used.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """  # noqa: E501
 
     weight: Optional[Float[Array, "input_size"]]
@@ -79,7 +79,7 @@ def __init__(
         - `inference`: If `False` then the batch means and variances will be calculated
             and used to update the running statistics. If `True` then the running
             statistics are directly used for normalisation. This may be toggled with
-            [`equinox.tree_inference`][] or overridden during
+            [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.BatchNorm.__call__`][].
         - `dtype`: The dtype of the input array.
         """
diff --git a/equinox/nn/_dropout.py b/equinox/nn/_dropout.py
index e3b82960..023a00b6 100644
--- a/equinox/nn/_dropout.py
+++ b/equinox/nn/_dropout.py
@@ -16,7 +16,7 @@ class Dropout(Module):
     Note that this layer behaves differently during training and inference. During
     training then dropout is randomly applied; during inference this layer does nothing.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """
 
     # Not static fields as it makes sense to want to modify them via equinox.tree_at.
@@ -35,7 +35,7 @@ def __init__(
         - `p`: The fraction of entries to set to zero. (On average.)
        - `inference`: Whether to actually apply dropout at all. If `True` then dropout
             is *not* applied. If `False` then dropout is applied. This may be toggled
-            with [`equinox.tree_inference`][] or overridden during
+            with [`equinox.nn.inference_mode`][] or overridden during
             [`equinox.nn.Dropout.__call__`][].
         - `deterministic`: Deprecated alternative to `inference`.
         """
diff --git a/equinox/nn/_inference.py b/equinox/nn/_inference.py
new file mode 100644
index 00000000..e37f57f3
--- /dev/null
+++ b/equinox/nn/_inference.py
@@ -0,0 +1,74 @@
+import jax.tree_util as jtu
+from jaxtyping import PyTree
+
+from .._tree import tree_at
+
+
+def _inferences(pytree):
+    is_leaf = lambda x: hasattr(x, "inference") and x is not pytree
+
+    out = [pytree.inference] if hasattr(pytree, "inference") else []
+
+    leaves = [x for x in jtu.tree_leaves(pytree, is_leaf=is_leaf) if is_leaf(x)]
+    # Nodes with an inference flag might have sub-nodes with an inference flag.
+
+    for x in leaves:
+        out.extend(_inferences(x))
+    return out
+
+
+def inference_mode(pytree: PyTree, value: bool = True) -> PyTree:
+    """Convenience function for setting all `inference` attributes.
+
+    `inference` flags are used to toggle the behaviour of a number of the pre-built
+    neural network layers, such as [`equinox.nn.Dropout`][] or
+    [`equinox.nn.BatchNorm`][].
+
+    !!! Example
+
+        ```python
+        class Model(eqx.Module):
+            norm: eqx.nn.BatchNorm
+            dropout: eqx.nn.Dropout
+            linear: eqx.nn.Linear
+
+            def __init__(self, key):
+                key1, key2 = jax.random.split(key)
+                self.norm = eqx.nn.BatchNorm(3, "batch", key=key1)
+                self.dropout = eqx.nn.Dropout(0.4)
+                self.linear = eqx.nn.Linear(3, 1, key=key2)
+
+            def __call__(self, x, ctx, *, key):
+                x, ctx = self.norm(x, ctx)
+                x = self.dropout(x, key=key)
+                x = self.linear(x)
+                return x, ctx
+
+        training_model = Model(jax.random.PRNGKey(0))
+        inference_model = eqx.nn.inference_mode(training_model)
+        training_model_again = eqx.nn.inference_mode(inference_model, value=False)
+        ```
+
+    This function is essentially equivalent to:
+    ```python
+    has_inference = lambda leaf: hasattr(leaf, "inference")
+
+    def where(pytree):
+        return tuple(x.inference
+                     for x in jtu.tree_leaves(pytree, is_leaf=has_inference)
+                     if has_inference(x))
+
+    inference_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value)
+    ```
+
+    **Arguments:**
+
+    - `pytree`: the PyTree to modify.
+    - `value`: the value to set all `inference` attributes to. Defaults to `True`, i.e.
+        inference mode.
+
+    **Returns:**
+
+    A copy of `pytree` with all `inference` flags set to `value`.
+    """
+    return tree_at(_inferences, pytree, replace_fn=lambda _: value)
diff --git a/equinox/nn/_spectral_norm.py b/equinox/nn/_spectral_norm.py
index 4998c607..d4615738 100644
--- a/equinox/nn/_spectral_norm.py
+++ b/equinox/nn/_spectral_norm.py
@@ -49,7 +49,7 @@ class SpectralNorm(StatefulLayer, Generic[_Layer]):
     Note that this layer behaves differently during training and inference. During
     training then power iterations are updated; during inference they are fixed.
     Whether the model is in training or inference mode should be toggled using
-    [`equinox.tree_inference`][].
+    [`equinox.nn.inference_mode`][].
     """  # noqa: E501
 
     layer: _Layer
@@ -81,7 +81,7 @@ def __init__(
         - `eps`: Epsilon for numerical stability when calculating norms.
         - `inference`: Whether this is in inference mode, at which time no power
             iterations are performed. This may be toggled with
-            [`equinox.tree_inference`][].
+            [`equinox.nn.inference_mode`][].
         - `key`: A `jax.random.PRNGKey` used to provide randomness for initialisation.
             (Keyword only argument.)
diff --git a/examples/deep_convolutional_gan.ipynb b/examples/deep_convolutional_gan.ipynb
index 9f9b4f9b..c56aa286 100644
--- a/examples/deep_convolutional_gan.ipynb
+++ b/examples/deep_convolutional_gan.ipynb
@@ -640,12 +640,12 @@
     "    return out\n",
     "\n",
     "\n",
-    "inference_gen = eqx.tree_inference(generator, value=True)\n",
+    "inference_gen = eqx.nn.inference_mode(generator)\n",
     "inference_gen = eqx.Partial(inference_gen, state=generator_state)\n",
     "\n",
     "generated_images = evaluate(inference_gen, noise)\n",
     "\n",
-    "inference_discriminator = eqx.tree_inference(discriminator, value=True)\n",
+    "inference_discriminator = eqx.nn.inference_mode(discriminator)\n",
     "inference_discriminator = eqx.Partial(\n",
     "    inference_discriminator, state=discriminator_state\n",
     ")\n",
diff --git a/examples/stateful.ipynb b/examples/stateful.ipynb
index 38044b8c..ab43867c 100644
--- a/examples/stateful.ipynb
+++ b/examples/stateful.ipynb
@@ -147,7 +147,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "inference_model = eqx.tree_inference(model, value=True)\n",
+    "inference_model = eqx.nn.inference_mode(model)\n",
     "inference_model = eqx.Partial(inference_model, state=state)\n",
     "\n",
     "\n",
diff --git a/mkdocs.yml b/mkdocs.yml
index 8f363f65..bca0bd88 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -127,6 +127,7 @@ nav:
         - 'api/nn/embedding.md'
         - 'api/nn/mlp.md'
         - 'api/nn/sequential.md'
+        - 'api/nn/inference.md'
         - 'api/nn/stateful.md'
     - Filtering:
         - 'api/filtering/partition-combine.md'
diff --git a/tests/test_nn.py b/tests/test_nn.py
index 2ce46abd..5a82441e 100644
--- a/tests/test_nn.py
+++ b/tests/test_nn.py
@@ -823,7 +823,7 @@ def test_batch_norm(getkey):
 
     # Test that the statistics don't update at inference
 
-    ibn = eqx.tree_inference(bn, value=True)
+    ibn = eqx.nn.inference_mode(bn, value=True)
     vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
     out, state = vibn(4 * x1 + 20, state)
     running_mean3, running_var3 = state.get(bn.state_index)
@@ -869,7 +869,7 @@ def λ1():
     spectral = eqx.tree_at(
         lambda s: s.layer.weight, spectral, spectral.layer.weight + 1
     )
-    spectral = eqx.tree_inference(spectral, value=True)
+    spectral = eqx.nn.inference_mode(spectral, value=True)
     assert not jnp.allclose(λ1(), 1)
     for _ in range(100):
         _, state = spectral(x, state)
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 1f40cc07..8e4e38e4 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -155,10 +155,10 @@ def run3(x, y):
     assert not run3(a, 1)
 
 
-def test_tree_inference(getkey):
+def test_inference_mode(getkey):
     attention = eqx.nn.MultiheadAttention(2, 4, key=getkey())
     assert attention.dropout.inference is False
-    attention2 = eqx.tree_inference(attention, True)
+    attention2 = eqx.nn.inference_mode(attention)
     assert attention.dropout.inference is False
     assert attention2.dropout.inference is True
 