From 6f3f26780371026c8b6200c2f946037ca0360374 Mon Sep 17 00:00:00 2001
From: Neo Zhang Jianyu
Date: Fri, 19 Sep 2025 10:59:43 +0800
Subject: [PATCH] refactor for CAT by sycl free func

---
 src/ATen/native/xpu/sycl/Shape.cpp | 92 +++++++++++-------------------
 1 file changed, 33 insertions(+), 59 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/Shape.cpp b/src/ATen/native/xpu/sycl/Shape.cpp
index 12bd0ba66..a251a315f 100644
--- a/src/ATen/native/xpu/sycl/Shape.cpp
+++ b/src/ATen/native/xpu/sycl/Shape.cpp
@@ -20,8 +20,7 @@ namespace at::native::xpu {
 constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
 
 // Maximum parallel dimension to supporte
-constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 5;
-
+#define CAT_ARRAY_MAX_INPUT_DIMS 5
 // Similar to any other IndexToOffset calculation for copying along a given
 // dimension.
 template <typename IndexType, int Dims>
@@ -66,57 +65,37 @@ struct OutputTensorSizeStride {
   IndexType outputStride[MaxDims];
 };
 
-template <
-    typename Tout,
-    typename underlying_out_t,
-    typename Tin,
-    typename underlying_in_t,
-    typename IndexType,
-    int Dims>
-struct CatArrayBatchedCopyKernelFunctor {
-  void operator()(sycl::nd_item<2> item) const {
-    IndexType tid = item.get_global_id(1);
-    IndexType in = item.get_group(0);
-    IndexType nElements = inputs[in].nElements;
+template <
+    typename Tout, typename underlying_out_t, typename Tin,
+    typename underlying_in_t, typename IndexType, int Dims>
+SYCL_EXT_ONEAPI_FUNCTION_PROPERTY((syclexp::nd_range_kernel<2>))
+void cat_array_batched_copy_kernel(
+    Tout *output, CatArrInputTensor<Tin, IndexType> *inputs,
+    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
+    const int concatDim, int dimStride) {
+  auto item = syclext::this_work_item::get_nd_item<2>();
+  IndexType tid = item.get_global_id(1);
+  IndexType in = item.get_group(0);
 
-    if (tid >= nElements)
-      return;
+  IndexType nElements = inputs[in].nElements;
 
-    Tin* data = inputs[in].input;
-    IndexType offset = inputs[in].offset;
-    IndexType dimSize = inputs[in].dimSize;
-    IndexType dataOffset = offset * dimStride;
+  if (tid >= nElements)
+    return;
 
-    IndexType stride = item.get_global_range(1);
+  Tin *data = inputs[in].input;
+  IndexType offset = inputs[in].offset;
+  IndexType dimSize = inputs[in].dimSize;
+  IndexType dataOffset = offset * dimStride;
 
-    while (tid < nElements) {
-      IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
-          os.outputSize, os.outputStride, dimSize, concatDim, tid);
-      output[dataOffset + elementOffset] = data[tid];
-      tid += stride;
-    }
-  }
+  IndexType stride = item.get_global_range(1);
 
-  CatArrayBatchedCopyKernelFunctor(
-      Tout* output_,
-      CatArrInputTensor<Tin, IndexType>* inputs_,
-      OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os_,
-      const int concatDim_,
-      IndexType dimStride_)
-      : output(output_),
-        inputs(inputs_),
-        os(os_),
-        concatDim(concatDim_),
-        dimStride(dimStride_) {}
-
- private:
-  Tout* output;
-  CatArrInputTensor<Tin, IndexType>* inputs;
-  OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os;
-  const int concatDim;
-  IndexType dimStride;
-};
+  while (tid < nElements) {
+    IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
+        os.outputSize, os.outputStride, dimSize, concatDim, tid);
+    output[dataOffset + elementOffset] = data[tid];
+    tid += stride;
+  }
+}
 
 /**
  * Kernel used to concatenated grimDim.y tensors into an output tensor. Uses a
@@ -146,19 +125,11 @@ void CatArrayBatchedCopy(
     const int concatDim,
     IndexType dimStride,
     int batchCounter) {
-  CatArrayBatchedCopyKernelFunctor<
-      Tout,
-      underlying_out_t,
-      Tin,
-      underlying_in_t,
-      IndexType,
-      Dims>
-      kfn(output, inputs, os, concatDim, dimStride);
-
   // Get grid where x dim fills half gpu and y dim is number of tensors.
   // This will have cating two tensors fill the entire grid, but prevent
   // many threads from needlessly load meta data if their sizes is small.
-  int64_t numWI = syclMaxWorkGroupSize(kfn);
+  int64_t numWI = syclMaxWorkGroupSize<cat_array_batched_copy_kernel<
+      Tout, underlying_out_t, Tin, underlying_in_t, IndexType, Dims>>();
 
   // We set limited numWG to prevent over schedule.
   // numWG = 512 EUs * 8 threads * SIMD lanes 32 / max_compute_units
@@ -175,8 +146,11 @@ void CatArrayBatchedCopy(
   sycl::range<2> global_range(batchCounter, numWG * numWI);
   sycl::range<2> local_range(1, numWI);
   auto& q = getCurrentSYCLQueue();
-
-  sycl_kernel_submit(global_range, local_range, q, kfn);
+  sycl_kernel_submit<
+      cat_array_batched_copy_kernel<
+          Tout, underlying_out_t, Tin, underlying_in_t, IndexType, Dims>,
+      2>(global_range, local_range, q, 0, output, inputs, os, concatDim,
+      dimStride);
 }
 
 template <
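
The launch scheme that the new free-function kernel keeps from the removed functor is a 2-D nd_range: work-group dimension 0 selects which of the batched inputs to copy, and dimension 1 runs a grid-stride loop over that input's elements. The standalone sketch below is not part of the patch; it illustrates that scheme in plain SYCL 2020 with USM. InputSlice, the element counts, and the contiguous 1-D concatenation are simplified, hypothetical stand-ins for CatArrInputTensor, CatArrIndexToOffset, and the real launch parameters.

// Standalone illustration only; simplified from the kernel above.
#include <sycl/sycl.hpp>
#include <cstdio>
#include <vector>

// Simplified stand-in for CatArrInputTensor: one input of the concatenation.
struct InputSlice {
  const float* data; // device pointer to this input's elements
  size_t offset;     // element offset of this input inside the output
  size_t nElements;  // number of elements to copy
};

int main() {
  sycl::queue q;

  // Two inputs of different sizes, concatenated contiguously (1-D case).
  std::vector<float> a(100, 1.0f), b(300, 2.0f);
  float* da = sycl::malloc_device<float>(a.size(), q);
  float* db = sycl::malloc_device<float>(b.size(), q);
  float* out = sycl::malloc_device<float>(a.size() + b.size(), q);
  q.copy(a.data(), da, a.size()).wait();
  q.copy(b.data(), db, b.size()).wait();

  InputSlice slices[2] = {{da, 0, a.size()}, {db, a.size(), b.size()}};
  InputSlice* dslices = sycl::malloc_device<InputSlice>(2, q);
  q.copy(slices, dslices, 2).wait();

  // Dimension 0: one row of work-groups per input tensor.
  // Dimension 1: a fixed number of work-groups that grid-stride over elements.
  constexpr size_t wgSize = 256; // arbitrary for the sketch
  constexpr size_t numWG = 4;
  sycl::nd_range<2> launch(
      sycl::range<2>(2, numWG * wgSize), sycl::range<2>(1, wgSize));

  q.parallel_for(launch, [=](sycl::nd_item<2> item) {
     size_t in = item.get_group(0);      // which input tensor
     size_t tid = item.get_global_id(1); // starting element for this thread
     size_t stride = item.get_global_range(1);
     InputSlice s = dslices[in];
     for (; tid < s.nElements; tid += stride)
       out[s.offset + tid] = s.data[tid]; // trivial offset: 1-D contiguous
   }).wait();

  std::vector<float> host(a.size() + b.size());
  q.copy(out, host.data(), host.size()).wait();
  std::printf("out[0]=%g out[%zu]=%g\n", host[0], a.size(), host[a.size()]);

  sycl::free(da, q);
  sycl::free(db, q);
  sycl::free(out, q);
  sycl::free(dslices, q);
  return 0;
}

In the actual kernel the destination index additionally goes through CatArrIndexToOffset, so copied elements land at the correct position along concatDim for arbitrary output strides rather than the trivial contiguous offset used here.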