Releases: patrick-kidger/quax
Quax v0.0.5
Features
- Now supports jax.lax.while (Thanks @ymahlau! #16)
- Fix for JAX 0.4.31 by adding a sharding parameter (#25)
- Support passing precedence to the plum dispatcher (Thanks @nstarman! #24)
- Updated ecosystem docs, in particular adding Quaxed & unxt. (Thanks @nstarman! #27)
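For readers unfamiliar with plum's precedence: when several registered rules match a call and none is strictly more specific than the others, precedence is used to pick one instead of raising an ambiguity error. The sketch below is a toy, pure-Python model of that tie-breaking under a simplified specificity rule. `Value`, `MyValue`, `register` and `dispatch` are hypothetical names, not plum's or Quax's actual API; see the Quax docs for how precedence is passed when registering a real rule.

```python
from typing import Any, Callable

# Hypothetical stand-ins for custom Quax value types (not the real quax.Value).
class Value:
    pass

class MyValue(Value):
    pass

# Each toy rule is (argument types, precedence, implementation).
_rules: list[tuple[tuple[type, ...], int, Callable[..., Any]]] = []

def register(*types: type, precedence: int = 0):
    """Toy registration decorator; `precedence` breaks ties between incomparable rules."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _rules.append((types, precedence, fn))
        return fn
    return decorator

def _applies(types: tuple[type, ...], args: tuple[Any, ...]) -> bool:
    return len(types) == len(args) and all(isinstance(a, t) for a, t in zip(args, types))

def _at_least_as_specific(a: tuple[type, ...], b: tuple[type, ...]) -> bool:
    return all(issubclass(x, y) for x, y in zip(a, b))

def dispatch(*args: Any) -> Any:
    matching = [r for r in _rules if _applies(r[0], args)]
    if not matching:
        raise TypeError("no rule matches")
    # Drop any rule that is strictly less specific than another matching rule.
    best = [
        r for r in matching
        if not any(
            _at_least_as_specific(o[0], r[0]) and not _at_least_as_specific(r[0], o[0])
            for o in matching
        )
    ]
    # Among the remaining, mutually incomparable rules, the highest precedence wins.
    top = max(r[1] for r in best)
    winners = [r for r in best if r[1] == top]
    if len(winners) > 1:
        raise TypeError("ambiguous: give one of the rules a higher precedence")
    return winners[0][2](*args)

@register(MyValue, object)
def _left(x, y):
    return "MyValue on the left"

@register(object, MyValue, precedence=1)
def _right(x, y):
    return "MyValue on the right"

print(dispatch(MyValue(), 3))          # only the left rule applies
print(dispatch(MyValue(), MyValue()))  # both apply; precedence=1 selects the right rule
```

Without `precedence=1`, the second call would match two equally specific rules and the toy `dispatch` would raise an ambiguity error.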
Full Changelog: v0.0.4...v0.0.5
quax v0.0.4
Added a py.typed marker by @nstarman in #20!
Full Changelog: v0.0.3...v0.0.4
Quax v0.0.3
This release is an attempt to start bringing some stability to the wild west that is Quax! :D
Documentation
We now have some shiny new documentation available at https://docs.kidger.site/quax ! Go check it out, including the tutorials for writing your own custom rules using Quax.
Design
There have been a few outstanding design questions with Quax that this release aims to resolve. This is really the highlight of this release.
Pass Values across quaxify boundaries
Where we used to be a bit laid-back about this, we are now quite careful: the pattern is always that you should create your custom value first, and then pass it across a quaxify boundary. For example, this is okay:
```python
import quax
import quax.examples.lora as lora

# Okay: x is created outside, then passed across the quaxify boundary.
x = lora.LoraArray(...)
quax.quaxify(some_function)(x)
```
but this is not:
```python
@quax.quaxify
def some_function(...):
    # Not okay: x is created inside the quaxify-wrapped function.
    x = lora.LoraArray(...)
```
Unless you're about to pass that x into a nested quaxify, that is.
Previously, if a quaxified argument encountered another Value during runtime, then it would automatically quaxify that value for us. The problem with this was that the autoquaxification only happened on primitive binds, not before running any other traced code. For example, jnp.where(..., MyValue(), ...) would not work, as this argument must be an array, and MyValue() had not yet interacted with a quaxified argument and been wrapped in a Quax tracer.
If the previous paragraph is all technical gobbledygook to you, then the summary is that this removes a major "gotcha".
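To make the jnp.where gotcha concrete, here is a toy, pure-Python model. It is not how Quax or JAX is implemented: `MyValue`, `QuaxTracer`, `is_array_like`, `where` and `toy_quaxify` are all hypothetical names. The only idea being modelled is that input validation can happen before any primitive bind, so a value that was never wrapped at a quaxify boundary gets rejected.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class MyValue:
    """Hypothetical custom value type (stands in for a quax.Value subclass)."""
    data: float

@dataclass
class QuaxTracer:
    """Hypothetical stand-in for the tracer that wraps a value at a quaxify boundary."""
    value: MyValue

def is_array_like(x: Any) -> bool:
    # In this toy, tracers count as array-like (as JAX tracers do); raw values do not.
    return isinstance(x, (int, float, QuaxTracer))

def where(cond: bool, x: Any, y: Any) -> Any:
    # Like jnp.where, validate the inputs *before* any primitive is bound.
    for arg in (x, y):
        if not is_array_like(arg):
            raise TypeError(f"{type(arg).__name__} is not array-like")
    return x if cond else y

def toy_quaxify(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Only values passed in as arguments are wrapped; anything created inside fn is not.
    def wrapped(*args: Any) -> Any:
        return fn(*(QuaxTracer(a) if isinstance(a, MyValue) else a for a in args))
    return wrapped

# Okay: the value crosses the boundary, so `where` sees a tracer.
ok = toy_quaxify(lambda v: where(True, v, 0.0))(MyValue(1.0))

# Not okay: the value is created inside the boundary and is never wrapped.
def bad() -> Any:
    return where(True, MyValue(1.0), 0.0)

try:
    toy_quaxify(bad)()
    failed = False
except TypeError as e:
    failed = True
    print("error:", e)  # error: MyValue is not array-like
```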
Values no longer support __add__ etc.
In line with the previous change, there is no longer a reason to want to write something like MyValue() + 1, as this would be an operation on a Value outside of a quax.quaxify-wrapped function! So all such dunder methods have been removed for safety.
No more DenseArrayValue
This is just a case of simplifying the API a bit. Instead of writing rules that look like:
```python
import quax

@quax.register(...)
def _(x: quax.DenseArrayValue, y: SomeValue):
    ...
```
you should write:
```python
import quax
from jaxtyping import ArrayLike

@quax.register(...)
def _(x: ArrayLike, y: SomeValue):
    ...
```
In particular, this means you no longer need to think carefully about whether a normal JAX array has been wrapped into a DenseArrayValue: instead, just use them like normal.
Previously, we had to think separately about being "in normal JAX code" (with ArrayLikes and quax.Values), and "currently writing a Quax rule" (with DenseArrayValues and quax.Values). This unifies these two things, so as to simplify the mental reasoning a bit. It also means we can remove the quax.quaxify(..., unwrap_builtin_values=...) argument we had to use to toggle between these two regimes.
Disabled dynamic tracing
(As per the discussion in #2.)
Quax will no longer perform dynamic tracing: that is, it will only run on a primitive bind if one of the arguments to that bind is downstream of an input passed to a quax.quaxify. It will no longer run on all primitive binds that happen to occur inside the quax.quaxify-wrapped function.
This removes a source of spooky action at a distance. Previously, it was possible to quax.register a primitive acting only on ArrayLikes, and thus change the behaviour of that primitive even in normal non-Quax usage. For example, our very own zero library did this, as a way to create symbolic zeros from operations like "broadcast(0, shape)". This caused all kinds of havoc, with random Zeros showing up in places that weren't expected, and which the author of some other type did not know to expect!
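A toy, pure-Python model of the new behaviour: in this sketch `bind` consults custom rules only when at least one argument is a custom value, so a rule registered purely on plain numbers can no longer leak into ordinary code. `Zero`, `register`, `bind` and `default_impl` are hypothetical names invented for illustration; Quax's real dispatch operates on JAX primitives and tracers, not on strings and Python numbers.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class Zero:
    """Hypothetical symbolic-zero value (stands in for a quax.Value subclass)."""

# Toy rule table keyed by (primitive name, argument types); not Quax's registry.
_rules: dict[tuple[str, tuple[type, ...]], Callable[..., Any]] = {}

def register(prim: str, *types: type):
    def decorator(fn):
        _rules[(prim, types)] = fn
        return fn
    return decorator

def default_impl(prim: str, *args: Any) -> Any:
    if prim == "add":
        return args[0] + args[1]
    if prim == "mul":
        return args[0] * args[1]
    raise NotImplementedError(prim)

def bind(prim: str, *args: Any) -> Any:
    # Non-dynamic behaviour: custom rules are only consulted when some argument
    # is a custom value (i.e. downstream of something passed to quaxify).
    if any(isinstance(a, Zero) for a in args):
        return _rules[(prim, tuple(type(a) for a in args))](*args)
    # Otherwise this is ordinary code, and the default implementation always runs.
    return default_impl(prim, *args)

@register("mul", Zero, int)
def _(x: Zero, y: int) -> Zero:
    return Zero()  # 0 * y is still a (symbolic) zero

@register("mul", int, int)
def _(x: int, y: int) -> Any:
    # A rule on plain numbers only: under dynamic tracing this could hijack
    # ordinary code; here it is never reached, because no argument is a Zero.
    return Zero() if x == 0 else x * y

print(bind("mul", Zero(), 5))  # Zero()
print(bind("mul", 0, 5))       # 0, not Zero(): the int-only rule is ignored
```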
Features
- All examples are now built into the library: check out quax.examples.{lora, zeros, named, ...}!
- Added quax.quaxify(..., filter_spec): for the advanced user who only wants to quaxify a few arguments when working with nested quax.quaxifys. (Which is realistically probably just me at this point!) It is now possible to easily specify which arguments should be quaxified. This is just a nice-to-have; you could previously work around this with tree un/flattening, or by capturing values via closure.
- Better debugging by giving names to overloaded primitive rules. (Thanks @nstarman! #1)
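As an illustration of what filter_spec buys you with nested quaxify calls, here is a toy, pure-Python sketch in which each layer wraps only the arguments its spec marks True. `Wrapped` and `toy_quaxify` are hypothetical, and the spec is simplified to one boolean per positional argument; check the Quax documentation for the real form that filter_spec takes.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence

@dataclass
class Wrapped:
    """Hypothetical marker for an argument that a given toy quaxify layer takes charge of."""
    layer: str
    value: Any

def toy_quaxify(fn: Callable[..., Any], layer: str, filter_spec: Sequence[bool]) -> Callable[..., Any]:
    # Wrap only the positional arguments whose filter_spec entry is True;
    # every other argument is passed through untouched.
    def wrapped(*args: Any) -> Any:
        if len(args) != len(filter_spec):
            raise ValueError("filter_spec must have one entry per argument")
        return fn(*(Wrapped(layer, a) if keep else a for a, keep in zip(args, filter_spec)))
    return wrapped

def inner(x: Any, y: Any) -> tuple[Any, Any]:
    return x, y

def outer(x: Any, y: Any) -> tuple[Any, Any]:
    # The nested layer only takes charge of y; x stays as the outer layer left it.
    return toy_quaxify(inner, "inner", filter_spec=(False, True))(x, y)

x_seen, y_seen = toy_quaxify(outer, "outer", filter_spec=(True, False))(1.0, 2.0)
print(x_seen)  # Wrapped(layer='outer', value=1.0)
print(y_seen)  # Wrapped(layer='inner', value=2.0)
```

Without a spec, the workaround is the one described above: flatten the arguments yourself, or capture the ones you do not want wrapped via a closure.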
Full Changelog: v0.0.2...v0.0.3
quax v0.0.2
Autogenerated release notes as follows:
Full Changelog: https://github.com/patrick-kidger/quax/commits/v0.0.2