55#include < ATen/native/Pool.h>
66
77#include < ATen/native/xpu/sycl/AveragePool2dKernels.h>
8+ #include < ATen/native/xpu/sycl/KernelUtils.h>
89#include < comm/Runtime.h>
910#include < comm/SYCLContext.h>
1011#include < comm/SYCLHelpers.h>
@@ -25,9 +26,7 @@ inline int max(int a, int b) {
2526template <typename scalar_t , typename accscalar_t , typename index_t >
2627struct AvgPool2dKernelFunctor {
2728 void operator ()(sycl::nd_item<1 > item) const {
28- index_t index = item.get_global_linear_id ();
29-
30- if (index < total_elements_) {
29+ XPU_KERNEL_LOOP (item, index, total_elements_) {
3130 const int pw = index % pooled_width_;
3231 const int ph = (index / pooled_width_) % pooled_height_;
3332 const int c = (index / pooled_width_ / pooled_height_) % channels_;
@@ -73,19 +72,19 @@ struct AvgPool2dKernelFunctor {
7372 AvgPool2dKernelFunctor (
7473 scalar_t * top_data,
7574 const scalar_t * bottom_data,
76- index_t total_elements,
77- index_t channels,
78- index_t height,
79- index_t width,
80- int pooled_height,
81- int pooled_width,
82- int kernel_h,
83- int kernel_w,
84- int stride_h,
85- int stride_w,
86- int pad_h,
87- int pad_w,
88- int divisor_override,
75+ const int total_elements,
76+ const int64_t channels,
77+ const int64_t height,
78+ const int64_t width,
79+ const int64_t pooled_height,
80+ const int pooled_width,
81+ const int kernel_h,
82+ const int kernel_w,
83+ const int stride_h,
84+ const int stride_w,
85+ const int pad_h,
86+ const int pad_w,
87+ const int divisor_override,
8988 bool count_include_pad,
9089 bool use_divisor)
9190 : top_data_(top_data),
@@ -109,29 +108,27 @@ struct AvgPool2dKernelFunctor {
109108 private:
110109 scalar_t * top_data_;
111110 const scalar_t * bottom_data_;
112- index_t total_elements_;
113- index_t channels_;
114- index_t height_;
115- index_t width_;
116- int pooled_height_;
117- int pooled_width_;
118- int kernel_h_;
119- int kernel_w_;
120- int stride_h_;
121- int stride_w_;
122- int pad_h_;
123- int pad_w_;
124- int divisor_override_;
111+ const int total_elements_;
112+ const int64_t channels_;
113+ const int64_t height_;
114+ const int64_t width_;
115+ const int64_t pooled_height_;
116+ const int pooled_width_;
117+ const int kernel_h_;
118+ const int kernel_w_;
119+ const int stride_h_;
120+ const int stride_w_;
121+ const int pad_h_;
122+ const int pad_w_;
123+ const int divisor_override_;
125124 bool count_include_pad_;
126125 bool use_divisor_;
127126};
128127
129128template <typename scalar_t , typename accscalar_t , typename index_t >
130129struct AvgPool2dChannelsLastKernelFunctor {
131130 void operator ()(sycl::nd_item<1 > item) const {
132- index_t index = item.get_global_linear_id ();
133-
134- if (index < total_elements_) {
131+ XPU_KERNEL_LOOP (item, index, total_elements_) {
135132 const int c = index % channels_;
136133 const int pw = (index / channels_) % pooled_width_;
137134 const int ph = (index / channels_ / pooled_width_) % pooled_height_;
@@ -327,8 +324,7 @@ void launch_avg_pool2d_kernel(
327324template <typename scalar_t , typename accscalar_t , typename index_t >
328325struct AvgPool2dChannelsLastBackwardKernelFunctor {
329326 void operator ()(sycl::nd_item<1 > item) const {
330- index_t index = item.get_global_linear_id ();
331- if (index < total_elements_) {
327+ XPU_KERNEL_LOOP_TYPE (item, index, total_elements_, index_t ) {
332328 const int c = index % channels_;
333329 const int w = (index / channels_) % width_ + pad_w_;
334330 const int h = (index / channels_ / width_) % height_ + pad_h_;
@@ -431,8 +427,7 @@ struct AvgPool2dChannelsLastBackwardKernelFunctor {
431427template <typename scalar_t , typename accscalar_t , typename index_t >
432428struct AvgPool2dBackwarKernelFunctor {
433429 void operator ()(sycl::nd_item<1 > item) const {
434- index_t index = item.get_global_linear_id ();
435- if (index < total_elements_) {
430+ XPU_KERNEL_LOOP_TYPE (item, index, total_elements_, index_t ) {
436431 // find out the local index
437432 // find out the local offset
438433 const int w = index % width_ + pad_w_;
0 commit comments