From 37dc139ffe97d707e4075b42febd1f14eb45e330 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Fri, 31 Jan 2025 11:53:33 -0500 Subject: [PATCH] test(lora): add regression test Signed-off-by: Nathaniel Starkman --- tests/test_lora.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_lora.py b/tests/test_lora.py index e20f8d5..d8a7199 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -4,6 +4,8 @@ import jax.numpy as jnp import jax.random as jr import pytest +from jaxtyping import TypeCheckError +from plum import NotFoundLookupError import quax import quax.examples.lora as lora @@ -110,3 +112,18 @@ def test_materialise(): _ = quax.quaxify(jax.nn.relu)(x_true) with pytest.raises(RuntimeError, match="Refusing to materialise"): _ = quax.quaxify(jax.nn.relu)(x_false) + + +def test_regression_38(getkey): + """Regression test for PR 38 (stackless tracers).""" + x = jnp.arange(4.0).reshape(2, 2) + y = lora.LoraArray(x, rank=1, key=getkey()) + + def f(x): + return jax.lax.add_p.bind(x, y) + + func = quax.quaxify(f) + + # Error type depends on whether jaxtyping is on + with pytest.raises((TypeCheckError, NotFoundLookupError)): + _ = func(y)