File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
aten/src/ATen/native/cuda Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -306,13 +306,30 @@ __global__ void batch_norm_collect_statistics_kernel(
306306 stat_accscalar_t var_n = 0 ;
307307 int n = 0 ;
308308 for (int batch = threadIdx .y ; batch < input.size (0 ); batch += blockDim .y ) {
309+ #if defined(USE_ROCM)
310+ constexpr int UNRL = 4 ;
311+ stat_accscalar_t v_[UNRL];
312+ for (int x = threadIdx .x ; x < input.size (2 ); x += blockDim .x *UNRL) {
313+ for (int u = 0 ; u < UNRL; u++)
314+ v_[u] = input[batch][plane][min (x+u*blockDim .x , input.size (2 )-1 )];
315+ for (int u = 0 ; u < UNRL; u++) {
316+ if (x+u*blockDim .x < input.size (2 )) {
317+ stat_accscalar_t d1 = v_[u] - avg;
318+ n++;
319+ avg += d1 / n;
320+ var_n += d1 * (v_[u] - avg);
321+ }
322+ }
323+ }
324+ #else
309325 for (int x = threadIdx .x ; x < input.size (2 ); x += blockDim .x ) {
310326 stat_accscalar_t v = input[batch][plane][x];
311327 stat_accscalar_t d1 = v - avg;
312328 n++;
313329 avg += d1 / n;
314330 var_n += d1 * (v - avg);
315331 }
332+ #endif
316333 }
317334
318335 // first warpSum to get one value per thread to
You can’t perform that action at this time.
0 commit comments