Skip to content

Commit a5e7747

Browse files
committed
Fix Eye test
1 parent f89e567 commit a5e7747

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,8 +1453,7 @@ def eye(n, m=None, k=0, dtype=None):
14531453
dtype = config.floatX
14541454
if m is None:
14551455
m = n
1456-
localop = Eye(dtype)
1457-
return localop(n, m, k)
1456+
return Eye(dtype)(n, m, k)
14581457

14591458

14601459
def identity_like(x, dtype: str | np.generic | np.dtype | None = None):

tests/tensor/test_basic.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -923,22 +923,18 @@ def test_infer_static_shape():
923923
class TestEye:
924924
# This is slow for the ('int8', 3) version.
925925
def test_basic(self):
926-
def check(dtype, N, M_=None, k=0):
927-
# PyTensor does not accept None as a tensor.
928-
# So we must use a real value.
929-
M = M_
930-
# Currently DebugMode does not support None as inputs even if this is
931-
# allowed.
932-
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
933-
M = N
926+
def check(dtype, N, M=None, k=0):
934927
N_symb = iscalar()
935928
M_symb = iscalar()
936929
k_symb = iscalar()
930+
test_inputs = [N, k] if M is None else [N, M, k]
931+
inputs = [N_symb, k_symb] if M is None else [N_symb, M_symb, k_symb]
937932
f = function(
938-
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
933+
inputs,
934+
eye(N_symb, None if (M is None) else M_symb, k_symb, dtype=dtype),
939935
)
940-
result = f(N, M, k)
941-
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
936+
result = f(*test_inputs)
937+
assert np.allclose(result, np.eye(N, M, k, dtype=dtype))
942938
assert result.dtype == np.dtype(dtype)
943939

944940
for dtype in ALL_DTYPES:
@@ -1744,7 +1740,7 @@ def test_join_matrixV_negative_axis(self):
17441740
got = f(-2)
17451741
assert np.allclose(got, want)
17461742

1747-
with pytest.raises(ValueError):
1743+
with pytest.raises((ValueError, IndexError)):
17481744
f(-3)
17491745

17501746
@pytest.mark.parametrize("py_impl", (False, True))

0 commit comments

Comments
 (0)