Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

register lax.cond_p #34

Merged
merged 11 commits into from
Nov 11, 2024
41 changes: 38 additions & 3 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import functools as ft
import itertools as it
from collections.abc import Callable, Sequence
from typing import Any, cast, Generic, TypeVar, Union
from typing import Any, cast, Generic, List, Optional, TypeVar, Union
vadmbertr marked this conversation as resolved.
Show resolved Hide resolved
from typing_extensions import TypeGuard

import equinox as eqx
import jax
import jax._src
import jax.api_util as api_util
import jax.core as core
import jax.extend.linear_util as lu
import jax.interpreters.partial_eval as pe
import jax.numpy as jnp
import jax.tree_util as jtu
import plum
Expand Down Expand Up @@ -438,7 +440,6 @@ def default(primitive, values, params):
(Using the [Equinox](https://github.com/patrick-kidger/equinox) library that
underlies much of the JAX ecosystem.)
"""

arrays: list[ArrayLike] = []
for x in values:
if _is_value(x):
Expand Down Expand Up @@ -583,4 +584,38 @@ def _(
return result


# TODO: also register higher-order primitives like `lax.cond_p` etc.
@register(jax.lax.cond_p)
def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple):
def quaxed_jaxpr_out_tree(jaxpr):
quax_fn = quaxify(core.jaxpr_as_fun(jaxpr))
wrapped_fn, out_tree = api_util.flatten_fun_nokwargs( # pyright: ignore
lu.wrap_init(quax_fn), in_tree
)
in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves])
quax_jaxpr = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)[0]
return core.ClosedJaxpr(quax_jaxpr, ()), out_tree()

args_leaves, in_tree = jtu.tree_flatten(args)

quax_branches_jaxpr: List[Optional[core.ClosedJaxpr]] = [None] * len(branches)
quax_jaxpr0, out_tree0 = quaxed_jaxpr_out_tree(branches[0])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this complexity is necessary. In particular I'd like to avoid reaching so heavily into JAX internals like this.

I think something like the following would work:

flat_args, tree_args = jtu.tree_flatten(args)
tree_outs = []

new_branches = []
for jaxpr in branches:
    def flat_call(flat_args):
        args = jtu.tree_unflatten(tree_args, flat_args)
        out = quaxify(jax.core.jaxpr_as_fun(jaxpr))(*args)
        flat_out, tree_out = jtu.tree_flatten(out)
        tree_outs.append(tree_out)
        return flat_out

    new_jaxpr = jax.make_jaxpr(flat_call)(flat_args)
    new_branches.append(new_jaxpr)

if any(tree_outs_i != tree_outs[0] for tree_outs_i in tree_outs[1:]):
    raise ValueError(...)

flat_out = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(new_branches)
return jtu.tree_unflatten(tree_outs[0], flat_out)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not get we could just unflatten and flatten while creating the quaxed jaxpr. Much simpler indeed!
Committing with those changes right now

quax_branches_jaxpr[0] = quax_jaxpr0
for i in range(1, len(branches)):
quax_jaxpr, out_tree = quaxed_jaxpr_out_tree(branches[i])
jax._src.lax.control_flow.common._check_tree_and_avals( # pyright: ignore
f"branch 0 and {i + 1} outputs",
out_tree0,
quax_jaxpr0.out_avals,
out_tree,
quax_jaxpr.out_avals,
)
quax_branches_jaxpr[i] = quax_jaxpr

out_val = jax.lax.cond_p.bind(
index, *args_leaves, branches=tuple(quax_branches_jaxpr)
)
result = jtu.tree_unflatten(out_tree0, out_val)
return result


# TODO: also register higher-order primitives like `lax.scan_p` etc.
159 changes: 159 additions & 0 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from typing import Union

import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a + c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res


def test_cond_basic():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(_outer_fn)(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_different_units():
a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {kilograms: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_different_out_trees():
def _outer_fn(
a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]
):
def _true_fn(a: jax.Array):
return a + b

def _false_fn(a: jax.Array):
return a * c

res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res

a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {meters: 1})

with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c, False)


def test_cond_switch():
def _outer_fn(index: int, a: jax.Array, b: jax.Array, c: jax.Array):
def _fn0(a: jax.Array):
return a + b

def _fn1(a: jax.Array):
return a + c

def _fn2(a: jax.Array):
return a + b + c

res = jax.lax.switch(index, (_fn0, _fn1, _fn2), a)
return res

a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {meters: 1})

res = quax.quaxify(_outer_fn)(0, a, b, c)
assert res.array == 3
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(1, a, b, c)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(_outer_fn)(2, a, b, c)
assert res.array == 13
assert res.units == {meters: 1}


def test_cond_jit():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False)
assert res.array == 11
assert res.units == {meters: 1}

res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True)
assert res.array == 3
assert res.units == {meters: 1}


def test_cond_vmap():
a = Unitful(jnp.arange(1), {meters: 1})
b = Unitful(jnp.asarray(2), {meters: 1})
c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None))

res = quax.quaxify(vmap_fn)(a, b, c, True)
assert (res.array == a.array + b.array).all()
assert res.units == {meters: 1}

res = quax.quaxify(vmap_fn)(a, b, c, False)
assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all() # type: ignore
assert res.units == {meters: 1}


def test_cond_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
pred: Union[bool, jax.Array],
):
def _true_fn_grad(a: jax.Array):
return a + outer_var

def _false_fn_grad(a: jax.Array):
return a + outer_var * 2

def _outer_fn_grad(a: jax.Array):
return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)

primals = (outer_var,)
tangents = (dummy,)
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, dummy, True)
assert p.array == 4
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

p, t = quax.quaxify(outer_fn)(x, dummy, False)
assert p.array == 6
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}
Loading