Skip to content

Commit

Permalink
register lax.cond_p
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Oct 24, 2024
1 parent 166266f commit 7cbbd96
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 1 deletion.
30 changes: 29 additions & 1 deletion quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import equinox as eqx
import jax
import jax._src
import jax.api_util as api_util
import jax.core as core
import jax.extend.linear_util as lu
import jax.interpreters.partial_eval as pe
import jax.numpy as jnp
import jax.tree_util as jtu
import plum
Expand Down Expand Up @@ -583,4 +585,30 @@ def _(
return result


# TODO: also register higher-order primitives like `lax.cond_p` etc.
@register(jax.lax.cond_p)
def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple):
false_jaxpr, true_jaxpr = branches

# compute jaxpr of quaxified false and true functions
quax_false_fn = quaxify(core.jaxpr_as_fun(false_jaxpr))
quax_false_jaxpr = jax.make_jaxpr(quax_false_fn)(*args)
quax_true_fn = quaxify(core.jaxpr_as_fun(true_jaxpr))
quax_true_jaxpr = jax.make_jaxpr(quax_true_fn)(*args)

# infer the output treedef
args_leaves, in_treedef = jtu.tree_flatten(args)
wrapped_fn, out_treedef = api_util.flatten_fun_nokwargs( # pyright: ignore
lu.wrap_init(quax_false_fn), in_treedef
)
in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves])
_ = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)
out_treedef = out_treedef()

out_val = jax.lax.cond_p.bind(
index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr)
)
result = jtu.tree_unflatten(out_treedef, out_val)
return result


# TODO: also register higher-order primitives like `lax.scan_p` etc.
109 changes: 109 additions & 0 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a + c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res


def test_cond_basic():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(_outer_fn)(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_different_units():
a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {kilograms: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_jit():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_vmap():
a = Unitful(jnp.arange(1), {meters: 1})
b = Unitful(jnp.asarray(2), {meters: 1})
c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None))

res = quax.quaxify(vmap_fn)(a, b, c, True)
assert (res.array == a.array + b.array).all()
assert res.units == {meters: 1}

res = quax.quaxify(vmap_fn)(a, b, c, False)
assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all() # type: ignore
assert res.units == {meters: 1}


def test_cond_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
b: jax.Array,
c: jax.Array,
pred: bool | jax.Array,
):
def _true_fn_grad(outer_var: jax.Array):
return outer_var + b

def _false_fn_grad(outer_var: jax.Array):
return outer_var + c

def _outer_fn_grad(a: jax.Array):
return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)

primals = (outer_var,)
tangents = (dummy,)
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, dummy, b, c, True)
assert p.array == 4
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

p, t = quax.quaxify(outer_fn)(x, dummy, b, c, False)
assert p.array == 12
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

0 comments on commit 7cbbd96

Please sign in to comment.