Skip to content

Commit 05ec481

Browse files
vfdev-5zou3519
authored andcommitted
[functorch] Added backward batch rule for pad replicate/reflect modes (pytorch/functorch#251)
Description: - Added backward batch rule for pad replicate/reflect modes - Updated tests Related to pytorch/functorch#240
1 parent 7a42ec1 commit 05ec481

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

Diff for: functorch/functorch/csrc/BatchRulesModules.cpp

+8
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,14 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
409409
EXISTING_BDIM(replication_pad2d);
410410
EXISTING_BDIM(replication_pad3d);
411411

412+
EXISTING_BDIM_ALL_BOXED(replication_pad1d_backward);
413+
EXISTING_BDIM_ALL_BOXED(replication_pad2d_backward);
414+
EXISTING_BDIM_ALL_BOXED(replication_pad3d_backward);
415+
416+
EXISTING_BDIM_ALL_BOXED(reflection_pad1d_backward);
417+
EXISTING_BDIM_ALL_BOXED(reflection_pad2d_backward);
418+
EXISTING_BDIM_ALL_BOXED(reflection_pad3d_backward);
419+
412420
UPSAMPLE_BATCH(upsample_bicubic2d);
413421
UPSAMPLE_BATCH(upsample_bilinear2d);
414422
UPSAMPLE_BATCH(upsample_linear1d);

Diff for: functorch/test/test_ops.py

-2
Original file line numberDiff line numberDiff line change
@@ -465,8 +465,6 @@ def test_vmapvjp(self, device, dtype, op):
465465
xfail('nn.functional.grid_sample'),
466466
xfail('nn.functional.interpolate', 'area'),
467467
xfail('nn.functional.pad', 'circular'),
468-
xfail('nn.functional.pad', 'reflect'),
469-
xfail('nn.functional.pad', 'replicate'),
470468
xfail('nn.functional.unfold'),
471469
xfail('norm', 'fro'),
472470
xfail('norm', 'inf'),

0 commit comments

Comments
 (0)