@@ -740,18 +740,15 @@ def grad(self, inputs, gout):
740740 (gz ,) = gout
741741 axis = self .axis
742742
743- # To sum the gradients that belong to the same repeated x,
744- # We create a repeated eye and dot product it with the gradient.
743+ # Use IncSubtensor to sum the gradients that belong to the repeated entries of x
745744 axis_size = x .shape [axis ]
746- repeated_eye = repeat (
747- ptb .eye (axis_size ), repeats , axis = 0
748- ) # A sparse repeat would be neat
749-
750- # Place gradient axis at end for dot product
751- gx = ptb .moveaxis (gz , axis , - 1 )
752- gx = gx @ repeated_eye
753- # Place gradient back into the correct axis
754- gx = ptb .moveaxis (gx , - 1 , axis )
745+ repeated_arange = repeat (ptb .arange (axis_size ), repeats , axis = 0 )
746+
747+ # Move the axis to repeat to front for easier indexing
748+ x_transpose = ptb .moveaxis (x , axis , 0 )
749+ gz_transpose = ptb .moveaxis (gz , axis , 0 )
750+ gx_transpose = ptb .zeros_like (x_transpose )[repeated_arange ].inc (gz_transpose )
751+ gx = ptb .moveaxis (gx_transpose , 0 , axis )
755752
756753 return [gx , DisconnectedType ()()]
757754
0 commit comments