Skip to content

Commit

Permalink
update test_cond.py
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Oct 25, 2024
1 parent 7cbbd96 commit b2ec956
Showing 1 changed file with 7 additions and 11 deletions.
18 changes: 7 additions & 11 deletions tests/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,18 @@ def test_cond_vmap():

def test_cond_grad_closure():
x = Unitful(jnp.asarray(2.0), {meters: 1})
b = Unitful(jnp.asarray(2.0), {meters: 1})
c = Unitful(jnp.asarray(10.0), {meters: 1})
dummy = Unitful(jnp.asarray(1.0), {meters: 1})

def outer_fn(
outer_var: jax.Array,
dummy: jax.Array,
b: jax.Array,
c: jax.Array,
pred: bool | jax.Array,
):
def _true_fn_grad(outer_var: jax.Array):
return outer_var + b
def _true_fn_grad(a: jax.Array):
return a + outer_var

def _false_fn_grad(outer_var: jax.Array):
return outer_var + c
def _false_fn_grad(a: jax.Array):
return a + outer_var * 2

def _outer_fn_grad(a: jax.Array):
return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a)
Expand All @@ -96,14 +92,14 @@ def _outer_fn_grad(a: jax.Array):
p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents)
return p_out, t_out

p, t = quax.quaxify(outer_fn)(x, dummy, b, c, True)
p, t = quax.quaxify(outer_fn)(x, dummy, True)
assert p.array == 4
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

p, t = quax.quaxify(outer_fn)(x, dummy, b, c, False)
assert p.array == 12
p, t = quax.quaxify(outer_fn)(x, dummy, False)
assert p.array == 6
assert p.units == {meters: 1}
assert t.array == 1
assert t.units == {meters: 1}

0 comments on commit b2ec956

Please sign in to comment.