From 7cbbd962406365a58d460a098e83349105d78dbc Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Thu, 24 Oct 2024 14:41:33 +0200 Subject: [PATCH 01/11] register lax.cond_p --- quax/_core.py | 30 ++++++++++++- tests/test_cond.py | 109 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 tests/test_cond.py diff --git a/quax/_core.py b/quax/_core.py index d6781af..155e42b 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -8,8 +8,10 @@ import equinox as eqx import jax import jax._src +import jax.api_util as api_util import jax.core as core import jax.extend.linear_util as lu +import jax.interpreters.partial_eval as pe import jax.numpy as jnp import jax.tree_util as jtu import plum @@ -583,4 +585,30 @@ def _( return result -# TODO: also register higher-order primitives like `lax.cond_p` etc. +@register(jax.lax.cond_p) +def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): + false_jaxpr, true_jaxpr = branches + + # compute jaxpr of quaxified false and true functions + quax_false_fn = quaxify(core.jaxpr_as_fun(false_jaxpr)) + quax_false_jaxpr = jax.make_jaxpr(quax_false_fn)(*args) + quax_true_fn = quaxify(core.jaxpr_as_fun(true_jaxpr)) + quax_true_jaxpr = jax.make_jaxpr(quax_true_fn)(*args) + + # infer the output treedef + args_leaves, in_treedef = jtu.tree_flatten(args) + wrapped_fn, out_treedef = api_util.flatten_fun_nokwargs( # pyright: ignore + lu.wrap_init(quax_false_fn), in_treedef + ) + in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves]) + _ = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals) + out_treedef = out_treedef() + + out_val = jax.lax.cond_p.bind( + index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr) + ) + result = jtu.tree_unflatten(out_treedef, out_val) + return result + + +# TODO: also register higher-order primitives like `lax.scan_p` etc. diff --git a/tests/test_cond.py b/tests/test_cond.py new file mode 100644 index 0000000..741f34b --- /dev/null +++ b/tests/test_cond.py @@ -0,0 +1,109 @@ +import jax +import jax.numpy as jnp +import pytest + +import quax +from quax.examples.unitful import kilograms, meters, Unitful + + +def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array): + def _true_fn(a: jax.Array): + return a + b + + def _false_fn(a: jax.Array): + return a + c + + res = jax.lax.cond(pred, _true_fn, _false_fn, a) + return res + + +def test_cond_basic(): + a = Unitful(jnp.asarray(1.0), {meters: 1}) + b = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + + res = quax.quaxify(_outer_fn)(a, b, c, False) + assert res.array == 11 + assert res.units == {meters: 1} + + res = quax.quaxify(_outer_fn)(a, b, c, True) + assert res.array == 3 + assert res.units == {meters: 1} + + +def test_cond_different_units(): + a = Unitful(jnp.asarray([1.0]), {meters: 1}) + b = Unitful(jnp.asarray([2.0]), {meters: 1}) + c = Unitful(jnp.asarray([10.0]), {kilograms: 1}) + + with pytest.raises(Exception): + quax.quaxify(_outer_fn)(a, b, c, False) + + +def test_cond_jit(): + a = Unitful(jnp.asarray(1.0), {meters: 1}) + b = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + + res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, False) + assert res.array == 11 + assert res.units == {meters: 1} + + res = quax.quaxify(jax.jit(_outer_fn))(a, b, c, True) + assert res.array == 3 + assert res.units == {meters: 1} + + +def test_cond_vmap(): + a = Unitful(jnp.arange(1), {meters: 1}) + b = Unitful(jnp.asarray(2), {meters: 1}) + c = Unitful(jnp.arange(2, 13, 2), {meters: 1}) + vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0, None)) + + res = quax.quaxify(vmap_fn)(a, b, c, True) + assert (res.array == a.array + b.array).all() + assert res.units == {meters: 1} + + res = quax.quaxify(vmap_fn)(a, b, c, False) + assert (res.array.ravel() == a.array.ravel() + c.array.ravel()).all() # type: ignore + assert res.units == {meters: 1} + + +def test_cond_grad_closure(): + x = Unitful(jnp.asarray(2.0), {meters: 1}) + b = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + dummy = Unitful(jnp.asarray(1.0), {meters: 1}) + + def outer_fn( + outer_var: jax.Array, + dummy: jax.Array, + b: jax.Array, + c: jax.Array, + pred: bool | jax.Array, + ): + def _true_fn_grad(outer_var: jax.Array): + return outer_var + b + + def _false_fn_grad(outer_var: jax.Array): + return outer_var + c + + def _outer_fn_grad(a: jax.Array): + return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a) + + primals = (outer_var,) + tangents = (dummy,) + p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents) + return p_out, t_out + + p, t = quax.quaxify(outer_fn)(x, dummy, b, c, True) + assert p.array == 4 + assert p.units == {meters: 1} + assert t.array == 1 + assert t.units == {meters: 1} + + p, t = quax.quaxify(outer_fn)(x, dummy, b, c, False) + assert p.array == 12 + assert p.units == {meters: 1} + assert t.array == 1 + assert t.units == {meters: 1} From b2ec9569ad051bf69b799ce75ac03a7ab07332ca Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 10:40:30 +0200 Subject: [PATCH 02/11] update test_cond.py --- tests/test_cond.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/test_cond.py b/tests/test_cond.py index 741f34b..7126231 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -71,22 +71,18 @@ def test_cond_vmap(): def test_cond_grad_closure(): x = Unitful(jnp.asarray(2.0), {meters: 1}) - b = Unitful(jnp.asarray(2.0), {meters: 1}) - c = Unitful(jnp.asarray(10.0), {meters: 1}) dummy = Unitful(jnp.asarray(1.0), {meters: 1}) def outer_fn( outer_var: jax.Array, dummy: jax.Array, - b: jax.Array, - c: jax.Array, pred: bool | jax.Array, ): - def _true_fn_grad(outer_var: jax.Array): - return outer_var + b + def _true_fn_grad(a: jax.Array): + return a + outer_var - def _false_fn_grad(outer_var: jax.Array): - return outer_var + c + def _false_fn_grad(a: jax.Array): + return a + outer_var * 2 def _outer_fn_grad(a: jax.Array): return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a) @@ -96,14 +92,14 @@ def _outer_fn_grad(a: jax.Array): p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents) return p_out, t_out - p, t = quax.quaxify(outer_fn)(x, dummy, b, c, True) + p, t = quax.quaxify(outer_fn)(x, dummy, True) assert p.array == 4 assert p.units == {meters: 1} assert t.array == 1 assert t.units == {meters: 1} - p, t = quax.quaxify(outer_fn)(x, dummy, b, c, False) - assert p.array == 12 + p, t = quax.quaxify(outer_fn)(x, dummy, False) + assert p.array == 6 assert p.units == {meters: 1} assert t.array == 1 assert t.units == {meters: 1} From c36064669057da53340fafa3154911dfaefcb5b7 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 10:42:49 +0200 Subject: [PATCH 03/11] fix union syntax for Python<3.10 --- tests/test_cond.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_cond.py b/tests/test_cond.py index 7126231..8172e4f 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -1,3 +1,5 @@ +from typing import Union + import jax import jax.numpy as jnp import pytest @@ -6,7 +8,7 @@ from quax.examples.unitful import kilograms, meters, Unitful -def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: bool | jax.Array): +def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]): def _true_fn(a: jax.Array): return a + b From 6bd9d15831c242f2260bf0ac40460552df71dcaa Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 10:43:25 +0200 Subject: [PATCH 04/11] fix union syntax for Python<3.10 --- tests/test_cond.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cond.py b/tests/test_cond.py index 8172e4f..41c45fd 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -78,7 +78,7 @@ def test_cond_grad_closure(): def outer_fn( outer_var: jax.Array, dummy: jax.Array, - pred: bool | jax.Array, + pred: Union[bool, jax.Array], ): def _true_fn_grad(a: jax.Array): return a + outer_var From 78e8a677ce3557984e4c6b9b695b6ea7006cccd0 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 12:21:44 +0200 Subject: [PATCH 05/11] add select_n and stop_gradient primitives in order to handle vmapping over the pred of lax.cond --- quax/_core.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/quax/_core.py b/quax/_core.py index 155e42b..af4b0d6 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -194,6 +194,7 @@ def process_primitive(self, primitive, tracers, params): try: method, _ = rule.resolve_method(values) except plum.NotFoundLookupError: + print(primitive, values, params) out = _default_process(primitive, values, params) else: out = method(*values, **params) @@ -611,4 +612,21 @@ def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): return result +@register(jax.lax.select_n_p) +def _(which: ArrayLike, *cases: Union[ArrayValue, ArrayLike]): + leaves, _ = jtu.tree_flatten(cases) + _, treedef = jtu.tree_flatten(cases[0]) + out_val = jax.lax.select_n_p.bind(which, *leaves) + result = jtu.tree_unflatten(treedef, [out_val]) + return result + + +@register(jax.lax.stop_gradient_p) +def _(x: ArrayValue): + leaves, treedef = jtu.tree_flatten(x) + out_val = jax.lax.stop_gradient_p.bind(*leaves) + result = jtu.tree_unflatten(treedef, [out_val]) + return result + + # TODO: also register higher-order primitives like `lax.scan_p` etc. From cd2217295006f58711ee962260bf09167e149150 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 12:47:43 +0200 Subject: [PATCH 06/11] gruff format --- quax/examples/prng/_core.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/quax/examples/prng/_core.py b/quax/examples/prng/_core.py index 046ba75..9a804b8 100644 --- a/quax/examples/prng/_core.py +++ b/quax/examples/prng/_core.py @@ -1,5 +1,4 @@ import abc -import functools as ft from collections.abc import Sequence from typing import Any, TypeVar from typing_extensions import Self, TYPE_CHECKING, TypeAlias @@ -10,7 +9,6 @@ import jax.core import jax.lax as lax import jax.numpy as jnp -import jax.tree_util as jtu import numpy as np from jaxtyping import Array, ArrayLike, Float, Integer, UInt, UInt32 @@ -161,9 +159,3 @@ def split(key: PRNG_T, num: int = 2) -> Sequence[PRNG_T]: """ return key.split(num) - - -# Allows for `jnp.where(pred, key1, key2)`. -@quax.register(lax.select_n_p) -def _(pred, *cases: PRNG) -> PRNG: - return jtu.tree_map(ft.partial(lax.select_n, pred), *cases) From b96b1b9b5ee0a5c74b53698191718f59c545491f Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 14:32:11 +0200 Subject: [PATCH 07/11] more concise impl of select and stop_gradient --- quax/_core.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/quax/_core.py b/quax/_core.py index af4b0d6..cf37e10 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -614,19 +614,12 @@ def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): @register(jax.lax.select_n_p) def _(which: ArrayLike, *cases: Union[ArrayValue, ArrayLike]): - leaves, _ = jtu.tree_flatten(cases) - _, treedef = jtu.tree_flatten(cases[0]) - out_val = jax.lax.select_n_p.bind(which, *leaves) - result = jtu.tree_unflatten(treedef, [out_val]) - return result + return jtu.tree_map(ft.partial(jax.lax.select_n_p.bind, which), *cases) @register(jax.lax.stop_gradient_p) def _(x: ArrayValue): - leaves, treedef = jtu.tree_flatten(x) - out_val = jax.lax.stop_gradient_p.bind(*leaves) - result = jtu.tree_unflatten(treedef, [out_val]) - return result + return jtu.tree_map(jax.lax.stop_gradient_p.bind, x) # TODO: also register higher-order primitives like `lax.scan_p` etc. From b68e80eed706cc2d0c0c22774494d8c63164c4ca Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 17:20:07 +0200 Subject: [PATCH 08/11] remove print left behind --- quax/_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/quax/_core.py b/quax/_core.py index cf37e10..f502f0b 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -194,7 +194,6 @@ def process_primitive(self, primitive, tracers, params): try: method, _ = rule.resolve_method(values) except plum.NotFoundLookupError: - print(primitive, values, params) out = _default_process(primitive, values, params) else: out = method(*values, **params) From 0685f83b6bce4a99f04033f5b05ebb861ccd8a5b Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Wed, 30 Oct 2024 15:41:16 +0100 Subject: [PATCH 09/11] handles unprescribed number of branches (i.e. switch). enforce exact same output type. add tests for those. remove stop_gradient and select primitives 'defaults' --- quax/_core.py | 57 ++++++++++++++++++------------------- quax/examples/prng/_core.py | 8 ++++++ tests/test_cond.py | 52 +++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 30 deletions(-) diff --git a/quax/_core.py b/quax/_core.py index f502f0b..54a46f7 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -2,7 +2,7 @@ import functools as ft import itertools as it from collections.abc import Callable, Sequence -from typing import Any, cast, Generic, TypeVar, Union +from typing import Any, cast, Generic, List, Optional, TypeVar, Union from typing_extensions import TypeGuard import equinox as eqx @@ -440,7 +440,6 @@ def default(primitive, values, params): (Using the [Equinox](https://github.com/patrick-kidger/equinox) library that underlies much of the JAX ecosystem.) """ - arrays: list[ArrayLike] = [] for x in values: if _is_value(x): @@ -587,38 +586,36 @@ def _( @register(jax.lax.cond_p) def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): - false_jaxpr, true_jaxpr = branches - - # compute jaxpr of quaxified false and true functions - quax_false_fn = quaxify(core.jaxpr_as_fun(false_jaxpr)) - quax_false_jaxpr = jax.make_jaxpr(quax_false_fn)(*args) - quax_true_fn = quaxify(core.jaxpr_as_fun(true_jaxpr)) - quax_true_jaxpr = jax.make_jaxpr(quax_true_fn)(*args) - - # infer the output treedef - args_leaves, in_treedef = jtu.tree_flatten(args) - wrapped_fn, out_treedef = api_util.flatten_fun_nokwargs( # pyright: ignore - lu.wrap_init(quax_false_fn), in_treedef - ) - in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves]) - _ = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals) - out_treedef = out_treedef() + def quaxed_jaxpr_out_tree(jaxpr): + quax_fn = quaxify(core.jaxpr_as_fun(jaxpr)) + wrapped_fn, out_tree = api_util.flatten_fun_nokwargs( # pyright: ignore + lu.wrap_init(quax_fn), in_tree + ) + in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves]) + quax_jaxpr = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)[0] + return core.ClosedJaxpr(quax_jaxpr, ()), out_tree() + + args_leaves, in_tree = jtu.tree_flatten(args) + + quax_branches_jaxpr: List[Optional[core.ClosedJaxpr]] = [None] * len(branches) + quax_jaxpr0, out_tree0 = quaxed_jaxpr_out_tree(branches[0]) + quax_branches_jaxpr[0] = quax_jaxpr0 + for i in range(1, len(branches)): + quax_jaxpr, out_tree = quaxed_jaxpr_out_tree(branches[i]) + jax._src.lax.control_flow.common._check_tree_and_avals( # pyright: ignore + f"branch 0 and {i + 1} outputs", + out_tree0, + quax_jaxpr0.out_avals, + out_tree, + quax_jaxpr.out_avals, + ) + quax_branches_jaxpr[i] = quax_jaxpr out_val = jax.lax.cond_p.bind( - index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr) + index, *args_leaves, branches=tuple(quax_branches_jaxpr) ) - result = jtu.tree_unflatten(out_treedef, out_val) + result = jtu.tree_unflatten(out_tree0, out_val) return result -@register(jax.lax.select_n_p) -def _(which: ArrayLike, *cases: Union[ArrayValue, ArrayLike]): - return jtu.tree_map(ft.partial(jax.lax.select_n_p.bind, which), *cases) - - -@register(jax.lax.stop_gradient_p) -def _(x: ArrayValue): - return jtu.tree_map(jax.lax.stop_gradient_p.bind, x) - - # TODO: also register higher-order primitives like `lax.scan_p` etc. diff --git a/quax/examples/prng/_core.py b/quax/examples/prng/_core.py index 9a804b8..046ba75 100644 --- a/quax/examples/prng/_core.py +++ b/quax/examples/prng/_core.py @@ -1,4 +1,5 @@ import abc +import functools as ft from collections.abc import Sequence from typing import Any, TypeVar from typing_extensions import Self, TYPE_CHECKING, TypeAlias @@ -9,6 +10,7 @@ import jax.core import jax.lax as lax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from jaxtyping import Array, ArrayLike, Float, Integer, UInt, UInt32 @@ -159,3 +161,9 @@ def split(key: PRNG_T, num: int = 2) -> Sequence[PRNG_T]: """ return key.split(num) + + +# Allows for `jnp.where(pred, key1, key2)`. +@quax.register(lax.select_n_p) +def _(pred, *cases: PRNG) -> PRNG: + return jtu.tree_map(ft.partial(lax.select_n, pred), *cases) diff --git a/tests/test_cond.py b/tests/test_cond.py index 41c45fd..c940fb2 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -42,6 +42,58 @@ def test_cond_different_units(): quax.quaxify(_outer_fn)(a, b, c, False) +def test_cond_different_out_trees(): + def _outer_fn( + a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array] + ): + def _true_fn(a: jax.Array): + return a + b + + def _false_fn(a: jax.Array): + return a * c + + res = jax.lax.cond(pred, _true_fn, _false_fn, a) + return res + + a = Unitful(jnp.asarray([1.0]), {meters: 1}) + b = Unitful(jnp.asarray([2.0]), {meters: 1}) + c = Unitful(jnp.asarray([10.0]), {meters: 1}) + + with pytest.raises(Exception): + quax.quaxify(_outer_fn)(a, b, c, False) + + +def test_cond_switch(): + def _outer_fn(index: int, a: jax.Array, b: jax.Array, c: jax.Array): + def _fn0(a: jax.Array): + return a + b + + def _fn1(a: jax.Array): + return a + c + + def _fn2(a: jax.Array): + return a + b + c + + res = jax.lax.switch(index, (_fn0, _fn1, _fn2), a) + return res + + a = Unitful(jnp.asarray([1.0]), {meters: 1}) + b = Unitful(jnp.asarray([2.0]), {meters: 1}) + c = Unitful(jnp.asarray([10.0]), {meters: 1}) + + res = quax.quaxify(_outer_fn)(0, a, b, c) + assert res.array == 3 + assert res.units == {meters: 1} + + res = quax.quaxify(_outer_fn)(1, a, b, c) + assert res.array == 11 + assert res.units == {meters: 1} + + res = quax.quaxify(_outer_fn)(2, a, b, c) + assert res.array == 13 + assert res.units == {meters: 1} + + def test_cond_jit(): a = Unitful(jnp.asarray(1.0), {meters: 1}) b = Unitful(jnp.asarray(2.0), {meters: 1}) From e130e9b6a738b3e70880bda9bec728f53eeb31c7 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Wed, 6 Nov 2024 16:46:16 +0100 Subject: [PATCH 10/11] simplify eval of quaxed jaxpr and out tree of the branches --- quax/_core.py | 52 +++++++++++++++++++++------------------------------ 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/quax/_core.py b/quax/_core.py index 54a46f7..2d0913e 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -2,16 +2,14 @@ import functools as ft import itertools as it from collections.abc import Callable, Sequence -from typing import Any, cast, Generic, List, Optional, TypeVar, Union +from typing import Any, cast, Generic, TypeVar, Union from typing_extensions import TypeGuard import equinox as eqx import jax import jax._src -import jax.api_util as api_util import jax.core as core import jax.extend.linear_util as lu -import jax.interpreters.partial_eval as pe import jax.numpy as jnp import jax.tree_util as jtu import plum @@ -586,35 +584,27 @@ def _( @register(jax.lax.cond_p) def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): - def quaxed_jaxpr_out_tree(jaxpr): - quax_fn = quaxify(core.jaxpr_as_fun(jaxpr)) - wrapped_fn, out_tree = api_util.flatten_fun_nokwargs( # pyright: ignore - lu.wrap_init(quax_fn), in_tree - ) - in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves]) - quax_jaxpr = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)[0] - return core.ClosedJaxpr(quax_jaxpr, ()), out_tree() - - args_leaves, in_tree = jtu.tree_flatten(args) - - quax_branches_jaxpr: List[Optional[core.ClosedJaxpr]] = [None] * len(branches) - quax_jaxpr0, out_tree0 = quaxed_jaxpr_out_tree(branches[0]) - quax_branches_jaxpr[0] = quax_jaxpr0 - for i in range(1, len(branches)): - quax_jaxpr, out_tree = quaxed_jaxpr_out_tree(branches[i]) - jax._src.lax.control_flow.common._check_tree_and_avals( # pyright: ignore - f"branch 0 and {i + 1} outputs", - out_tree0, - quax_jaxpr0.out_avals, - out_tree, - quax_jaxpr.out_avals, - ) - quax_branches_jaxpr[i] = quax_jaxpr + flat_args, in_tree = jtu.tree_flatten(args) - out_val = jax.lax.cond_p.bind( - index, *args_leaves, branches=tuple(quax_branches_jaxpr) - ) - result = jtu.tree_unflatten(out_tree0, out_val) + out_trees = [] + quax_branches = [] + for jaxpr in branches: + + def flat_quax_call(flat_args): + args = jtu.tree_unflatten(in_tree, flat_args) + out = quaxify(core.jaxpr_as_fun(jaxpr))(*args) + flat_out, out_tree = jtu.tree_flatten(out) + out_trees.append(out_tree) + return flat_out + + quax_jaxpr = jax.make_jaxpr(flat_quax_call)(flat_args) + quax_branches.append(quax_jaxpr) + + if any(tree_outs_i != out_trees[0] for tree_outs_i in out_trees[1:]): + raise TypeError("all branches output must have the same pytree.") + + out_val = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(quax_branches)) + result = jtu.tree_unflatten(out_trees[0], out_val) return result From 95cdab771c6835f592159c194f7e4d06fd4a2a88 Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Mon, 11 Nov 2024 16:19:23 +0100 Subject: [PATCH 11/11] compat with older jax versions --- quax/_core.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/quax/_core.py b/quax/_core.py index 2d0913e..132e7e4 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -582,8 +582,16 @@ def _( return result +_sentinel = object() + + @register(jax.lax.cond_p) -def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple): +def _( + index: ArrayLike, + *args: Union[ArrayValue, ArrayLike], + branches: tuple, + linear=_sentinel, +): flat_args, in_tree = jtu.tree_flatten(args) out_trees = [] @@ -603,7 +611,13 @@ def flat_quax_call(flat_args): if any(tree_outs_i != out_trees[0] for tree_outs_i in out_trees[1:]): raise TypeError("all branches output must have the same pytree.") - out_val = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(quax_branches)) + if linear is _sentinel: + maybe_linear = {} + else: + maybe_linear = dict(linear=linear) + out_val = jax.lax.cond_p.bind( + index, *flat_args, branches=tuple(quax_branches), **maybe_linear + ) result = jtu.tree_unflatten(out_trees[0], out_val) return result