From 148760f8d07dab9092600e27a04710a09ccf7935 Mon Sep 17 00:00:00 2001
From: Glen Cao
Date: Tue, 21 Oct 2025 03:44:36 -0700
Subject: [PATCH] Optimized Bilinear 2D Up Sampling for AMD MI devices

---
 .../ATen/native/cuda/UpSampleBilinear2d.cu | 103 +++++++++++++++++-
 1 file changed, 101 insertions(+), 2 deletions(-)

diff --git a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
index b891750891d58..b46bbaa6500b9 100644
--- a/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
+++ b/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu
@@ -127,6 +127,29 @@ __global__ void upsample_bilinear2d_nhwc_out_frame(
   }
 }
 
+#ifdef USE_ROCM
+// Helper function to compute output pixel range that can contribute to input pixel
+template <typename accscalar_t>
+__device__ __forceinline__ void compute_output_range(
+    int input_pos,
+    accscalar_t scale,
+    int output_size,
+    bool align_corners,
+    int& min_output,
+    int& max_output) {
+  accscalar_t lo, hi;
+  if (align_corners) {
+    lo = static_cast<accscalar_t>(input_pos - 1) / scale;
+    hi = static_cast<accscalar_t>(input_pos + 1) / scale;
+  } else {
+    lo = (input_pos - static_cast<accscalar_t>(0.5)) / scale - static_cast<accscalar_t>(0.5);
+    hi = (input_pos + static_cast<accscalar_t>(1.5)) / scale - static_cast<accscalar_t>(0.5);
+  }
+  min_output = max(0, static_cast<int>(std::ceil(lo)));
+  max_output = min(output_size - 1, static_cast<int>(std::floor(hi)));
+}
+#endif
+
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template <typename scalar_t, typename accscalar_t>
 C10_LAUNCH_BOUNDS_1(1024)
@@ -141,8 +164,74 @@ __global__ void upsample_bilinear2d_backward_out_frame(
     const bool align_corners,
     scalar_t* __restrict__ idata,
     const scalar_t* __restrict__ odata) {
-  const size_t o_numel = nc * width2 * height2;
+  // i_numel is used by both the ROCm gather path and the atomicAdd path below.
   const size_t i_numel = nc * width1 * height1;
+#ifdef USE_ROCM
+  for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < i_numel;
+       index += blockDim.x * gridDim.x) {
+    // Decode input pixel coordinates
+    size_t index_temp = index;
+    const int w1 = index_temp % width1;
+    index_temp /= width1;
+    const int h1 = index_temp % height1;
+    const size_t nc_idx = index_temp / height1;
+
+    accscalar_t grad_sum = 0;
+
+    // Find range of output pixels that could interpolate from this input pixel
+    int h2_min, h2_max, w2_min, w2_max;
+    compute_output_range<accscalar_t>(h1, rheight, height2, align_corners, h2_min, h2_max);
+    compute_output_range<accscalar_t>(w1, rwidth, width2, align_corners, w2_min, w2_max);
+
+    // Iterate over potential output pixels
+    for (int h2 = h2_min; h2 <= h2_max; h2++) {
+      for (int w2 = w2_min; w2 <= w2_max; w2++) {
+        // Compute source coordinates for this output pixel
+        const accscalar_t h1r = area_pixel_compute_source_index<accscalar_t>(
+            rheight, h2, align_corners, /*cubic=*/false);
+        const int h1_base = (int)h1r;
+        const int h1p = (h1_base < height1 - 1) ? 1 : 0;
+        const accscalar_t h1lambda = h1r - h1_base;
+        const accscalar_t h0lambda = static_cast<accscalar_t>(1) - h1lambda;
+
+        const accscalar_t w1r = area_pixel_compute_source_index<accscalar_t>(
+            rwidth, w2, align_corners, /*cubic=*/false);
+        const int w1_base = (int)w1r;
+        const int w1p = (w1_base < width1 - 1) ? 1 : 0;
+        const accscalar_t w1lambda = w1r - w1_base;
+        const accscalar_t w0lambda = static_cast<accscalar_t>(1) - w1lambda;
+
+        // Check if our input pixel participates in this interpolation and accumulate all weights
+        // At boundaries, h1p=0 or w1p=0 causes some sampling positions to collapse
+        // to the same pixel, so we need to accumulate weights from all matching positions
+        accscalar_t weight = 0;
+
+        // Check all four interpolation positions and accumulate weights
+        if (h1 == h1_base && w1 == w1_base) {
+          weight += h0lambda * w0lambda; // top-left
+        }
+        if (h1 == h1_base && w1 == w1_base + w1p) {
+          weight += h0lambda * w1lambda; // top-right (may be same as top-left if w1p=0)
+        }
+        if (h1 == h1_base + h1p && w1 == w1_base) {
+          weight += h1lambda * w0lambda; // bottom-left (may be same as top-left if h1p=0)
+        }
+        if (h1 == h1_base + h1p && w1 == w1_base + w1p) {
+          weight += h1lambda * w1lambda; // bottom-right (may collapse to other positions)
+        }
+
+        if (weight > 0) {
+          const size_t output_idx = nc_idx * height2 * width2 + h2 * width2 + w2;
+          grad_sum += weight * static_cast<accscalar_t>(odata[output_idx]);
+        }
+      }
+    }
+
+    // Write accumulated gradient (no atomics needed)
+    idata[index] = static_cast<scalar_t>(grad_sum);
+  }
+#else
+  const size_t o_numel = nc * width2 * height2;
   for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; index < o_numel;
        index += blockDim.x * gridDim.x) {
     size_t index_temp = index;
@@ -191,6 +280,7 @@ __global__ void upsample_bilinear2d_backward_out_frame(
         static_cast<scalar_t>(h1lambda * w1lambda * d2val),
         true);
   }
+#endif
 }
 
 template <typename scalar_t, typename accscalar_t>
@@ -387,7 +477,6 @@ static void upsample_bilinear2d_backward_out_cuda_template(
   // threads are not covering the whole input tensor.
   grad_input.zero_();
 
-  const size_t num_kernels = nbatch * channels * output_height * output_width;
   const int num_threads = std::min(
       at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -397,6 +486,12 @@ static void upsample_bilinear2d_backward_out_cuda_template(
     return;
   }
 
+#ifdef USE_ROCM
+  constexpr bool use_input = true;
+#else
+  constexpr bool use_input = false;
+#endif
+
   AT_DISPATCH_FLOATING_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16,
       grad_output_.scalar_type(), "upsample_bilinear2d_backward_out_frame", [&] {
@@ -414,6 +509,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
         const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
             input_width, output_width, align_corners, scales_w);
 
+        const size_t num_kernels = nbatch * channels * output_height * output_width;
+
         upsample_bilinear2d_backward_nhwc_out_frame<scalar_t, accscalar_t>
             <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)), num_threads, 0, stream>>>(
                 input_height,
@@ -444,6 +541,8 @@ static void upsample_bilinear2d_backward_out_cuda_template(
        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
            input_width, output_width, align_corners, scales_w);

+        const size_t num_kernels = nbatch * channels * (use_input ? input_height * input_width : output_height * output_width);
+
        upsample_bilinear2d_backward_out_frame<scalar_t, accscalar_t>
            <<<ceil_div(num_kernels, static_cast<size_t>(num_threads)),
               num_threads,
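
Note (not part of the patch): the ROCm branch above replaces the atomicAdd-based scatter backward with a per-input-pixel gather, so grad_input should match the default path up to floating-point accumulation order. A minimal parity check one might run against the CPU reference is sketched below in Python; it assumes a PyTorch ROCm/CUDA build where the device is exposed as "cuda", and the shapes, output sizes, and tolerances are illustrative only.

    import torch
    import torch.nn.functional as F

    def backward_parity_check(align_corners: bool) -> None:
        # Same input on CPU (reference backward) and on the GPU (patched kernel).
        x_cpu = torch.randn(2, 3, 17, 23, dtype=torch.float32, requires_grad=True)
        x_gpu = x_cpu.detach().to("cuda").requires_grad_(True)

        # Bilinear 2D upsampling; its backward is what this patch changes.
        y_cpu = F.interpolate(x_cpu, size=(40, 50), mode="bilinear", align_corners=align_corners)
        y_gpu = F.interpolate(x_gpu, size=(40, 50), mode="bilinear", align_corners=align_corners)

        # Identical upstream gradient so grad_input can be compared directly.
        g = torch.randn_like(y_cpu)
        y_cpu.backward(g)
        y_gpu.backward(g.to("cuda"))

        torch.testing.assert_close(x_gpu.grad.cpu(), x_cpu.grad, rtol=1e-5, atol=1e-5)

    if __name__ == "__main__":
        for ac in (False, True):
            backward_parity_check(ac)
        print("grad_input matches the CPU reference for both align_corners settings")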