From 99699846305c02d8e741de78d46c353a4082bedb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 10 Oct 2025 08:52:12 +0000 Subject: [PATCH 01/10] Merge fwd conv groups in CK Tile. --- .../grouped_convolution_forward_kernel.hpp | 86 +++++++++++++------ .../utils/transform_conv_fwd_to_gemm.hpp | 44 ++++++---- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 110ec2cb54..d5486c4493 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -66,9 +66,9 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; // GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0]; - GemmBatch = args.G_; + // GemmN = args.K_; + // GemmK = args.C_ * args.filter_spatial_lengths_[0]; + // GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -96,8 +96,9 @@ struct GroupedConvFwdKernelArgs conv_to_gemm_transformer .template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsPerBatch; + group_stride_b = args.K_ * args.C_ * NumGroupsPerBatch * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, @@ -114,8 +115,21 @@ struct GroupedConvFwdKernelArgs input_batch_stride = args.C_ * args.input_spatial_lengths_[0]; output_batch_stride = args.K_ * args.output_spatial_lengths_[0]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch + << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + } } template < @@ -156,11 +170,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -187,8 +196,9 @@ struct GroupedConvFwdKernelArgs conv_to_gemm_transformer .template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsPerBatch; + group_stride_b = args.K_ * args.C_ * NumGroupsPerBatch * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, @@ -207,8 +217,21 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch + << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + } } template < @@ -256,12 +279,6 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - // Note: GemmM will be set after Split-N calculation - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * - args.filter_spatial_lengths_[2]; - GemmBatch = args.G_; - in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; for(index_t d = 0; d < NumDTensor; d++) @@ -288,8 +305,9 @@ struct GroupedConvFwdKernelArgs conv_to_gemm_transformer .template MakeCDescriptor_M_N(); - group_stride_a = args.C_; - group_stride_b = args.K_ * args.C_ * + NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge; + group_stride_a = args.C_ * NumGroupsPerBatch; + group_stride_b = args.K_ * args.C_ * NumGroupsPerBatch * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), 1, @@ -308,11 +326,22 @@ struct GroupedConvFwdKernelArgs output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; - // Update GemmM to use split N (not original N) - GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; - } + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = integer_divide_ceil(args.G_, NumGroupsPerBatch); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK + << ", GemmBatch: " << GemmBatch + << ", N per split: " << n_per_split + << ", number of N splits: " << n_splits + << ", input_batch_stride: " << input_batch_stride + << ", output_batch_stride: " << output_batch_stride + << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl; + } + } using AGridDescMK = remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K())>; @@ -860,6 +889,7 @@ struct GroupedConvolutionForwardKernel static_cast(batch_offset) * static_cast(kargs.output_batch_stride); + // Adjust pointers: combine group offset and batch offset const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index cbe8fdbdaa..ef54d45d5d 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -192,13 +192,17 @@ struct TransformConvFwdToGemm std::is_same_v>); static_assert(std::is_same_v> || std::is_same_v>); + + // Store original N + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = original_N_; } } @@ -253,8 +257,7 @@ struct TransformConvFwdToGemm } else { - N_ = c_g_n_k_wos_lengths[I1]; - original_N_ = N_; + N_ = original_N_; } } @@ -438,10 +441,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -669,11 +672,11 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -928,12 +931,12 @@ struct TransformConvFwdToGemm CK_TILE_HOST auto MakeADescriptor_M_K() const { + IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType DiStride_ = Hi_ * Wi_ * G_ * C_; IndexType HiStride_ = Wi_ * G_ * C_; IndexType WiStride_ = G_ * C_; - IndexType CStrideTensorA_ = 1; - IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_; IndexType GStrideTensorA_ = C_; + IndexType CStrideTensorA_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) { @@ -1257,9 +1260,9 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeBDescriptor_N_K() const { - IndexType CStrideTensorB_ = 1; - IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_; + IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_; + IndexType CStrideTensorB_ = 1; if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) { @@ -1324,10 +1327,10 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1372,7 +1375,8 @@ struct TransformConvFwdToGemm unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), make_merge_transform(make_tuple(K_, NumGroupsToMerge))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + // TODO: sequence<0,1> or sequence<1,0>? + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } } @@ -1385,11 +1389,11 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1438,7 +1442,8 @@ struct TransformConvFwdToGemm unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), make_merge_transform(make_tuple(K_, NumGroupsToMerge))), - make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), + // TODO: sequence<0,1> or sequence<1,0>? + make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } } @@ -1450,12 +1455,12 @@ struct TransformConvFwdToGemm bool>::type = false> CK_TILE_HOST auto MakeCDescriptor_M_N() const { + IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType DoStride_ = Ho_ * Wo_ * G_ * K_; IndexType HoStride_ = Wo_ * G_ * K_; IndexType WoStride_ = G_ * K_; - IndexType KStrideTensorC_ = 1; - IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_; IndexType GStrideTensorC_ = K_; + IndexType KStrideTensorC_ = 1; const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; if constexpr(NumGroupsToMerge == 1) @@ -1505,6 +1510,7 @@ struct TransformConvFwdToGemm unmerged_padded_desc, make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + // TODO: sequence<0,1> or sequence<1,0>? make_tuple(sequence<0, 1>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } From 9968bef1071745f058412abe8acfe00121910fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 31 Oct 2025 09:59:30 +0000 Subject: [PATCH 02/10] Fix building CK fwd convs. --- .../kernel/grouped_convolution_forward_kernel.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index d5486c4493..ce41c6a99a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -367,6 +367,7 @@ struct GroupedConvFwdKernelArgs index_t GemmN; index_t GemmK; index_t GemmBatch; + index_t NumGroupsPerBatch; const void* in_ptr; const void* wei_ptr; From 01541edecac9e7411c23698bfd68bfd14e733ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 31 Oct 2025 10:07:10 +0000 Subject: [PATCH 03/10] Add number of merged groups to conv fwd kernel name. --- .../grouped_convolution_forward_kernel.hpp | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index ce41c6a99a..0f42c35df7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -481,13 +481,25 @@ struct GroupedConvolutionForwardKernel [[nodiscard]] CK_TILE_HOST static const std::string GetName() { + constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge; // clang-format off - return concat('_', "grouped_convolution_forward", - gemm_prec_str(), - "gemm", - GemmPipeline::GetName(), - "epilogue", - EpiloguePipeline::GetName()); + if (NumGroupsToMerge > 1) { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName(), + "merge", + NumGroupsToMerge); + } else { + return concat('_', "grouped_convolution_forward", + gemm_prec_str(), + "gemm", + GemmPipeline::GetName(), + "epilogue", + EpiloguePipeline::GetName()); + } // clang-format on } From 60cb60fa62ee359f9a9cac62e6a9881fd60d9a24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= Date: Fri, 31 Oct 2025 10:11:53 +0000 Subject: [PATCH 04/10] Get number of merged groups from conv config. --- .../grouped_convolution_forward_invoker.hpp | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 7ac6a20d70..1f95cf77dc 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -7,7 +7,7 @@ struct GroupedConvolutionForwardInvoker { template , - ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, ck_tile:: - sequence, - GemmConfig::PermuteA, - GemmConfig::PermuteB>; + sequence, + ConvConfig::PermuteA, + ConvConfig::PermuteB>; - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - constexpr ck_tile::index_t NumGroupsToMerge = 1; + constexpr ck_tile::index_t VectorSizeA = ConvConfig::VectorSizeA; + constexpr ck_tile::index_t VectorSizeB = ConvConfig::VectorSizeB; + constexpr ck_tile::index_t VectorSizeC = ConvConfig::VectorSizeC; constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner; + ConvConfig::TileParitionerGroupNum, + ConvConfig::TileParitionerM01>; using GroupedConvTraitsType = ck_tile::GroupedConvTraits; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< - GemmConfig::kPadM, - GemmConfig::kPadN, - GemmConfig::kPadK, - GemmConfig::DoubleSmemBuffer, + ConvConfig::kPadM, + ConvConfig::kPadN, + ConvConfig::kPadK, + ConvConfig::DoubleSmemBuffer, typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout, typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout, typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout, - GemmConfig::TransposeC, - GemmConfig::UseStructuredSparsity, + ConvConfig::TransposeC, + ConvConfig::UseStructuredSparsity, false, // Persistent, - GemmConfig::NumWaveGroups, - GemmConfig::Preshuffle>; + ConvConfig::NumWaveGroups, + ConvConfig::Preshuffle>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem< InDataType, @@ -82,7 +81,7 @@ struct GroupedConvolutionForwardInvoker VectorSizeB>; using BaseGemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template UniversalGemmPipeline; + ConvConfig::Pipeline>::template UniversalGemmPipeline; const ck_tile::index_t gemm_k = args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), @@ -90,8 +89,8 @@ struct GroupedConvolutionForwardInvoker 1, std::multiplies()); - const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; - const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t k_grain = args.k_batch * ConvConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * ConvConfig::K_Tile; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); @@ -101,7 +100,7 @@ struct GroupedConvolutionForwardInvoker [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto scheduler = ConvConfig::Scheduler; constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = @@ -121,7 +120,7 @@ struct GroupedConvolutionForwardInvoker VectorSizeB>; using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + ConvConfig::Pipeline>::template GemmPipeline; using ConvEpilogue = ck_tile::CShuffleEpilogue Date: Fri, 31 Oct 2025 10:21:49 +0000 Subject: [PATCH 05/10] Rename GemmConfig to ConvConfig. --- .../grouped_convolution_forward.cpp | 6 +++--- .../run_grouped_convolution_fwd_example.inc | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index b979d4feb3..bef404b53a 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -14,7 +14,7 @@ #include "grouped_convolution_forward_invoker.hpp" #include "run_grouped_convolution_fwd_example.inc" -template