v0.11.0 #536

Merged
merged 52 commits on Sep 29, 2023

Commits
7b2a918
Renamed eqx.tree_inference -> eqx.nn.inference_mode, as it's really a…
patrick-kidger Aug 31, 2023
776e7d7
Embedding now supports initialisation with just a weight. It now valid…
patrick-kidger Sep 5, 2023
d71dc06
LayerNorm is now a bit safer, by explicitly validating shapes. (dlwh …
patrick-kidger Sep 5, 2023
858123a
Added equinox.internal.closure_to_pytree
patrick-kidger Sep 7, 2023
3ce4668
Fixed tree_check raising false positives
patrick-kidger Sep 7, 2023
3c97e17
EQX_ON_ERROR=breakpoint now produces a warning, to help avoid folks l…
patrick-kidger Sep 7, 2023
9e9a5b8
Support class creation kwargs, e.g. `class A(eqx.Module, foo=bar)`.
patrick-kidger Sep 8, 2023
dfdb5d0
PyTree closures now have stricter equality
patrick-kidger Sep 8, 2023
b5482ff
Bound methods now support equality against each other.
patrick-kidger Sep 8, 2023
5295f42
PyTree closures now have stricter equality... in the REPL as well
patrick-kidger Sep 8, 2023
5ee2ee0
filter_eval_shape now always has a .out_struct property
patrick-kidger Sep 11, 2023
ff8b143
Improved documentation for eqx.field
patrick-kidger Sep 11, 2023
73b3385
Added Module.__check_init__ for after-initialisation checking of inva…
patrick-kidger Sep 11, 2023
bd9020a
Now explicitly check that sketchy method assignment doesn't happen in…
patrick-kidger Sep 11, 2023
c555307
module_update_wrapper now looks at __wrapped__, rather than requiring…
patrick-kidger Sep 11, 2023
d7513da
Added notes on the abstract/final design pattern. It's a fairly eleme…
patrick-kidger Sep 12, 2023
9b1f492
Added strict modules, for enforcing the abstract/final design pattern.
patrick-kidger Sep 12, 2023
60159f9
Fix doc typo
patrick-kidger Sep 12, 2023
93eeb02
Updated stateful operations to be able to support vmap'ing.
patrick-kidger Aug 31, 2023
244978d
Stateful operations now support creating states multiple times, such …
patrick-kidger Sep 12, 2023
9dda998
Added deprecation message
patrick-kidger Sep 12, 2023
dc8e968
Removed errors from internal README, now that they're part of the pub…
patrick-kidger Sep 12, 2023
013d0a4
Added lots of extra commentary for explaining how Modules work.
patrick-kidger Sep 12, 2023
a0ba042
add `filter_checkpoint` (#497)
dlwh Sep 12, 2023
56cc31f
Typo fix
patrick-kidger Sep 13, 2023
57afee1
Added support for sharing layers between different parts of a model.
patrick-kidger Sep 7, 2023
0e52c9e
Bump version number
patrick-kidger Sep 17, 2023
eb65acd
Calling equinox.filter_{grad, value_and_grad} without a positional fi…
patrick-kidger Sep 20, 2023
8638ac7
Fix static typechecking error with JAX version 0.4.16
patrick-kidger Sep 20, 2023
8e6e950
Statefulness is now propagated through Sequential layers.
patrick-kidger Sep 21, 2023
9cc2760
Added FAQ entries comparing against PyTorch etc.
patrick-kidger Sep 21, 2023
e210188
Fixed spurious error when accessing methods during __init__.
patrick-kidger Sep 22, 2023
c68fea4
Fixed up wrong annotation for closure conversion.
patrick-kidger Sep 22, 2023
d9b018a
Added an example for a Vision Transformer (ViT) (#483)
ahmed-alllam Sep 23, 2023
a3e9531
Using both __init__ and __post_init__ now raises a warning, as these …
patrick-kidger Sep 25, 2023
3b6fe93
Added test and explanation for assigning __init__.__doc__
patrick-kidger Sep 25, 2023
e0f0484
Added more tests for Modules. Fixed edge-case when using eqx.field in…
patrick-kidger Sep 25, 2023
e9cacff
Updated the abstract/final design pattern.
patrick-kidger Sep 26, 2023
b9a0e95
Removed the abstract/final design pattern from the documentation for …
patrick-kidger Sep 26, 2023
4ab56b8
strict=True now allows special forms like typing.Generic etc.
patrick-kidger Sep 26, 2023
a644ac4
Field conversion now happens before __post_init__
patrick-kidger Sep 26, 2023
fc13135
Tweaked ViT example and added it docs. Improved error from LayerNorm …
patrick-kidger Sep 27, 2023
7ea43cf
AbstractVars are now overridden by annotations in subclasses.
patrick-kidger Sep 28, 2023
89f3405
Moved conversion to Foo.__init__ from MetaFoo.__call__.
patrick-kidger Sep 27, 2023
dbcd6ad
Added help on debugging recompilation
patrick-kidger Sep 28, 2023
f8cea8b
Updated FAQ on sharing layers to mention eqx.nn.Shared
patrick-kidger Sep 28, 2023
8132b7c
Handle `np.generic` in ser/de and add path info to deserialisation er…
colehaus Sep 28, 2023
0c06b17
Fixed broken link
patrick-kidger Sep 28, 2023
9528aee
Fixed eqx.field(metadata=...) resulting in static and converter being…
patrick-kidger Sep 28, 2023
9ef366e
eqx.tree_equal now behaves better with NumPy.
patrick-kidger Sep 28, 2023
bcc29da
Updated serde errors to operate in a type-safe way.
patrick-kidger Sep 28, 2023
74a346b
Serialisation fix for float64 scalars, which otherwise get downcast b…
patrick-kidger Sep 29, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ build/
dist/
site/
examples/data
examples/CIFAR
.all_objects.cache
.pymon
.idea
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -12,7 +12,7 @@ repos:
    rev: v1.1.315
    hooks:
      - id: pyright
        additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions]
        additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions]
  - repo: https://github.com/nbQA-dev/nbQA
    rev: 1.6.3
    hooks:
4 changes: 4 additions & 0 deletions docs/api/filtering/transformations.md
@@ -38,6 +38,10 @@ Likewise, `eqx.filter_grad` will automatically differentiate all floating-point

---

::: equinox.filter_checkpoint

---

::: equinox.filter_custom_jvp

---
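
For context, `equinox.filter_checkpoint` is a filtered analogue of `jax.checkpoint` (rematerialisation). A minimal sketch of possible usage -- the `block` function, layer, and sizes here are made up for illustration:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

@eqx.filter_checkpoint
def block(model, x):
    # Activations inside here are recomputed during the backward pass,
    # rather than stored, trading compute for memory.
    return jnp.tanh(model(x))

model = eqx.nn.Linear(4, 4, key=jr.PRNGKey(0))
x = jnp.ones(4)
grads = eqx.filter_grad(lambda m: block(m, x).sum())(model)
```
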
4 changes: 0 additions & 4 deletions docs/api/manipulation.md
@@ -20,10 +20,6 @@

---

::: equinox.tree_inference

---

::: equinox.tree_flatten_one_level

---
99 changes: 99 additions & 0 deletions docs/api/module/advanced_fields.md
@@ -18,6 +18,105 @@ Equinox modules can be used as [abstract base classes](https://docs.python.org/3
    selection:
        members: false

## Checking invariants

Equinox extends dataclasses with a `__check_init__` method, which is automatically run after initialisation. This can be used to check invariants like so:

```python
class Positive(eqx.Module):
    x: int

    def __check_init__(self):
        if self.x <= 0:
            raise ValueError("Oh no!")
```
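
For example:

```python
Positive(1)   # fine
Positive(-1)  # raises ValueError("Oh no!")
```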

This method has three key differences compared to the `__post_init__` provided by dataclasses:

- It is not overridden by an `__init__` method of a subclass, unlike `__post_init__`. For example, the following code has a bug (Equinox will raise a warning if you do this):

```python
class Parent(eqx.Module):
    x: int

    def __post_init__(self):
        if self.x <= 0:
            raise ValueError("Oh no!")

class Child(Parent):
    x_as_str: str

    def __init__(self, x):
        self.x = x
        self.x_as_str = str(x)

Child(-1)  # No error!
```

- It is automatically called for parent classes; `super().__check_init__()` is not required:

```python
class Parent(eqx.Module):
    def __check_init__(self):
        print("Parent")

class Child(Parent):
    def __check_init__(self):
        print("Child")

Child()  # prints out both Child and Parent
```

As with the previous bullet point, this is to prevent child classes accidentally failing to check that the invariants of their parent hold.

- Assignment is not allowed:

```python
class MyModule(eqx.Module):
    foo: int

    def __check_init__(self):
        self.foo = 1  # will raise an error
```

This is to prevent `__check_init__` from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants.

## Creating wrapper modules

::: equinox.module_update_wrapper
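
As a rough sketch of how this is typically used -- assuming, per the commit above, that `module_update_wrapper` locates the wrapped function via a `__wrapped__` property; the `Wrapper` class and `my_decorator` here are made up for illustration:

```python
import equinox as eqx
from typing import Callable

class Wrapper(eqx.Module):
    fn: Callable

    @property
    def __wrapped__(self):
        return self.fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

def my_decorator(fn):
    # In the spirit of functools.wraps: copy __name__, __doc__, etc.
    # from `fn` onto the wrapper module.
    return eqx.module_update_wrapper(Wrapper(fn))
```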

<!--
## Strict modules

Equinox supports an entirely optional "strict mode", for validating that you follow the abstract/final design pattern as discussed in [this style guide](../../../pattern/).

When enabled via
```python
class Foo(eqx.Module, strict=True):
    ...
```
then the following things are checked when you define your class (an error is raised if they fail).

- That all base classes are also strict `eqx.Module`s.
- That concrete classes are final.
- That the `__init__` method and all fields are defined on a single class.
- That abstract classes have names beginning with `"Abstract"`.
- That no concrete method is overridden. For example, this will raise an error:
```python
class Foo(eqx.Module):
    def f(self): ...

class Bar(Foo, strict=True):
    def f(self): ...
```
but this is allowed:
```python
class Abstract(eqx.Module):
    @abc.abstractmethod
    def f(self): ...

class Concrete(Abstract, strict=True):
    def f(self): ...
```

Only the strict `Module` itself is checked. Subclasses will not become strict unless they also opt in. This makes it possible to safely enable strict modules in a library, without affecting any downstream users.-->
3 changes: 3 additions & 0 deletions docs/api/nn/inference.md
@@ -0,0 +1,3 @@
# Training/Inference

::: equinox.nn.inference_mode
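
A minimal sketch of typical usage, using `Dropout` as a convenient example of a layer with an `inference` flag:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

model = eqx.nn.Dropout(p=0.5)
x = jnp.ones(4)

train_out = model(x, key=jr.PRNGKey(0))  # stochastic: training behaviour

# Flip every `inference` flag in the pytree; dropout now acts as the identity.
inference_model = eqx.nn.inference_mode(model)
eval_out = inference_model(x)  # deterministic: no key needed
```
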
2 changes: 1 addition & 1 deletion docs/api/nn/sequential.md
@@ -25,4 +25,4 @@ These are useful when building fairly straightforward models. But for anything n
::: equinox.nn.StatefulLayer
    selection:
        members:
        - __call__
        - is_stateful
7 changes: 7 additions & 0 deletions docs/api/nn/shared.md
@@ -0,0 +1,7 @@
# Sharing layers

::: equinox.nn.Shared
    selection:
        members:
        - __init__
        - __call__
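
For context, a minimal sketch of what this enables (weight tying); the `TiedEmbedding` model and its sizes are made up for illustration, and the `where`/`get` pattern mirrors `eqx.tree_at`:

```python
import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

class TiedEmbedding(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self, key):
        embedding = eqx.nn.Embedding(num_embeddings=100, embedding_size=8, key=key)
        head = eqx.nn.Linear(8, 100, key=key)
        # Store the weight once; record that `head.weight` should be the
        # same node as `embedding.weight`.
        where = lambda pair: pair[1].weight
        get = lambda pair: pair[0].weight
        self.shared = eqx.nn.Shared((embedding, head), where, get)

    def __call__(self, token):
        embedding, head = self.shared()  # re-inserts the tied weight
        return head(embedding(token))

model = TiedEmbedding(jr.PRNGKey(0))
logits = model(jnp.array(3))
```
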
106 changes: 101 additions & 5 deletions docs/api/nn/stateful.md
@@ -1,17 +1,113 @@
# Stateful operations

These are the tools that underly stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][].
These are the tools that underlie stateful operations, like [`equinox.nn.BatchNorm`][] or [`equinox.nn.SpectralNorm`][]. These are fairly unusual layers, so most users will not need this part of the API.

See the [stateful example](../../examples/stateful.ipynb) for an example of working with stateful operations.
!!! Example

::: equinox.nn.State
    The [stateful example](../../examples/stateful.ipynb) is a good reference for the typical workflow for stateful layers.

---

::: equinox.nn.make_with_state

## Extra features

Let's explain how this works under the hood. First of all, all stateful layers (`BatchNorm` etc.) include an "index". This is basically just a unique hashable value (used later as a dictionary key), together with an initial value for the state:

::: equinox.nn.StateIndex
    selection:
        members:
        - __init__

---

::: equinox.nn.StateIndex
This `State` object that's being passed around is essentially just a dictionary, mapping from `StateIndex`s to PyTrees-of-arrays. Correspondingly this has `.get` and `.set` methods to read and write values to it.

::: equinox.nn.State
    selection:
        members:
        - __init__
        - get
        - set
        - substate
        - update

## Custom stateful layers

Let's use [`equinox.nn.StateIndex`][] to create a custom stateful layer.

```python
import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        init_state = jnp.array(0)
        self.index = eqx.nn.StateIndex(init_state)

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        value = state.get(self.index)
        new_x = x + value
        new_state = state.set(self.index, value + 1)
        return new_x, new_state

counter, state = eqx.nn.make_with_state(Counter)()
x = jnp.array(2.3)

num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 0

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 1

_, state = counter(x, state)
num_calls = state.get(counter.index)
print(f"Called {num_calls} times.") # 2
```

## Vmap'd stateful layers

This is an advanced thing to do! Here we'll build on [the ensembling guide](../../../tricks/#ensembling), and see how we can create vmap'd stateful layers.

This follows on from the previous example, in which we defined `Counter`.
```python
import jax.random as jr

class Model(eqx.Module):
    linear: eqx.nn.Linear
    counter: Counter
    v_counter: Counter

    def __init__(self, key):
        # Not-stateful layer
        self.linear = eqx.nn.Linear(2, 2, key=key)
        # Stateful layer.
        self.counter = Counter()
        # Vmap'd stateful layer. (Whose initial state will include a batch dimension.)
        self.v_counter = eqx.filter_vmap(Counter, axis_size=2)()

    def __call__(self, x: Array, state: eqx.nn.State) -> tuple[Array, eqx.nn.State]:
        # This bit happens as normal.
        assert x.shape == (2,)
        x = self.linear(x)
        x, state = self.counter(x, state)

        # For the vmap, we have to restrict our state to just those states we want to
        # vmap, and then update the overall state again afterwards.
        #
        # After all, the state for `self.counter` isn't expecting to be batched, so we
        # have to remove that.
        substate = state.substate(self.v_counter)
        x, substate = eqx.filter_vmap(self.v_counter)(x, substate)
        state = state.update(substate)

        return x, state

key = jr.PRNGKey(0)
model, state = eqx.nn.make_with_state(Model)(key)
x = jnp.array([5.0, -1.0])
model(x, state)
```
94 changes: 86 additions & 8 deletions docs/faq.md
@@ -31,17 +31,13 @@ class Module(eqx.Module):
        self.linear1 = shared_linear
        self.linear2 = shared_linear
```
in which the same object is saved multiple times in the model.
in which the same object is saved multiple times in the model. However, after making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.

Don't do this!
This is intended. In Equinox+JAX, models are Py*Trees*, not DAGs (directed acyclic graphs). This is basically just an arbitrary choice JAX made a long time ago in its design, but it does generally make reasoning about your code fairly easy. (You never need to track if an object is used in multiple places.)

After making some gradient updates you'll find that `self.linear1` and `self.linear2` are now different.
That said, it can sometimes happen that you really do want to tie together multiple nodes in your PyTree. If this is the case, then use [`equinox.nn.Shared`][], which provides this behaviour. (It stores things as a tree, and then inserts a reference to each node into the right place whenever you need it.)

Recall that in Equinox, models are PyTrees. Meanwhile, JAX treats all PyTrees as *trees*: that is, the same object does not appear more in the tree than once. (If it did, then it would be a *directed acyclic graph* instead.) If JAX ever encounters the same object multiple times then it will unwittingly make independent copies of the object whenever it transforms the overall PyTree.

The resolution is simple: just don't store the same object in multiple places in the PyTree.

You can check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function.
You can also check for whether you have duplicate nodes by using the [`equinox.tree_check`][] function.

## How do I input higher-order tensors (e.g. with batch dimensions) into my model?

@@ -160,3 +156,85 @@ class Model(eqx.Module):
    def __call__(self, x):
        return self.param * x + jax.lax.stop_gradient(self.buffer)
```

## I think my function is being recompiled each time it is run.

You can check each time your function is compiled by adding a print statement:
```python
@eqx.filter_jit
def your_function(x, y, z):
print("Compiling!")
... # rest of your code here
```
JAX calls your function each time it needs to compile it. Afterwards, it never actually calls it -- indeed it doesn't use Python at all! Instead, it uses its compiled copy of your function, which only performs array operations. Thus, a print statement is an easy way to check each time JAX is compiling your function.

A function will be recompiled every time the shape or dtype of its arrays changes, or if any of its static (non-array) inputs change (as measured by `__eq__`).
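
For example, a quick way to see this in action:

```python
import equinox as eqx
import jax.numpy as jnp

@eqx.filter_jit
def f(x):
    print("Compiling!")
    return x + 1

f(jnp.zeros(2))  # prints "Compiling!"
f(jnp.ones(2))   # silent: same shape and dtype, so the cached version is reused
f(jnp.zeros(3))  # prints "Compiling!" again: the shape changed
```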

If you want to check which argument is causing an undesired recompilation, then this can be done by checking each argument in turn:
```python
@eqx.filter_jit
def check_arg(arg, name):
print(f"Argument {name} is triggering a compile.")


for step, (x, y, z) in enumerate(...): # e.g. a training loop
print(f"Step is {step}")
check_arg(x, "x")
check_arg(y, "y")
check_arg(z, "z")
your_function(x, y, z)
```
for which you'll often see output like
```
Step is 0
Argument x is triggering a compile.
Argument y is triggering a compile.
Argument z is triggering a compile.
Step is 1
Argument y is triggering a compile.
Step is 2
Argument y is triggering a compile.
...
```
On the very first step, none of the arguments have been seen before, so they all trigger a compile. On later steps, just the problematic argument will trigger a recompilation of `check_arg` -- this will be the one that is triggering a recompilation of `your_function` as well!

## How does Equinox compare to...?

#### ...PyTorch?

JAX+Equinox is usually faster than PyTorch (thanks to a stronger JIT compiler), and more featureful (e.g. supporting jit-of-vmap, forward-mode autolinearisation, and autoparallelism).

For those doing scientific computing or scientific ML, JAX+Equinox also has a much stronger ecosystem. For example, PyTorch no longer has a library for solving differential equations (torchdiffeq is unmaintained). Meanwhile, JAX has [Diffrax](https://github.com/patrick-kidger/diffrax).

Both JAX+Equinox and PyTorch are roughly equally easy to use. PyTorch tends to be easier for new users (e.g. it's closer to being "Python as normal", and there's less functional programming), whilst JAX+Equinox generally supports advanced use-cases more cleanly (e.g. PyTorch has multiple JIT compilers each with their own quirks -- `torch.{fx, jit.script, jit.trace, compile, _dynamo, ...}` -- whilst JAX+Equinox just has the one).

PyTorch is older, and enjoys broader adoption -- it's generally easier to find developers for PyTorch, or off-the-shelf model implementations using it.

#### ...Keras?

These are two very different libraries, with very different target audiences. Keras is great for plug-and-play building of models -- it's often compared to using Lego. This makes it a convenient framework for standing up neural networks quickly. Equinox is much lower level: it tries to support more general use-cases (e.g. its downstream scientific ecosystem), but usually requires greater proficiency with numerical computing / software development / machine learning.

#### ...Flax?

- Flax introduces multiple new abstractions (`flax.linen.Module`, `flax.linen.Variable`, `Module.setup` vs `flax.linen.compact`, `flax.struct.dataclass`, etc.). Equinox tries to avoid adding new abstractions to core JAX; everything is always just a PyTree.
- Flax is a DSL: it is generally incompatible with non-Flax code, and requires using wrapped `flax.linen.{vmap, scan, ...}` rather than the native `jax.{vmap, ...}`. In contrast, Equinox allows you to use native JAX operations and aims to be compatible with arbitrary JAX code.
- Bound methods of `eqx.Module` are just PyTrees (see the sketch after this list). In Flax this isn't the case -- passing around bound methods will either result in errors or recompilations, depending on what you do. Likewise, `eqx.Module` handles inheritance correctly, including propagating metadata like docstrings. The equivalent `flax.struct.dataclass` silently misbehaves. Overall Equinox seems to have fewer footguns.
- Equinox offers several advanced features (like [runtime errors](../api/errors/) or [PyTree manipulation](../api/manipulation/#equinox.tree_at)) not found in other libraries.
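
As a quick sketch of the bound-methods point (the `Affine` module here is made up for illustration):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Affine(eqx.Module):
    weight: jax.Array

    def apply(self, x):
        return self.weight * x

m = Affine(jnp.ones(3))
# The bound method is itself a pytree -- its leaves include `weight` -- so it
# can be passed through JAX/Equinox transformations like any other model.
print(jax.tree_util.tree_leaves(m.apply))
out = eqx.filter_jit(m.apply)(jnp.arange(3.0))
```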

See also the [Equinox paper](https://arxiv.org/abs/2111.00254).

#### ...Julia?

The Julia ecosystem has [historically been buggy](https://kidger.site/thoughts/jax-vs-julia/).

At the time of writing, Julia does not yet have a robust autodifferentiation system. For example, it has multiple competing implementations: both Diffractor.jl and ForwardDiff.jl for forward-mode autodifferentiation, and all of Tracker.jl, Zygote.jl, Enzyme.jl, and ReverseDiff.jl for reverse-mode autodifferentiation. It does not yet support higher-order autodifferentiation robustly. In contrast, JAX+Equinox uses a single strong autodifferentiation system.

However, note that JAX+Equinox don't try to offer a completely general programming model: they are optimised for arrays and linear algebra. (Essentially, the sorts of things you use NumPy for.) They're not designed for e.g. a branch-and-bound combinatorial optimisation algorithm, and for these purposes Julia will be superior.

Julia is often a small amount faster in CPU microbenchmarks. JAX+Equinox supports running on TPUs, whilst Julia generally does not.

**You're obviously biased! Are the above comparisons fair?**

Seriously, we think they're fair! Nonetheless all of the above approaches have their adherents, so each of them is clearly doing something right. If you're already happily using one of them for your current project, then keep using it. (Don't rewrite things for no reason.) But conversely, we'd invite you to try Equinox for your next project. :)

For what it's worth, if you have the time to learn (e.g. you're a grad student), then we'd strongly recommend trying all of the above. All of these libraries have made substantial innovations, and have all substantially moved the numerical computing space forward. Equinox deliberately takes inspiration from them. For example, Julia has an excellent type system, and this has strongly informed [this Equinox design pattern](../pattern/).