Skip to content

Commit

Permalink
Implemented jax.lax.while primitive (#16)
Browse files Browse the repository at this point in the history
* Implemented jax.lax.while primitive

* extended unitful example

* added test cases for while implementation

* test grad over closure

* minor changes to while primitive
  • Loading branch information
ymahlau authored Jul 21, 2024
1 parent 9b27fc0 commit dcbadb6
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 0 deletions.
35 changes: 35 additions & 0 deletions quax/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,4 +545,39 @@ def _(*args: Union[ArrayLike, ArrayValue], jaxpr, inline, **kwargs):
return jax.jit(flat_fun)(leaves) # now we can call without Quax.


@register(jax.lax.while_p)
def _(
*args: Union[ArrayValue, ArrayLike],
cond_nconsts: int,
cond_jaxpr,
body_nconsts: int,
body_jaxpr,
):
cond_consts = args[:cond_nconsts]
body_consts = args[cond_nconsts : cond_nconsts + body_nconsts]
init_vals = args[cond_nconsts + body_nconsts :]

# compute jaxpr of quaxified body and condition function
quax_cond_fn = quaxify(core.jaxpr_as_fun(cond_jaxpr))
quax_cond_jaxpr = jax.make_jaxpr(quax_cond_fn)(*cond_consts, *init_vals)
quax_body_fn = quaxify(core.jaxpr_as_fun(body_jaxpr))
quax_body_jaxpr = jax.make_jaxpr(quax_body_fn)(*body_consts, *init_vals)

cond_leaves, _ = jtu.tree_flatten(cond_consts)
body_leaves, _ = jtu.tree_flatten(body_consts)
init_val_leaves, val_treedef = jtu.tree_flatten(init_vals)

out_val = jax.lax.while_p.bind(
*cond_leaves,
*body_leaves,
*init_val_leaves,
cond_nconsts=cond_nconsts,
cond_jaxpr=quax_cond_jaxpr,
body_nconsts=body_nconsts,
body_jaxpr=quax_body_jaxpr,
)
result = jtu.tree_unflatten(val_treedef, out_val)
return result


# TODO: also register higher-order primitives like `lax.cond_p` etc.
7 changes: 7 additions & 0 deletions quax/examples/unitful/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from ._core import (
Dimension as Dimension,
kilograms as kilograms,
meters as meters,
seconds as seconds,
Unitful as Unitful,
)
93 changes: 93 additions & 0 deletions quax/examples/unitful/_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from typing import Union

import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.core as core
import jax.numpy as jnp
from jaxtyping import ArrayLike # https://github.com/patrick-kidger/jaxtyping

import quax


class Dimension:
def __init__(self, name):
self.name = name

def __repr__(self):
return self.name


kilograms = Dimension("kg")
meters = Dimension("m")
seconds = Dimension("s")


def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]:
if isinstance(x, Dimension):
return {x: 1}
else:
return x


class Unitful(quax.ArrayValue):
array: ArrayLike
units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit)

def aval(self):
shape = jnp.shape(self.array)
dtype = jnp.result_type(self.array)
return core.ShapedArray(shape, dtype)

def materialise(self):
raise ValueError("Refusing to materialise Unitful array.")


@quax.register(jax.lax.add_p)
def _(x: Unitful, y: Unitful): # function name doesn't matter
if x.units == y.units:
return Unitful(x.array + y.array, x.units)
else:
raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.")


@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: Unitful):
units = x.units.copy()
for k, v in y.units.items():
if k in units:
units[k] += v
else:
units[k] = v
return Unitful(x.array * y.array, units)


@quax.register(jax.lax.mul_p)
def _(x: ArrayLike, y: Unitful):
return Unitful(x * y.array, y.units)


@quax.register(jax.lax.mul_p)
def _(x: Unitful, y: ArrayLike):
return Unitful(x.array * y, x.units)


@quax.register(jax.lax.integer_pow_p)
def _(x: Unitful, *, y: int):
units = {k: v * y for k, v in x.units.items()}
return Unitful(x.array, units)


@quax.register(jax.lax.lt_p)
def _(x: Unitful, y: Unitful, **kwargs):
if x.units == y.units:
return jax.lax.lt(x.array, y.array, **kwargs)
else:
raise ValueError(
f"Cannot compare two arrays with units {x.units} and {y.units}."
)


@quax.register(jax.lax.broadcast_in_dim_p)
def _(operand: Unitful, **kwargs):
new_arr = jax.lax.broadcast_in_dim(operand.array, **kwargs)
return Unitful(new_arr, operand.units)
91 changes: 91 additions & 0 deletions tests/test_while.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import jax
import jax.numpy as jnp
import pytest

import quax
from quax.examples.unitful import kilograms, meters, Unitful


def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array):
# body has b as static argument
def _body_fn(a: jax.Array):
return a + b

# cond has c as static argument
def _cond_fn(a: jax.Array):
return (a < c).squeeze()

res = jax.lax.while_loop(
body_fun=_body_fn,
cond_fun=_cond_fn,
init_val=a,
)
return res


def test_while_basic():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
res = quax.quaxify(_outer_fn)(a, b, c)
assert res.array == 11
assert res.units == {meters: 1}


def test_while_different_units():
a = Unitful(jnp.asarray([1.0]), {meters: 1})
b = Unitful(jnp.asarray([2.0]), {meters: 1})
c = Unitful(jnp.asarray([10.0]), {kilograms: 1})
with pytest.raises(Exception):
quax.quaxify(_outer_fn)(a, b, c)


def test_while_jit():
a = Unitful(jnp.asarray(1.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
res = quax.quaxify(jax.jit(_outer_fn))(a, b, c)
assert res.array == 11
assert res.units == {meters: 1}


def test_while_vmap():
a = Unitful(jnp.arange(1), {meters: 1})
b = Unitful(jnp.asarray(2), {meters: 1})
c = Unitful(jnp.arange(2, 13, 2), {meters: 1})
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0))
res = quax.quaxify(vmap_fn)(a, b, c)
for i in range(len(c.array)): # type: ignore
assert res.array[i] == c.array[i] # type: ignore
assert res.units == {meters: 1}


def test_while_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(outer_var: jax.Array, c: jax.Array, dummy: jax.Array):
def _body_fn_grad(a: jax.Array):
return a + outer_var

def _cond_fn_grad(a: jax.Array):
return (a < c).squeeze()

def _outer_fn_grad(a: jax.Array):
return jax.lax.while_loop(
body_fun=_body_fn_grad,
cond_fun=_cond_fn_grad,
init_val=a,
)

primals = (outer_var,)
tangents = (dummy,)
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, c, dummy)
assert p.array == 10
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

0 comments on commit dcbadb6

Please sign in to comment.