From 8394dabd2609e814fbab3d64939d2ba5b0a7f47e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 8 Nov 2021 22:02:24 +0000 Subject: [PATCH 1/3] Fixed nn.functional.pad constant mode Description: - Fixed nn.functional.pad constant mode - Updated tests --- functorch/csrc/BatchRulesModules.cpp | 2 +- test/test_ops.py | 2 -- test/test_vmap.py | 1 - 3 files changed, 1 insertion(+), 4 deletions(-) diff --git a/functorch/csrc/BatchRulesModules.cpp b/functorch/csrc/BatchRulesModules.cpp index aa733fa1c..115fbb8e9 100644 --- a/functorch/csrc/BatchRulesModules.cpp +++ b/functorch/csrc/BatchRulesModules.cpp @@ -401,7 +401,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); VMAP_SUPPORT("cross", cross_batch_rule); - UNARY_POINTWISE(constant_pad_nd); + VMAP_SUPPORT("constant_pad_nd", VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(constant_pad_nd))); EXISTING_BDIM(reflection_pad1d); EXISTING_BDIM(reflection_pad2d); EXISTING_BDIM(reflection_pad3d); diff --git a/test/test_ops.py b/test/test_ops.py index d55603579..1636c44b2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -310,7 +310,6 @@ def vjp_of_vjp(*args_and_cotangents): xfail('diag_embed'), xfail('eig'), xfail('nn.functional.conv_transpose2d'), - xfail('nn.functional.pad', 'constant'), xfail('view_as_complex'), xfail('fft.fft'), xfail('fft.ifft'), @@ -389,7 +388,6 @@ def test_vmapvjp(self, device, dtype, op): @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,)) @skipOps('TestOperators', 'test_vmapvjp_has_batch_rule', vmapvjp_fail.union({ - xfail('nn.functional.pad', 'constant'), xfail('view_as_complex'), xfail('__getitem__'), xfail('__rpow__'), diff --git a/test/test_vmap.py b/test/test_vmap.py index 8cffb0674..e641a120c 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3020,7 +3020,6 @@ class TestVmapOperatorsOpInfo(TestCase): xfail('fft.rfftn'), xfail('nn.functional.batch_norm'), xfail('lu_unpack'), - xfail('nn.functional.pad', 'constant'), xfail('empty_like'), xfail('histogramdd'), xfail('nn.functional.embedding'), From 4c3d2bd68c427a0e7f0e6b5cdfd40b17efc18ce9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 8 Nov 2021 22:52:46 +0000 Subject: [PATCH 2/3] Fixed issues with unexpected failures for fft tests --- test/test_ops.py | 12 ------------ test/test_vmap.py | 5 ----- 2 files changed, 17 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 1636c44b2..277447b4f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -311,15 +311,11 @@ def vjp_of_vjp(*args_and_cotangents): xfail('eig'), xfail('nn.functional.conv_transpose2d'), xfail('view_as_complex'), - xfail('fft.fft'), - xfail('fft.ifft'), xfail('fft.ihfft'), xfail('fft.ihfft'), xfail('fft.rfft'), xfail('fft.rfft'), - xfail('fft.fftn'), xfail('fft.rfftn'), - xfail('fft.ifftn'), xfail('cdist'), xfail('fmax'), xfail('fmin'), @@ -357,8 +353,6 @@ def vjp_of_vjp(*args_and_cotangents): xfail('nanmean'), xfail('block_diag'), xfail('nn.functional.dropout'), - xfail('fft.fft2'), - xfail('fft.ifft2'), xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), @@ -405,10 +399,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('diag'), xfail('diag_embed'), xfail('eig'), - xfail('fft.fft'), - xfail('fft.fftn'), - xfail('fft.ifft'), - xfail('fft.ifftn'), xfail('fft.ihfft'), xfail('fft.rfft'), xfail('fft.rfftn'), @@ -498,8 +488,6 @@ def test_vmapvjp(self, device, dtype, op): xfail('_masked.sum'), xfail('_masked.prod'), xfail('cholesky_solve'), - xfail('fft.fft2'), - xfail('fft.ifft2'), xfail('fft.ihfft2'), xfail('fft.ihfftn'), xfail('fft.rfft2'), diff --git a/test/test_vmap.py b/test/test_vmap.py index e641a120c..d4d5a42a8 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3013,11 +3013,6 @@ class TestVmapOperatorsOpInfo(TestCase): xfail('linalg.svd', device_type='cuda'), xfail('index_put'), xfail('matrix_exp'), - xfail('fft.fft'), - xfail('fft.ifft'), - xfail('fft.ihfft'), - xfail('fft.rfft'), - xfail('fft.rfftn'), xfail('nn.functional.batch_norm'), xfail('lu_unpack'), xfail('empty_like'), From 6c363f6ca1804fd2083c910773bacab437ea3baa Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 9 Nov 2021 00:37:15 +0100 Subject: [PATCH 3/3] Update BatchRulesModules.cpp --- functorch/csrc/BatchRulesModules.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/functorch/csrc/BatchRulesModules.cpp b/functorch/csrc/BatchRulesModules.cpp index 115fbb8e9..da746d675 100644 --- a/functorch/csrc/BatchRulesModules.cpp +++ b/functorch/csrc/BatchRulesModules.cpp @@ -401,7 +401,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { VMAP_SUPPORT("cudnn_grid_sampler", GRID_SAMPLE_BATCH_RULE(cudnn_grid_sampler)); VMAP_SUPPORT("cross", cross_batch_rule); - VMAP_SUPPORT("constant_pad_nd", VARIADIC_BDIMS_BATCH_RULE(ATEN_FN(constant_pad_nd))); + VARIADIC_BDIMS(constant_pad_nd); EXISTING_BDIM(reflection_pad1d); EXISTING_BDIM(reflection_pad2d); EXISTING_BDIM(reflection_pad3d);