Equinox v0.11.4
Features
- Added `eqx.filter_shard`. This lowers to `jax.lax.with_sharding_constraint` as a single way to transfer data, or reshard data, both inside and outside of JIT! (No more `jax.device_put`.) In addition, the parallelism example has been updated to use this simpler new functionality. (Thanks @homerjed and @dlwh! #688, #691)
- Added `eqx.filter_{jacfwd,jacrev,hessian}`. These are the filtered counterparts of `jax.jacfwd`, `jax.jacrev` and `jax.hessian`, and do what you expect! (Thanks @lockwo! #677)
- Added `eqx.nn.RotaryPositionalEmbedding`. This is designed to be used in conjunction with the existing `eqx.nn.MultiheadAttention`. (Thanks @Artur-Galstyan! #568)
- Added support for `padding='VALID'`, `padding='SAME'` and `padding='SAME_LOWER'` to the convolutional layers: `eqx.nn.{Conv, ...}`. (Thanks @ChenAo-Phys! #658)
- Added support for `padding_mode='ZEROS'`, `padding_mode='REFLECT'`, `padding_mode='REPLICATE'` and `padding_mode='CIRCULAR'` to the convolutional layers: `eqx.nn.{Conv, ...}`. (Thanks @ChenAo-Phys! #658)
- Added a `dtype` argument to `eqx.nn.{MultiheadAttention, Linear, Conv, ...}` for specifying the dtype of their parameters. In addition, `eqx.nn.BatchNorm` will now also use its `dtype` argument to determine the dtype of its weights and bias, not just the dtype of its moving statistics. (Thanks @Artur-Galstyan and @AakashKumarNain! #680, #689)
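For intuition about the `eqx.nn.RotaryPositionalEmbedding` entry, here is a from-scratch sketch of rotary positional embeddings in plain JAX. It is not Equinox's implementation: the function name `rotary_sketch`, the "split-halves" pairing (feature `i` rotates together with feature `i + dim/2`) and the base `theta = 10000.0` are assumptions of this sketch, and the real layer may pair features differently. It shows the property that makes rotary embeddings useful in attention: the score between a rotated query and a rotated key depends only on their relative offset.

```python
import jax
import jax.numpy as jnp


def rotary_sketch(x, theta=10000.0):
    """Rotate feature pairs (i, i + dim/2) of x, shape (seq_len, dim), by position."""
    seq_len, dim = x.shape
    half = dim // 2
    # One frequency per feature pair, geometrically spaced from 1 down towards 1/theta.
    freqs = theta ** (-jnp.arange(half) / half)
    angles = jnp.arange(seq_len)[:, None] * freqs[None, :]  # (seq_len, half)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, :half], x[:, half:]
    # Standard 2D rotation applied to each (x1[:, i], x2[:, i]) pair.
    return jnp.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)


# Use the same query/key vector at every position, so any variation in the
# scores comes purely from the positional rotation.
seq_len, dim = 6, 8
q = jnp.tile(jax.random.normal(jax.random.PRNGKey(0), (dim,)), (seq_len, 1))
k = jnp.tile(jax.random.normal(jax.random.PRNGKey(1), (dim,)), (seq_len, 1))
scores = rotary_sketch(q) @ rotary_sketch(k).T  # scores[m, n] = <R_m q, R_n k>

# Entries with the same offset n - m agree up to floating-point error.
print(scores[0, 2], scores[3, 5])
```

Since each pair is rotated by an angle proportional to its position, the product of the two rotations collapses to a single rotation by the position difference.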
Compatibility
- `eqx.error_if` is now compatible with JAX 0.4.26, which changed JAX's own reporting of error messages slightly. (Thanks @hawkinsp! #670)
- Added a warning that checks for doing something like:

  ```python
  class MyModule(eqx.Module):
      fn: Callable

      def __init__(self, ...):
          self.fn = jax.vmap(some_fn)
  ```

  This is an easy source of bugs: the vmap'd function is not a PyTree, so it will not propagate anything in the PyTree structure of `some_fn`.
Technical internal stuff
- `eqx.internal.while_loop(..., kind="checkpointed")` will now only propagate forward JVP tracers for those outputs which are perturbed due to the input to the loop being perturbed (rather than all of them). This change just means that later calls to a nondifferentiable operation, like `jax.pure_callback` or `eqx.internal.nondifferentiable`, will no longer crash at trace time. (See patrick-kidger/diffrax#396.)
- `eqx.internal.while_loop(..., kind="bounded")` will now handle certain vmap+grad combinations without crashing. (It seems like JAX is adding some spurious batch tracers.) (See this comment on patrick-kidger/optimistix#48.)
- The transpose rule for `eqx.internal.create_vprim` now understands symbolic zeros, fixing a crash for grad-of-vmap-of-`lineax.linear_solve` when only some of its outputs are used. (See patrick-kidger/optimistix#48.)
- The type annotation for the input of any converter function used in `eqx.field(converter=...)` will now be used as the type annotation in any dataclass-autogenerated `__init__` functions. In particular, this should mean such functions are now compatible with runtime type checkers like beartype. (jaxtyping users, you were already covered: this checks the assigned annotations instead.)
New Contributors
- @ChenAo-Phys made their first contribution in #658
- @hawkinsp made their first contribution in #670
- @AakashKumarNain made their first contribution in #680
- @imilas made their first contribution in #699
Full Changelog: v0.11.3...v0.11.4