-
-
Notifications
You must be signed in to change notification settings - Fork 147
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Renamed eqx.tree_inference -> eqx.nn.inference_mode, as it's really a…
…n eqx.nn thing, not an eqx thing.
- Loading branch information
1 parent
8c39b21
commit 3bcc628
Showing
15 changed files
with
94 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,10 +20,6 @@ | |
|
||
--- | ||
|
||
::: equinox.tree_inference | ||
|
||
--- | ||
|
||
::: equinox.tree_flatten_one_level | ||
|
||
--- | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Training/Inference | ||
|
||
::: equinox.nn.inference_mode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import jax.tree_util as jtu | ||
from jaxtyping import PyTree | ||
|
||
from .._tree import tree_at | ||
|
||
|
||
def _inferences(pytree): | ||
is_leaf = lambda x: hasattr(x, "inference") and x is not pytree | ||
|
||
out = [pytree.inference] if hasattr(pytree, "inference") else [] | ||
|
||
leaves = [x for x in jtu.tree_leaves(pytree, is_leaf=is_leaf) if is_leaf(x)] | ||
# Nodes with an inference flag might have sub-nodes with an inference flag. | ||
|
||
for x in leaves: | ||
out.extend(_inferences(x)) | ||
return out | ||
|
||
|
||
def inference_mode(pytree: PyTree, value: bool = True) -> PyTree: | ||
"""Convenience function for setting all `inference` attributes. | ||
`inference` flags are used to toggle the behaviour of a number of the pre-built | ||
neural network layers, such as [`equinox.nn.Dropout`][] or | ||
[`equinox.nn.BatchNorm`][]. | ||
!!! Example | ||
```python | ||
class Model(eqx.Module): | ||
norm: eqx.nn.BatchNorm | ||
dropout: eqx.nn.Dropout | ||
linear: eqx.nn.Linear | ||
def __init__(self, key): | ||
key1, key2 = jax.random.split(key) | ||
self.norm = eqx.nn.BatchNorm(3, "batch", key=key1) | ||
self.dropout = eqx.nn.Dropout(0.4) | ||
self.linear = eqx.nn.Linear(3, 1, key=key2) | ||
def __call__(self, x, ctx, *, key): | ||
x, ctx = self.norm(x, ctx) | ||
x = self.dropout(x, key=key) | ||
x = self.linear(x) | ||
return x, ctx | ||
training_model = Model(jax.random.PRNGKey(0)) | ||
inference_model = eqx.nn.inference_mode(training_model) | ||
training_model_again = eqx.nn.inference_mode(inference_model, value=False) | ||
``` | ||
This function is essentially equivalent to: | ||
```python | ||
has_inference = lambda leaf: hasattr(leaf, "inference") | ||
def where(pytree): | ||
return tuple(x.inference | ||
for x in jtu.tree_leaves(pytree, is_leaf=has_inference) | ||
if has_inference(x)) | ||
inference_pytree = equinox.tree_at(where, pytree, replace_fn=lambda _: value) | ||
``` | ||
**Arguments:** | ||
- `pytree`: the PyTree to modify. | ||
- `value`: the value to set all `inference` attributes to. Defaults to `True`, i.e. | ||
inference mode. | ||
**Returns:** | ||
A copy of `pytree` with all `inference` flags set to `value`. | ||
""" | ||
return tree_at(_inferences, pytree, replace_fn=lambda _: value) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters