diff --git a/src/ATen/native/xpu/sycl/SortingKernels.h b/src/ATen/native/xpu/sycl/SortingKernels.h
index bd5972a99..aad93d9eb 100644
--- a/src/ATen/native/xpu/sycl/SortingKernels.h
+++ b/src/ATen/native/xpu/sycl/SortingKernels.h
@@ -316,6 +316,13 @@ void segmented_radix_sort_pairs_downsweep_kernel(
 
 // ======================= large sort =======================
 
+template <typename scalar_t>
+struct ABBufferCopyFunctor {
+  scalar_t operator()(scalar_t x) const {
+    return x;
+  }
+};
+
 template <
     typename key_t,
     typename value_t,
@@ -409,18 +416,20 @@ void segmented_radix_sort_pairs_kernel(
     auto input_calc = TrivialOffsetCalculator<2>();
     at::detail::Array<char*, 2> data;
     if (keys_out) {
-      auto q = at::xpu::getCurrentSYCLQueue();
-      q.memcpy(
-          (void*)keys_out,
-          (void*)keys_temp,
-          sizeof(key_t) * num_segments * num_elements);
+      data[0] = (char*)keys_out;
+      data[1] = (char*)keys_temp;
+      auto fn = ABBufferCopyFunctor<key_t>();
+      auto vec_size = memory::can_vectorize_up_to<decltype(fn)>(data);
+      launch_vectorized_kernel(
+          num_segments * num_elements, fn, data, input_calc, vec_size);
     }
     if (values_out) {
-      auto q = at::xpu::getCurrentSYCLQueue();
-      q.memcpy(
-          (void*)values_out,
-          (void*)values_temp,
-          sizeof(value_t) * num_segments * num_elements);
+      data[0] = (char*)values_out;
+      data[1] = (char*)values_temp;
+      auto fn = ABBufferCopyFunctor<value_t>();
+      auto vec_size = memory::can_vectorize_up_to<decltype(fn)>(data);
+      launch_vectorized_kernel(
+          num_segments * num_elements, fn, data, input_calc, vec_size);
     }
   }
 }