Skip to content

Commit

Permalink
test(lora): add regression test
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman authored and patrick-kidger committed Feb 4, 2025
1 parent 7d65e9b commit 37dc139
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import jax.numpy as jnp
import jax.random as jr
import pytest
from jaxtyping import TypeCheckError
from plum import NotFoundLookupError

import quax
import quax.examples.lora as lora
Expand Down Expand Up @@ -110,3 +112,18 @@ def test_materialise():
_ = quax.quaxify(jax.nn.relu)(x_true)
with pytest.raises(RuntimeError, match="Refusing to materialise"):
_ = quax.quaxify(jax.nn.relu)(x_false)


def test_regression_38(getkey):
"""Regression test for PR 38 (stackless tracers)."""
x = jnp.arange(4.0).reshape(2, 2)
y = lora.LoraArray(x, rank=1, key=getkey())

def f(x):
return jax.lax.add_p.bind(x, y)

func = quax.quaxify(f)

# Error type depends on whether jaxtyping is on
with pytest.raises((TypeCheckError, NotFoundLookupError)):
_ = func(y)

0 comments on commit 37dc139

Please sign in to comment.