From ff1413c7fd06ab800550649503abf4e8ebb43204 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 29 Oct 2025 15:32:46 +0000 Subject: [PATCH 01/17] Add DirectLoad tparam & clean up headers. --- .../ck_tile/builder/reflect/instance_traits.hpp | 10 ---------- ...rouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp | 13 ++++++++++--- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index a47ad0ef57..cab61cdee8 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -14,18 +14,8 @@ #pragma once -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include "instance_traits_util.hpp" namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 21201b8d50..88b884387b 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. 
// This file will be included by the device implementation header, so we cannot include @@ -69,7 +70,8 @@ template + typename BComputeDataType, + bool DirectLoad> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; } // namespace ck::tensor_operation::device @@ -124,7 +126,8 @@ template + typename BComputeDataType_, + bool DirectLoad_> struct InstanceTraits> + BComputeDataType_, + DirectLoad_>> { // Spatial dimension static constexpr int kSpatialDim = NDimSpatial; @@ -256,6 +260,8 @@ struct InstanceTraits(); // 47. AComputeDataType oss << "," << detail::type_name(); // 48. BComputeDataType + oss << "," << DirectLoad; // 49. DirectLoad oss << ">"; return oss.str(); From 0b5713ad37068e33539c7573e9a1df5f7211cb96 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 29 Oct 2025 15:35:42 +0000 Subject: [PATCH 02/17] Add convolution traits. --- .../ck_tile/builder/reflect/conv_traits.hpp | 368 ++++++++++++++++++ experimental/builder/test/CMakeLists.txt | 2 + .../builder/test/conv/test_conv_traits.cpp | 144 +++++++ 3 files changed, 514 insertions(+) create mode 100644 experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp create mode 100644 experimental/builder/test/conv/test_conv_traits.cpp diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp new file mode 100644 index 0000000000..738fa2bda3 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -0,0 +1,368 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::conv { + +/// @brief Helper structures for organizing trait data with domain-specific naming + +/// @brief Data tile dimensions processed by workgroup +struct DataTileInfo +{ + int m; // Processed tile m dimension + int n; // Processed tile n dimension + int k; // Processed tile k dimension +}; + +struct InputTileTransferDimensions +{ + int k0; + int m_or_n; // m for A transfer, n for B transfer + int k1; +}; + +struct InputTileTransferParams +{ + int k1; + std::array thread_cluster_dims; + std::array thread_cluster_order; + std::array src_access_order; + int src_vector_dim; + int src_scalar_per_vector; + int dst_scalar_per_vector_k1; + bool lds_padding; +}; + +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; + InputTileTransferParams transfer_params; +}; + +struct WarpGemmParams +{ + int gemm_m; + int gemm_n; + int num_m_gemms; + int num_n_gemms; +}; + +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; + int n_gemms_per_shuffle; +}; + +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; + int scalar_per_vector; +}; + +// Helper metafunctions to derive signature information from Instance types + +// Derive ConvDirection from device kernel type +template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + // Check if conv_forward_specialization exists + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + { + return builder::ConvDirection::FORWARD; + } + // Check if kConvBwdDataSpecialization exists + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + { + return builder::ConvDirection::BACKWARD_DATA; + } + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + { + return builder::ConvDirection::BACKWARD_WEIGHT; + } + 
else + { + return builder::ConvDirection::FORWARD; // Default fallback + } +} + +// Derive GroupConvLayout from layout types/devel/composable_kernel +template +constexpr auto conv_layout() +{ + using InstTraits = InstanceTraits; + using ALayout = typename InstTraits::ALayout; + using BLayout = typename InstTraits::BLayout; + using ELayout = typename InstTraits::ELayout; + + namespace ctc = ck::tensor_layout::convolution; + + if constexpr(InstTraits::kSpatialDim == 1) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + } + } + else if constexpr(InstTraits::kSpatialDim == 2) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + } + } + else if constexpr(InstTraits::kSpatialDim == 3) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + } + else if constexpr(std::is_same_v && + 
std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + } + } +} + +// Derive DataType from data type +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + + if constexpr(std::is_same_v) + { + return builder::DataType::FP16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::BF16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP32; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else + { + // Default fallback + return builder::DataType::FP32; + } +} + +// Helper to extract values from Sequence types at compile time +template +struct SequenceAt; + +template +struct SequenceAt, Idx> +{ + static constexpr int value = ck::Sequence::At(Idx); +}; + +// Primary template for ConvTraits +template +struct ConvTraits; + +// Specialization 1: Direct from Instance (Primary use case) +template + requires requires { typename InstanceTraits; } +struct ConvTraits +{ + using InstTraits = InstanceTraits; + + // Signature information (derived from Instance template parameters) + static constexpr int spatial_dim = InstTraits::kSpatialDim; + static constexpr builder::ConvDirection direction = conv_direction(); + static constexpr auto layout = conv_layout(); + static constexpr builder::DataType data_type = conv_data_type(); + + static constexpr auto gemm_specialization = InstTraits::kGemmSpecialization; + static constexpr auto conv_specialization = InstTraits::kConvForwardSpecialization; + + // Algorithm information (extracted from Instance template parameters) + static constexpr 
int thread_block_size = InstTraits::kBlockSize; + static constexpr DataTileInfo tile_dims = { + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + + static constexpr InputTileTransferInfo a_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + + static constexpr InputTileTransferInfo b_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .num_m_gemms = InstTraits::kMXdlPerWave, + .num_n_gemms = InstTraits::kNXdlPerWave}; + + static constexpr OutputTileTransferInfo c_tile_transfer = { + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = 
InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + + // Pipeline version (only available for forward convolutions) + // For backward data, this member doesn't exist in InstanceTraits + template + static constexpr auto get_pipeline_version() + { + if constexpr(requires { T::kPipelineVersion; }) + { + return T::kPipelineVersion; + } + else + { + // Return a default or indicate not available + return ck::BlockGemmPipelineVersion::v1; + } + } + + static constexpr auto pipeline_version = get_pipeline_version(); + + // Pipeline version (only available for forward convolutions) + // For backward data, this member doesn't exist in InstanceTraits + template + static constexpr auto get_pipeline_scheduler() + { + if constexpr(requires { T::kPipelineScheduler; }) + { + return T::kPipelineScheduler; + } + else + { + // Return a default or indicate not available + return ck::BlockGemmPipelineScheduler::Intrawave; + } + } + + static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); +}; + +// Specialization 2: From Builder (Backward compatibility) +template +struct ConvTraits> +{ + using Factory = builder::ConvFactory; + using Instance = typename Factory::Instance; + + // Delegate to Instance-based ConvTraits + using InstanceConvTraits = ConvTraits; + + // Forward all members from Instance-based traits + static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; + static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; + static constexpr auto layout = InstanceConvTraits::layout; + static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; + static constexpr DataTileInfo tile_dims = 
InstanceConvTraits::tile_dims; + static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; + static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; + static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; + static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; + static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; + static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; +}; + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index b7adbc116a..58a1244776 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -58,3 +58,5 @@ add_ck_factory_test(test_ck_factory_grouped_convolution_forward_bias_bnorm_clamp add_ck_factory_test(test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ck_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) +add_ck_builder_test(test_conv_traits + conv/test_conv_traits.cpp) diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp new file mode 100644 index 0000000000..5a9f3fe854 --- /dev/null +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; + +// Test fixture for ConvTraits tests +class ConvTraitsTest : public ::testing::Test +{ +}; + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // 
BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + + // Verify specializations + EXPECT_EQ(Traits::gemm_specialization, + ck::tensor_operation::device::GemmSpecialization::Default); + EXPECT_EQ(Traits::conv_specialization, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + 
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); + EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); + EXPECT_EQ(Traits::warp_gemm.num_m_gemms, 4); + EXPECT_EQ(Traits::warp_gemm.num_n_gemms, 4); + + // Verify output tile transfer info + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::pipeline_scheduler, ck::BlockGemmPipelineScheduler::Intrawave); + EXPECT_EQ(Traits::pipeline_version, 
ck::BlockGemmPipelineVersion::v1); +} + +} // anonymous namespace From dc469617398ea9c9e63060dc63e3f321599d4e23 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 29 Oct 2025 16:03:19 +0000 Subject: [PATCH 03/17] Update inline documentation. --- .../ck_tile/builder/reflect/conv_traits.hpp | 165 +++++++++++++----- 1 file changed, 118 insertions(+), 47 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 738fa2bda3..f51475ca29 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -14,75 +14,113 @@ namespace ck_tile::reflect::conv { /// @brief Helper structures for organizing trait data with domain-specific naming -/// @brief Data tile dimensions processed by workgroup +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. struct DataTileInfo { - int m; // Processed tile m dimension - int n; // Processed tile n dimension - int k; // Processed tile k dimension + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). }; +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. 
struct InputTileTransferDimensions { - int k0; - int m_or_n; // m for A transfer, n for B transfer - int k1; + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. }; +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. struct InputTileTransferParams { - int k1; - std::array thread_cluster_dims; - std::array thread_cluster_order; - std::array src_access_order; - int src_vector_dim; - int src_scalar_per_vector; - int dst_scalar_per_vector_k1; - bool lds_padding; + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. + std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. }; +/// @brief Complete information for an input tile transfer. 
+/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. struct InputTileTransferInfo { - InputTileTransferDimensions tile_dimensions; - InputTileTransferParams transfer_params; + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. }; +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. struct WarpGemmParams { - int gemm_m; - int gemm_n; - int num_m_gemms; - int num_n_gemms; + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int num_m_gemms; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int num_n_gemms; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). }; +/// @brief Parameters for shuffling data between warps (CShuffle optimization). +/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. struct WarpShuffleParams { - int m_gemms_per_shuffle; - int n_gemms_per_shuffle; + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. }; +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. 
struct OutputTileTransferInfo { - WarpShuffleParams shuffle_params; + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl - std::array thread_cluster_dims; - int scalar_per_vector; + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. }; // Helper metafunctions to derive signature information from Instance types -// Derive ConvDirection from device kernel type +/// @brief Derives the convolution direction from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). template constexpr builder::ConvDirection conv_direction() { using InstTraits = InstanceTraits; - // Check if conv_forward_specialization exists if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) { return builder::ConvDirection::FORWARD; } - // Check if kConvBwdDataSpecialization exists else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) { return builder::ConvDirection::BACKWARD_DATA; @@ -97,7 +135,9 @@ constexpr builder::ConvDirection conv_direction() } } -// Derive GroupConvLayout from layout types/devel/composable_kernel +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. template constexpr auto conv_layout() { @@ -185,7 +225,9 @@ constexpr auto conv_layout() } } -// Derive DataType from data type +/// @brief Derives the data type from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. 
+/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). template constexpr builder::DataType conv_data_type() { @@ -214,7 +256,7 @@ constexpr builder::DataType conv_data_type() } else if constexpr(std::is_same_v) { - return builder::DataType::I8; + return builder::DataType::U8; } else { @@ -223,41 +265,59 @@ constexpr builder::DataType conv_data_type() } } -// Helper to extract values from Sequence types at compile time +/// @brief Helper to extract a value from a `ck::Sequence` type at a specific index. +/// @tparam Seq The `ck::Sequence` type. +/// @tparam Idx The index of the value to extract. template struct SequenceAt; +/// @brief Specialization of `SequenceAt` for `ck::Sequence`. template struct SequenceAt, Idx> { + /// The integer value at the specified index within the sequence. static constexpr int value = ck::Sequence::At(Idx); }; -// Primary template for ConvTraits +/// @brief Primary template for extracting convolution traits. +/// @details This struct is the main entry point for reflecting on a convolution +/// kernel's properties. It is specialized to handle different kinds of input types. template struct ConvTraits; -// Specialization 1: Direct from Instance (Primary use case) +/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. +/// @details This is the primary specialization used to extract a comprehensive +/// set of traits directly from a fully-formed device kernel `Instance` type. +/// It uses `InstanceTraits` to access the kernel's template parameters. template requires requires { typename InstanceTraits; } struct ConvTraits { using InstTraits = InstanceTraits; - // Signature information (derived from Instance template parameters) - static constexpr int spatial_dim = InstTraits::kSpatialDim; + // --- Signature Information --- + /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). 
+ static constexpr int spatial_dim = InstTraits::kSpatialDim; + /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). static constexpr builder::ConvDirection direction = conv_direction(); - static constexpr auto layout = conv_layout(); - static constexpr builder::DataType data_type = conv_data_type(); + /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). + static constexpr auto layout = conv_layout(); + /// @brief The primary data type used in the computation (e.g., FP16, FP32). + static constexpr builder::DataType data_type = conv_data_type(); + /// @brief The GEMM specialization used by the kernel (e.g., Default, MNKPadding). static constexpr auto gemm_specialization = InstTraits::kGemmSpecialization; + /// @brief The convolution-specific specialization (e.g., Default, 1x1). static constexpr auto conv_specialization = InstTraits::kConvForwardSpecialization; - // Algorithm information (extracted from Instance template parameters) - static constexpr int thread_block_size = InstTraits::kBlockSize; + // --- Algorithm Information --- + /// @brief The total number of threads in a thread block (workgroup). + static constexpr int thread_block_size = InstTraits::kBlockSize; + /// @brief The dimensions of the data tile processed by the thread block. static constexpr DataTileInfo tile_dims = { .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + /// @brief Configuration for the A-matrix (input) tile transfer. static constexpr InputTileTransferInfo a_tile_transfer = { .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, .m_or_n = InstTraits::kMPerBlock, @@ -272,6 +332,7 @@ struct ConvTraits InstTraits::kABlockTransferDstScalarPerVectorK1, .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + /// @brief Configuration for the B-matrix (weights) tile transfer.
static constexpr InputTileTransferInfo b_tile_transfer = { .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, .m_or_n = InstTraits::kNPerBlock, @@ -286,11 +347,13 @@ struct ConvTraits InstTraits::kBBlockTransferDstScalarPerVectorK1, .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + /// @brief Parameters for the warp-level GEMM computation. static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, .gemm_n = InstTraits::kNPerXDL, .num_m_gemms = InstTraits::kMXdlPerWave, .num_n_gemms = InstTraits::kNXdlPerWave}; + /// @brief Configuration for the C-matrix (output) tile transfer. static constexpr OutputTileTransferInfo c_tile_transfer = { .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, @@ -300,8 +363,9 @@ struct ConvTraits InstTraits::kCThreadClusterLengths[3]}, .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; - // Pipeline version (only available for forward convolutions) - // For backward data, this member doesn't exist in InstanceTraits + /// @brief Helper to safely get the pipeline version. + /// @details This is only available for some convolutions (e.g., forward). + /// If not present in `InstanceTraits`, it returns a default value. template static constexpr auto get_pipeline_version() { @@ -316,10 +380,12 @@ struct ConvTraits } } + /// @brief The block GEMM pipeline version used by the kernel. static constexpr auto pipeline_version = get_pipeline_version(); - // Pipeline version (only available for forward convolutions) - // For backward data, this member doesn't exist in InstanceTraits + /// @brief Helper to safely get the pipeline scheduler. + /// @details This is only available for some convolutions. If not present + /// in `InstanceTraits`, it returns a default value. 
template static constexpr auto get_pipeline_scheduler() { @@ -334,10 +400,15 @@ struct ConvTraits } } + /// @brief The pipeline scheduler used by the kernel. static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); }; -// Specialization 2: From Builder (Backward compatibility) +/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. +/// @details This specialization provides backward compatibility for reflecting +/// on kernels defined via the `ConvBuilder` interface. It works by first +/// creating the `Instance` via the builder's factory, and then delegating +/// all trait extraction to the `ConvTraits` specialization. template From 0afd37456e062f9648bddd9274a1c9929ad6f8f1 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 30 Oct 2025 13:02:19 +0000 Subject: [PATCH 04/17] Add more convolution specialization and gemm padding types. --- .../builder/include/ck_tile/builder/types.hpp | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 7f49e77f81..ea771820f8 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -87,4 +87,41 @@ enum class ConvFwdSpecialization FILTER_3x3 }; +// Enums for the backward data convolution specialization. +enum class ConvBwdDataSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, +}; + +// Enums for the backward weight convolution specialization. +enum class ConvBwdWeightSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, + FILTER_1X1_PAD0, + ODD_C, +}; + +// Enums for the Gemm padding.
+enum class GemmPadding +{ + DEFAULT, + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING, + O_PADDING, + MO_PADDING, + NO_PADDING, + KO_PADDING, + MNO_PADDING, + MKO_PADDING, + NKO_PADDING, + MNKO_PADDING, +}; + } // namespace ck_tile::builder From e84d1f22aa2c8b0257038b963aaef15328a8ea7e Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 30 Oct 2025 13:03:01 +0000 Subject: [PATCH 05/17] Add additional helper functions & more tests to conv traits. --- .../ck_tile/builder/reflect/conv_traits.hpp | 235 ++++++++++++++++-- .../builder/test/conv/test_conv_traits.cpp | 182 +++++++++++++- 2 files changed, 387 insertions(+), 30 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index f51475ca29..f66de240db 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -9,6 +9,9 @@ #include #include #include +#include +#include +#include namespace ck_tile::reflect::conv { @@ -75,12 +78,12 @@ struct InputTileTransferInfo /// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. struct WarpGemmParams { - int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). - int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). - int num_m_gemms; ///< The number of MFMA iterations along the M dimension of the output tile per - ///< wavefront (MXdlPerWave). - int num_n_gemms; ///< The number of MFMA iterations along the N dimension of the output tile per - ///< wavefront (NXdlPerWave). + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). 
+ int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). }; /// @brief Parameters for shuffling data between warps (CShuffle optimization). @@ -135,6 +138,72 @@ constexpr builder::ConvDirection conv_direction() } } +/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or +/// `builder::ConvBwdWeightSpecialization` enum value. +template +constexpr auto conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + + if constexpr(InstTraits::kConvForwardSpecialization == Default) + { + return builder::ConvFwdSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) + { + return builder::ConvFwdSpecialization::FILTER_3x3; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + + if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + { + return builder::ConvBwdDataSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + 
+ if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + { + return builder::ConvBwdWeightSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) + { + return builder::ConvBwdWeightSpecialization::ODD_C; + } + } +} + /// @brief Derives the grouped convolution layout from a device kernel `Instance` type. /// @tparam Instance The device kernel instance type. /// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. @@ -265,19 +334,120 @@ constexpr builder::DataType conv_data_type() } } -/// @brief Helper to extract a value from a `ck::Sequence` type at a specific index. -/// @tparam Seq The `ck::Sequence` type. -/// @tparam Idx The index of the value to extract. -template -struct SequenceAt; +/// @brief Derives the elementwise operation from op type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. 
+template +constexpr builder::ElementwiseOperation elementwise_op() +{ + constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "Bias")) + { + return builder::ElementwiseOperation::BIAS; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) + { + return builder::ElementwiseOperation::BIAS_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + { + return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + { + return builder::ElementwiseOperation::BILINEAR; + } + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + { + return builder::ElementwiseOperation::CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + { + return builder::ElementwiseOperation::SCALE; + } + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + { + return builder::ElementwiseOperation::PASS_THROUGH; + } +} -/// @brief Specialization of `SequenceAt` for `ck::Sequence`. -template -struct SequenceAt, Idx> +/// @brief Derives a gemm padding from a kernel instance type. +/// @tparam Instance - A Device Kernel object type. +/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. +template +constexpr builder::GemmPadding gemm_padding() { - /// The integer value at the specified index within the sequence. 
- static constexpr int value = ck::Sequence::At(Idx); -}; + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto gemm_spec = InstTraits::kGemmSpecialization; + + if constexpr(gemm_spec == Default) + { + return DEFAULT; + } + else if constexpr(gemm_spec == MPadding) + { + return M_PADDING; + } + else if constexpr(gemm_spec == NPadding) + { + return N_PADDING; + } + else if constexpr(gemm_spec == KPadding) + { + return K_PADDING; + } + else if constexpr(gemm_spec == MNPadding) + { + return MN_PADDING; + } + else if constexpr(gemm_spec == MKPadding) + { + return MK_PADDING; + } + else if constexpr(gemm_spec == NKPadding) + { + return NK_PADDING; + } + else if constexpr(gemm_spec == MNKPadding) + { + return MNK_PADDING; + } + else if constexpr(gemm_spec == OPadding) + { + return O_PADDING; + } + else if constexpr(gemm_spec == MOPadding) + { + return MO_PADDING; + } + else if constexpr(gemm_spec == NOPadding) + { + return NO_PADDING; + } + else if constexpr(gemm_spec == KOPadding) + { + return KO_PADDING; + } + else if constexpr(gemm_spec == MNOPadding) + { + return MNO_PADDING; + } + else if constexpr(gemm_spec == MKOPadding) + { + return MKO_PADDING; + } + else if constexpr(gemm_spec == NKOPadding) + { + return NKO_PADDING; + } + else if constexpr(gemm_spec == MNKOPadding) + { + return MNKO_PADDING; + } +} /// @brief Primary template for extracting convolution traits. /// @details This struct is the main entry point for reflecting on a convolution @@ -305,10 +475,17 @@ struct ConvTraits /// @brief The primary data type used in the computation (e.g., FP16, FP32). static constexpr builder::DataType data_type = conv_data_type(); - /// @brief The GEMM specialization used by the kernel (e.g., Tiling, Partition). 
- static constexpr auto gemm_specialization = InstTraits::kGemmSpecialization; + static constexpr builder::ElementwiseOperation input_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation weight_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation output_element_op = + elementwise_op(); + + /// @brief The GEMM specialization used by the kernel - padding + static constexpr auto gemm_padding = gemm_padding(); /// @brief The convolution-specific specialization (e.g., Default, 1x1). - static constexpr auto conv_specialization = InstTraits::kConvForwardSpecialization; + static constexpr auto conv_specialization = conv_spec(); // --- Algorithm Information --- /// @brief The total number of threads in a thread block (workgroup). @@ -348,10 +525,10 @@ struct ConvTraits .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; /// @brief Parameters for the warp-level GEMM computation. - static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .num_m_gemms = InstTraits::kMXdlPerWave, - .num_n_gemms = InstTraits::kNXdlPerWave}; + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; /// @brief Configuration for the C-matrix (output) tile transfer. 
static constexpr OutputTileTransferInfo c_tile_transfer = { @@ -426,6 +603,16 @@ struct ConvTraits> static constexpr auto layout = InstanceConvTraits::layout; static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + static constexpr builder::ElementwiseOperation input_element_op = + InstanceConvTraits::input_element_op; + static constexpr builder::ElementwiseOperation weight_element_op = + InstanceConvTraits::weight_element_op; + static constexpr builder::ElementwiseOperation output_element_op = + InstanceConvTraits::output_element_op; + + static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; + static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp index 5a9f3fe854..b48a165550 100644 --- a/experimental/builder/test/conv/test_conv_traits.cpp +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace { @@ -83,12 +85,13 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); // Verify specializations - EXPECT_EQ(Traits::gemm_specialization, - 
ck::tensor_operation::device::GemmSpecialization::Default); - EXPECT_EQ(Traits::conv_specialization, - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default); + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); // Verify algorithm information EXPECT_EQ(Traits::thread_block_size, 256); @@ -127,8 +130,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) // Verify warp GEMM params EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); - EXPECT_EQ(Traits::warp_gemm.num_m_gemms, 4); - EXPECT_EQ(Traits::warp_gemm.num_n_gemms, 4); + EXPECT_EQ(Traits::warp_gemm.m_iter, 4); + EXPECT_EQ(Traits::warp_gemm.n_iter, 4); // Verify output tile transfer info EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); @@ -141,4 +144,171 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) EXPECT_EQ(Traits::pipeline_version, ck::BlockGemmPipelineVersion::v1); } +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + 
Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + 
EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 
2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} } // anonymous namespace From 
8f92299a6de51cb7a4060dd446853a7203f3c422 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 30 Oct 2025 13:03:40 +0000 Subject: [PATCH 06/17] Fix tests cmake file. --- experimental/builder/test/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 58a1244776..8d645250bc 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -19,7 +19,7 @@ endfunction() # The test_conv_builder target has all the unit tests (each test should run < 10 ms) add_ck_builder_test(test_conv_builder test_conv_builder.cpp - test_instance_traits.cpp + test_fwd_instance_traits.cpp test_instance_traits_util.cpp) add_ck_builder_test(test_inline_diff test_inline_diff.cpp) From 76dd2151d9b3d99f76f2ab8d2cdc0bbe9b59bbfa Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 30 Oct 2025 13:04:06 +0000 Subject: [PATCH 07/17] Add case insensitive string comparison --- .../builder/reflect/instance_traits_util.hpp | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 545441fd90..bc5861d312 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -322,4 +322,30 @@ constexpr std::string tuple_name() }(static_cast(nullptr)); } +/// @brief Makes a case insensitive comparison of two string views. 
+/// @param a First string view +/// @param b Second string view +/// @return Whether the two string views are equal, ignoring case +constexpr bool case_insensitive_equal(std::string_view a, std::string_view b) +{ + if(a.size() != b.size()) + return false; + + for(size_t i = 0; i < a.size(); ++i) + { + char c1 = a[i]; + char c2 = b[i]; + + // Convert to lowercase for comparison + if(c1 >= 'A' && c1 <= 'Z') + c1 += 32; + if(c2 >= 'A' && c2 <= 'Z') + c2 += 32; + + if(c1 != c2) + return false; + } + return true; +} + } // namespace ck_tile::reflect::detail From 7e2adb19e6d2cdee5da2e4239035752ad5fc7f6b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 11:47:42 +0000 Subject: [PATCH 08/17] Fix function name overlapping with variable name. --- .../builder/include/ck_tile/builder/reflect/conv_traits.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index f66de240db..798a168dab 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -375,7 +375,7 @@ constexpr builder::ElementwiseOperation elementwise_op() /// @tparam Instance - A Device Kernel object type. /// @return A `builder::GemmPadding` enum value corresponding to kernel padding. template -constexpr builder::GemmPadding gemm_padding() +constexpr builder::GemmPadding gemm_spec() { using InstTraits = InstanceTraits; using enum builder::GemmPadding; @@ -483,7 +483,7 @@ struct ConvTraits elementwise_op(); /// @brief The GEMM specialization used by the kernel - padding - static constexpr auto gemm_padding = gemm_spec(); + static constexpr auto gemm_padding = gemm_spec(); /// @brief The convolution-specific specialization (e.g., Default, 1x1).
static constexpr auto conv_specialization = conv_spec(); From b0b2614e3a4a4814ac222e797b6d79d853744192 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 14:31:13 +0000 Subject: [PATCH 09/17] Unify pipeline version and scheduler enums. --- .../ck_tile/builder/reflect/conv_traits.hpp | 101 +++++++++++++++++- .../builder/include/ck_tile/builder/types.hpp | 24 +---- 2 files changed, 102 insertions(+), 23 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp index 798a168dab..a74d77d155 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -12,9 +13,97 @@ #include #include #include +#include +#include namespace ck_tile::reflect::conv { +// Helper metafunctions to convert from ck enums to builder enums + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. 
+template +constexpr auto convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v3) + return V3; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == v5) + return V5; +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == weight_only) + return WEIGHT_ONLY; +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates +/// across multiple wavefronts. 
+template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Intrawave) + return INTRAWAVE; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. DEFAULT uses the standard +/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved +/// performance in certain scenarios. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Default) + return DEFAULT; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + /// @brief Helper structures for organizing trait data with domain-specific naming /// @brief Data tile dimensions processed by a workgroup. 
@@ -548,12 +637,12 @@ struct ConvTraits { if constexpr(requires { T::kPipelineVersion; }) { - return T::kPipelineVersion; + return convert_pipeline_version(); } else { // Return a default or indicate not available - return ck::BlockGemmPipelineVersion::v1; + return builder::PipelineVersion::V1; } } @@ -568,12 +657,16 @@ struct ConvTraits { if constexpr(requires { T::kPipelineScheduler; }) { - return T::kPipelineScheduler; + return convert_pipeline_scheduler(); + } + else if constexpr(requires { T::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); } else { // Return a default or indicate not available - return ck::BlockGemmPipelineScheduler::Intrawave; + return builder::PipelineScheduler::DEFAULT; } } diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 84625a3378..3ee29e564c 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -128,29 +128,14 @@ enum class ElementwiseOperation PASS_THROUGH }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion +// Enums for pipeline versions & schedulers +enum class PipelineVersion { V1, V2, V3, V4, - V5 -}; - -enum struct BlockGemmPipelineScheduler -{ - INTRAWAVE, - INTERWAVE, -}; - -// Enums for the gridwise GEMM pipeline versions. -enum class GridwiseGemmPipelineVersion -{ - V1, - V2, - V3, // Only used in stream-K implementation - V4, + V5, WEIGHT_ONLY }; @@ -223,9 +208,10 @@ enum class GemmPadding MNKO_PADDING, }; -enum class LoopScheduler +enum class PipelineScheduler { DEFAULT, + INTRAWAVE, INTERWAVE }; From 325dbab6397e0b3312d3c3203cf02ecb8960f71f Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 14:33:20 +0000 Subject: [PATCH 10/17] Fix includes. 
--- .../include/ck_tile/builder/reflect/instance_traits.hpp | 1 + ...aits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 1 + .../include/ck_tile/builder/reflect/instance_traits_util.hpp | 5 +++++ experimental/builder/test/CMakeLists.txt | 1 + experimental/builder/test/conv/test_conv_traits.cpp | 2 ++ .../device/convolution_backward_weight_specialization.hpp | 2 ++ 6 files changed, 12 insertions(+) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index cab61cdee8..29c687f491 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -16,6 +16,7 @@ #include #include +#include namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 9f64c7d6e8..22b4f06772 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. 
// This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 1ca4d213fd..74e9cc9be4 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -9,9 +9,14 @@ #include #include +#include #include #include #include +#include +#include +#include +#include #include #include #include diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 1b3650081f..8636ef93fd 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -65,6 +65,7 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test add_ck_builder_test(test_conv_traits conv/test_conv_traits.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp index b48a165550..f114493a0b 100644 --- a/experimental/builder/test/conv/test_conv_traits.cpp +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -3,6 +3,8 @@ #include #include +#include + #include #include #include diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 01bb806789..219206c5ce 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace ck { namespace tensor_operation { namespace device { From 
9c7464ed75669072c2570539dfd025bd23391aca Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 14:33:52 +0000 Subject: [PATCH 11/17] Update test conv traits with unified enums. --- experimental/builder/test/conv/test_conv_traits.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp index f114493a0b..ca453d2ad4 100644 --- a/experimental/builder/test/conv/test_conv_traits.cpp +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -74,7 +74,7 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock 8, // CDEBlockTransferScalarPerVector_NPerBlock ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched - ck::BlockGemmPipelineVersion::v1, // BlkGemmPipelineVer + ck::PipelineVersion::v1, // BlkGemmPipelineVer ck::half_t, // AComputeDataType ck::half_t, // BComputeDataType false>; // DirectLoad @@ -142,8 +142,8 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); // Verify pipeline configuration - EXPECT_EQ(Traits::pipeline_scheduler, ck::BlockGemmPipelineScheduler::Intrawave); - EXPECT_EQ(Traits::pipeline_version, ck::BlockGemmPipelineVersion::v1); + EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); } // Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle From 42580762b182f0d10c6b9c0eef7150eaf192bca6 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 14:34:24 +0000 Subject: [PATCH 12/17] Update concepts etc with update unified enum --- .../builder/conv_algorithm_concepts.hpp | 12 ++--- .../include/ck_tile/builder/conv_factory.hpp | 48 +++++++++---------- .../test/impl/conv_algorithm_types.hpp | 10 ++-- 3 files changed, 35 insertions(+), 35 deletions(-) diff 
--git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 586a119c75..ef252986d0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template concept BlockGemmDescriptor = requires(T t) { - { t.pipeline_version } -> std::convertible_to; - { t.scheduler } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; }; // Concept for parameters that describe a gridwise WMMA GEMM problem. @@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> std::convertible_to; { t.m_wmma_per_wave } -> std::convertible_to; { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; }; // Concept for vectorized data transfer for convolution input tensors. @@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. 
template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires { template concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; + { T::loop_scheduler } -> std::convertible_to; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 8ea3e18d65..83baeadbb0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm() ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Intrawave; } - else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) + else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Interwave; } else { - static_assert(false, "Unknown BlockGemmPipelineScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } - if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + if constexpr(BG.pipeline_version == PipelineVersion::V1) { version = ck::BlockGemmPipelineVersion::v1; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) + else if constexpr(BG.pipeline_version == PipelineVersion::V2) { version = ck::BlockGemmPipelineVersion::v2; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) + else if constexpr(BG.pipeline_version == 
PipelineVersion::V3) { version = ck::BlockGemmPipelineVersion::v3; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) + else if constexpr(BG.pipeline_version == PipelineVersion::V4) { version = ck::BlockGemmPipelineVersion::v4; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) + else if constexpr(BG.pipeline_version == PipelineVersion::V5) { version = ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) { return ck::LoopScheduler::Default; } - else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) + else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) { return ck::LoopScheduler::Interwave; } else { - static_assert(false, "Unknown LoopScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } } @@ -460,29 +460,29 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + if constexpr(pipeline_version == PipelineVersion::V1) { return ck::PipelineVersion::v1; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) + else if constexpr(pipeline_version == PipelineVersion::V2) { return ck::PipelineVersion::v2; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) + else if constexpr(pipeline_version == PipelineVersion::V3) { static_assert(false, "V3 is used only for stream-K."); } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) + else if 
constexpr(pipeline_version == PipelineVersion::V4) { return ck::PipelineVersion::v4; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) + else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) { return ck::PipelineVersion::weight_only; } else { - static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } @@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - if constexpr(version == BlockGemmPipelineVersion::V1) + if constexpr(version == PipelineVersion::V1) { return ck::BlockGemmPipelineVersion::v1; } - else if constexpr(version == BlockGemmPipelineVersion::V2) + else if constexpr(version == PipelineVersion::V2) { return ck::BlockGemmPipelineVersion::v2; } - else if constexpr(version == BlockGemmPipelineVersion::V3) + else if constexpr(version == PipelineVersion::V3) { return ck::BlockGemmPipelineVersion::v3; } - else if constexpr(version == BlockGemmPipelineVersion::V4) + else if constexpr(version == PipelineVersion::V4) { return ck::BlockGemmPipelineVersion::v4; } - else if constexpr(version == BlockGemmPipelineVersion::V5) + else if constexpr(version == PipelineVersion::V5) { return ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 9c5ca9b97b..d22c40e4c2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,14 +49,14 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - GridwiseGemmPipelineVersion pipeline_version; + PipelineVersion pipeline_version; }; 
static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemm { - BlockGemmPipelineVersion pipeline_version; - BlockGemmPipelineScheduler scheduler; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; }; static_assert(ckb::BlockGemmDescriptor); @@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; size_t num_groups_to_merge; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); @@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); From b28c7cea12dc1bc5fb6e02f61bf682939af86527 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 3 Nov 2025 14:34:55 +0000 Subject: [PATCH 13/17] Fix ckb conv fwd test - unified enum usage. 
--- .../test/conv/test_ckb_conv_fwd_1d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_bf16.cpp | 4 ++-- .../test/conv/test_ckb_conv_fwd_2d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_2d_fp32.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_bf16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp16.cpp | 2 +- .../test/conv/test_ckb_conv_fwd_3d_fp32.cpp | 2 +- .../test/utils/ckb_conv_test_common.hpp | 18 +++++++++--------- 8 files changed, 17 insertions(+), 17 deletions(-) diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 472c43438d..388edee873 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -24,7 +24,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V2, + PipelineVersion::V2, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index b9969f7e95..6420da951d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } @@ -44,7 +44,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index cd5186cc10..fd8e8718c4 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - 
BlockGemmPipelineVersion::V3, + PipelineVersion::V3, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 584e0ab182..a4e1c1de38 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 17caf98457..9bfbdbb838 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index ec4649a6ff..b2bff4cde9 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 393ea9206d..df237d6a0c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, 
run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V1, + PipelineVersion::V1, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index d18a008015..5813ab565a 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -16,7 +16,7 @@ using namespace test; // Common test implementation template constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() { @@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() .src_access_order_b = {1, 0, 2}}; constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = PipelineScheduler::INTRAWAVE}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock, @@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); // Verify pipeline version is correct - if(FwdPipelineVersion == BlockGemmPipelineVersion::V1) + if(FwdPipelineVersion == PipelineVersion::V1) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3) + else if(FwdPipelineVersion == PipelineVersion::V3) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4) + else if(FwdPipelineVersion == PipelineVersion::V4) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5) + else if(FwdPipelineVersion == PipelineVersion::V5) 
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); // Verify specialization is correct @@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, .num_groups_to_merge = 2, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; @@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1, - .pipeline_version = GridwiseGemmPipelineVersion::V1}; + .pipeline_version = PipelineVersion::V1}; constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .fwd_specialization = FwdConvSpecialization, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; From 3d6678b3fd763b3ba458db20f8a25611b2cd8c6a Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:35:18 +0100 Subject: [PATCH 14/17] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../include/ck_tile/builder/reflect/instance_traits_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 74e9cc9be4..84011290c2 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -13,7 +13,7 @@ 
#include #include #include -#include +#include #include #include #include From cef7681ff28102b37f447f8727d73a5482a8bb21 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:36:38 +0100 Subject: [PATCH 15/17] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- experimental/builder/include/ck_tile/builder/types.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index 3ee29e564c..875a9c8816 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -171,14 +171,14 @@ enum class ConvFwdSpecialization FILTER_3x3 }; -// Enums for the bacward data convolution specialization. +// Enums for the backward data convolution specialization. enum class ConvBwdDataSpecialization { DEFAULT, FILTER_1X1_STRIDE1_PAD0, }; -// Enums for the bacward weight convolution specialization. +// Enums for the backward weight convolution specialization. 
enum class ConvBwdWeightSpecialization { DEFAULT, From 5b09fd64fea5a1ff34e8ae5f4cb3885d9da303b8 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:36:51 +0100 Subject: [PATCH 16/17] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../include/ck_tile/builder/reflect/instance_traits_util.hpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index 84011290c2..14aeb5074a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -14,9 +14,6 @@ #include #include #include -#include -#include -#include #include #include #include From 5d5ed41a724d7b7f0ebbb49e1efa9e8be21bd12c Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Tue, 4 Nov 2025 10:37:10 +0100 Subject: [PATCH 17/17] Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- experimental/builder/test/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 8636ef93fd..c7bfac3c16 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -20,7 +20,6 @@ endfunction() add_ck_builder_test(test_ckb_conv_builder test_conv_builder.cpp test_fwd_instance_traits.cpp - test_fwd_instance_traits.cpp test_instance_traits_util.cpp) add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)