
Equinox v0.11.2

Released by github-actions on 13 Nov 18:28

Features

  • Added eqx.filter_jit(..., donate="all-except-first") and eqx.filter_jit(..., donate="warn-except-first"). These donate all arguments except the first one. (If you have several arguments that should not be donated, pack them together into a tuple and pass that as the first argument.) This aims to be a low-overhead, easy way to handle buffer donation.
  • Added eqx.debug.{assert_max_traces, get_num_traces}, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled and, if it is, of reporting which argument changed to cause the recompilation.
  • eqx.tree_pprint and eqx.tree_pformat now handle PyTorch tensors and jax.ShapeDtypeStructs.
  • eqx.tree_equal now has new arguments:
    • typematch=True: this requires that corresponding leaves have exactly the same type. By default the requirement is essentially leaf == leaf2; with this flag it becomes type(leaf) == type(leaf2) and leaf == leaf2.
    • rtol and atol: setting these to nonzero values allows for checking that inexact (floating or complex) arrays are allclose, rather than exactly equal.
    • The expectation is that these will be useful in unit tests, e.g. to write checks of the form assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5).

Bugfixes

  • Previously, a learnt activation function for eqx.nn.MLP would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
  • Subclasses of eqx.Module should now have their __init__ signatures correctly reported by downstream tooling, e.g. automated doc generators, some IDEs. (Thanks @danielward27! #573)

Typing

  • eqx.filter_value_and_grad now declares that it preserves the return type of its function (Thanks @ConnorBaker! #557)

Documentation

  • Fixed a missing index argument in the docstring example for StateIndex (Thanks @edwardwli! #556)
  • Fixed a broken link in the eqx.Enumeration docstrings (Thanks @LouisDesdoigts! #579)
  • Fixed a missing shape specification in one of the tricks. (Thanks @homerjed! #582)

Other

  • Improved a few IPython tracebacks with appropriate __tracebackhide__ = True assignments.
  • Subclasses of an eqx.Enumeration can now override a message inherited from their parent Enumeration; doing so now produces a warning rather than an error.
  • Documented the EQX_ON_ERROR_BREAKPOINT_FRAMES config variable, which is used to work around a JAX bug when setting EQX_ON_ERROR=breakpoint.
  • Can now monkey-patch the methods of an eqx.Module, e.g.
    import equinox as eqx

    class Foo(eqx.Module):
        def f(self): ...

    Foo.f = some_transform(Foo.f)

    The anticipated use case for this is to make things easier for typecheckers; see #584.
  • eqx.debug.store_dce now supports non-arrays in its argument.
  • eqx.Enumeration.where(traced_pred, x, x) will now statically return x without tracing. This is occasionally useful to better propagate information at compile time.

Internal features (not officially supported, advanced use only)

  • Added eqx.internal.GetKey. This generates a random JAX PRNG key when called, and crucially has a nice __repr__ reporting what the seed value is. This should not be used in normal JAX code! This is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
  • Added eqx.internal.MaybeBuffer to indicate that an argument of an eqx.internal.{while_loop,scan} might be wrapped in a buffer.
  • Added eqx.internal.buffer_at_set to support buffer.at[...].set(..., pred=...) whilst being agnostic to whether buffer is a JAX array or one of our while loop buffers.


Full Changelog: v0.11.1...v0.11.2