
Releases: patrick-kidger/equinox

Equinox v0.11.10

08 Dec 02:44

This is a JAX 0.4.36 compatibility release.

JAX 0.4.36 changed how custom primitive rules are called: they are now always called, rather than only when the data requires them. This required some updates in Equinox to avoid crashes in the downstream ecosystem. (patrick-kidger/diffrax#532, jax-ml/jax#25289 and links therein.)

Full Changelog: v0.11.9...v0.11.10

Equinox v0.11.9

24 Nov 15:01

This is an important bugfix release.

  • Fixed eqx.filter_vmap producing outputs with the wrong axis order when out_axes was something other than 0 or 1. (Thanks @remifan! #900, #901)

Full Changelog: v0.11.8...v0.11.9

Equinox v0.11.8

18 Oct 17:19
689c35a

The main change in this release is compatibility with JAX 0.4.34, which introduced breaking changes. (#871)

Bugfixes

  • Accessing the concrete implementation of an abstract class attribute within __init_subclass__ should no longer crash. (This probably also makes __init_subclass__ better behaved overall.)

Miscellaneous

  • JAX 0.4.33 introduced a change that broke the nice display of eqx.error_if's error messages. This release restores them!
  • eqx.nn.StateIndex can now be passed through jax.jit (and not just eqx.filter_jit). (Thanks @NeilGirdhar! #843)
  • Normalization layers now upcast to at least 32-bit precision. (Thanks @AakashKumarNain! #876)
  • Poetry has a bug in its interpretation of ~= version constraints. We now work around that for better compatibility with certain kinds of Poetry installations. (Thanks @norpadon! #878)

Documentation

New Contributors

Full Changelog: v0.11.7...v0.11.8

Equinox v0.11.7

18 Sep 17:10
31b554f

Quick release: JAX 0.4.32 / 0.4.33 introduced a breaking change, and this release makes Equinox compatible with it. (#856)

Full Changelog: v0.11.6...v0.11.7

Equinox v0.11.6

14 Sep 09:34

This is primarily a bug fix release.

  • Runtime error messages (those from eqx.error_if, in particular when wrapped with eqx.filter_jit) should now be compatible with PyCharm's debugger, and with certain multithreaded contexts. (Thanks @adam-hartshorne, @dlwh! #828, #844, #849)

  • Marking a jax.Array or np.ndarray field as eqx.field(static=True) now raises a warning. This was technically fine in certain very narrow contexts (e.g. to smuggle an array into a JIT'd region without it being traced), but in practice it was nearly always a common new-user footgun. (Thanks @lockwo! #800)

  • eqx.tree_at has improved support for replacing empty tuples. (Thanks @danielward27! #818, #819)

  • eqx.nn.RotaryPositionalEmbedding no longer promotes input dtypes to at least float32. (Thanks @knyazer! #836)

  • Mypy now understands that eqx.Modules are dataclasses. (Pyright always did, but mypy needed a slightly different approach to appreciate this fact.) (Thanks @NeilGirdhar! #822)

  • Multiple eqx.Modules participating in co-operative multiple inheritance, with some of them overriding the __post_init__s of others, should now follow their expected resolution order. (Triggering the old behaviour seems to have required at least five Modules inheriting from each other.) (Thanks @NeilGirdhar! #832, #834)

  • We now have a .editorconfig file. (Thanks @NeilGirdhar! #821)

  • Doc improvements. (Thanks @garymm, @ColCarroll! #804, #805)

New Contributors

Full Changelog: v0.11.5...v0.11.6

Equinox v0.11.5

18 Aug 19:11

JAX compatibility

Recent versions of JAX (0.4.28+) have made some changes to:

  • Hashing of tracers;
  • Tree-map'ing over Nones;
  • Callbacks;
  • Pretty-printing.

With this update, Equinox should now be compatible with both old and new versions of JAX, fixing some new crashes and some new warnings. (#719, #724, #753, #758, thanks @jakevdp, @hawkinsp!)

Better errors

  • The error messages from eqx.error_if are now substantially more informative: they include traceback information (the stack), and mention the availability of the EQX_ON_ERROR environment variable. We also do a much better job of hiding the large, unhelpful printouts that XLA gives by default. (#785, #803)

  • The default value of EQX_ON_ERROR_BREAKPOINT_FRAMES is now 1. (#777) The impact of this is that using eqx.error_if alongside EQX_ON_ERROR=breakpoint will now:

    • reliably always open a debugger, rather than sometimes crashing at trace-time due to upstream JAX bug #16732.
    • however, by default the debugger will no longer include any additional stack frames above it (accessed via u).
    • much of the above is now explained in a printed-out informative message prior to the debugger opening.

Bugfixes

  • eqx.filter_{jacfwd, jacrev} now only apply filtering to their inputs but not their outputs. Previously this was problematic as there was no way to represent static-input-by-static-output in the returned Jacobian, so pieces were silently dropped. (#734, thanks @lockwo!)

  • eqx.tree_at can now be used to replace empty tuples. (#715, #717, #722, thanks @lockwo!)

  • eqx.filter_custom_jvp no longer crashes at trace time in some scenarios in which its **kwargs were erroneously counted as having tangents. (#745 (comment), #749)

  • Fixed a trace-time crash for a particular combination of vmap + autodiff + checkpointed while loops. This occurred when using optimistix.BFGS around diffrax.diffeqsolve. (#777)

  • Fixed a trace-time crash when:

    • using a checkpointed while loop...
    • ...with a body function that has a closed-over tracer...
    • ...and that closed-over tracer is differentiated...
    • ...and there are no other closed-over tracers that are differentiated...
    • ...and the dependency on that tracer is only linear.
    • (patrick-kidger/diffrax#387 (comment), #752, thanks @dkweiss31!)
  • Fixed a trace-time crash when composing the grad of vmap of lineax.linear_solve. (patrick-kidger/lineax#101, #795, thanks @rhacking!)

  • eqx.nn.RMSNorm now uses at least 32-bit precision for numerical stability (#723, thanks @AakashKumarNain!)

New features

Other changes

New Contributors

Full Changelog: v0.11.4...v0.11.5

Equinox v0.11.4

14 Apr 13:04

Features

  • Added eqx.filter_shard. This lowers to jax.lax.with_sharding_constraint, and provides a single way to transfer or reshard data, both inside and outside of JIT! (No more jax.device_put.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! #688, #691)

  • Added eqx.filter_{jacfwd,jacrev,hessian}. These do what you expect! (Thanks @lockwo! #677)

  • Added eqx.nn.RotaryPositionalEmbedding. This is designed to be used in conjunction with the existing eqx.nn.MultiheadAttention. (Thanks @Artur-Galstyan! #568)

  • Added support for padding='VALID', padding='SAME', padding='SAME_LOWER' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added support for padding_mode='ZEROS', padding_mode='REFLECT', padding_mode='REPLICATE', padding_mode='CIRCULAR' to the convolutional layers: eqx.nn.{Conv, ...}. (Thanks @ChenAo-Phys! #658)

  • Added a dtype argument to eqx.nn.{MultiheadAttention, Linear, Conv, ...} for specifying the dtype of their parameters. In addition, eqx.nn.BatchNorm now also uses its dtype argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! #680, #689)

Compatibility

  • eqx.error_if is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks @hawkinsp! #670)

  • Added a warning that triggers on code like the following:

    from typing import Callable
    import equinox as eqx
    import jax

    class MyModule(eqx.Module):
        fn: Callable

        def __init__(self, some_fn: Callable):
            self.fn = jax.vmap(some_fn)

    This is an easy source of bugs: the vmap'd function is not a PyTree, so any PyTree structure of some_fn (for example, the parameters of an eqx.Module) will not be propagated.

Technical internal stuff

  • eqx.internal.while_loop(..., kind="checkpointed") will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed. (Rather than all of them.) This change just means that later calls to a nondifferentiable operation, like jax.pure_callback or eqx.internal.nondifferentiable, will no longer crash at trace time. (See patrick-kidger/diffrax#396.)

  • eqx.internal.while_loop(..., kind="bounded") will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See patrick-kidger/optimistix#48 (comment))

  • The transpose rule for eqx.internal.create_vprim now understands symbolic zeros, fixing a crash for grad-of-vmap-of-lineax.linear_solve when only some of its outputs are used. (See patrick-kidger/optimistix#48.)

  • The type annotation for the input of any converter function used in eqx.field(converter=...) will now be used as the type annotation in any dataclass-autogenerated __init__ functions. In particular this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)

New Contributors

Full Changelog: v0.11.3...v0.11.4

Equinox v0.11.3

10 Jan 21:26

Features

  • Added equinox.nn.RMSNorm.
  • Added equinox.nn.WeightNorm.
  • equinox.tree_deserialise_leaves now treats jax.ShapeDtypeStructs in the same way as arrays. This makes it possible to avoid instantiating the initial model parameters only to throw them away again, by using equinox.filter_eval_shape:
    model = eqx.filter_eval_shape(Model, ...hyperparameters...)
    model = eqx.tree_deserialise_leaves(load_path, model)
    (#259)

Bugfixes

  • equinox.internal.noinline no longer initialises the JAX backend on use.
  • equinox.filter_jit(...).lower(..., some_kwarg=...) no longer crashes (#625, #627)
  • The state of equinox.nn.BatchNorm now uses the default floating point dtype, rather than always using float32.
  • equinox.nn.MultiheadAttention should now perform the softmax in float32 even when the input is of lower dtype. (This is important for numerical stability.)

Refactor

  • All the layers in equinox.nn.{Linear, MLP, ...} now standardise on accepting extra **kwargs and not calling super().__init__. The intention is that these layers be treated as final, i.e. not subclassable. (Previously things were inconsistent: some did this and some did not.)
  • Should now be compatible with JAX_NUMPY_DTYPE_PROMOTION=strict and JAX_NUMPY_RANK_PROMOTION=raise, and this is checked in tests.
  • Better error message when no kwargs passed to filter_grad (Thanks @knyazer! #589)

Internal features

These are undocumented internal features, that may be changed at any time.

  • Added EQX_GETKEY_SEED for use with equinox.internal.GetKey.
  • equinox.internal.while_loop now has its runtime errors removed. This should help with compatibility with TPUs. (#628)

New Contributors

Full Changelog: v0.11.2...v0.11.3

Equinox v0.11.2

13 Nov 18:28

Features

  • Added eqx.filter_jit(..., donate="all-except-first") and eqx.filter_jit(..., donate="warn-except-first"). This offers a way to donate all arguments except the first one. (If you have multiple such arguments then just pack them together into a tuple in the first argument.) This aims to be a low-overhead easy way to handle buffer donation.
  • Added eqx.debug.{assert_max_traces, get_num_traces}, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled -- and if it is, which argument changed to cause the recompilation.
  • eqx.tree_pprint and eqx.tree_pformat now handle PyTorch tensors and jax.ShapeDtypeStructs.
  • eqx.tree_equal now has new arguments:
    • typematch=True: this requires corresponding leaves to have exactly the same type. Without this flag the requirement is essentially leaf == leaf2; with it, the requirement becomes type(leaf) == type(leaf2) and leaf == leaf2.
    • rtol and atol: setting these to nonzero values allows for checking that inexact (floating or complex) arrays are allclose, rather than exactly equal.
    • The expectation is that these will be useful in unit tests, e.g. to write checks of the form assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5).

Bugfixes

  • Previously, a learnt activation function for eqx.nn.MLP would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
  • Subclasses of eqx.Module should now have their __init__ signatures correctly reported by downstream tooling, e.g. automated doc generators, some IDEs. (Thanks @danielward27! #573)

Typing

  • eqx.filter_value_and_grad now declares that it preserves the return type of its function (Thanks @ConnorBaker! #557)

Documentation

  • Fix missing index argument in docstring example for StateIndex (Thanks @edwardwli! #556)
  • Fixed broken link in eqx.Enumeration docstrings (Thanks @LouisDesdoigts! #579)
  • Fixed a missing shape specification in one of the tricks. (Thanks @homerjed! #582)

Other

  • Improved a few IPython tracebacks with appropriate __tracebackhide__ = True assignments.
  • Subclassed eqx.Enumerations can now override the message associated with their parent Enumeration: this now produces a warning rather than an error.
  • Documented the EQX_ON_ERROR_BREAKPOINT_FRAMES config variable, which is used to work around a JAX bug when setting EQX_ON_ERROR=breakpoint.
  • Can now monkey-patch the methods of an eqx.Module, e.g.
    class Foo(eqx.Module):
        def f(self): ...
    
    Foo.f = some_transform(Foo.f)
    The anticipated use case for this is to make things easier for typecheckers; see #584.
  • eqx.debug.store_dce now supports non-arrays in its argument.
  • eqx.Enumeration.where(traced_pred, x, x) will now statically return x without tracing. This is occasionally useful to better propagate information at compile time.

Internal features (not officially supported, advanced use only)

  • Added eqx.internal.GetKey. This generates a random JAX PRNG key when called, and crucially has a nice __repr__ reporting what the seed value is. This should not be used in normal JAX code! This is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
  • Added eqx.internal.MaybeBuffer to indicate that an argument of an eqx.internal.{while_loop,scan} might be wrapped in a buffer.
  • Added eqx.internal.buffer_at_set to support buffer.at[...].set(..., pred=...) whilst being agnostic to whether buffer is a JAX array or one of our while loop buffers.

New Contributors

Full Changelog: v0.11.1...v0.11.2

Equinox v0.11.1

13 Oct 02:17

This is a minor bugfix release.

Bugfixes

  • Checkpointed while loops (eqx.internal.while_loop(..., kind="checkpointed")) now perform a more careful analysis of which arguments need to be differentiated. (#548) This fix is the primary reason for this release -- it unlocks some efficiency improvements when solving SDEs in Diffrax: patrick-kidger/diffrax#320
  • Fixed Abstract{Class,}Var misbehaving around multiple inheritance. (#544)
  • Better compatibility with the beartype library. In a few cases this was throwing some spurious errors to do with forward references. (#543)

Documentation

  • Fix scan-over-layers example in docs. (Thanks @mcbal! #542)

Other

  • Static type checkers should now use Equinox's type hints correctly. (Specifically, we now have the py.typed marker file. Thanks @vidhanio! #547)
  • Added the EQX_ON_ERROR_BREAKPOINT_FRAMES environment variable, to work around JAX bug jax-ml/jax#16732 when using EQX_ON_ERROR=breakpoint. This new variable sets the number of stack frames you can access via the u debugger command, when the on-error debugger is triggered. Set this to a small enough number, e.g. EQX_ON_ERROR_BREAKPOINT_FRAMES=1, and it should fix unusual trace-time errors when using EQX_ON_ERROR=breakpoint.

New Contributors

Full Changelog: v0.11.0...v0.11.1