@@ -1295,12 +1295,28 @@ def local_inplace_setsubtensor(fgraph, node):
12951295
12961296@node_rewriter ([AdvancedIncSubtensor1 ], inplace = True ) 
12971297def  local_inplace_AdvancedIncSubtensor1 (fgraph , node ):
1298-     if  isinstance (node .op , AdvancedIncSubtensor1 ) and  not  node .op .inplace :
1299-         new_op  =  node .op .clone_inplace ()
1300-         new_node  =  new_op (* node .inputs )
1301-         copy_stack_trace (node .outputs , new_node )
1302-         return  [new_node ]
1303-     return  False 
1298+     if  node .op .inplace :
1299+         return 
1300+ 
1301+     x , y , idx  =  node .inputs 
1302+     if  fgraph .has_destroyers ([x ]):
1303+         # In this case we can't operate inplace, but if x is just an alloc of zeros 
1304+         # We're better off duplicating it and then acting on it inplace. 
1305+         if  (
1306+             x .owner  is  not None 
1307+             and  isinstance (x .owner .op , Alloc )
1308+             and  all (x .owner .inputs [0 ].type .broadcastable )
1309+             and  isinstance (x .owner .inputs [0 ], Constant )
1310+             and  x .owner .inputs [0 ].unique_value  ==  0 
1311+         ):
1312+             x  =  x .owner .clone ().outputs [0 ]
1313+         else :
1314+             return  None   # Inplace isn't valid 
1315+ 
1316+     new_op  =  node .op .clone_inplace ()
1317+     new_node  =  new_op (x , y , idx )
1318+     copy_stack_trace (node .outputs , new_node )
1319+     return  [new_node ]
13041320
13051321
13061322compile .optdb .register (
0 commit comments