v0.11.0 #536
Merged
…n eqx.nn thing, not an eqx thing.
…ates its input as being a scalar index.
…pointed out that otherwise this can silently normalise over sequence dimensions in transformer models if you miss the vmap)
…eaving it on permanently by accident.
Previously, this was erroneously skipped when the function lacked any closed-over variables. In addition, this commit adds eqxi.cached_filter_eval_shape, as that is needed for the above.
… __init__. This emerged as a prerequisite to having #491 give the desired behaviour.
… an extra argument. It no longer mutates its input.
…ntary pattern but I find myself needing a reference to this quite frequently...
This is actually a minor breaking change: the API has changed from StateIndex(callable_returning_value) to just StateIndex(value). This change also introduces substantially updated documentation on stateful layers, which should help a lot.
…that they are compatible with the original model.
The implementation really isn't that tricky, so this serves to highlight what each piece is doing.
…rst argument now raises a better error message.
Drive-by: improved pretty-printing of dataclasses with uninitialised fields.
- Added an example for a vision transformer (ViT)
- Changed dataset to CIFAR10, added reference to eqxvision's ViT module
- Refactored the Vision Transformer example for improved code structure and readability
- Fixed a small issue in positional embeddings
…are typically mutually exclusive
… superclass with property in subclass.
- Updated the text to be easier to read.
- Loosened the requirements on the pattern a little bit -- in particular to allow abstract classes to have __init__ methods and attributes, as long as they're the only class in the hierarchy to have them.
- Now explicitly checking that you don't subclass a concrete class, as otherwise you could still do so, just without overriding any methods / only adding new methods.
…now -- let's test the strict=True implementation in Diffrax/Lineax etc. first
…(that this example was hitting).
When subclassing, it used to be the case that `cls.__abstractvars__` basically never got smaller (only if an element was overridden by a class-level attribute or method):

```python
class Foo(eqx.Module):
    x: AbstractVar[bool]

class Bar(Foo):
    x: bool

Bar.__abstractvars__ == frozenset({"x"})
```

This was intended -- the idea is that all abstractvars would get checked during initialisation, i.e. validity with respect to this condition being a property of the instance, rather than a property of just the class object. With this change, the above example will remove `x` from `__abstractvars__`. This is because it's useful and typical to reason about whether a class is abstract or not -- it's much more annoying to have to reason about whether each individual instance is abstract. Indeed the recent changes to `eqx.Module`, in strict mode, are a use case in which we want to be able to reason about things in this way. Thus, any element of either `subcls.__dict__` or `subcls.__annotations__` can be used to concretise any abstract variable, rather than just doing `hasattr(self, var)` during initialisation.
This is necessary to allow downstream libraries, like jaxtyping, to monkey-patch in their own checks.
Better errors
Equinox now includes several additional checks to guard against various bugs! If you have a new error, then this is probably an indication that your code always had a silent bug, and should be updated.
- `eqx.nn.LayerNorm` now correctly validates the shape of its input. This was a common cause of silent bugs; see the sketch after this list. (Thanks @dlwh for pointing this one out!)
- An error is now raised if a module defines both `__init__` and `__post_init__` -- the former actually overwrites the latter. (This is normal Python dataclass behaviour, but probably unexpected.)
- An error is now raised if you assign a bound method to an attribute `self.foo`, as this results in two copies of the module: one at `self`, the other at `self.foo.__self__`. (The latter being in the bound method.)
- `eqx.tree_at` now gives a better error message if you try to use it to update something that isn't a PyTree leaf. (Thanks @LouisDesdoigts!)
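As a quick illustration of the first item above, here is a minimal sketch of the kind of bug that `eqx.nn.LayerNorm` now catches (the shapes are illustrative):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

norm = eqx.nn.LayerNorm(shape=8)

norm(jnp.zeros(8))                  # fine: the input shape matches `shape`
jax.vmap(norm)(jnp.zeros((16, 8)))  # fine: the batch dimension is handled by vmap

try:
    # Previously this could silently normalise over the leading (e.g. sequence)
    # dimension as well; it now raises a shape error instead.
    norm(jnp.zeros((16, 8)))
except Exception as e:
    print(e)
```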
API changes
These should all be very minor.
- `eqx.nn.StateIndex` now takes the initial value, rather than a function that returns the initial value. (See the sketch after this list.)
- When using `eqx.field(converter=...)`, conversion now happens before `__post_init__`, rather than after it.
- Prefer `eqx.nn.make_with_state` over `eqx.nn.State`. The latter will continue to work, but the former is more memory-efficient. (It deletes the original copy of the initial state.)
- Prefer `eqx.nn.inference_mode` over `eqx.tree_inference`. The latter will continue to exist for backward compatibility. These are the same function; this is really just a matter of moving it into the `eqx.nn` namespace, where it always belonged.
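To make the stateful changes concrete, here is a minimal sketch of the new-style API. (The `Counter` layer is a toy example for illustration, not part of Equinox.)

```python
import jax.numpy as jnp
import equinox as eqx

class Counter(eqx.Module):
    index: eqx.nn.StateIndex

    def __init__(self):
        # New API: pass the initial value directly, rather than a function
        # returning the initial value.
        self.index = eqx.nn.StateIndex(jnp.array(0))

    def __call__(self, x, state):
        count = state.get(self.index)
        new_state = state.set(self.index, count + 1)
        return x + count, new_state

# Preferred: construct the model and its initial state together.
counter, state = eqx.nn.make_with_state(Counter)()
out, state = counter(jnp.array(1.0), state)
```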
Sharing layers
Equinox now supports sharing a layer between multiple parts of your model! This has probably been our longest-requested feature -- in large part because of how intractable it seemed. Equinox models are PyTrees, not PyDAGs, so how exactly are we supposed to have two different parts of our model point at the same layer?
The answer turned out to be the following -- in this example, we're reusing the embedding weight matrix between the initial embedding layer, and the final readout layer, of a language model.
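A sketch of that pattern, with illustrative sizes and a placeholder for the rest of the architecture:

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

class LanguageModel(eqx.Module):
    shared: eqx.nn.Shared

    def __init__(self, key):
        ekey, rkey = jr.split(key)
        embedding = eqx.nn.Embedding(num_embeddings=1000, embedding_size=64, key=ekey)
        readout = eqx.nn.Linear(in_features=64, out_features=1000, key=rkey)
        # Tie the readout weight to the embedding weight: `where` selects the
        # node to remove; `get` selects the node that stands in for it.
        where = lambda pair: pair[1].weight
        get = lambda pair: pair[0].weight
        self.shared = eqx.nn.Shared((embedding, readout), where, get)

    def __call__(self, tokens):
        # Put the shared node back in place before using the layers.
        embedding, readout = self.shared()
        assert embedding.weight is readout.weight  # one shared parameter
        hidden = jax.vmap(embedding)(tokens)
        # ... the transformer blocks would go here ...
        return jax.vmap(readout)(hidden)

model = LanguageModel(jr.PRNGKey(0))
logits = model(jnp.arange(5))
```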
Here, `eqx.nn.Shared(...)` simply removes all of the nodes at `where`, so that we don't have two separate copies. Then when it is called at `self.shared()`, it puts them back again. Note that this isn't a copy and doesn't incur any additional memory overhead; this all happens at the Python level, not the XLA level.

(The curious may like to take a look at the implementation in `equinox/nn/_shared.py`, which turned out to be very simple.)

On a meta level, I'd like to comment that I'm quite proud of having gotten this one in! It means that Equinox now supports both stateful layers and shared layers, which have always been the two pieces that seemed out of reach when using something as simple as PyTrees to represent models. But it turns out that PyTrees really are all you need. :D
Other changes
Documentation
Features
- Added `eqx.filter_checkpoint`, which as you might expect is a filtered version of `jax.checkpoint`. (Thanks @dlwh!)
- Added `eqx.Module.__check_init__`. This is run in a similar fashion to `__post_init__`; see the documentation. This can be used to check that invariants of your module hold after initialisation. (See the sketch after this list.)
- Added `eqx.nn.State.{substate, update}`. This offers a way to subset or update a `State` object, so that only the parts of it that need to be vmap'd are passed in. See the stateful documentation for an example of how to do this.
- Runtime error messages no longer include a wall of `INTERNAL: Generated function failed: CpuCallback error` stuff! This clean-up of the runtime error message is done by `eqx.filter_jit`, so that will need to be your top-level way of JIT'ing your computation.
- Added `eqx.nn.StatefulLayer` -- this is used (only!) with `eqx.nn.Sequential`, to indicate that the layer should be called with `x, state`, and not just `x`. If you would like a custom stateful layer to be compatible with `Sequential` then go ahead and subclass this, and potentially implement the `is_stateful` method. (Thanks @paganpasta!)
- Every `eqx.nn.*` layer is now wrapped in a `jax.named_scope`, for a better debugging experience. (Thanks @ahmed-alllam!)
- `eqx.module_update_wrapper` no longer requires a second argument; it will look at the `__wrapped__` attribute of its first argument.
- Added `eqx.internal.closure_to_pytree`, for... you guessed it, turning function closures into PyTrees. The closed-over variables are treated as the subnodes in the PyTree. This will operate recursively, so that closed-over closures will themselves become PyTrees, etc. Note that closed-over global variables are not included.
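As an example of the new `__check_init__` hook, here is a minimal sketch (the `Interval` module is purely illustrative):

```python
import equinox as eqx

class Interval(eqx.Module):
    lower: float
    upper: float

    def __check_init__(self):
        # Runs after initialisation, like __post_init__, but is intended
        # purely for checking invariants; it does not assign attributes.
        if self.lower > self.upper:
            raise ValueError("`lower` must not exceed `upper`.")

Interval(0.0, 1.0)  # fine

try:
    Interval(2.0, 1.0)
except ValueError as e:
    print(e)  # the invariant check fires during initialisation
```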
Bugfixes
- `eqx.tree_{serialise,deserialise}_leaves` now correctly handle unusual NumPy scalars, like `bfloat16`. (Thanks @colehaus!)
- Passing `eqx.field(metadata=...)` arguments no longer results in the `static`/`converter` arguments being ignored. (Thanks @mjo22!)
- `eqx.filter_custom_vjp` now supports residuals that are not arrays. (The residuals are the pytree that is passed between the forward and backward pass.)
- `eqx.{AbstractVar,AbstractClassVar}` should now support overridden generics in subclasses. That is, something like the sketch below.
- `eqx.internal.while_loop` now supports using custom (non-Equinox) pytrees in the state.
- `eqx.tree_check` no longer raises some false positives.
- Fixed using `__init_subclass__` with additional class creation kwargs. (Thanks @ASEM000, @Roger-luo!)
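As a rough, illustrative sketch of the overridden-generics pattern mentioned above (the class names are made up, not taken from the original example):

```python
from typing import Generic, TypeVar

import equinox as eqx

T = TypeVar("T")

class AbstractWrapper(eqx.Module, Generic[T]):
    value: eqx.AbstractVar[T]

class IntWrapper(AbstractWrapper[int]):
    value: int

IntWrapper(3)  # overriding the generic parameter in a subclass now works
```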