Skip to content

Commit

Permalink
register lax.cond_p (#34)
Browse files Browse the repository at this point in the history
* register lax.cond_p

* update test_cond.py

* fix union syntax for Python<3.10

* fix union syntax for Python<3.10

* add select_n and stop_gradient primitives in order to handle vmapping over the pred of lax.cond

* gruff format

* more concise impl of select and stop_gradient

* remove print left behind

* handles unprescribed number of branches (i.e. switch). enforce exact same output type. add tests for those. remove stop_gradient and select primitives 'defaults'

* simplify eval of quaxed jaxpr and out tree of the branches

* compat with older jax versions
  • Loading branch information
vadmbertr authored Nov 11, 2024
1 parent 166266f commit a6fba20
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 2 deletions.
43 changes: 41 additions & 2 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,6 @@ def default(primitive, values, params):
(Using the [Equinox](https://github.com/patrick-kidger/equinox) library that
underlies much of the JAX ecosystem.)
"""

arrays: list[ArrayLike] = []
for x in values:
if _is_value(x):
Expand Down Expand Up @@ -583,4 +582,44 @@ def _(
return result


# TODO: also register higher-order primitives like `lax.cond_p` etc.
_sentinel = object()


@register(jax.lax.cond_p)
def _(
index: ArrayLike,
*args: Union[ArrayValue, ArrayLike],
branches: tuple,
linear=_sentinel,
):
flat_args, in_tree = jtu.tree_flatten(args)

out_trees = []
quax_branches = []
for jaxpr in branches:

def flat_quax_call(flat_args):
args = jtu.tree_unflatten(in_tree, flat_args)
out = quaxify(core.jaxpr_as_fun(jaxpr))(*args)
flat_out, out_tree = jtu.tree_flatten(out)
out_trees.append(out_tree)
return flat_out

quax_jaxpr = jax.make_jaxpr(flat_quax_call)(flat_args)
quax_branches.append(quax_jaxpr)

if any(tree_outs_i != out_trees[0] for tree_outs_i in out_trees[1:]):
raise TypeError("all branches output must have the same pytree.")

if linear is _sentinel:
maybe_linear = {}
else:
maybe_linear = dict(linear=linear)
out_val = jax.lax.cond_p.bind(
index, *flat_args, branches=tuple(quax_branches), **maybe_linear
)
result = jtu.tree_unflatten(out_trees[0], out_val)
return result


# TODO: also register higher-order primitives like `lax.scan_p` etc.
159 changes: 159 additions & 0 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import Union

import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a + c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res


def test_cond_basic():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(_outer_fn)(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_different_units():
a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {kilograms: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_different_out_trees():
def _outer_fn(
a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]
):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a * c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res

a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {meters: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_switch():
def _outer_fn(index: int, a: jax.Array, b: jax.Array, c: jax.Array):
def _fn0(a: jax.Array):
return a + b

def _fn1(a: jax.Array):
return a + c

def _fn2(a: jax.Array):
return a + b + c

res = jax.lax.switch(index, (_fn0, _fn1, _fn2), a)
return res

a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {meters: 1})

res = quax.quaxify(_outer_fn)(0, a, b, c)
assert res.array == 3
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(1, a, b, c)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(2, a, b, c)
assert res.array == 13
assert res.units == {meters: 1}


def test_cond_jit():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_vmap():
a = Unitful(jnp.arange(1), {meters: 1})
b = Unitful(jnp.asarray(2), {meters: 1})
c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None))

res = quax.quaxify(vmap_fn)(a, b, c, True)
assert (res.array == a.array + b.array).all()
assert res.units == {meters: 1}

res = quax.quaxify(vmap_fn)(a, b, c, False)
assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all() # type: ignore
assert res.units == {meters: 1}


def test_cond_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
pred: Union[bool, jax.Array],
):
def _true_fn_grad(a: jax.Array):
return a + outer_var

def _false_fn_grad(a: jax.Array):
return a + outer_var * 2

def _outer_fn_grad(a: jax.Array):
return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)

primals = (outer_var,)
tangents = (dummy,)
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, dummy, True)
assert p.array == 4
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

p, t = quax.quaxify(outer_fn)(x, dummy, False)
assert p.array == 6
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

0 comments on commit a6fba20

Please sign in to comment.