From dcbadb6ef4dccb5155be2313e42701fa004a588e Mon Sep 17 00:00:00 2001 From: Yannik Mahlau <59509701+ymahlau@users.noreply.github.com> Date: Sun, 21 Jul 2024 16:44:38 +0200 Subject: [PATCH] Implemented jax.lax.while primitive (#16) * Implemented jax.lax.while primitive * extended unitful example * added test cases for while implementation * test grad over closure * minor changes to while primitive --- quax/_core.py | 35 ++++++++++++ quax/examples/unitful/__init__.py | 7 +++ quax/examples/unitful/_core.py | 93 +++++++++++++++++++++++++++++++ tests/test_while.py | 91 ++++++++++++++++++++++++++++++ 4 files changed, 226 insertions(+) create mode 100644 quax/examples/unitful/__init__.py create mode 100644 quax/examples/unitful/_core.py create mode 100644 tests/test_while.py diff --git a/quax/_core.py b/quax/_core.py index 35edd5d..102ae3a 100644 --- a/quax/_core.py +++ b/quax/_core.py @@ -545,4 +545,39 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs): return jax.jit(flat_fun)(leaves) # now we can call without Quax. +@register(jax.lax.while_p) +def _( + *args: Union[ArrayValue, ArrayLike], + cond_nconsts: int, + cond_jaxpr, + body_nconsts: int, + body_jaxpr, +): + cond_consts = args[:cond_nconsts] + body_consts = args[cond_nconsts : cond_nconsts + body_nconsts] + init_vals = args[cond_nconsts + body_nconsts :] + + # compute jaxpr of quaxified body and condition function + quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr)) + quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals) + quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr)) + quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals) + + cond_leaves, _ = jtu.tree_flatten(cond_consts) + body_leaves, _ = jtu.tree_flatten(body_consts) + init_val_leaves, val_treedef = jtu.tree_flatten(init_vals) + + out_val = jax.lax.while_p.bind( + *cond_leaves, + *body_leaves, + *init_val_leaves, + cond_nconsts=cond_nconsts, + cond_jaxpr=quax_cond_jaxpr, + body_nconsts=body_nconsts, + body_jaxpr=quax_body_jaxpr, + ) + result = jtu.tree_unflatten(val_treedef, out_val) + return result + + # TODO: also register higher-order primitives like `lax.cond_p` etc. diff --git a/quax/examples/unitful/__init__.py b/quax/examples/unitful/__init__.py new file mode 100644 index 0000000..9961381 --- /dev/null +++ b/quax/examples/unitful/__init__.py @@ -0,0 +1,7 @@ +from ._core import ( + Dimension as Dimension, + kilograms as kilograms, + meters as meters, + seconds as seconds, + Unitful as Unitful, +) diff --git a/quax/examples/unitful/_core.py b/quax/examples/unitful/_core.py new file mode 100644 index 0000000..e4e5e0f --- /dev/null +++ b/quax/examples/unitful/_core.py @@ -0,0 +1,93 @@ +from typing import Union + +import equinox as eqx # https://github.com/patrick-kidger/equinox +import jax +import jax.core as core +import jax.numpy as jnp +from jaxtyping import ArrayLike # https://github.com/patrick-kidger/jaxtyping + +import quax + + +class Dimension: + def __init__(self, name): + self.name = name + + def __repr__(self): + return self.name + + +kilograms = Dimension("kg") +meters = Dimension("m") +seconds = Dimension("s") + + +def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]: + if isinstance(x, Dimension): + return {x: 1} + else: + return x + + +class Unitful(quax.ArrayValue): + array: ArrayLike + units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit) + + def aval(self): + shape = jnp.shape(self.array) + dtype = jnp.result_type(self.array) + return core.ShapedArray(shape, dtype) + + def materialise(self): + raise ValueError("Refusing to materialise Unitful array.") + + +@quax.register(jax.lax.add_p) +def _(x: Unitful, y: Unitful): # function name doesn't matter + if x.units == y.units: + return Unitful(x.array + y.array, x.units) + else: + raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.") + + +@quax.register(jax.lax.mul_p) +def _(x: Unitful, y: Unitful): + units = x.units.copy() + for k, v in y.units.items(): + if k in units: + units[k] += v + else: + units[k] = v + return Unitful(x.array * y.array, units) + + +@quax.register(jax.lax.mul_p) +def _(x: ArrayLike, y: Unitful): + return Unitful(x * y.array, y.units) + + +@quax.register(jax.lax.mul_p) +def _(x: Unitful, y: ArrayLike): + return Unitful(x.array * y, x.units) + + +@quax.register(jax.lax.integer_pow_p) +def _(x: Unitful, *, y: int): + units = {k: v * y for k, v in x.units.items()} + return Unitful(x.array, units) + + +@quax.register(jax.lax.lt_p) +def _(x: Unitful, y: Unitful, **kwargs): + if x.units == y.units: + return jax.lax.lt(x.array, y.array, **kwargs) + else: + raise ValueError( + f"Cannot compare two arrays with units {x.units} and {y.units}." + ) + + +@quax.register(jax.lax.broadcast_in_dim_p) +def _(operand: Unitful, **kwargs): + new_arr = jax.lax.broadcast_in_dim(operand.array, **kwargs) + return Unitful(new_arr, operand.units) diff --git a/tests/test_while.py b/tests/test_while.py new file mode 100644 index 0000000..d31237d --- /dev/null +++ b/tests/test_while.py @@ -0,0 +1,91 @@ +import jax +import jax.numpy as jnp +import pytest + +import quax +from quax.examples.unitful import kilograms, meters, Unitful + + +def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array): + # body has b as static argument + def _body_fn(a: jax.Array): + return a + b + + # cond has c as static argument + def _cond_fn(a: jax.Array): + return (a < c).squeeze() + + res = jax.lax.while_loop( + body_fun=_body_fn, + cond_fun=_cond_fn, + init_val=a, + ) + return res + + +def test_while_basic(): + a = Unitful(jnp.asarray(1.0), {meters: 1}) + b = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + res = quax.quaxify(_outer_fn)(a, b, c) + assert res.array == 11 + assert res.units == {meters: 1} + + +def test_while_different_units(): + a = Unitful(jnp.asarray([1.0]), {meters: 1}) + b = Unitful(jnp.asarray([2.0]), {meters: 1}) + c = Unitful(jnp.asarray([10.0]), {kilograms: 1}) + with pytest.raises(Exception): + quax.quaxify(_outer_fn)(a, b, c) + + +def test_while_jit(): + a = Unitful(jnp.asarray(1.0), {meters: 1}) + b = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + res = quax.quaxify(jax.jit(_outer_fn))(a, b, c) + assert res.array == 11 + assert res.units == {meters: 1} + + +def test_while_vmap(): + a = Unitful(jnp.arange(1), {meters: 1}) + b = Unitful(jnp.asarray(2), {meters: 1}) + c = Unitful(jnp.arange(2, 13, 2), {meters: 1}) + vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0)) + res = quax.quaxify(vmap_fn)(a, b, c) + for i in range(len(c.array)): # type: ignore + assert res.array[i] == c.array[i] # type: ignore + assert res.units == {meters: 1} + + +def test_while_grad_closure(): + x = Unitful(jnp.asarray(2.0), {meters: 1}) + c = Unitful(jnp.asarray(10.0), {meters: 1}) + dummy = Unitful(jnp.asarray(1.0), {meters: 1}) + + def outer_fn(outer_var: jax.Array, c: jax.Array, dummy: jax.Array): + def _body_fn_grad(a: jax.Array): + return a + outer_var + + def _cond_fn_grad(a: jax.Array): + return (a < c).squeeze() + + def _outer_fn_grad(a: jax.Array): + return jax.lax.while_loop( + body_fun=_body_fn_grad, + cond_fun=_cond_fn_grad, + init_val=a, + ) + + primals = (outer_var,) + tangents = (dummy,) + p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents) + return p_out, t_out + + p, t = quax.quaxify(outer_fn)(x, c, dummy) + assert p.array == 10 + assert p.units == {meters: 1} + assert t.array == 1 + assert t.units == {meters: 1}