Skip to content

Commit fdbf3aa

Browse files
committed
Fix bug in local_div_switch_sink rewrite
Introduced in 4f7d709
1 parent c2e88c6 commit fdbf3aa

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

pytensor/tensor/rewriting/math.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,10 @@ def local_div_switch_sink(fgraph, node):
699699
# will point to the new division op.
700700
copy_stack_trace(node.outputs, fdiv)
701701

702-
fct = switch(switch_cond, zero_switch_input, fdiv)
702+
if branch == 0:
703+
fct = switch(switch_cond, zero_switch_input, fdiv)
704+
else:
705+
fct = switch(switch_cond, fdiv, zero_switch_input)
703706

704707
# Tell debug_mode than the output is correct, even if nan disappear
705708
fct.tag.values_eq_approx = values_eq_approx_remove_nan

tests/tensor/rewriting/test_math.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
21632163
# The zero branch upcasts the output, so we can't ignore its dtype
21642164
zero_branch = constant(np.array(0, dtype="float64"), name="zero_branch")
21652165
other_branch = scalar("other_branch", dtype="float32")
2166-
outer_var = scalar("mul_var", dtype="bool")
2166+
outer_var = scalar("outer_var", dtype="bool")
21672167

21682168
out = op(switch(cond, zero_branch, other_branch), outer_var)
21692169
fgraph = FunctionGraph(outputs=[out], clone=False)
@@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
21732173
expected_out = switch(cond, zero_branch, op(other_branch, outer_var))
21742174
assert equal_computations([new_out], [expected_out])
21752175

2176+
@pytest.mark.parametrize(
2177+
"op, rewrite", [(mul, local_mul_switch_sink), (true_div, local_div_switch_sink)]
2178+
)
2179+
def test_local_mul_div_switch_sink_branch_order(self, op, rewrite):
2180+
cond = scalar("cond", dtype="bool")
2181+
zero_branch = constant(np.array(0.0, dtype="float64"), "zero_branch")
2182+
other_branch = scalar("other_branch", dtype="float64")
2183+
outer_var = scalar("outer_var", dtype="float64")
2184+
2185+
left = op(switch(cond, zero_branch, other_branch), outer_var)
2186+
right = op(switch(cond, other_branch, zero_branch), outer_var)
2187+
fgraph = FunctionGraph(outputs=[left, right], clone=False)
2188+
[new_left] = rewrite.transform(fgraph, left.owner)
2189+
[new_right] = rewrite.transform(fgraph, right.owner)
2190+
2191+
expected_left = switch(cond, zero_branch, op(other_branch, outer_var))
2192+
expected_right = switch(cond, op(other_branch, outer_var), zero_branch)
2193+
assert equal_computations(
2194+
[new_left, new_right], [expected_left, expected_right]
2195+
)
2196+
21762197

21772198
@pytest.mark.skipif(
21782199
config.cxx == "",

0 commit comments

Comments
 (0)