Skip to content

Commit

Permalink
Added named_scope to each module in equinox.nn for better profiling a…
Browse files Browse the repository at this point in the history
…nd debugging
  • Loading branch information
ahmed-alllam authored and patrick-kidger committed Aug 30, 2023
1 parent ec21970 commit 7dd080a
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 0 deletions.
2 changes: 2 additions & 0 deletions equinox/nn/_activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Union

import jax
import jax.numpy as jnp
from jaxtyping import Array

Expand Down Expand Up @@ -29,6 +30,7 @@ def __init__(

self.negative_slope = jnp.asarray(init_alpha)

@jax.named_scope("eqx.nn.PReLU")
def __call__(self, x: Array) -> Array:
r"""**Arguments:**
Expand Down
1 change: 1 addition & 0 deletions equinox/nn/_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(
self.use_value_bias = use_value_bias
self.use_output_bias = use_output_bias

@jax.named_scope("eqx.nn.MultiheadAttention")
def __call__(
self,
query: Float[Array, "q_seq q_size"],
Expand Down
1 change: 1 addition & 0 deletions equinox/nn/_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
self.channelwise_affine = channelwise_affine
self.momentum = momentum

@jax.named_scope("eqx.nn.BatchNorm")
def __call__(
self,
x: Array,
Expand Down
3 changes: 3 additions & 0 deletions equinox/nn/_composed.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Union,
)

import jax
import jax.nn as jnn
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(
self.use_bias = use_bias
self.use_final_bias = use_final_bias

@jax.named_scope("eqx.nn.MLP")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -137,6 +139,7 @@ def __init__(self, layers: Sequence[Callable]):
self.layers = tuple(layers)

@overload
@jax.named_scope("eqx.nn.Sequential")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
...

Expand Down
3 changes: 3 additions & 0 deletions equinox/nn/_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Callable, Sequence
from typing import Optional, TypeVar, Union

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
Expand Down Expand Up @@ -150,6 +151,7 @@ def __init__(
self.groups = groups
self.use_bias = use_bias

@jax.named_scope("eqx.nn.Conv")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -422,6 +424,7 @@ def __init__(
self.groups = groups
self.use_bias = use_bias

@jax.named_scope("eqx.nn.ConvTranspose")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down
2 changes: 2 additions & 0 deletions equinox/nn/_dropout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings
from typing import Optional

import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
def deterministic(self):
return self.inference

@jax.named_scope("eqx.nn.Dropout")
def __call__(
self,
x: Array,
Expand Down
2 changes: 2 additions & 0 deletions equinox/nn/_embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

import jax
import jax.random as jrandom
from jaxtyping import Array, Float, PRNGKeyArray

Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(
self.num_embeddings = num_embeddings
self.embedding_size = embedding_size

@jax.named_scope("eqx.nn.Embedding")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down
3 changes: 3 additions & 0 deletions equinox/nn/_linear.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import Any, Literal, Optional, TypeVar, Union

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jaxtyping import Array, PRNGKeyArray
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(
self.out_features = out_features
self.use_bias = use_bias

@jax.named_scope("eqx.nn.Linear")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -109,6 +111,7 @@ def __init__(self, *args: Any, **kwargs: Any):
# Ignores args and kwargs
super().__init__()

@jax.named_scope("eqx.nn.Identity")
def __call__(self, x: _T, *, key: Optional[PRNGKeyArray] = None) -> _T:
"""**Arguments:**
Expand Down
2 changes: 2 additions & 0 deletions equinox/nn/_normalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __call__(
) -> tuple[Array, State]:
...

@jax.named_scope("eqx.nn.LayerNorm")
def __call__(
self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None
) -> Union[Array, tuple[Array, State]]:
Expand Down Expand Up @@ -212,6 +213,7 @@ def __call__(
) -> tuple[Array, State]:
...

@jax.named_scope("eqx.nn.GroupNorm")
def __call__(
self, x: Array, state: State = sentinel, *, key: Optional[PRNGKeyArray] = None
) -> Union[Array, tuple[Array, State]]:
Expand Down
8 changes: 8 additions & 0 deletions equinox/nn/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def _check_is_padding_valid(self, padding):
f"{kernel_size}."
)

@jax.named_scope("eqx.nn.Pool")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -189,6 +190,7 @@ def __init__(
**kwargs,
)

@jax.named_scope("eqx.nn.AvgPool1d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -238,6 +240,7 @@ def __init__(
)

# Redefined to get them in the right order in docs
@jax.named_scope("eqx.nn.MaxPool1d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -286,6 +289,7 @@ def __init__(
**kwargs,
)

@jax.named_scope("eqx.nn.AvgPool2d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -335,6 +339,7 @@ def __init__(
)

# Redefined to get them in the right order in docs
@jax.named_scope("eqx.nn.MaxPool2d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -383,6 +388,7 @@ def __init__(
**kwargs,
)

@jax.named_scope("eqx.nn.AvgPool3d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -432,6 +438,7 @@ def __init__(
**kwargs,
)

@jax.named_scope("eqx.nn.MaxPool3d")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down Expand Up @@ -511,6 +518,7 @@ def __init__(
f"{num_spatial_dims} containing ints."
)

@jax.named_scope("eqx.nn.AdaptivePool")
def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
"""**Arguments:**
Expand Down
3 changes: 3 additions & 0 deletions equinox/nn/_rnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import Optional

import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
self.hidden_size = hidden_size
self.use_bias = use_bias

@jax.named_scope("eqx.nn.GRUCell")
def __call__(
self, input: Array, hidden: Array, *, key: Optional[PRNGKeyArray] = None
):
Expand Down Expand Up @@ -182,6 +184,7 @@ def __init__(
self.hidden_size = hidden_size
self.use_bias = use_bias

@jax.named_scope("eqx.nn.LSTMCell")
def __call__(self, input, hidden, *, key=None):
"""**Arguments:**
Expand Down
1 change: 1 addition & 0 deletions equinox/nn/_spectral_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def __init__(
u0, v0 = _power_iteration(weight, u0, v0, eps)
self.uv_index = StateIndex(lambda **_: (u0, v0))

@jax.named_scope("eqx.nn.SpectralNorm")
def __call__(
self,
x: Array,
Expand Down

0 comments on commit 7dd080a

Please sign in to comment.