Skip to content

Commit 5423fdd

Browse files
authored
Added batch rule for (adaptive_)avg_poolNd (#248)
* Added batch rule for (adaptive_)avg_poolNd Description: - Added batch rule for adaptive_avg_pool{1d,2d,3d} and avg_pool{1d,2d,3d} - Updated tests * Enabled nn.functional.interpolate mode=area
1 parent fedd426 commit 5423fdd

File tree

4 files changed

+8
-10
lines changed

4 files changed

+8
-10
lines changed

functorch/csrc/BatchRulesDecompositions.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ namespace at { namespace functorch {
1818

1919
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
2020
OP_DECOMPOSE(absolute);
21+
OP_DECOMPOSE(avg_pool1d);
22+
OP_DECOMPOSE(adaptive_avg_pool1d);
2123
OP_DECOMPOSE(adaptive_avg_pool2d);
24+
OP_DECOMPOSE(adaptive_avg_pool3d);
2225
OP_DECOMPOSE(arccos);
2326
OP_DECOMPOSE(arccosh);
2427
OP_DECOMPOSE(arcsin);

functorch/csrc/BatchRulesPooling.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,13 @@ max_pool2d_with_indices_batch_rule(
8484

8585
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
8686
EXISTING_BDIM(_adaptive_avg_pool2d);
87+
EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward);
88+
EXISTING_BDIM(_adaptive_avg_pool3d);
89+
EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool3d_backward);
8790
EXISTING_BDIM(avg_pool2d);
91+
EXISTING_BDIM(avg_pool3d);
8892
EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward);
93+
EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward);
8994
VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule);
9095
VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule);
9196
}

test/test_ops.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,9 @@ def test_vmapvjp(self, device, dtype, op):
447447
xfail('msort'),
448448
xfail('nanmedian'),
449449
xfail('nanquantile'),
450-
xfail('nn.functional.adaptive_avg_pool2d'),
451450
xfail('nn.functional.conv_transpose2d'),
452451
xfail('nn.functional.gelu'),
453452
xfail('nn.functional.grid_sample'),
454-
xfail('nn.functional.interpolate', 'area'),
455453
xfail('nn.functional.pad', 'circular'),
456454
xfail('nn.functional.unfold'),
457455
xfail('norm', 'fro'),
@@ -487,9 +485,6 @@ def test_vmapvjp(self, device, dtype, op):
487485
xfail('fft.ihfft2'),
488486
xfail('fft.ihfftn'),
489487
xfail('fft.rfft2'),
490-
xfail('nn.functional.adaptive_avg_pool1d'),
491-
xfail('nn.functional.adaptive_avg_pool3d'),
492-
xfail('nn.functional.avg_pool3d'),
493488
xfail('nn.functional.embedding'),
494489
}))
495490
def test_vmapvjp_has_batch_rule(self, device, dtype, op):

test/test_vmap.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3088,7 +3088,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
30883088
xfail('nn.functional.cross_entropy', 'mean'),
30893089
xfail('nn.functional.cross_entropy', 'none'),
30903090
xfail('nn.functional.cross_entropy', 'sum'),
3091-
xfail('nn.functional.interpolate', 'area'),
30923091
xfail('nn.functional.pad', 'circular'),
30933092
xfail('nn.functional.unfold'),
30943093
xfail('norm', 'fro'),
@@ -3137,10 +3136,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
31373136
xfail('fft.rfft2'),
31383137
xfail('isinf'),
31393138
xfail('isreal'),
3140-
xfail('nn.functional.adaptive_avg_pool1d'),
3141-
xfail('nn.functional.adaptive_avg_pool3d'),
3142-
xfail('nn.functional.avg_pool1d'),
3143-
xfail('nn.functional.avg_pool3d'),
31443139
xfail('nn.functional.pixel_shuffle'),
31453140
xfail('nn.functional.pixel_unshuffle'),
31463141
}))

0 commit comments

Comments
 (0)