Skip to content

Commit 3658645

Browse files
[ROCm] deserialize loads in planer sum portion of stats() of norm
Cherry-pick of #2743 Co-authored-by: Jerry Mannil <[email protected]>
1 parent a35461f commit 3658645

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

aten/src/ATen/native/cuda/Normalization.cuh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,13 +306,30 @@ __global__ void batch_norm_collect_statistics_kernel(
306306
stat_accscalar_t var_n = 0;
307307
int n = 0;
308308
for (int batch = threadIdx.y; batch < input.size(0); batch += blockDim.y) {
309+
#if defined(USE_ROCM)
310+
constexpr int UNRL = 4;
311+
stat_accscalar_t v_[UNRL];
312+
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x*UNRL) {
313+
for (int u = 0; u < UNRL; u++)
314+
v_[u] = input[batch][plane][min(x+u*blockDim.x, input.size(2)-1)];
315+
for (int u = 0; u < UNRL; u++) {
316+
if (x+u*blockDim.x < input.size(2)) {
317+
stat_accscalar_t d1 = v_[u] - avg;
318+
n++;
319+
avg += d1 / n;
320+
var_n += d1 * (v_[u] - avg);
321+
}
322+
}
323+
}
324+
#else
309325
for (int x = threadIdx.x; x < input.size(2); x += blockDim.x) {
310326
stat_accscalar_t v = input[batch][plane][x];
311327
stat_accscalar_t d1 = v - avg;
312328
n++;
313329
avg += d1 / n;
314330
var_n += d1 * (v - avg);
315331
}
332+
#endif
316333
}
317334

318335
// first warpSum to get one value per thread to

0 commit comments

Comments
 (0)