diff --git a/functorch/csrc/BatchRulesScatterOps.cpp b/functorch/csrc/BatchRulesScatterOps.cpp index 1fde93293..472d2519e 100644 --- a/functorch/csrc/BatchRulesScatterOps.cpp +++ b/functorch/csrc/BatchRulesScatterOps.cpp @@ -433,6 +433,15 @@ Tensor index_copy_decomp( return at::scatter(self, dim, index_, source); ; } +Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, + int64_t dim, c10::optional start, + c10::optional end, int64_t step) +{ + auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong)); + idx = get_expanded_index(idx, self.sizes(), dim); + return at::scatter(self, dim, idx, src); +} + Tensor select_scatter_decomp( const Tensor &self, const Tensor &source, int64_t dim, int64_t index) @@ -447,6 +456,7 @@ Tensor select_scatter_decomp( TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { m.impl("index.Tensor", index_plumbing); m.impl("index_put_", index_put__plumbing); + m.impl("slice_scatter", slice_scatter_decomp); m.impl("select_scatter", select_scatter_decomp); m.impl("index_copy", index_copy_decomp); m.impl("index_select", index_select_decomp); diff --git a/test/test_ops.py b/test/test_ops.py index 28ddc97b9..aa7294f52 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -665,7 +665,6 @@ def test_vmapjvp(self, device, dtype, op): xfail('nn.functional.huber_loss'), xfail('nn.functional.instance_norm'), xfail('nn.functional.poisson_nll_loss'), - xfail('slice_scatter'), })) def test_vmapvjp_has_batch_rule(self, device, dtype, op): # These are too annoying to put into the list above diff --git a/test/test_vmap.py b/test/test_vmap.py index 66619e56a..4973d98f4 100644 --- a/test/test_vmap.py +++ b/test/test_vmap.py @@ -3191,7 +3191,6 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail('randint_like'), xfail('searchsorted'), xfail('short', 'channels_last'), - xfail('slice_scatter'), xfail('unique_consecutive'), xfail('unique'), xfail('nn.functional.conv1d'),