Skip to content

Commit a35461f

Browse files
[ROCm] Deserialize loads in planar sum portion of reduce() of norm
Cherry-pick of #2740 Co-authored-by: Jerry Mannil <[email protected]>
1 parent b067d8f commit a35461f

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

aten/src/ATen/native/cuda/Normalization.cuh

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,9 +115,23 @@ __device__ scalar_t reduce(Op op, PTA tensor, int plane) {
115115
// first the reductions each thread does separately
116116
scalar_t sum = static_cast<scalar_t>(0);
117117
for (int batch = threadIdx.y; batch < tensor.size(0); batch += blockDim.y) {
118+
#if defined(USE_ROCM)
119+
constexpr int UNRL = 4; // load deserialization (unroll) factor
120+
scalar_t tmp[UNRL];
121+
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x*UNRL) {
122+
#pragma unroll
123+
for (int u = 0; u < UNRL; u++)
124+
tmp[u] = op(batch, plane, min((int)tensor.size(2)-1, (int)(x+u*blockDim.x)));
125+
#pragma unroll
126+
for (int u = 0; u < UNRL; u++)
127+
if (x+u*blockDim.x < tensor.size(2))
128+
sum += tmp[u];
129+
}
130+
#else
118131
for (int x = threadIdx.x; x < tensor.size(2); x += blockDim.x) {
119132
sum += op(batch, plane, x);
120133
}
134+
#endif
121135
}
122136
__shared__ scalar_t shared[C10_WARP_SIZE];
123137
SumReduceOp<scalar_t> reduce_op;

0 commit comments

Comments
 (0)