Commit d87d44c

Use IncSubtensor in gradient of RepeatOp
1 parent 63c513e commit d87d44c

File tree

1 file changed: +8, -11 lines changed

pytensor/tensor/extra_ops.py

Lines changed: 8 additions & 11 deletions
@@ -740,18 +740,15 @@ def grad(self, inputs, gout):
         (gz,) = gout
         axis = self.axis
 
-        # To sum the gradients that belong to the same repeated x,
-        # We create a repeated eye and dot product it with the gradient.
+        # Use IncSubtensor to sum the gradients that belong to the repeated entries of x
         axis_size = x.shape[axis]
-        repeated_eye = repeat(
-            ptb.eye(axis_size), repeats, axis=0
-        )  # A sparse repeat would be neat
-
-        # Place gradient axis at end for dot product
-        gx = ptb.moveaxis(gz, axis, -1)
-        gx = gx @ repeated_eye
-        # Place gradient back into the correct axis
-        gx = ptb.moveaxis(gx, -1, axis)
+        repeated_arange = repeat(ptb.arange(axis_size), repeats, axis=0)
+
+        # Move the axis to repeat to front for easier indexing
+        x_transpose = ptb.moveaxis(x, axis, 0)
+        gz_transpose = ptb.moveaxis(gz, axis, 0)
+        gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
+        gx = ptb.moveaxis(gx_transpose, 0, axis)
 
         return [gx, DisconnectedType()()]
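For illustration, here is a minimal standalone sketch of the accumulation trick the new gradient relies on, written against the public pytensor.tensor API (pt.repeat, pt.eye, pt.arange, pt.inc_subtensor; pt.inc_subtensor is assumed to be the functional form of the .inc() method used above). The repeats and gradient values are made up for the example. With duplicate indices, the increment sums every contribution that lands on the same entry of x, which is what the removed repeated-eye matmul computed:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
gz = pt.dvector("gz")                      # upstream gradient w.r.t. repeat(x, repeats)
repeats = np.array([2, 3, 1])              # hypothetical repeat counts, one per entry of x
axis_size = x.shape[0]

# Old approach (removed): dot the gradient with a repeated identity matrix.
repeated_eye = pt.repeat(pt.eye(axis_size), repeats, axis=0)
gx_old = gz @ repeated_eye

# New approach (added): scatter-add the gradient back onto x; duplicate
# indices in repeated_arange are accumulated, not overwritten.
repeated_arange = pt.repeat(pt.arange(axis_size), repeats, axis=0)
gx_new = pt.inc_subtensor(pt.zeros_like(x)[repeated_arange], gz)

f = pytensor.function([x, gz], [gx_old, gx_new])
old, new = f(np.zeros(3), np.arange(6, dtype="float64"))
print(old, new)                            # both [1. 9. 5.]: 0+1, 2+3+4, 5
assert np.allclose(old, new)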

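A quick cross-check through autodiff, using the same hypothetical values: the vector-Jacobian product of repeat obtained from pytensor.grad should match the scatter-add result above, whichever internal gradient implementation is active.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
gz = pt.dvector("gz")
repeats = np.array([2, 3, 1])

# Differentiate sum(repeat(x, repeats) * gz) w.r.t. x to get the VJP of repeat.
y = pt.repeat(x, repeats, axis=0)
gx_auto = pytensor.grad((y * gz).sum(), wrt=x)

f = pytensor.function([x, gz], gx_auto)
print(f(np.zeros(3), np.arange(6, dtype="float64")))   # [1. 9. 5.]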