diff --git a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp index 4caa4a7ead..cf856d6177 100644 --- a/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp +++ b/src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -25,9 +26,7 @@ inline int max(int a, int b) { template struct AvgPool2dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - index_t index = item.get_global_linear_id(); - - if (index < total_elements_) { + XPU_KERNEL_LOOP(item, index, total_elements_) { const int pw = index % pooled_width_; const int ph = (index / pooled_width_) % pooled_height_; const int c = (index / pooled_width_ / pooled_height_) % channels_; @@ -73,19 +72,19 @@ struct AvgPool2dKernelFunctor { AvgPool2dKernelFunctor( scalar_t* top_data, const scalar_t* bottom_data, - index_t total_elements, - index_t channels, - index_t height, - index_t width, - int pooled_height, - int pooled_width, - int kernel_h, - int kernel_w, - int stride_h, - int stride_w, - int pad_h, - int pad_w, - int divisor_override, + const int total_elements, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t pooled_height, + const int pooled_width, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int divisor_override, bool count_include_pad, bool use_divisor) : top_data_(top_data), @@ -109,19 +108,19 @@ struct AvgPool2dKernelFunctor { private: scalar_t* top_data_; const scalar_t* bottom_data_; - index_t total_elements_; - index_t channels_; - index_t height_; - index_t width_; - int pooled_height_; - int pooled_width_; - int kernel_h_; - int kernel_w_; - int stride_h_; - int stride_w_; - int pad_h_; - int pad_w_; - int divisor_override_; + const int total_elements_; + const int64_t channels_; + const int64_t height_; + const int64_t width_; + const int64_t pooled_height_; + const int pooled_width_; + const int kernel_h_; + const int kernel_w_; + const int stride_h_; + const int stride_w_; + const int pad_h_; + const int pad_w_; + const int divisor_override_; bool count_include_pad_; bool use_divisor_; }; @@ -129,9 +128,7 @@ struct AvgPool2dKernelFunctor { template struct AvgPool2dChannelsLastKernelFunctor { void operator()(sycl::nd_item<1> item) const { - index_t index = item.get_global_linear_id(); - - if (index < total_elements_) { + XPU_KERNEL_LOOP(item, index, total_elements_) { const int c = index % channels_; const int pw = (index / channels_) % pooled_width_; const int ph = (index / channels_ / pooled_width_) % pooled_height_; @@ -327,8 +324,7 @@ void launch_avg_pool2d_kernel( template struct AvgPool2dChannelsLastBackwardKernelFunctor { void operator()(sycl::nd_item<1> item) const { - index_t index = item.get_global_linear_id(); - if (index < total_elements_) { + XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) { const int c = index % channels_; const int w = (index / channels_) % width_ + pad_w_; const int h = (index / channels_ / width_) % height_ + pad_h_; @@ -431,8 +427,7 @@ struct AvgPool2dChannelsLastBackwardKernelFunctor { template struct AvgPool2dBackwarKernelFunctor { void operator()(sycl::nd_item<1> item) const { - index_t index = item.get_global_linear_id(); - if (index < total_elements_) { + XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) { // find out the local index // find out the local offset const int w = index % width_ + pad_w_;