Skip to content

Commit

Permalink
Moved StatefulLayer to the same file as Sequential, as it's really pa…
Browse files Browse the repository at this point in the history
…rt of the Sequential subsystem.
  • Loading branch information
patrick-kidger committed Aug 31, 2023
1 parent 7dd080a commit a234840
Show file tree
Hide file tree
Showing 6 changed files with 175 additions and 166 deletions.
11 changes: 6 additions & 5 deletions equinox/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from ._activations import PReLU as PReLU
from ._attention import MultiheadAttention as MultiheadAttention
from ._batch_norm import BatchNorm as BatchNorm
from ._composed import Lambda as Lambda, MLP as MLP, Sequential as Sequential
from ._conv import (
Conv as Conv,
Conv1d as Conv1d,
Expand All @@ -15,6 +14,7 @@
from ._dropout import Dropout as Dropout
from ._embedding import Embedding as Embedding
from ._linear import Identity as Identity, Linear as Linear
from ._mlp import MLP as MLP
from ._normalisation import GroupNorm as GroupNorm, LayerNorm as LayerNorm
from ._pool import (
AdaptiveAvgPool1d as AdaptiveAvgPool1d,
Expand All @@ -33,9 +33,10 @@
Pool as Pool,
)
from ._rnn import GRUCell as GRUCell, LSTMCell as LSTMCell
from ._spectral_norm import SpectralNorm as SpectralNorm
from ._stateful import (
State as State,
from ._sequential import (
Lambda as Lambda,
Sequential as Sequential,
StatefulLayer as StatefulLayer,
StateIndex as StateIndex,
)
from ._spectral_norm import SpectralNorm as SpectralNorm
from ._stateful import State as State, StateIndex as StateIndex
3 changes: 2 additions & 1 deletion equinox/nn/_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from jaxtyping import Array, Bool, Float

from .._module import field
from ._stateful import State, StatefulLayer, StateIndex
from ._sequential import StatefulLayer
from ._stateful import State, StateIndex


class BatchNorm(StatefulLayer):
Expand Down
133 changes: 1 addition & 132 deletions equinox/nn/_composed.py → equinox/nn/_mlp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections.abc import Callable, Sequence
from collections.abc import Callable
from typing import (
Any,
Literal,
Optional,
overload,
Union,
)

Expand All @@ -12,11 +10,9 @@
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray

from .._custom_types import sentinel
from .._doc_utils import doc_repr
from .._module import field, Module
from ._linear import Linear
from ._stateful import State, StatefulLayer


_identity = doc_repr(lambda x: x, "lambda x: x")
Expand Down Expand Up @@ -123,130 +119,3 @@ def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
x = self.layers[-1](x)
x = self.final_activation(x)
return x


class Sequential(Module):
"""A sequence of [`equinox.Module`][]s applied in order.
!!! note
Activation functions can be added by wrapping them in [`equinox.nn.Lambda`][].
"""

layers: tuple

def __init__(self, layers: Sequence[Callable]):
self.layers = tuple(layers)

@overload
@jax.named_scope("eqx.nn.Sequential")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
...

@overload
def __call__(
self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None
) -> tuple[Array, State]:
...

def __call__(
self,
x: Array,
state: State = sentinel,
*,
key: Optional[PRNGKeyArray] = None,
) -> Union[Array, tuple[Array, State]]:
"""**Arguments:**
- `x`: passed to the first member of the sequence.
- `state`: If provided, then it is passed to, and updated from, any layer
which subclasses [`equinox.nn.StatefulLayer`][].
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)
**Returns:**
The output of the last member of the sequence.
If `state` is passed, then a 2-tuple of `(output, state)` is returned.
If `state` is not passed, then just the output is returned.
"""

if key is None:
keys = [None] * len(self.layers)
else:
keys = jrandom.split(key, len(self.layers))
for layer, key in zip(self.layers, keys):
if isinstance(layer, StatefulLayer):
x, state = layer(x, state=state, key=key)
else:
x = layer(x, key=key)
if state is sentinel:
return x
else:
return x, state

def __getitem__(self, i: Union[int, slice]) -> Callable:
if isinstance(i, int):
return self.layers[i]
elif isinstance(i, slice):
return Sequential(self.layers[i])
else:
raise TypeError(f"Indexing with type {type(i)} is not supported")

def __iter__(self):
yield from self.layers

def __len__(self):
return len(self.layers)


Sequential.__init__.__doc__ = """**Arguments:**
- `layers`: A sequence of [`equinox.Module`][]s.
"""


class Lambda(Module):
"""Wraps a callable (e.g. an activation function) for use with
[`equinox.nn.Sequential`][].
Precisely, this just adds an extra `key` argument (that is ignored). Given some
function `fn`, then `Lambda` is essentially a convenience for `lambda x, key: f(x)`.
!!! faq
If you get a TypeError saying the function is not a valid JAX type, see the
[FAQ](https://docs.kidger.site/equinox/faq/).
!!! Example
```python
model = eqx.nn.Sequential(
[
eqx.nn.Linear(...),
eqx.nn.Lambda(jax.nn.relu),
...
]
)
```
"""

fn: Callable[[Any], Any]

def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
- `x`: The input JAX array.
- `key`: Ignored.
**Returns:**
The output of the `fn(x)` operation.
"""
return self.fn(x)


Lambda.__init__.__doc__ = """**Arguments:**
- `fn`: A callable to be wrapped in [`equinox.Module`][].
"""
162 changes: 162 additions & 0 deletions equinox/nn/_sequential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import abc
from collections.abc import Callable, Sequence
from typing import Any, Optional, overload, Union

import jax
import jax.random as jr
from jaxtyping import Array, PRNGKeyArray

from .._custom_types import sentinel
from .._module import Module
from ._stateful import State


class Sequential(Module):
"""A sequence of [`equinox.Module`][]s applied in order.
!!! note
Activation functions can be added by wrapping them in [`equinox.nn.Lambda`][].
"""

layers: tuple

def __init__(self, layers: Sequence[Callable]):
self.layers = tuple(layers)

@overload
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
...

@overload
def __call__(
self, x: Array, state: State, *, key: Optional[PRNGKeyArray] = None
) -> tuple[Array, State]:
...

@jax.named_scope("eqx.nn.Sequential")
def __call__(
self,
x: Array,
state: State = sentinel,
*,
key: Optional[PRNGKeyArray] = None,
) -> Union[Array, tuple[Array, State]]:
"""**Arguments:**
- `x`: passed to the first member of the sequence.
- `state`: If provided, then it is passed to, and updated from, any layer
which subclasses [`equinox.nn.StatefulLayer`][].
- `key`: Ignored; provided for compatibility with the rest of the Equinox API.
(Keyword only argument.)
**Returns:**
The output of the last member of the sequence.
If `state` is passed, then a 2-tuple of `(output, state)` is returned.
If `state` is not passed, then just the output is returned.
"""

if key is None:
keys = [None] * len(self.layers)
else:
keys = jr.split(key, len(self.layers))
for layer, key in zip(self.layers, keys):
if isinstance(layer, StatefulLayer):
x, state = layer(x, state=state, key=key)
else:
x = layer(x, key=key)
if state is sentinel:
return x
else:
return x, state

def __getitem__(self, i: Union[int, slice]) -> Callable:
if isinstance(i, int):
return self.layers[i]
elif isinstance(i, slice):
return Sequential(self.layers[i])
else:
raise TypeError(f"Indexing with type {type(i)} is not supported")

def __iter__(self):
yield from self.layers

def __len__(self):
return len(self.layers)


Sequential.__init__.__doc__ = """**Arguments:**
- `layers`: A sequence of [`equinox.Module`][]s.
"""


class StatefulLayer(Module):
"""An abstract base class, used to mark a stateful layer for the sake of
[`equinox.nn.Sequential`][]. If `Sequential` sees that a layer inherits
from `StatefulLayer`, then it will know to pass in `state` as well as the
piped data `x`.
Subclasses must implement the `__call__` method that takes input data and the
current state as arguments and returns the output data and updated state.
"""

@abc.abstractmethod
def __call__(
self,
x: Array,
state: State,
*,
key: Optional[PRNGKeyArray],
) -> tuple[Array, State]:
"""The function signature that stateful layers should conform to, to be
compatible with [`equinox.nn.Sequential`][].
"""
raise NotImplementedError("Subclasses must implement the __call__ method.")


class Lambda(Module):
"""Wraps a callable (e.g. an activation function) for use with
[`equinox.nn.Sequential`][].
Precisely, this just adds an extra `key` argument (that is ignored). Given some
function `fn`, then `Lambda` is essentially a convenience for `lambda x, key: f(x)`.
!!! faq
If you get a TypeError saying the function is not a valid JAX type, see the
[FAQ](https://docs.kidger.site/equinox/faq/).
!!! Example
```python
model = eqx.nn.Sequential(
[
eqx.nn.Linear(...),
eqx.nn.Lambda(jax.nn.relu),
...
]
)
```
"""

fn: Callable[[Any], Any]

def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
- `x`: The input JAX array.
- `key`: Ignored.
**Returns:**
The output of the `fn(x)` operation.
"""
return self.fn(x)


Lambda.__init__.__doc__ = """**Arguments:**
- `fn`: A callable to be wrapped in [`equinox.Module`][].
"""
3 changes: 2 additions & 1 deletion equinox/nn/_spectral_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from .._module import field
from .._tree import tree_at
from ._stateful import State, StatefulLayer, StateIndex
from ._sequential import StatefulLayer
from ._stateful import State, StateIndex


def _power_iteration(weight, u, v, eps):
Expand Down
Loading

0 comments on commit a234840

Please sign in to comment.