From a6fba204df3d60f7df4e98041a07e7074b19d2df Mon Sep 17 00:00:00 2001 From: Vadim Bertrand <36510417+vadmbertr@users.noreply.github.com> Date: Mon, 11 Nov 2024 18:52:05 +0100 Subject: [PATCH] register lax.cond_p (#34) * register lax.cond_p * update test_cond.py * fix union syntax for Python<3.10 * fix union syntax for Python<3.10 * add select_n and stop_gradient primitives in order to handle vmapping over the pred of lax.cond * gruff format * more concise impl of select and stop_gradient * remove print left behind * handles unprescribed number of branches (i.e. switch). enforce exact same output type. add tests for those. remove stop_gradient and select primitives 'defaults' * simplify eval of quaxed jaxpr and out tree of the branches * compat with older jax versions --- quax/_core.py | 43 +++++++++++- tests/test_cond.py | 159 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 tests/test_cond.py diff --git a/quax/_core.py b/quax/_core.py index d6781af..132e7e4 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -438,7 +438,6 @@ def default(primitive, values, params): (Using the [Equinox](https://github.com/patrick-kidger/equinox) library that underlies much of the JAX ecosystem.) """ - arrays: list[ArrayLike] = [] for x in values: if _is_value(x): @@ -583,4 +582,44 @@ def _( return result -# TODO: also register higher-order primitives like `lax.cond_p` etc. 
+_sentinel = object()
+
+
+@register(jax.lax.cond_p)
+def _(
+    index: ArrayLike,
+    *args: Union[ArrayValue, ArrayLike],
+    branches: tuple,
+    linear=_sentinel,
+):
+    """Quax rule for `lax.cond_p` (reached through `lax.cond` / `lax.switch`).
+
+    Each branch jaxpr is re-traced through `quaxify` over the flattened
+    operands. All branches must yield the same output pytree structure,
+    otherwise a `TypeError` is raised.
+    """
+    leaves, operand_treedef = jtu.tree_flatten(args)
+    result_treedefs = []
+
+    def _retrace(branch_jaxpr):
+        def _flat_branch(flat_operands):
+            operands = jtu.tree_unflatten(operand_treedef, flat_operands)
+            branch_out = quaxify(core.jaxpr_as_fun(branch_jaxpr))(*operands)
+            out_leaves, out_treedef = jtu.tree_flatten(branch_out)
+            # The traced function returns leaves only; keep the structure aside.
+            result_treedefs.append(out_treedef)
+            return out_leaves
+
+        return jax.make_jaxpr(_flat_branch)(leaves)
+
+    retraced = tuple(_retrace(branch_jaxpr) for branch_jaxpr in branches)
+    if any(other != result_treedefs[0] for other in result_treedefs[1:]):
+        raise TypeError("all branches output must have the same pytree.")
+
+    # Forward `linear` only when the caller actually passed it.
+    extra_params = {} if linear is _sentinel else dict(linear=linear)
+    flat_out = jax.lax.cond_p.bind(index, *leaves, branches=retraced, **extra_params)
+    return jtu.tree_unflatten(result_treedefs[0], flat_out)
+
+
+# TODO: also register higher-order primitives like `lax.scan_p` etc.
diff --git a/tests/test_cond.py b/tests/test_cond.py
new file mode 100644
index 0000000..c940fb2
--- /dev/null
+++ b/tests/test_cond.py
@@ -0,0 +1,167 @@
+from typing import Union
+
+import jax
+import jax.numpy as jnp
+import pytest
+
+import quax
+from quax.examples.unitful import kilograms, meters, Unitful
+
+
+def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
+    """Return `a + b` when `pred` is true, else `a + c`, chosen via `jax.lax.cond`."""
+    def _true_fn(a: jax.Array):
+        return a + b
+
+    def _false_fn(a: jax.Array):
+        return a + c
+
+    res = jax.lax.cond(pred, _true_fn, _false_fn, a)
+    return res
+
+
+def test_cond_basic():
+    """`lax.cond` under `quaxify` returns the selected branch and keeps the units."""
+    a = Unitful(jnp.asarray(1.0), {meters: 1})
+    b = Unitful(jnp.asarray(2.0), {meters: 1})
+    c = Unitful(jnp.asarray(10.0), {meters: 1})
+
+    res = quax.quaxify(_outer_fn)(a, b, c, False)
+    assert res.array == 11
+    assert res.units == {meters: 1}
+
+    res = quax.quaxify(_outer_fn)(a, b, c, True)
+    assert res.array == 3
+    assert res.units == {meters: 1}
+
+
+def test_cond_different_units():
+    """Expects an exception (any type) when `c` is in kilograms and `a`, `b` in meters."""
+    a = Unitful(jnp.asarray([1.0]), {meters: 1})
+    b = Unitful(jnp.asarray([2.0]), {meters: 1})
+    c = Unitful(jnp.asarray([10.0]), {kilograms: 1})
+
+    with pytest.raises(Exception):
+        quax.quaxify(_outer_fn)(a, b, c, False)
+
+
+def test_cond_different_out_trees():
+    """Expects an exception (any type) when branches compute `a + b` vs. `a * c`."""
+    def _outer_fn(
+        a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]
+    ):
+        def _true_fn(a: jax.Array):
+            return a + b
+
+        def _false_fn(a: jax.Array):
+            return a * c
+
+        res = jax.lax.cond(pred, _true_fn, _false_fn, a)
+        return res
+
+    a = Unitful(jnp.asarray([1.0]), {meters: 1})
+    b = Unitful(jnp.asarray([2.0]), {meters: 1})
+    c = Unitful(jnp.asarray([10.0]), {meters: 1})
+
+    with pytest.raises(Exception):
+        quax.quaxify(_outer_fn)(a, b, c, False)
+
+
+def test_cond_switch():
+    """`lax.switch` over three branches under `quaxify` returns the indexed branch."""
+    def _outer_fn(index: int, a: jax.Array, b: jax.Array, c: jax.Array):
+        def _fn0(a: jax.Array):
+            return a + b
+
+        def _fn1(a: jax.Array):
+            return a + c
+
+        def _fn2(a: jax.Array):
+            return a + b + c
+
+        res = jax.lax.switch(index, (_fn0, _fn1, _fn2), a)
+        return res
+
+    a = Unitful(jnp.asarray([1.0]), {meters: 1})
+    b = Unitful(jnp.asarray([2.0]), {meters: 1})
+    c = Unitful(jnp.asarray([10.0]), {meters: 1})
+
+    res = quax.quaxify(_outer_fn)(0, a, b, c)
+    assert res.array == 3
+    assert res.units == {meters: 1}
+
+    res = quax.quaxify(_outer_fn)(1, a, b, c)
+    assert res.array == 11
+    assert res.units == {meters: 1}
+
+    res = quax.quaxify(_outer_fn)(2, a, b, c)
+    assert res.array == 13
+    assert res.units == {meters: 1}
+
+
+def test_cond_jit():
+    """Same checks as `test_cond_basic`, with `_outer_fn` wrapped in `jax.jit`."""
+    a = Unitful(jnp.asarray(1.0), {meters: 1})
+    b = Unitful(jnp.asarray(2.0), {meters: 1})
+    c = Unitful(jnp.asarray(10.0), {meters: 1})
+
+    res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False)
+    assert res.array == 11
+    assert res.units == {meters: 1}
+
+    res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True)
+    assert res.array == 3
+    assert res.units == {meters: 1}
+
+
+def test_cond_vmap():
+    """`lax.cond` inside `jax.vmap` (mapped over `c` only) under `quaxify`."""
+    a = Unitful(jnp.arange(1), {meters: 1})
+    b = Unitful(jnp.asarray(2), {meters: 1})
+    c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
+    vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None))
+
+    res = quax.quaxify(vmap_fn)(a, b, c, True)
+    assert (res.array == a.array + b.array).all()
+    assert res.units == {meters: 1}
+
+    res = quax.quaxify(vmap_fn)(a, b, c, False)
+    assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all()  # type: ignore
+    assert res.units == {meters: 1}
+
+
+def test_cond_grad_closure():
+    """`jax.jvp` through a `lax.cond` whose branches close over `outer_var`."""
+    x = Unitful(jnp.asarray(2.0), {meters: 1})
+    dummy = Unitful(jnp.asarray(1.0), {meters: 1})
+
+    def outer_fn(
+        outer_var: jax.Array,
+        dummy: jax.Array,
+        pred: Union[bool, jax.Array],
+    ):
+        def _true_fn_grad(a: jax.Array):
+            return a + outer_var
+
+        def _false_fn_grad(a: jax.Array):
+            return a + outer_var * 2
+
+        def _outer_fn_grad(a: jax.Array):
+            return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)
+
+        primals = (outer_var,)
+        tangents = (dummy,)
+        p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
+        return p_out, t_out
+
+    p, t = quax.quaxify(outer_fn)(x, dummy, True)
+    assert p.array == 4
+    assert p.units == {meters: 1}
+    assert t.array == 1
+    assert t.units == {meters: 1}
+
+    p, t = quax.quaxify(outer_fn)(x, dummy, False)
+    assert p.array == 6
+    assert p.units == {meters: 1}
+    assert t.array == 1
+    assert t.units == {meters: 1}