From b2ec9569ad051bf69b799ce75ac03a7ab07332ca Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Fri, 25 Oct 2024 10:40:30 +0200 Subject: [PATCH] update test_cond.py --- tests/test_cond.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/test_cond.py b/tests/test_cond.py index 741f34b..7126231 100644 --- a/tests/test_cond.py +++ b/tests/test_cond.py @@ -71,22 +71,18 @@ def test_cond_vmap(): def test_cond_grad_closure(): x = Unitful(jnp.asarray(2.0), {meters: 1}) - b = Unitful(jnp.asarray(2.0), {meters: 1}) - c = Unitful(jnp.asarray(10.0), {meters: 1}) dummy = Unitful(jnp.asarray(1.0), {meters: 1}) def outer_fn( outer_var: jax.Array, dummy: jax.Array, - b: jax.Array, - c: jax.Array, pred: bool | jax.Array, ): - def _true_fn_grad(outer_var: jax.Array): - return outer_var + b + def _true_fn_grad(a: jax.Array): + return a + outer_var - def _false_fn_grad(outer_var: jax.Array): - return outer_var + c + def _false_fn_grad(a: jax.Array): + return a + outer_var * 2 def _outer_fn_grad(a: jax.Array): return jax.lax.cond(pred, _true_fn_grad, _false_fn_grad, a) @@ -96,14 +92,14 @@ def _outer_fn_grad(a: jax.Array): p_out, t_out = jax.jvp(_outer_fn_grad, primals, tangents) return p_out, t_out - p, t = quax.quaxify(outer_fn)(x, dummy, b, c, True) + p, t = quax.quaxify(outer_fn)(x, dummy, True) assert p.array == 4 assert p.units == {meters: 1} assert t.array == 1 assert t.units == {meters: 1} - p, t = quax.quaxify(outer_fn)(x, dummy, b, c, False) - assert p.array == 12 + p, t = quax.quaxify(outer_fn)(x, dummy, False) + assert p.array == 6 assert p.units == {meters: 1} assert t.array == 1 assert t.units == {meters: 1}