Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register lax.cond_p #34

Merged
merged 11 commits into from
Nov 11, 2024
40 changes: 39 additions & 1 deletion quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import equinox as eqx
import jax
import jax._src
import jax.api_util as api_util
import jax.core as core
import jax.extend.linear_util as lu
import jax.interpreters.partial_eval as pe
import jax.numpy as jnp
import jax.tree_util as jtu
import plum
Expand Down Expand Up @@ -583,4 +585,40 @@ def _(
return result


# TODO: also register higher-order primitives like `lax.cond_p` etc.
@register(jax.lax.cond_p)
def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple):
false_jaxpr, true_jaxpr = branches
vadmbertr marked this conversation as resolved.
Show resolved Hide resolved

# compute jaxpr of quaxified false and true functions
quax_false_fn = quaxify(core.jaxpr_as_fun(false_jaxpr))
quax_false_jaxpr = jax.make_jaxpr(quax_false_fn)(*args)
quax_true_fn = quaxify(core.jaxpr_as_fun(true_jaxpr))
quax_true_jaxpr = jax.make_jaxpr(quax_true_fn)(*args)

# infer the output treedef
args_leaves, in_treedef = jtu.tree_flatten(args)
wrapped_fn, out_treedef = api_util.flatten_fun_nokwargs( # pyright: ignore
lu.wrap_init(quax_false_fn), in_treedef
)
in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves])
_ = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)
out_treedef = out_treedef()

out_val = jax.lax.cond_p.bind(
index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr)
)
result = jtu.tree_unflatten(out_treedef, out_val)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is necessary -- the inputs and outputs of cond should already be flattened. (The PyTree stuff has already been handled by the time binding happens.)

Copy link
Contributor Author

@vadmbertr vadmbertr Oct 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure to understand. Does it mean that we should not return custom objects (ArrayValue), but only ArrayLike?

return result


@register(jax.lax.select_n_p)
def _(which: ArrayLike, *cases: Union[ArrayValue, ArrayLike]):
return jtu.tree_map(ft.partial(jax.lax.select_n_p.bind, which), *cases)


@register(jax.lax.stop_gradient_p)
def _(x: ArrayValue):
return jtu.tree_map(jax.lax.stop_gradient_p.bind, x)


# TODO: also register higher-order primitives like `lax.scan_p` etc.
8 changes: 0 additions & 8 deletions quax/examples/prng/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import abc
import functools as ft
from collections.abc import Sequence
from typing import Any, TypeVar
from typing_extensions import Self, TYPE_CHECKING, TypeAlias
Expand All @@ -10,7 +9,6 @@
import jax.core
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from jaxtyping import Array, ArrayLike, Float, Integer, UInt, UInt32

Expand Down Expand Up @@ -161,9 +159,3 @@ def split(key: PRNG_T, num: int = 2) -> Sequence[PRNG_T]:
"""

return key.split(num)


# Allows for `jnp.where(pred, key1, key2)`.
@quax.register(lax.select_n_p)
def _(pred, *cases: PRNG) -> PRNG:
return jtu.tree_map(ft.partial(lax.select_n, pred), *cases)
107 changes: 107 additions & 0 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from typing import Union

import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a + c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res


def test_cond_basic():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(_outer_fn)(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_different_units():
a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {kilograms: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_jit():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_vmap():
a = Unitful(jnp.arange(1), {meters: 1})
b = Unitful(jnp.asarray(2), {meters: 1})
c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None))

res = quax.quaxify(vmap_fn)(a, b, c, True)
assert (res.array == a.array + b.array).all()
assert res.units == {meters: 1}

res = quax.quaxify(vmap_fn)(a, b, c, False)
assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all() # type: ignore
assert res.units == {meters: 1}


def test_cond_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
pred: Union[bool, jax.Array],
):
def _true_fn_grad(a: jax.Array):
return a + outer_var

def _false_fn_grad(a: jax.Array):
return a + outer_var * 2

def _outer_fn_grad(a: jax.Array):
return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)

primals = (outer_var,)
tangents = (dummy,)
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, dummy, True)
assert p.array == 4
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

p, t = quax.quaxify(outer_fn)(x, dummy, False)
assert p.array == 6
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}
Loading