Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 31 additions & 36 deletions src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/native/Pool.h>

#include <ATen/native/xpu/sycl/AveragePool2dKernels.h>
#include <ATen/native/xpu/sycl/KernelUtils.h>
#include <comm/Runtime.h>
#include <comm/SYCLContext.h>
#include <comm/SYCLHelpers.h>
Expand All @@ -25,9 +26,7 @@ inline int max(int a, int b) {
template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
index_t index = item.get_global_linear_id();

if (index < total_elements_) {
XPU_KERNEL_LOOP(item, index, total_elements_) {
const int pw = index % pooled_width_;
const int ph = (index / pooled_width_) % pooled_height_;
const int c = (index / pooled_width_ / pooled_height_) % channels_;
Expand Down Expand Up @@ -73,19 +72,19 @@ struct AvgPool2dKernelFunctor {
AvgPool2dKernelFunctor(
scalar_t* top_data,
const scalar_t* bottom_data,
index_t total_elements,
index_t channels,
index_t height,
index_t width,
int pooled_height,
int pooled_width,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int divisor_override,
const int total_elements,
const int64_t channels,
const int64_t height,
const int64_t width,
const int64_t pooled_height,
const int pooled_width,
const int kernel_h,
const int kernel_w,
const int stride_h,
const int stride_w,
const int pad_h,
const int pad_w,
const int divisor_override,
bool count_include_pad,
bool use_divisor)
: top_data_(top_data),
Expand All @@ -109,29 +108,27 @@ struct AvgPool2dKernelFunctor {
private:
scalar_t* top_data_;
const scalar_t* bottom_data_;
index_t total_elements_;
index_t channels_;
index_t height_;
index_t width_;
int pooled_height_;
int pooled_width_;
int kernel_h_;
int kernel_w_;
int stride_h_;
int stride_w_;
int pad_h_;
int pad_w_;
int divisor_override_;
const int total_elements_;
const int64_t channels_;
const int64_t height_;
const int64_t width_;
const int64_t pooled_height_;
const int pooled_width_;
const int kernel_h_;
const int kernel_w_;
const int stride_h_;
const int stride_w_;
const int pad_h_;
const int pad_w_;
const int divisor_override_;
bool count_include_pad_;
bool use_divisor_;
};

template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dChannelsLastKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
index_t index = item.get_global_linear_id();

if (index < total_elements_) {
XPU_KERNEL_LOOP(item, index, total_elements_) {
const int c = index % channels_;
const int pw = (index / channels_) % pooled_width_;
const int ph = (index / channels_ / pooled_width_) % pooled_height_;
Expand Down Expand Up @@ -327,8 +324,7 @@ void launch_avg_pool2d_kernel(
template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dChannelsLastBackwardKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
index_t index = item.get_global_linear_id();
if (index < total_elements_) {
XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) {
const int c = index % channels_;
const int w = (index / channels_) % width_ + pad_w_;
const int h = (index / channels_ / width_) % height_ + pad_h_;
Expand Down Expand Up @@ -431,8 +427,7 @@ struct AvgPool2dChannelsLastBackwardKernelFunctor {
template <typename scalar_t, typename accscalar_t, typename index_t>
struct AvgPool2dBackwarKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
index_t index = item.get_global_linear_id();
if (index < total_elements_) {
XPU_KERNEL_LOOP_TYPE(item, index, total_elements_, index_t) {
// find out the local index
// find out the local offset
const int w = index % width_ + pad_w_;
Expand Down