Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grouped conv backward data GKCYX support #2029

Merged
merged 9 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added

* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW; the number of instances in the instance factory for NGCHW/GKYXC/NGKHW has been reduced).
* Added support for Stream-K version of mixed fp8/bf16 GEMM

### Optimized

None
Expand All @@ -22,6 +25,8 @@ None
* Removed support for gfx940 and gfx941 targets (#1944)
* Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876)
* DL and DPP kernels are now enabled by default.
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.

### Known issues

Expand Down

Large diffs are not rendered by default.

113 changes: 113 additions & 0 deletions include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,119 @@ __global__ void
}
}

/// @brief Fused dual batched elementwise kernel: one launch services two
///        independent batched elementwise problems ("A" and "B").
///
/// The 1D grid is partitioned by block id: blocks [0, a_grid_size) execute
/// GridwiseElementwiseFunctorA on the A tensor set; the remaining blocks
/// execute GridwiseElementwiseFunctorB on the B tensor set. Within each
/// partition, blocks are divided evenly across the batch count, and every
/// input/output base pointer is shifted by its per-batch stride before the
/// functor runs.
///
/// @note The *_batch_strides arrays are presumably element (not byte) strides
///       between consecutive batches -- TODO confirm against the caller.
template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation,
          index_t NumInputsA,
          index_t NumInputsB,
          index_t NumOutputsA,
          index_t NumOutputsB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_elementwise_batched_dual(
            const InAGridDescTuple in_grid_desc_tuple_a,
            const InBGridDescTuple in_grid_desc_tuple_b,
            const OutAGridDescTuple out_grid_desc_tuple_a,
            const OutBGridDescTuple out_grid_desc_tuple_b,
            const InADataTypePointerTuple p_in_global_tuple_a,
            const InBDataTypePointerTuple p_in_global_tuple_b,
            const OutADataTypePointerTuple p_out_global_tuple_a,
            const OutBDataTypePointerTuple p_out_global_tuple_b,
            const Block2TileMapA block_2_tile_map_a,
            const Block2TileMapB block_2_tile_map_b,
            const ElementwiseOperation elementwise_op,
            const index_t a_grid_size,
            const index_t batch_count_a,
            const index_t batch_count_b,
            const std::array<index_t, NumInputsA> input_batch_strides_a,
            const std::array<index_t, NumInputsB> input_batch_strides_b,
            const std::array<index_t, NumOutputsA> output_batch_strides_a,
            const std::array<index_t, NumOutputsB> output_batch_strides_b)
{
    // Each descriptor tuple must pair one-to-one with its pointer tuple.
    static_assert(InAGridDescTuple::Size() == NumInputsA &&
                  InADataTypePointerTuple::Size() == NumInputsA);
    static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
                  OutADataTypePointerTuple::Size() == NumOutputsA);
    static_assert(InBGridDescTuple::Size() == NumInputsB &&
                  InBDataTypePointerTuple::Size() == NumInputsB);
    static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
                  OutBDataTypePointerTuple::Size() == NumOutputsB);

    // readfirstlane broadcasts the value as a wave-uniform scalar,
    // keeping the block-id arithmetic in scalar registers.
    const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());

    if(block_id < a_grid_size)
    {
        // Problem A: the first a_grid_size blocks, split evenly over
        // batch_count_a batches (assumes a_grid_size % batch_count_a == 0
        // -- presumably guaranteed by the host-side launch; verify).
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
        const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);

        InADataTypePointerTuple p_in_global_with_offset_tuple;
        OutADataTypePointerTuple p_out_global_with_offset_tuple;

        // Shift every input pointer to this block's batch; the stride is
        // widened to long_index_t before the multiply so large batch
        // offsets do not overflow 32-bit index arithmetic.
        static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_a.At(i) +
                type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
        });

        static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_a.At(i) +
                type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
        });

        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         block_id);
    }
    else
    {
        // Problem B: the remaining (get_grid_size() - a_grid_size) blocks,
        // split evenly over batch_count_b batches. Block ids are rebased by
        // a_grid_size so problem B sees a zero-based block index.
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);

        InBDataTypePointerTuple p_in_global_with_offset_tuple;
        OutBDataTypePointerTuple p_out_global_with_offset_tuple;

        // Same per-batch pointer shift as problem A, using the B strides.
        static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_b.At(i) +
                type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
        });

        static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_b.At(i) +
                type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
        });

        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         block_id - a_grid_size);
    }
}

template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

Expand Down Expand Up @@ -36,6 +36,24 @@ static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;

// f16_f16_f32_f16
// Single generic (fallback) f16 instance: both DoPadGemmM and DoPadGemmN are
// enabled (the two `true` arguments), so this instance is presumably intended
// to cover problem sizes the tuned instances cannot -- TODO confirm against
// the instance-factory selection logic.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f16_generic_instances =
    std::tuple<
        // clang-format off
        // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
        // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
        // clang-format on
        >;

template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
Expand Down Expand Up @@ -73,6 +91,23 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
>;

// bf16_bf16_f32_bf16
// Single generic (fallback) bf16 instance: both DoPadGemmM and DoPadGemmN are
// enabled (the two `true` arguments), so this instance is presumably intended
// to cover problem sizes the tuned instances cannot -- TODO confirm against
// the instance-factory selection logic.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_bf16_generic_instances = std::tuple<
    // clang-format off
    // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
    // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
    // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    // generic instance
    DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
    // clang-format on
    >;

template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
Expand Down Expand Up @@ -109,6 +144,24 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple<
>;

// f32_f32_f32_f32
// Single generic (fallback) f32 instance: both DoPadGemmM and DoPadGemmN are
// enabled (the two `true` arguments), and scalar-per-vector sizes are smaller
// than in the f16/bf16 generic instances (4 vs 8), consistent with the wider
// element type. Presumably intended to cover problem sizes the tuned
// instances cannot -- TODO confirm against the instance-factory selection logic.
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_xdl_f32_generic_instances =
    std::tuple<
        // clang-format off
        // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
        // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
        // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // generic instance
        DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>
        // clang-format on
        >;

template <index_t NDimSpatial,
typename ALayout,
typename BLayout,
Expand Down
Loading