Tidy up FAQ
patrick-kidger authored Oct 2, 2023
1 parent 6d13e4e commit 539a367
Showing 1 changed file with 32 additions and 49 deletions: docs/faq.md

# FAQ

## Optax is throwing a `TypeError`.

Probably you're passing your whole model straight to Optax, with something like `optim.init(model)`. Instead, this should be
```python
optim.init(eqx.filter(model, eqx.is_inexact_array))
```
which after a little thought should make sense: Optax can only optimise floating-point JAX arrays. It's not meaningful to ask Optax to optimise whichever other arbitrary Python objects may be a part of your model. (e.g. the activation function of an `eqx.nn.MLP`).
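
For reference, here is a minimal end-to-end sketch of this filtering pattern. The model size, loss, and choice of optimiser below are illustrative assumptions, not anything prescribed by Equinox or Optax:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import optax

key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(in_size=2, out_size=1, width_size=32, depth=2, key=key)

optim = optax.adam(1e-3)
# Initialise the optimiser state from the floating-point arrays only.
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, x, y):
    def loss_fn(m, x, y):
        pred = jax.vmap(m)(x)  # the model acts on single elements; vmap over the batch
        return jnp.mean((pred - y) ** 2)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optim.update(
        grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

x, y = jnp.zeros((8, 2)), jnp.zeros((8, 1))
model, opt_state, loss = make_step(model, opt_state, x, y)
```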

## How are batch dimensions handled?

All layers in `equinox.nn` are defined to operate on single batch elements, not a whole batch.

To act on a batch, use [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap). This maps arbitrary JAX operations -- including any Equinox module -- over additional dimensions.

For example, if `x` is an array/tensor of shape `(batch_size, input_size)`, then the PyTorch pattern `y = linear(x)` (for `linear = torch.nn.Linear(input_size, output_size)`) corresponds to the following Equinox code:

```python
linear = eqx.nn.Linear(input_size, output_size, key=key)
y = jax.vmap(linear)(x)
```
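
The same applies to whole models, not just individual layers. A small sketch (the MLP hyperparameters and batch size here are made up for illustration):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mlp = eqx.nn.MLP(in_size=4, out_size=2, width_size=8, depth=1, key=key)

x = jnp.zeros((4,))      # a single element, of shape (in_size,)
xs = jnp.zeros((32, 4))  # a batch, of shape (batch_size, in_size)

y = mlp(x)               # layers and models act on single elements...
ys = jax.vmap(mlp)(xs)   # ...and jax.vmap maps them over the batch dimension
```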

## How to share a layer across two parts of a model?

Use [`equinox.nn.Shared`][] to tie together multiple nodes (layers, weights, ...) in a PyTree.

In particular, *don't* do something like this:
```python
# Buggy code!
class Module(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(...):
        shared_linear = eqx.nn.Linear(...)
        self.linear1 = shared_linear
        self.linear2 = shared_linear
```
as this accomplishes something different: it creates two separate layers that are initialised with the same values for their parameters. After making some gradient updates, you'll find that `self.linear1` and `self.linear2` are now different.

The reason for this is that in Equinox+JAX, models are Py*Trees*, not DAGs (directed acyclic graphs). JAX follows a functional-programming-like style, in which the *identity* of an object (whether that be a layer, a weight, or whatever) doesn't matter; only its *value* matters. (This is known as referential transparency.)

See also the [`equinox.tree_check`][] function, which can be run on a model to check whether you have duplicate nodes.
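
For example, tying an embedding matrix to a final linear layer might look something like the following sketch. (The sizes, the `LanguageModel` structure, and the activation are illustrative; the key point is the `where`/`get` pair telling `eqx.nn.Shared` which node to reuse where.)

```python
import equinox as eqx
import jax


class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self, key):
        embed_key, linear_key = jax.random.split(key)
        embedding = eqx.nn.Embedding(num_embeddings=1000, embedding_size=32, key=embed_key)
        linear = eqx.nn.Linear(32, 1000, key=linear_key)
        # Tie the linear layer's weight matrix to the embedding's weight matrix.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, token):
        # Calling the Shared module re-inserts the tied weight; then unpack as usual.
        embedding, linear = self.shared()
        return linear(jax.nn.relu(embedding(token)))
```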

## My model is slow...

#### ...to train.

Make sure you have JIT covering all JAX operations.

Most autodifferentiable programs will have a "numerical bit" (e.g. a training step for your model) and a "normal programming bit" (e.g. saving models to disk). JAX makes this difference explicit. All the numerical work should go inside a single big JIT region, within which all numerical operations are compiled.

See [the RNN example](https://docs.kidger.site/equinox/examples/train_rnn/) as an example of good practice. The whole `make_step` function is JIT compiled in one go.

A common mistake is to put `jax.jit`/`eqx.filter_jit` on just your loss function, which leaves either (a) computing gradients or (b) applying updates with `eqx.apply_updates` outside of JIT.
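
For example, the recommended structure looks something like the following sketch, in which everything numerical lives inside a single JIT-compiled `make_step`. (The MLP, the dummy `dataloader`, and plain SGD are illustrative stand-ins for your own model, data pipeline and update rule.)

```python
import equinox as eqx
import jax
import jax.numpy as jnp

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=16, depth=1, key=jax.random.PRNGKey(0))
learning_rate = 0.1
number_of_steps = 100
dataloader = ((jnp.zeros((8, 2)), jnp.zeros((8, 1))) for _ in range(number_of_steps))  # stand-in

@eqx.filter_jit
def make_step(model, x, y):
    # Inside the JIT region: loss, gradients and the parameter update are all compiled together.
    grads = compute_loss(model, x, y)
    model = stochastic_gradient_descent(model, grads)
    return model

@eqx.filter_grad
def compute_loss(model, x, y):
    # Still inside the JIT region when called from `make_step`.
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

def stochastic_gradient_descent(model, grads):
    # Also inside the JIT region when called from `make_step`: plain SGD.
    return eqx.apply_updates(
        model, jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
    )

for step, (x, y) in zip(range(number_of_steps), dataloader):
    # Outside the JIT region: ordinary Python (logging, checkpointing, ...).
    model = make_step(model, x, y)
```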

#### ...to compile.

95% of the time, it's because you've done something like this:
```python
# (original code example collapsed in this diff view)
```

[The rest of this entry, and some further FAQ content, is collapsed in this diff view.]

This error happens because a model, when treated as a PyTree, may have leaves that are arbitrary Python objects rather than JAX arrays.

Instead of [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), use [`equinox.filter_jit`][]. Likewise for [other transformations](https://docs.kidger.site/equinox/api/filtering/transformations).
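
For instance, a minimal sketch (the MLP and the `forward` function are just placeholders):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

model = eqx.nn.MLP(in_size=2, out_size=2, width_size=8, depth=1, key=jax.random.PRNGKey(0))

@eqx.filter_jit  # traces the array leaves; non-array leaves (e.g. activation functions) are held static
def forward(model, x):
    return model(x)

y = forward(model, jnp.zeros(2))  # using plain jax.jit here would fail on the non-array leaves
```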

## How to mark arrays as non-trainable? (Like PyTorch's buffers?)

This can be done by using `jax.lax.stop_gradient`.
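
The original code example is collapsed in this diff, but the idea looks something like the following sketch (the field names and the arithmetic are illustrative):

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Model(eqx.Module):
    param: jax.Array   # trainable
    buffer: jax.Array  # non-trainable

    def __call__(self, x):
        # `buffer` still lives in the PyTree (so it is saved, loaded and JIT-traced as
        # normal), but stop_gradient means no gradients ever flow into it.
        return self.param * x + jax.lax.stop_gradient(self.buffer)


model = Model(param=jnp.array(2.0), buffer=jnp.array(1.0))
```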

[Several further FAQ entries, and most of the closing comparison section, are collapsed in this diff view.]

Julia is often a small amount faster on microbenchmarks on CPUs. […]

Seriously, we think they're fair! Nonetheless all of the above approaches have their adherents, so it seems like each of them is doing something right. So if you're already happily using one of them for your current project... then keep using it. (Don't rewrite things for no reason.) But conversely, we'd invite you to try Equinox for your next project. :)

For what it's worth, if you have the time to learn (e.g. you're a grad student), then we'd strongly recommend trying all of the above. All of these libraries have made substantial innovations, and have all substantially moved the numerical computing space forward. Equinox deliberately takes inspiration from them!
