Skip to content

Commit 57a4ddd

Browse files
committed
Test wasn't actually covering rewrite
1 parent 428bfdd commit 57a4ddd

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ def local_0_dot_x(fgraph, node):
147147
x, y = node.inputs
148148
if (
149149
get_underlying_scalar_constant_value(
150-
x, only_process_constants=True, raise_not_constant=False
150+
x, only_process_constants=False, raise_not_constant=False
151151
)
152152
== 0
153153
or get_underlying_scalar_constant_value(
154-
y, only_process_constants=True, raise_not_constant=False
154+
y, only_process_constants=False, raise_not_constant=False
155155
)
156156
== 0
157157
):

tests/tensor/rewriting/test_subtensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1448,11 +1448,14 @@ def test_dot_allocs_0(self):
14481448
not isinstance(n.op, Dot) for n in f.maker.fgraph.toposort()
14491449
)
14501450

1451-
# test that we don't remove shape errors
1451+
# test that we don't remove shape errors if we exclude shape_unsafe
1452+
f_safe = f = function(
1453+
[_e1[0], _e2[0]], o, mode=self.mode.excluding("shape_unsafe")
1454+
)
14521455
with pytest.raises((ValueError, AssertionError)):
1453-
f(_e1[1], _e2[2])
1456+
f_safe(_e1[1], _e2[2])
14541457
with pytest.raises((ValueError, AssertionError)):
1455-
f(_e1[2], _e2[1])
1458+
f_safe(_e1[2], _e2[1])
14561459

14571460

14581461
def test_local_IncSubtensor_serialize():

0 commit comments

Comments
 (0)