Equinox v0.11.2
Features
- Added eqx.filter_jit(..., donate="all-except-first") and eqx.filter_jit(..., donate="warn-except-first"). These offer a way to donate all arguments except the first one. (If you have multiple arguments that should not be donated, pack them together into a tuple and pass that as the first argument.) This aims to be a low-overhead, easy way to handle buffer donation.
- Added eqx.debug.{assert_max_traces, get_num_traces}, which aim to provide a friendly way of asserting that a JIT'd function is not recompiled, and of reporting which argument changed to cause the recompilation if it is.
- eqx.tree_pprint and eqx.tree_pformat now handle PyTorch tensors and jax.ShapeDtypeStruct objects.
- eqx.tree_equal now has new arguments:
  - typematch=True: this requires every leaf to have precisely the same type as its counterpart. By default the requirement is essentially leaf == leaf2; with this flag it becomes type(leaf) == type(leaf2) and leaf == leaf2.
  - rtol and atol: setting these to nonzero values allows checking that inexact (floating or complex) arrays are allclose, rather than exactly equal.
  - The expectation is that these will be useful in unit tests, e.g. to write checks of the form assert eqx.tree_equal(output, expected_output, typematch=True, rtol=1e-5, atol=1e-5).
Bugfixes
- Previously, a learnt activation function passed to eqx.nn.MLP would use the exact same learnt weights for every neuron in every layer. Now, a separate copy of the activation function is used in each location.
- Subclasses of eqx.Module should now have their __init__ signatures correctly reported by downstream tooling, e.g. automated doc generators and some IDEs. (Thanks @danielward27! #573)
Typing
- eqx.filter_value_and_grad now declares that it preserves the return type of its function. (Thanks @ConnorBaker! #557)
Documentation
- Fixed missing index argument in docstring example for StateIndex. (Thanks @edwardwli! #556)
- Fixed broken link in eqx.Enumeration docstrings. (Thanks @LouisDesdoigts! #579)
- Fixed missing shape specification in one of the tricks. (Thanks @homerjed! #582)
Other
- Improved a few IPython tracebacks with appropriate __tracebackhide__ = True assignments.
- Subclasses of eqx.Enumeration can now override messages defined in their parent Enumeration: doing so now produces a warning rather than an error.
- Documented the EQX_ON_ERROR_BREAKPOINT_FRAMES config variable, which is used to work around a JAX bug when setting EQX_ON_ERROR=breakpoint.
- Can now monkey-patch the methods of an eqx.Module, e.g.
  ```python
  class Foo(eqx.Module):
      def f(self):
          ...

  Foo.f = some_transform(Foo.f)
  ```
  The anticipated use-case for this is to make things easier for typecheckers; see #584.
- eqx.debug.store_dce now supports non-arrays in its argument.
- eqx.Enumeration.where(traced_pred, x, x) will now statically return x without tracing. This is occasionally useful to better propagate information at compile time.
Internal features (not officially supported, advanced use only)
- Added eqx.internal.GetKey. This generates a random JAX PRNG key when called, and crucially has a nice __repr__ reporting what the seed value is. This should not be used in normal JAX code! It is intended as a convenience for tests, so that the random seed appears in the debug printout of a failed test.
- Added eqx.internal.MaybeBuffer to indicate that an argument of an eqx.internal.{while_loop,scan} might be wrapped in a buffer.
- Added eqx.internal.buffer_at_set to support buffer.at[...].set(..., pred=...) whilst being agnostic to whether buffer is a JAX array or one of our while loop buffers.
New Contributors
- @edwardwli made their first contribution in #556
- @ConnorBaker made their first contribution in #557
- @danielward27 made their first contribution in #573
Full Changelog: v0.11.1...v0.11.2