-
-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
register lax.cond_p #34
register lax.cond_p #34
Conversation
For a first go at JAX internals this looks pretty good to me! I've left two comments but I think this is basically doing the right thing. As for |
Hi! I'm afraid I don't see the comments. About from typing import Union
import jax
import jax.numpy as jnp
import pytest
import quax
from quax.examples.unitful import kilograms, meters, Unitful
def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
def _true_fn(a: jax.Array):
return a + b
def _false_fn(a: jax.Array):
return a + c
res = jax.lax.cond(pred, _true_fn, _false_fn, a)
return res
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, None, 0))
res = quax.quaxify(vmap_fn)(a, b, c, jnp.asarray([True, False]))
res.array # Array([ 3., 11.], dtype=float32, weak_type=True) (no test for this yet) I also noticed something wrong with my implementation: using the |
quax/_core.py
Outdated
out_val = jax.lax.cond_p.bind( | ||
index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr) | ||
) | ||
result = jtu.tree_unflatten(out_treedef, out_val) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is necessary -- the inputs and outputs of cond should already be flattened. (The PyTree stuff has already been handled by the time binding happens.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure to understand. Does it mean that we should not return custom objects (ArrayValue
), but only ArrayLike
?
Oh, sorry -- looks like I didn't submit those properly. Just added the two comments! As for vmap -- right, the I agree that we should require that the output have the exact same type. |
…same output type. add tests for those. remove stop_gradient and select primitives 'defaults'
Got it. But don't you think that the
Done! |
quax/_core.py
Outdated
def quaxed_jaxpr_out_tree(jaxpr): | ||
quax_fn = quaxify(core.jaxpr_as_fun(jaxpr)) | ||
wrapped_fn, out_tree = api_util.flatten_fun_nokwargs( # pyright: ignore | ||
lu.wrap_init(quax_fn), in_tree | ||
) | ||
in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves]) | ||
quax_jaxpr = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)[0] | ||
return core.ClosedJaxpr(quax_jaxpr, ()), out_tree() | ||
|
||
args_leaves, in_tree = jtu.tree_flatten(args) | ||
|
||
quax_branches_jaxpr: List[Optional[core.ClosedJaxpr]] = [None] * len(branches) | ||
quax_jaxpr0, out_tree0 = quaxed_jaxpr_out_tree(branches[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this complexity is necessary. In particular I'd like to avoid reaching so heavily into JAX internals like this.
I think something like the following would work:
flat_args, tree_args = jtu.tree_flatten(args)
tree_outs = []
new_branches = []
for jaxpr in branches:
def flat_call(flat_args):
args = jtu.tree_unflatten(tree_args, flat_args)
out = quaxify(jax.core.jaxpr_as_fun(jaxpr))(*args)
flat_out, tree_out = jtu.tree_flatten(out)
tree_outs.append(tree_out)
return flat_out
new_jaxpr = jax.make_jaxpr(flat_call)(flat_args)
new_branches.append(new_jaxpr)
if any(tree_outs_i != tree_outs[0] for tree_outs_i in tree_outs[1:]):
raise ValueError(...)
flat_out = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(new_branches)
return jtu.tree_unflatten(tree_outs[0], flat_out)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not get we could just unflatten and flatten while creating the quaxed jaxpr. Much simpler indeed!
Committing with those changes right now
Okay, this LGTM! It looks like the test failure is because older versions of JAX used to have a To be compatible with both old and new JAX, then I think it should be enough to tweak things like so: _sentinel = object()
@register(jax.lax.cond_p)
def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple, linear=_sentinel):
...
if linear is _sentinel:
maybe_linear = {}
else:
maybe_linear = dict(linear=linear)
out_val = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(quax_branches), **maybe_linear)
... If you can make that tweak then I'd be happy to merge this :) |
Tests are passing for both 3.9 and 3.11 python version now! |
Wonderful stuff -- and this now all looks good to me, so merged! Thank you for putting all of this together. :D |
Great! |
Hi!
I'm interested in having additional primitives supported so I started with the
lax.cond_p
one.This implementation does not support
vmap
ping over the predictorpred
. In this case, as mentionned in JAX documentation, it falls back toselect
. And it is unclear to me how to support this as there is no primitive for it AFAIU.It's the first time I'm digging inside JAX internals, so do not hesite to point out any misunderstanding! (For instance, I'm not 100% sure about how I retrieved the tree definition of the outputs).
I will be happy to add support for the scan primitive if / when you're good with that one.
Vadim