@@ -2163,7 +2163,7 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
2163
2163
# The zero branch upcasts the output, so we can't ignore its dtype
2164
2164
zero_branch = constant (np .array (0 , dtype = "float64" ), name = "zero_branch" )
2165
2165
other_branch = scalar ("other_branch" , dtype = "float32" )
2166
- outer_var = scalar ("mul_var " , dtype = "bool" )
2166
+ outer_var = scalar ("outer_var " , dtype = "bool" )
2167
2167
2168
2168
out = op (switch (cond , zero_branch , other_branch ), outer_var )
2169
2169
fgraph = FunctionGraph (outputs = [out ], clone = False )
@@ -2173,6 +2173,27 @@ def test_local_mul_div_switch_sink_cast(self, op, rewrite):
2173
2173
expected_out = switch (cond , zero_branch , op (other_branch , outer_var ))
2174
2174
assert equal_computations ([new_out ], [expected_out ])
2175
2175
2176
+ @pytest .mark .parametrize (
2177
+ "op, rewrite" , [(mul , local_mul_switch_sink ), (true_div , local_div_switch_sink )]
2178
+ )
2179
+ def test_local_mul_div_switch_sink_branch_order (self , op , rewrite ):
2180
+ cond = scalar ("cond" , dtype = "bool" )
2181
+ zero_branch = constant (np .array (0.0 , dtype = "float64" ), "zero_branch" )
2182
+ other_branch = scalar ("other_branch" , dtype = "float64" )
2183
+ outer_var = scalar ("outer_var" , dtype = "float64" )
2184
+
2185
+ left = op (switch (cond , zero_branch , other_branch ), outer_var )
2186
+ right = op (switch (cond , other_branch , zero_branch ), outer_var )
2187
+ fgraph = FunctionGraph (outputs = [left , right ], clone = False )
2188
+ [new_left ] = rewrite .transform (fgraph , left .owner )
2189
+ [new_right ] = rewrite .transform (fgraph , right .owner )
2190
+
2191
+ expected_left = switch (cond , zero_branch , op (other_branch , outer_var ))
2192
+ expected_right = switch (cond , op (other_branch , outer_var ), zero_branch )
2193
+ assert equal_computations (
2194
+ [new_left , new_right ], [expected_left , expected_right ]
2195
+ )
2196
+
2176
2197
2177
2198
@pytest .mark .skipif (
2178
2199
config .cxx == "" ,
0 commit comments