v0.11.0 #536

Merged
merged 52 commits into from
Sep 29, 2023

@patrick-kidger patrick-kidger commented Sep 29, 2023

Better errors

Equinox now includes several additional checks to guard against various bugs! If you have a new error, then this is probably an indication that your code always had a silent bug, and should be updated.

  • eqx.nn.LayerNorm now correctly validates the shape of its input. This was a common cause of silent bugs. (Thanks @dlwh for pointing this one out!)
  • Equinox now prints out a warning if you supply both __init__ and __post_init__ -- the former actually overrides the latter, so __post_init__ is never called. (This is normal Python dataclass behaviour, but probably unexpected.)
  • Equinox now prevents you from assigning Module attributes with a bound method of your current instance, e.g.
    class Model(eqx.Module):
        foo: Callable
    
        def __init__(self):
            self.foo = self.bar
    
        def bar(self):
            ...
    Otherwise, you end up with two different copies of your model! One at self, the other at self.foo.__self__. (The latter being in the bound method.)
  • eqx.tree_at now gives a better error message if you try to use it to update something that isn't a PyTree leaf. (Thanks @LouisDesdoigts!)
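
Two of the checks above guard against behaviour that can be reproduced in plain Python, without Equinox. The sketch below uses only the standard library; the names `WithBoth`, `calls` and `Model` are made up for illustration. It shows that a dataclass with a hand-written __init__ never runs __post_init__, and that a bound method stored on an instance carries a reference back to that instance (which is what turns into a second copy once the model is handled as a PyTree).

```python
import dataclasses

calls = []  # records which methods actually ran


@dataclasses.dataclass
class WithBoth:
    x: int

    # A hand-written __init__ is kept; dataclasses does not generate its own.
    def __init__(self, x):
        calls.append("__init__")
        self.x = x

    # Only the *generated* __init__ calls __post_init__, so this never runs.
    def __post_init__(self):
        calls.append("__post_init__")


WithBoth(1)


class Model:
    def __init__(self):
        # The bound method holds a reference to `self` in `__self__`.
        self.foo = self.bar

    def bar(self):
        ...


model = Model()
print(calls)                        # only "__init__" was recorded
print(model.foo.__self__ is model)  # the bound method points back at the model
```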

API changes

These should all be very minor.

  • Breaking change: eqx.nn.StateIndex now takes the initial value, rather than a function that returns the initial value.
  • Breaking change: If using eqx.field(converter=...), then conversion now happens before __post_init__, rather than after it.
  • Prefer eqx.nn.make_with_state over eqx.nn.State. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
  • Prefer eqx.nn.inference_mode over eqx.tree_inference. The latter will continue to exist for backward compatibility. These are the same function; this is really just a matter of moving it into the eqx.nn namespace where it always belonged.

Sharing layers

Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are PyTrees, not PyDAGs, so how exactly are we supposed to have two different parts of our model point at the same layer?

The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.

class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self):
        embedding = eqx.nn.Embedding(...)
        linear = eqx.nn.Linear(...)
        # These two weights will now be tied together.
        where = lambda embed_and_lin: embed_and_lin[1].weight
        get = lambda embed_and_lin: embed_and_lin[0].weight
        self.shared = eqx.nn.Shared((embedding, linear), where, get)

    def __call__(self, tokens):
        # Expand back out so we can evaluate these layers.
        embedding, linear = self.shared()
        assert embedding.weight is linear.weight  # same parameter!
        # Now go ahead and evaluate your language model.
        ...

Here, eqx.nn.Shared(...) simply removes all of the nodes at where, so that we don't have two separate copies. Then when it is called at self.shared(), it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.

(The curious may like to take a look at the implementation in equinox/nn/_shared.py, which turned out to be very simple.)
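
To make the "remove the node, then put it back" idea concrete, here is a toy sketch in plain Python. It is not the code in equinox/nn/_shared.py and does not use Equinox at all: `Layer`, `ToyShared` and `set_linear_weight` are hypothetical names, and because plain Python has no equivalent of a PyTree `where` lookup, the sketch takes an explicit setter function instead.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Layer:
    weight: object


class ToyShared:
    """Toy stand-in for the idea described above; not Equinox's implementation."""

    def __init__(self, layers, set_where, get):
        # Store the layers with the duplicated node blanked out,
        # so only one copy of the shared weight is kept.
        self._layers = set_where(layers, None)
        self._set_where = set_where
        self._get = get

    def __call__(self):
        # Put the shared node back: both places now refer to the same object.
        return self._set_where(self._layers, self._get(self._layers))


def set_linear_weight(pair, value):
    # Hypothetical setter: returns a new pair with the second layer's weight replaced.
    embedding, linear = pair
    return (embedding, dataclasses.replace(linear, weight=value))


def get_embedding_weight(pair):
    return pair[0].weight


embedding = Layer(weight=[1.0, 2.0, 3.0])
linear = Layer(weight=[0.0, 0.0, 0.0])
shared = ToyShared((embedding, linear), set_linear_weight, get_embedding_weight)

embedding2, linear2 = shared()
print(embedding2.weight is linear2.weight)  # same object, no copy made
```

Nothing is copied when `shared()` is called; the second layer is simply rebuilt around a reference to the first layer's weight.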

On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D

Other changes

Documentation

  • Many documentation fixes courtesy of @colehaus and @Artur-Galstyan!
  • Added two new examples to the documentation. Thank you to @ahmed-alllam for both of them!
    • Deep convolutional GAN
    • Vision Transformer
  • Added an FAQ entry on comparisons between Equinox and PyTorch/Keras/Julia/Flax. It's a common enough question that it should probably have had an answer before now.
  • Added an FAQ entry on debugging recompilation.

Features

  • Added eqx.filter_checkpoint, which as you might expect is a filtered version of jax.checkpoint. (Thanks @dlwh!)
  • Added eqx.Module.__check_init__. This is run in a similar fashion to __post_init__; see the documentation. This can be used to check that invariants of your module hold after initialisation.
  • Added support for vmap'ing stateful layers, by adding eqx.nn.State.{substate, update}. This offers a way to subset or update a State object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
  • Runtime errors should now produce much more readable results, without any of the terrifying INTERNAL: Generated function failed: CpuCallback error stuff! This clean-up of the runtime error message is done by eqx.filter_jit, so that will need to be your top-level way of JIT'ing your computation.
  • Added eqx.nn.StatefulLayer -- this is (only!) used with eqx.nn.Sequential, to indicate that the layer should be called with x, state, and not just x. If you would like a custom stateful layer to be compatible with Sequential then go ahead and subclass this, and potentially implement the is_stateful method. (Thanks @paganpasta!)
  • The forward pass of each eqx.nn.* layer is now wrapped in a jax.named_scope, for better debugging experience. (Thanks @ahmed-alllam!)
  • eqx.module_update_wrapper no longer requires a second argument; it will look at the __wrapped__ attribute of its first argument.
  • Added eqx.internal.closure_to_pytree, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.
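
For readers unfamiliar with how closed-over variables are stored, the following plain-Python sketch walks a function's closure recursively. It is only an illustration of the concept behind the last bullet, not eqx.internal.closure_to_pytree itself; `closed_over`, `make` and `SCALE` are hypothetical names.

```python
def closed_over(fn):
    """Return {name: value} for fn's closed-over variables, recursing into closures."""
    cells = fn.__closure__ or ()
    out = {}
    for name, cell in zip(fn.__code__.co_freevars, cells):
        value = cell.cell_contents
        # A closed-over function that itself has a closure is expanded recursively.
        if callable(value) and getattr(value, "__closure__", None):
            value = closed_over(value)
        out[name] = value
    return out


SCALE = 10.0  # a module-level global: looked up by name at call time, not closed over


def make(a):
    b = a + 1

    def inner(x):
        return x + b  # closes over `b`

    def outer(x):
        return SCALE * inner(x) + a  # closes over `inner` and `a`

    return outer


captured = closed_over(make(1))
print(captured)
```

When run as a script, `captured` holds entries for `a` and `inner` (the latter expanded into its own closed-over `b`), while `SCALE` does not appear because globals are not part of a function's closure.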

Bugfixes

  • eqx.tree_{serialise,deserialise}_leaves now correctly handle unusual NumPy scalars, like bfloat16. (Thanks @colehaus!)
  • eqx.field(metadata=...) arguments no longer result in the static/converter arguments being ignored. (Thanks @mjo22!)
  • eqx.filter_custom_vjp now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
  • eqx.{AbstractVar,AbstractClassVar} should now support overridden generics in subclasses. That is, something like this:
    class Foo(eqx.Module):
        x: eqx.AbstractVar[list[str]]
    
    class Bar(Foo):
        x: list[str]
    should no longer raise spurious errors under certain conditions.
  • eqx.internal.while_loop now supports using custom (non-Equinox) pytrees in the state.
  • eqx.tree_check no longer raises some false positives.
  • Equinox modules now support __init_subclass__ with additional class creation kwargs. (Thanks @ASEM000, @Roger-luo!)
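
As a reminder of what "class creation kwargs" means in the last bullet, here is the plain-Python mechanism on an ordinary class. The sketch does not use Equinox; `Base`, `Child`, `registry` and the `name` keyword are made up for illustration.

```python
class Base:
    registry = {}

    def __init_subclass__(cls, /, name=None, **kwargs):
        # Forward any remaining keyword arguments up the class hierarchy.
        super().__init_subclass__(**kwargs)
        if name is not None:
            Base.registry[name] = cls


# `name="child"` is a class creation keyword argument; Python passes it
# to Base.__init_subclass__ when the class statement is executed.
class Child(Base, name="child"):
    pass


print(Base.registry)
```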

patrick-kidger and others added 30 commits September 14, 2023 11:37
…pointed out that otherwise this can silently normalise over sequence dimensions in transformer models if you miss the vmap)
Previously, this was erroneously skipped when the function lacked any closed-over variables.
In addition, this commit adds eqxi.cached_filter_eval_shape, as that is needed for the above.
… __init__. This emerged as a prerequisite to having #491 give the desired behaviour.
… an extra argument. It no longer mutates its input.
…ntary pattern but I find myself needing a reference to this quite frequently...
This is actually a minor breaking change: the API has changed from
StateIndex(callable_returning_value) to just StateIndex(value).

This change also introduces substantially updated documentation on
stateful layers, which should help a lot.
…that they are compatible with the original model.
The implementation really isn't that tricky, so this serves to highlight what each piece is doing.
* add `filter_checkpoint`
…rst argument now raises a better error message.
patrick-kidger and others added 22 commits September 22, 2023 10:17
Drive-by: improved pretty-printing of dataclasses with uninitialised fields.
* Added an example for a vision transformer (vit)

* Changed dataset to CIFAR10, added reference to eqxvision's ViT module

* Refactored the Vision Transformer example for improved code structure and readability.

* Fixed a small issue in positional embeddings
- Updated the text to be easier to read.
- Loosened the requirements on the pattern a little bit -- in particular
    to allow abstract classes to have __init__ methods and attributes,
    as long as they're the only class in the hierarchy to have them.
- Now explicitly checking that you don't subclass a concrete class, as
    otherwise you could still do so, just without overriding any
    methods / only adding new methods.
…now -- let's test the strict=True implementation in Diffrax/Lineax etc. first
When subclassing, it used to be the case that `cls.__abstractvars__`
basically never got smaller (only if an element was overridden by a
class-level attribute or method):
```python
class Foo(eqx.Module):
  x: AbstractVar[bool]

class Bar(Foo):
  x: bool

Bar.__abstractvars__ == frozenset({"x"})
```
This was intended -- the idea is that all abstractvars would get
checked during initialisation, i.e. validity wrt this condition being
a property of the instance, rather than being a property of just the
class object.

With this change, the above example will remove `x` from
`__abstractvars__`.

This is because it's useful and typical to reason about whether a class
is abstract or not -- it's much more annoying to have to reason about
whether each individual instance is abstract. Indeed the recent changes
to `eqx.Module`, in strict mode, are a use case in which we want to be
able to reason about things in this way.

Thus, any element of either `subcls.__dict__` or
`subcls.__annotations__` can be used to concretise any abstract
variable, rather than just doing `hasattr(self, var)` during
initialisation.
This is necessary to allow downstream libraries, like jaxtyping, to
monkey-patch in their own checks.
…rors (#511)

* Add poetry and nix configuration

* Add support for `np.generic` to serialisation and deserialisation

* Add path info to deserialisation failures

* Make assorted simplifications and tweaks from review

* Revert "Add poetry and nix configuration"

This reverts commit afaf546.
@patrick-kidger patrick-kidger changed the title [DO NOT MERGE - testing only] v0.11.0 Sep 29, 2023
@patrick-kidger patrick-kidger merged commit 557bf36 into main Sep 29, 2023
@patrick-kidger patrick-kidger deleted the dev branch September 29, 2023 21:52