register lax.cond_p #34

Merged
merged 11 commits into from
Nov 11, 2024


vadmbertr
Contributor

Hi!

I'm interested in having additional primitives supported so I started with the lax.cond_p one.

This implementation does not support vmapping over the predicate pred. In that case, as mentioned in the JAX documentation, cond falls back to select, and it is unclear to me how to support this, as there is no primitive for it AFAIU.

It's the first time I'm digging into JAX internals, so do not hesitate to point out any misunderstanding! (For instance, I'm not 100% sure about how I retrieved the tree definition of the outputs.)

I will be happy to add support for the scan primitive if / when you're good with that one.
Vadim

@patrick-kidger
Owner

For a first go at JAX internals this looks pretty good to me! I've left two comments but I think this is basically doing the right thing.

As for vmap, I think this should work already without needing any special support from us. Once vmap'd then the cond_p disappears altogether (turns into a select), and in fact this now no longer involves any higher-order primitives.
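
A minimal sketch of this behaviour in plain JAX (no Quax involved). The function f and the variable names are illustrative assumptions, and the exact set of primitives in the batched jaxpr may differ between JAX versions:

```python
import jax
import jax.numpy as jnp


def f(pred, x):
    # lax.cond with a traced predicate stays a `cond` primitive in the jaxpr.
    return jax.lax.cond(pred, lambda v: v + 1.0, lambda v: v - 1.0, x)


x = jnp.asarray(1.0)
preds = jnp.asarray([True, False])

# Unbatched trace: the jaxpr contains a higher-order `cond` equation.
scalar_jaxpr = jax.make_jaxpr(f)(jnp.asarray(True), x)

# Batched over the predicate: both branches are evaluated and the results
# are combined element-wise with a select.
batched_f = jax.vmap(f, in_axes=(0, None))
batched_jaxpr = jax.make_jaxpr(batched_f)(preds, x)

scalar_prims = {eqn.primitive.name for eqn in scalar_jaxpr.jaxpr.eqns}
batched_prims = {eqn.primitive.name for eqn in batched_jaxpr.jaxpr.eqns}
out = batched_f(preds, x)
```

Comparing scalar_prims with batched_prims shows whether a cond equation is still present after batching.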

@vadmbertr
Contributor Author

Hi!

I'm afraid I don't see the comments.

About vmap, I double-checked and realized that we in fact need two more primitives: stop_gradient and select_n. It is now possible to do:

from typing import Union

import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array, pred: Union[bool, jax.Array]):
    def _true_fn(a: jax.Array):
        return a + b

    def _false_fn(a: jax.Array):
        return a + c

    res = jax.lax.cond(pred, _true_fn, _false_fn, a)
    return res


a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})

vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, None, 0))
res = quax.quaxify(vmap_fn)(a, b, c, jnp.asarray([True, False]))

res.array  # Array([ 3., 11.], dtype=float32, weak_type=True)

(no test for this yet)

I also noticed something wrong with my implementation: using the Unitful example, if the True and False branches return different units, then the returned Unitful currently always has the units returned by the False branch.
I don't think it is possible to track which out_treedef should be used, but we can check that they are both the same.
I believe this comment also applies to the select_n implementation.
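
A rough sketch of such a check, using a toy pytree whose unit is stored as static auxiliary data. The Tagged class and the check_same_structure helper are hypothetical stand-ins, not Quax's Unitful or the code in this PR:

```python
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Tagged:
    """Toy value-with-unit; the unit lives in the treedef, not in the leaves."""

    def __init__(self, value, unit):
        self.value = value
        self.unit = unit

    def tree_flatten(self):
        # Children (leaves) first, then static auxiliary data.
        return (self.value,), self.unit

    @classmethod
    def tree_unflatten(cls, unit, children):
        (value,) = children
        return cls(value, unit)


def check_same_structure(outs):
    """Return the shared treedef, or raise if the outputs disagree."""
    treedefs = [jtu.tree_structure(o) for o in outs]
    if any(t != treedefs[0] for t in treedefs[1:]):
        raise ValueError("all branches must return the same output structure")
    return treedefs[0]


# Same unit in both branches: the treedefs compare equal.
shared = check_same_structure([Tagged(1.0, "m"), Tagged(2.0, "m")])
```

Because the unit is part of the treedef, two outputs with different units compare unequal even though their leaves have the same shape and dtype.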

quax/_core.py Outdated
out_val = jax.lax.cond_p.bind(
    index, *args_leaves, branches=(quax_false_jaxpr, quax_true_jaxpr)
)
result = jtu.tree_unflatten(out_treedef, out_val)
Owner


I don't think this is necessary -- the inputs and outputs of cond should already be flattened. (The PyTree stuff has already been handled by the time binding happens.)

Contributor Author

@vadmbertr Oct 30, 2024


I'm not sure I understand. Does it mean that we should not return custom objects (ArrayValue), but only ArrayLike?

@patrick-kidger
Owner

Oh, sorry -- looks like I didn't submit those properly. Just added the two comments!

As for vmap -- right, the stop_gradient and select_n will depend on whether the types being used have default rules registered for them. To explain that a bit: Quax has some central tracing machinery that is shared throughout the library, and handling higher-order primitives like lax.cond_p is part of this. Then there are the more specialised rules for any individual application (unitful, symbolic zeros, lora, ...), and these are where the normal primitives are handled. So basically, handling cond_p is unrelated to handling stop_gradient and select_n.

I agree that we should require that the output have the exact same type.

…same output type. add tests for those. remove stop_gradient and select primitives 'defaults'
@vadmbertr
Contributor Author

> As for vmap -- right, the stop_gradient and select_n will depend on whether the types being used have default rules registered for them. To explain that a bit: Quax has some central tracing machinery that is shared throughout the library, and handling higher-order primitives like lax.cond_p is part of this. Then there are the more specialised rules for any individual application (unitful, symbolic zeros, lora, ...), and these are where the normal primitives are handled. So basically, handling cond_p is unrelated to handling stop_gradient and select_n.

Got it. But don't you think that the ValueError being raised if no default rule is registered (and materialization is not allowed) is a bit vague in that specific case?

> I agree that we should require that the output have the exact same type.

Done!

quax/_core.py Outdated
Comment on lines 589 to 601
def quaxed_jaxpr_out_tree(jaxpr):
    quax_fn = quaxify(core.jaxpr_as_fun(jaxpr))
    wrapped_fn, out_tree = api_util.flatten_fun_nokwargs(  # pyright: ignore
        lu.wrap_init(quax_fn), in_tree
    )
    in_avals = tuple([core.raise_to_shaped(core.get_aval(x)) for x in args_leaves])
    quax_jaxpr = pe.trace_to_jaxpr_dynamic(wrapped_fn, in_avals)[0]
    return core.ClosedJaxpr(quax_jaxpr, ()), out_tree()

args_leaves, in_tree = jtu.tree_flatten(args)

quax_branches_jaxpr: List[Optional[core.ClosedJaxpr]] = [None] * len(branches)
quax_jaxpr0, out_tree0 = quaxed_jaxpr_out_tree(branches[0])
Owner


I don't think this complexity is necessary. In particular I'd like to avoid reaching so heavily into JAX internals like this.

I think something like the following would work:

flat_args, tree_args = jtu.tree_flatten(args)
tree_outs = []

new_branches = []
for jaxpr in branches:
    def flat_call(flat_args):
        args = jtu.tree_unflatten(tree_args, flat_args)
        out = quaxify(jax.core.jaxpr_as_fun(jaxpr))(*args)
        flat_out, tree_out = jtu.tree_flatten(out)
        tree_outs.append(tree_out)
        return flat_out

    new_jaxpr = jax.make_jaxpr(flat_call)(flat_args)
    new_branches.append(new_jaxpr)

if any(tree_outs_i != tree_outs[0] for tree_outs_i in tree_outs[1:]):
    raise ValueError(...)

flat_out = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(new_branches))
return jtu.tree_unflatten(tree_outs[0], flat_out)
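
The key step in this suggestion is recording each branch's output treedef as a side effect of tracing. A minimal standalone sketch of just that step, without Quax and without binding cond_p; the helper trace_flat and the two branch functions are hypothetical names used only for illustration:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def trace_flat(fn, args):
    """Trace `fn` on flat inputs; return (closed jaxpr, output treedef)."""
    flat_args, tree_args = jtu.tree_flatten(args)
    recorded = []

    def flat_call(flat_args):
        # Rebuild the structured arguments, call fn, then flatten the result.
        unflat = jtu.tree_unflatten(tree_args, flat_args)
        out = fn(*unflat)
        flat_out, tree_out = jtu.tree_flatten(out)
        recorded.append(tree_out)  # captured while make_jaxpr traces
        return flat_out

    closed_jaxpr = jax.make_jaxpr(flat_call)(flat_args)
    return closed_jaxpr, recorded[0]


def true_branch(x, y):
    return {"sum": x + y, "diff": x - y}


def false_branch(x, y):
    return {"sum": x + 2 * y, "diff": x - 2 * y}


args = (jnp.asarray(1.0), jnp.asarray(2.0))
jaxpr_t, tree_t = trace_flat(true_branch, args)
jaxpr_f, tree_f = trace_flat(false_branch, args)

# Both branches must agree on structure before their flat outputs can be
# unflattened with a single treedef.
same_structure = tree_t == tree_f
```

Each resulting jaxpr takes and returns flat lists of arrays, which is the form a higher-order primitive expects for its branches, while the recorded treedef is what allows the flat result to be unflattened afterwards.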

Contributor Author


I had not realized we could just unflatten and flatten while creating the quaxed jaxpr. Much simpler indeed!
Committing those changes right now.

@patrick-kidger
Owner

Okay, this LGTM! It looks like the test failure is because older versions of JAX used to have a linear argument here:

jax-ml/jax@6becf71

To be compatible with both old and new JAX, then I think it should be enough to tweak things like so:

_sentinel = object()

@register(jax.lax.cond_p)
def _(index: ArrayLike, *args: Union[ArrayValue, ArrayLike], branches: tuple, linear=_sentinel):
    ...
    if linear is _sentinel:
        maybe_linear = {}
    else:
        maybe_linear = dict(linear=linear)
    out_val = jax.lax.cond_p.bind(index, *flat_args, branches=tuple(quax_branches), **maybe_linear)
    ...

If you can make that tweak then I'd be happy to merge this :)
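
The sentinel pattern above can be tried in isolation. In this sketch, new_bind and old_bind are hypothetical stand-ins for a bind function without and with a linear parameter; they are not JAX APIs:

```python
_sentinel = object()


def new_bind(index, *args, branches):
    # Stand-in for a newer API that no longer accepts `linear`.
    return ("new", index, args, branches)


def old_bind(index, *args, branches, linear):
    # Stand-in for an older API that requires `linear`.
    return ("old", index, args, branches, linear)


def forward(bind, index, *args, branches, linear=_sentinel):
    # Forward `linear` only if the caller actually supplied it.
    maybe_linear = {} if linear is _sentinel else {"linear": linear}
    return bind(index, *args, branches=branches, **maybe_linear)


new_result = forward(new_bind, 0, 1, 2, branches=("a", "b"))
old_result = forward(old_bind, 0, 1, branches=("a",), linear=(False,))
```

A unique object is used rather than None as the default so that an explicitly passed None is still forwarded.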

@vadmbertr
Contributor Author

Tests are now passing for both Python 3.9 and 3.11!

@patrick-kidger patrick-kidger merged commit a6fba20 into patrick-kidger:main Nov 11, 2024
2 checks passed
@patrick-kidger
Owner

Wonderful stuff -- and this now all looks good to me, so merged!

Thank you for putting all of this together. :D

@vadmbertr
Contributor Author

Great!
