diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp index 1528e914e..eb230315e 100644 --- a/functorch/csrc/BatchRulesDecompositions.cpp +++ b/functorch/csrc/BatchRulesDecompositions.cpp @@ -18,7 +18,10 @@ namespace at { namespace functorch { TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE(absolute); + OP_DECOMPOSE(avg_pool1d); + OP_DECOMPOSE(adaptive_avg_pool1d); OP_DECOMPOSE(adaptive_avg_pool2d); + OP_DECOMPOSE(adaptive_avg_pool3d); OP_DECOMPOSE(arccos); OP_DECOMPOSE(arccosh); OP_DECOMPOSE(arcsin); diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp index 071df32ef..df144bf70 100644 --- a/functorch/csrc/BatchRulesPooling.cpp +++ b/functorch/csrc/BatchRulesPooling.cpp @@ -84,8 +84,13 @@ max_pool2d_with_indices_batch_rule( TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { EXISTING_BDIM(_adaptive_avg_pool2d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward); + EXISTING_BDIM(_adaptive_avg_pool3d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool3d_backward); EXISTING_BDIM(avg_pool2d); + EXISTING_BDIM(avg_pool3d); EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward); + EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward); VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule); VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule); } diff --git a/test/test_ops.py b/test/test_ops.py index d55603579..1704e062c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -461,11 +461,9 @@ def test_vmapvjp(self, device, dtype, op): xfail('msort'), xfail('nanmedian'), xfail('nanquantile'), - xfail('nn.functional.adaptive_avg_pool2d'), xfail('nn.functional.conv_transpose2d'), xfail('nn.functional.gelu'), xfail('nn.functional.grid_sample'), - xfail('nn.functional.interpolate', 'area'), xfail('nn.functional.pad', 'circular'), xfail('nn.functional.pad', 'reflect'), xfail('nn.functional.pad', 'replicate'), @@ -505,9 +503,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), - xfail('nn.functional.adaptive_avg_pool1d'), - xfail('nn.functional.adaptive_avg_pool3d'), - xfail('nn.functional.avg_pool3d'), xfail('nn.functional.embedding'), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): diff --git a/test/test_vmap.py b/test/test_vmap.py index 8cffb0674..5d47d65d4 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3094,7 +3094,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('nn.functional.cross_entropy', 'mean'), xfail('nn.functional.cross_entropy', 'none'), xfail('nn.functional.cross_entropy', 'sum'), - xfail('nn.functional.interpolate', 'area'), xfail('nn.functional.pad', 'circular'), xfail('nn.functional.unfold'), xfail('norm', 'fro'), @@ -3143,10 +3142,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('fft.rfft2'), xfail('isinf'), xfail('isreal'), - xfail('nn.functional.adaptive_avg_pool1d'), - xfail('nn.functional.adaptive_avg_pool3d'), - xfail('nn.functional.avg_pool1d'), - xfail('nn.functional.avg_pool3d'), xfail('nn.functional.pixel_shuffle'), xfail('nn.functional.pixel_unshuffle'), }))