-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implemented jax.lax.while primitive (#16)
* Implemented jax.lax.while primitive * extended unitful example * added test cases for while implementation * test grad over closure * minor changes to while primitive
- Loading branch information
Showing
4 changed files
with
226 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from ._core import ( | ||
Dimension as Dimension, | ||
kilograms as kilograms, | ||
meters as meters, | ||
seconds as seconds, | ||
Unitful as Unitful, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from typing import Union | ||
|
||
import equinox as eqx # https://github.com/patrick-kidger/equinox | ||
import jax | ||
import jax.core as core | ||
import jax.numpy as jnp | ||
from jaxtyping import ArrayLike # https://github.com/patrick-kidger/jaxtyping | ||
|
||
import quax | ||
|
||
|
||
class Dimension: | ||
def __init__(self, name): | ||
self.name = name | ||
|
||
def __repr__(self): | ||
return self.name | ||
|
||
|
||
kilograms = Dimension("kg") | ||
meters = Dimension("m") | ||
seconds = Dimension("s") | ||
|
||
|
||
def _dim_to_unit(x: Union[Dimension, dict[Dimension, int]]) -> dict[Dimension, int]: | ||
if isinstance(x, Dimension): | ||
return {x: 1} | ||
else: | ||
return x | ||
|
||
|
||
class Unitful(quax.ArrayValue): | ||
array: ArrayLike | ||
units: dict[Dimension, int] = eqx.field(static=True, converter=_dim_to_unit) | ||
|
||
def aval(self): | ||
shape = jnp.shape(self.array) | ||
dtype = jnp.result_type(self.array) | ||
return core.ShapedArray(shape, dtype) | ||
|
||
def materialise(self): | ||
raise ValueError("Refusing to materialise Unitful array.") | ||
|
||
|
||
@quax.register(jax.lax.add_p) | ||
def _(x: Unitful, y: Unitful): # function name doesn't matter | ||
if x.units == y.units: | ||
return Unitful(x.array + y.array, x.units) | ||
else: | ||
raise ValueError(f"Cannot add two arrays with units {x.units} and {y.units}.") | ||
|
||
|
||
@quax.register(jax.lax.mul_p) | ||
def _(x: Unitful, y: Unitful): | ||
units = x.units.copy() | ||
for k, v in y.units.items(): | ||
if k in units: | ||
units[k] += v | ||
else: | ||
units[k] = v | ||
return Unitful(x.array * y.array, units) | ||
|
||
|
||
@quax.register(jax.lax.mul_p) | ||
def _(x: ArrayLike, y: Unitful): | ||
return Unitful(x * y.array, y.units) | ||
|
||
|
||
@quax.register(jax.lax.mul_p) | ||
def _(x: Unitful, y: ArrayLike): | ||
return Unitful(x.array * y, x.units) | ||
|
||
|
||
@quax.register(jax.lax.integer_pow_p) | ||
def _(x: Unitful, *, y: int): | ||
units = {k: v * y for k, v in x.units.items()} | ||
return Unitful(x.array, units) | ||
|
||
|
||
@quax.register(jax.lax.lt_p) | ||
def _(x: Unitful, y: Unitful, **kwargs): | ||
if x.units == y.units: | ||
return jax.lax.lt(x.array, y.array, **kwargs) | ||
else: | ||
raise ValueError( | ||
f"Cannot compare two arrays with units {x.units} and {y.units}." | ||
) | ||
|
||
|
||
@quax.register(jax.lax.broadcast_in_dim_p) | ||
def _(operand: Unitful, **kwargs): | ||
new_arr = jax.lax.broadcast_in_dim(operand.array, **kwargs) | ||
return Unitful(new_arr, operand.units) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
|
||
import quax | ||
from quax.examples.unitful import kilograms, meters, Unitful | ||
|
||
|
||
def _outer_fn(a: jax.Array, b: jax.Array, c: jax.Array): | ||
# body has b as static argument | ||
def _body_fn(a: jax.Array): | ||
return a + b | ||
|
||
# cond has c as static argument | ||
def _cond_fn(a: jax.Array): | ||
return (a < c).squeeze() | ||
|
||
res = jax.lax.while_loop( | ||
body_fun=_body_fn, | ||
cond_fun=_cond_fn, | ||
init_val=a, | ||
) | ||
return res | ||
|
||
|
||
def test_while_basic(): | ||
a = Unitful(jnp.asarray(1.0), {meters: 1}) | ||
b = Unitful(jnp.asarray(2.0), {meters: 1}) | ||
c = Unitful(jnp.asarray(10.0), {meters: 1}) | ||
res = quax.quaxify(_outer_fn)(a, b, c) | ||
assert res.array == 11 | ||
assert res.units == {meters: 1} | ||
|
||
|
||
def test_while_different_units(): | ||
a = Unitful(jnp.asarray([1.0]), {meters: 1}) | ||
b = Unitful(jnp.asarray([2.0]), {meters: 1}) | ||
c = Unitful(jnp.asarray([10.0]), {kilograms: 1}) | ||
with pytest.raises(Exception): | ||
quax.quaxify(_outer_fn)(a, b, c) | ||
|
||
|
||
def test_while_jit(): | ||
a = Unitful(jnp.asarray(1.0), {meters: 1}) | ||
b = Unitful(jnp.asarray(2.0), {meters: 1}) | ||
c = Unitful(jnp.asarray(10.0), {meters: 1}) | ||
res = quax.quaxify(jax.jit(_outer_fn))(a, b, c) | ||
assert res.array == 11 | ||
assert res.units == {meters: 1} | ||
|
||
|
||
def test_while_vmap(): | ||
a = Unitful(jnp.arange(1), {meters: 1}) | ||
b = Unitful(jnp.asarray(2), {meters: 1}) | ||
c = Unitful(jnp.arange(2, 13, 2), {meters: 1}) | ||
vmap_fn = jax.vmap(_outer_fn, in_axes=(None, None, 0)) | ||
res = quax.quaxify(vmap_fn)(a, b, c) | ||
for i in range(len(c.array)): # type: ignore | ||
assert res.array[i] == c.array[i] # type: ignore | ||
assert res.units == {meters: 1} | ||
|
||
|
||
def test_while_grad_closure(): | ||
x = Unitful(jnp.asarray(2.0), {meters: 1}) | ||
c = Unitful(jnp.asarray(10.0), {meters: 1}) | ||
dummy = Unitful(jnp.asarray(1.0), {meters: 1}) | ||
|
||
def outer_fn(outer_var: jax.Array, c: jax.Array, dummy: jax.Array): | ||
def _body_fn_grad(a: jax.Array): | ||
return a + outer_var | ||
|
||
def _cond_fn_grad(a: jax.Array): | ||
return (a < c).squeeze() | ||
|
||
def _outer_fn_grad(a: jax.Array): | ||
return jax.lax.while_loop( | ||
body_fun=_body_fn_grad, | ||
cond_fun=_cond_fn_grad, | ||
init_val=a, | ||
) | ||
|
||
primals = (outer_var,) | ||
tangents = (dummy,) | ||
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents) | ||
return p_out, t_out | ||
|
||
p, t = quax.quaxify(outer_fn)(x, c, dummy) | ||
assert p.array == 10 | ||
assert p.units == {meters: 1} | ||
assert t.array == 1 | ||
assert t.units == {meters: 1} |