From 50853a8161f26f2c230a62991542221d30c23598 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 8 Nov 2021 17:33:17 +0000 Subject: [PATCH 1/2] Added batch rule for (adaptive_)avg_poolNd Description: - Added batch rule for adaptive_avg_pool{1d,2d,3d} and avg_pool{1d,2d,3d} - Updated tests --- functorch/csrc/BatchRulesDecompositions.cpp | 3 +++ functorch/csrc/BatchRulesPooling.cpp | 5 +++++ test/test_ops.py | 4 ---- test/test_vmap.py | 4 ---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/functorch/csrc/BatchRulesDecompositions.cpp b/functorch/csrc/BatchRulesDecompositions.cpp index 1528e914e..eb230315e 100644 --- a/functorch/csrc/BatchRulesDecompositions.cpp +++ b/functorch/csrc/BatchRulesDecompositions.cpp @@ -18,7 +18,10 @@ namespace at { namespace functorch { TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { OP_DECOMPOSE(absolute); + OP_DECOMPOSE(avg_pool1d); + OP_DECOMPOSE(adaptive_avg_pool1d); OP_DECOMPOSE(adaptive_avg_pool2d); + OP_DECOMPOSE(adaptive_avg_pool3d); OP_DECOMPOSE(arccos); OP_DECOMPOSE(arccosh); OP_DECOMPOSE(arcsin); diff --git a/functorch/csrc/BatchRulesPooling.cpp b/functorch/csrc/BatchRulesPooling.cpp index 071df32ef..df144bf70 100644 --- a/functorch/csrc/BatchRulesPooling.cpp +++ b/functorch/csrc/BatchRulesPooling.cpp @@ -84,8 +84,13 @@ max_pool2d_with_indices_batch_rule( TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { EXISTING_BDIM(_adaptive_avg_pool2d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool2d_backward); + EXISTING_BDIM(_adaptive_avg_pool3d); + EXISTING_BDIM_ALL_BOXED(_adaptive_avg_pool3d_backward); EXISTING_BDIM(avg_pool2d); + EXISTING_BDIM(avg_pool3d); EXISTING_BDIM_ALL_BOXED(avg_pool2d_backward); + EXISTING_BDIM_ALL_BOXED(avg_pool3d_backward); VMAP_SUPPORT("max_pool2d_with_indices", max_pool2d_with_indices_batch_rule); VMAP_SUPPORT("max_pool2d_with_indices_backward", max_pool2d_with_indices_backward_batch_rule); } diff --git a/test/test_ops.py b/test/test_ops.py index d55603579..9e0e697ef 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -461,7 +461,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('msort'), xfail('nanmedian'), xfail('nanquantile'), - xfail('nn.functional.adaptive_avg_pool2d'), xfail('nn.functional.conv_transpose2d'), xfail('nn.functional.gelu'), xfail('nn.functional.grid_sample'), @@ -505,9 +504,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), - xfail('nn.functional.adaptive_avg_pool1d'), - xfail('nn.functional.adaptive_avg_pool3d'), - xfail('nn.functional.avg_pool3d'), xfail('nn.functional.embedding'), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): diff --git a/test/test_vmap.py b/test/test_vmap.py index 8cffb0674..d9f7ab481 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3143,10 +3143,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('fft.rfft2'), xfail('isinf'), xfail('isreal'), - xfail('nn.functional.adaptive_avg_pool1d'), - xfail('nn.functional.adaptive_avg_pool3d'), - xfail('nn.functional.avg_pool1d'), - xfail('nn.functional.avg_pool3d'), xfail('nn.functional.pixel_shuffle'), xfail('nn.functional.pixel_unshuffle'), })) From 72d49fec8f4ff6b1cc801d87cf70be96c9a8c93b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 8 Nov 2021 17:47:05 +0000 Subject: [PATCH 2/2] Enabled nn.functional.interpolate mode=area --- test/test_ops.py | 1 - test/test_vmap.py | 1 - 2 files changed, 2 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 9e0e697ef..1704e062c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -464,7 +464,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('nn.functional.conv_transpose2d'), xfail('nn.functional.gelu'), xfail('nn.functional.grid_sample'), - xfail('nn.functional.interpolate', 'area'), xfail('nn.functional.pad', 'circular'), xfail('nn.functional.pad', 'reflect'), xfail('nn.functional.pad', 'replicate'), diff --git a/test/test_vmap.py b/test/test_vmap.py index d9f7ab481..5d47d65d4 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3094,7 +3094,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('nn.functional.cross_entropy', 'mean'), xfail('nn.functional.cross_entropy', 'none'), xfail('nn.functional.cross_entropy', 'sum'), - xfail('nn.functional.interpolate', 'area'), xfail('nn.functional.pad', 'circular'), xfail('nn.functional.unfold'), xfail('norm', 'fro'),