diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index 586a119c75..ef252986d0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -38,8 +38,8 @@ concept GridwiseXdlGemmDescriptor = requires(T t) { // Concept for parameter that describe block GEMM problem. template concept BlockGemmDescriptor = requires(T t) { - { t.pipeline_version } -> std::convertible_to; - { t.scheduler } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; + { t.scheduler } -> std::convertible_to; }; // Concept for parameters that describe a gridwise WMMA GEMM problem. @@ -50,7 +50,7 @@ concept GridwiseWmmaGemmDescriptor = requires(T t) { { t.n_per_wmma } -> std::convertible_to; { t.m_wmma_per_wave } -> std::convertible_to; { t.n_wmma_per_wave } -> std::convertible_to; - { t.pipeline_version } -> std::convertible_to; + { t.pipeline_version } -> std::convertible_to; }; // Concept for vectorized data transfer for convolution input tensors. @@ -154,8 +154,8 @@ concept SpecifiesSourceAccessOrder = requires(T t) { // Concept to check if struct specifies block GEMM. template concept SpecifiesBlockGemm = requires { - { T::block_gemm.pipeline_version } -> std::convertible_to; - { T::block_gemm.scheduler } -> std::convertible_to; + { T::block_gemm.pipeline_version } -> std::convertible_to; + { T::block_gemm.scheduler } -> std::convertible_to; }; template @@ -180,7 +180,7 @@ concept SpecifiesNumGroupsToMerge = requires { template concept SpecifiesLoopScheduler = requires { - { T::loop_scheduler } -> std::convertible_to; + { T::loop_scheduler } -> std::convertible_to; }; } // namespace ck_tile::builder diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index 8ea3e18d65..83baeadbb0 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -297,42 +297,42 @@ constexpr BlockGemmSpec SetBlockGemm() ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; - if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTRAWAVE) + if constexpr(BG.scheduler == PipelineScheduler::INTRAWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Intrawave; } - else if constexpr(BG.scheduler == BlockGemmPipelineScheduler::INTERWAVE) + else if constexpr(BG.scheduler == PipelineScheduler::INTERWAVE) { scheduler = ck::BlockGemmPipelineScheduler::Interwave; } else { - static_assert(false, "Unknown BlockGemmPipelineScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } - if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V1) + if constexpr(BG.pipeline_version == PipelineVersion::V1) { version = ck::BlockGemmPipelineVersion::v1; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V2) + else if constexpr(BG.pipeline_version == PipelineVersion::V2) { version = ck::BlockGemmPipelineVersion::v2; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V3) + else if constexpr(BG.pipeline_version == PipelineVersion::V3) { version = ck::BlockGemmPipelineVersion::v3; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V4) + else if constexpr(BG.pipeline_version == PipelineVersion::V4) { version = ck::BlockGemmPipelineVersion::v4; } - else if constexpr(BG.pipeline_version == BlockGemmPipelineVersion::V5) + else if constexpr(BG.pipeline_version == PipelineVersion::V5) { version = ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } return BlockGemmSpec{.pipeline_version = version, .scheduler = scheduler}; @@ -442,17 +442,17 @@ consteval ck::LoopScheduler SetLoopScheduler() { constexpr auto loop_scheduler = ALGORITHM.loop_scheduler; - if constexpr(loop_scheduler == LoopScheduler::DEFAULT) + if constexpr(loop_scheduler == PipelineScheduler::DEFAULT) { return ck::LoopScheduler::Default; } - else if constexpr(loop_scheduler == LoopScheduler::INTERWAVE) + else if constexpr(loop_scheduler == PipelineScheduler::INTERWAVE) { return ck::LoopScheduler::Interwave; } else { - static_assert(false, "Unknown LoopScheduler"); + static_assert(false, "Unknown PipelineScheduler"); } } @@ -460,29 +460,29 @@ template consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() { constexpr auto pipeline_version = ALGORITHM.gridwise_gemm.pipeline_version; - if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V1) + if constexpr(pipeline_version == PipelineVersion::V1) { return ck::PipelineVersion::v1; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V2) + else if constexpr(pipeline_version == PipelineVersion::V2) { return ck::PipelineVersion::v2; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V3) + else if constexpr(pipeline_version == PipelineVersion::V3) { static_assert(false, "V3 is used only for stream-K."); } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::V4) + else if constexpr(pipeline_version == PipelineVersion::V4) { return ck::PipelineVersion::v4; } - else if constexpr(pipeline_version == GridwiseGemmPipelineVersion::WEIGHT_ONLY) + else if constexpr(pipeline_version == PipelineVersion::WEIGHT_ONLY) { return ck::PipelineVersion::weight_only; } else { - static_assert(false, "Unknown GridwiseGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } @@ -566,29 +566,29 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() { constexpr auto version = ALGORITHM.pipeline_version; - if constexpr(version == BlockGemmPipelineVersion::V1) + if constexpr(version == PipelineVersion::V1) { return ck::BlockGemmPipelineVersion::v1; } - else if constexpr(version == BlockGemmPipelineVersion::V2) + else if constexpr(version == PipelineVersion::V2) { return ck::BlockGemmPipelineVersion::v2; } - else if constexpr(version == BlockGemmPipelineVersion::V3) + else if constexpr(version == PipelineVersion::V3) { return ck::BlockGemmPipelineVersion::v3; } - else if constexpr(version == BlockGemmPipelineVersion::V4) + else if constexpr(version == PipelineVersion::V4) { return ck::BlockGemmPipelineVersion::v4; } - else if constexpr(version == BlockGemmPipelineVersion::V5) + else if constexpr(version == PipelineVersion::V5) { return ck::BlockGemmPipelineVersion::v5; } else { - static_assert(false, "Unknown BlockGemmPipelineVersion"); + static_assert(false, "Unknown PipelineVersion"); } } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp new file mode 100644 index 0000000000..a74d77d155 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits.hpp @@ -0,0 +1,719 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile::reflect::conv { + +// Helper metafunctions to convert from ck enums to builder enums + +/// @brief Converts a CK BlockGemmPipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK BlockGemmPipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V3, V4, or V5). +/// @details This function maps CK's block GEMM pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. The pipeline version +/// determines the strategy used for data movement and computation overlap in the +/// GEMM kernel's main loop. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::BlockGemmPipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v3) + return V3; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == v5) + return V5; +} + +/// @brief Converts a CK PipelineVersion enum to a builder PipelineVersion enum. +/// @tparam ck_ver The CK PipelineVersion enum value to convert. +/// @return The corresponding builder::PipelineVersion enum value (V1, V2, V4, or WEIGHT_ONLY). +/// @details This function maps CK's general pipeline version identifiers to the +/// builder framework's standardized pipeline version enum. Note that this overload +/// handles a different set of pipeline versions compared to the BlockGemmPipelineVersion +/// variant, including support for specialized weight-only pipelines. +template +constexpr auto convert_pipeline_version() +{ + using enum ck::PipelineVersion; + using enum builder::PipelineVersion; + if constexpr(ck_ver == v1) + return V1; + else if constexpr(ck_ver == v2) + return V2; + else if constexpr(ck_ver == v4) + return V4; + else if constexpr(ck_ver == weight_only) + return WEIGHT_ONLY; +} + +/// @brief Converts a CK BlockGemmPipelineScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK BlockGemmPipelineScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (INTRAWAVE or INTERWAVE). +/// @details This function maps CK's block GEMM pipeline scheduler identifiers to the +/// builder framework's standardized scheduler enum. The scheduler determines how work +/// is distributed and synchronized within and across wavefronts during pipeline execution. +/// INTRAWAVE scheduling operates within a single wavefront, while INTERWAVE coordinates +/// across multiple wavefronts. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::BlockGemmPipelineScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Intrawave) + return INTRAWAVE; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Converts a CK LoopScheduler enum to a builder PipelineScheduler enum. +/// @tparam ck_sched The CK LoopScheduler enum value to convert. +/// @return The corresponding builder::PipelineScheduler enum value (DEFAULT or INTERWAVE). +/// @details This function maps CK's loop scheduler identifiers to the builder framework's +/// standardized pipeline scheduler enum. The loop scheduler controls how iterations of +/// the main computational loop are scheduled across threads. DEFAULT uses the standard +/// scheduling strategy, while INTERWAVE enables cross-wavefront coordination for improved +/// performance in certain scenarios. +template +constexpr auto convert_pipeline_scheduler() +{ + using enum ck::LoopScheduler; + using enum builder::PipelineScheduler; + if constexpr(ck_sched == Default) + return DEFAULT; + else if constexpr(ck_sched == Interwave) + return INTERWAVE; +} + +/// @brief Helper structures for organizing trait data with domain-specific naming + +/// @brief Data tile dimensions processed by a workgroup. +/// @details This struct defines the M, N, and K dimensions of the data tile +/// that a single workgroup (thread block) is responsible for processing in the +/// underlying GEMM computation. +struct DataTileInfo +{ + int m; ///< M dimension of the tile processed by the workgroup (MPerBlock). + int n; ///< N dimension of the tile processed by the workgroup (NPerBlock). + int k; ///< K dimension of the tile processed by the workgroup (KPerBlock). +}; + +/// @brief Dimensions for an input data tile transfer. +/// @details Defines the shape of the input tile (A or B matrix) as it is +/// transferred from global memory to LDS. The tile is conceptually divided +/// into k0 and k1 dimensions. +struct InputTileTransferDimensions +{ + int k0; ///< The outer dimension of K, where K = k0 * k1. + int m_or_n; ///< The M dimension for the A matrix transfer, or the N dimension for the B matrix. + int k1; ///< The inner dimension of K, often corresponding to the vector load size from global + ///< memory. +}; + +/// @brief Parameters governing the transfer of an input tile. +/// @details This struct holds configuration details for how an input tile is +/// loaded from global memory into LDS, including thread clustering, memory +/// access patterns, and vectorization settings. +struct InputTileTransferParams +{ + int k1; ///< The inner K dimension size, often matching the vectorization width. + std::array + thread_cluster_dims; ///< Spatial thread distribution over the input data tile; defines how + ///< many threads are arranged on each axis. + std::array thread_cluster_order; ///< The order of thread spatial distribution over the + ///< input tensor dimensions. + std::array src_access_order; ///< The order of accessing input tensor axes (e.g., which + ///< dimension to read first). + int src_vector_dim; ///< The index of the axis on which vectorized memory access is performed + ///< (the contiguous dimension). + int src_scalar_per_vector; ///< The size of the vector access instruction; the number of + ///< elements accessed per thread per instruction. + int dst_scalar_per_vector_k1; ///< The size of the vectorized store into LDS memory along the K1 + ///< dimension. + bool lds_padding; ///< Flag indicating if padding is used for the LDS tensor to prevent bank + ///< conflicts. +}; + +/// @brief Complete information for an input tile transfer. +/// @details Combines the dimensional information and transfer parameters for +/// a full description of an input tile's journey from global memory to LDS. +struct InputTileTransferInfo +{ + InputTileTransferDimensions tile_dimensions; ///< The shape and layout of the tile. + InputTileTransferParams transfer_params; ///< The parameters for the memory transfer operation. +}; + +/// @brief Parameters for the warp-level GEMM computation. +/// @details Defines the configuration of the GEMM operation performed by each +/// warp using hardware MFMA (Matrix Fused Multiply-Add) instructions. +struct WarpGemmParams +{ + int gemm_m; ///< The M dimension of a single MFMA instruction (MPerXdl). + int gemm_n; ///< The N dimension of a single MFMA instruction (NPerXdl). + int m_iter; ///< The number of MFMA iterations along the M dimension of the output tile per + ///< wavefront (MXdlPerWave). + int n_iter; ///< The number of MFMA iterations along the N dimension of the output tile per + ///< wavefront (NXdlPerWave). +}; + +/// @brief Parameters for shuffling data between warps (CShuffle optimization). +/// @details Configures how many MFMA instruction results are processed per +/// wave in each iteration of the CShuffle routine. +struct WarpShuffleParams +{ + int m_gemms_per_shuffle; ///< Number of MFMA results along the M dimension to process per wave + ///< per shuffle iteration. + int n_gemms_per_shuffle; ///< Number of MFMA results along the N dimension to process per wave + ///< per shuffle iteration. +}; + +/// @brief Information for the output tile transfer (CShuffle). +/// @details Describes how the final computed tile (C matrix) is written out from +/// LDS to global memory, including shuffling, thread clustering, and vectorization. +struct OutputTileTransferInfo +{ + WarpShuffleParams shuffle_params; ///< Configuration for cross-warp data shuffling. + // m_block, m_wave_per_xdl, n_block, n_wave_per_xdl + std::array thread_cluster_dims; ///< The spatial thread distribution used for storing + ///< data into the output tensor. + int scalar_per_vector; ///< The size of the vectorized memory access when storing data to the + ///< output tensor. +}; + +// Helper metafunctions to derive signature information from Instance types + +/// @brief Derives the convolution direction from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvDirection` enum value (FORWARD, BACKWARD_DATA, or BACKWARD_WEIGHT). +template +constexpr builder::ConvDirection conv_direction() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { &InstTraits::kConvForwardSpecialization; }) + { + return builder::ConvDirection::FORWARD; + } + else if constexpr(requires { &InstTraits::kConvBwdDataSpecialization; }) + { + return builder::ConvDirection::BACKWARD_DATA; + } + else if constexpr(requires { &InstTraits::kConvBwdWeightSpecialization; }) + { + return builder::ConvDirection::BACKWARD_WEIGHT; + } + else + { + return builder::ConvDirection::FORWARD; // Default fallback + } +} + +/// @brief Derives the convolution-specific specialization from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::ConvFwdSpecialization`, `builder::ConvBwdDataSpecialization`, or +/// `builder::ConvBwdWeightSpecialization` enum value. +template +constexpr auto conv_spec() +{ + using InstTraits = InstanceTraits; + + if constexpr(requires { InstTraits::kConvForwardSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionForwardSpecialization; + + if constexpr(InstTraits::kConvForwardSpecialization == Default) + { + return builder::ConvFwdSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvForwardSpecialization == Filter3x3) + { + return builder::ConvFwdSpecialization::FILTER_3x3; + } + } + else if constexpr(requires { InstTraits::kConvBwdDataSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardDataSpecialization; + + if constexpr(InstTraits::kConvBwdDataSpecialization == Default) + { + return builder::ConvBwdDataSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdDataSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdDataSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + } + else if constexpr(requires { InstTraits::kConvBwdWeightSpecialization; }) + { + using enum ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization; + + if constexpr(InstTraits::kConvBwdWeightSpecialization == Default) + { + return builder::ConvBwdWeightSpecialization::DEFAULT; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Stride1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_STRIDE1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == Filter1x1Pad0) + { + return builder::ConvBwdWeightSpecialization::FILTER_1X1_PAD0; + } + else if constexpr(InstTraits::kConvBwdWeightSpecialization == OddC) + { + return builder::ConvBwdWeightSpecialization::ODD_C; + } + } +} + +/// @brief Derives the grouped convolution layout from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::GroupConvLayout{1D|2D|3D}` enum value corresponding to the tensor layouts. +template +constexpr auto conv_layout() +{ + using InstTraits = InstanceTraits; + using ALayout = typename InstTraits::ALayout; + using BLayout = typename InstTraits::BLayout; + using ELayout = typename InstTraits::ELayout; + + namespace ctc = ck::tensor_layout::convolution; + + if constexpr(InstTraits::kSpatialDim == 1) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout1D::GNWC_GKXC_GNWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NWGC_GKXC_NWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKXC_NGKW; + } + else if constexpr(std::is_same_v && + std::is_same_v && std::is_same_v) + { + return builder::GroupConvLayout1D::NGCW_GKCX_NGKW; + } + } + else if constexpr(InstTraits::kSpatialDim == 2) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NHWGC_GKYXC_NHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKYXC_NGKHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout2D::NGCHW_GKCYX_NGKHW; + } + } + else if constexpr(InstTraits::kSpatialDim == 3) + { + if constexpr(std::is_same_v && std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return builder::GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW; + } + } +} + +/// @brief Derives the data type from a device kernel `Instance` type. +/// @tparam Instance The device kernel instance type. +/// @return A `builder::DataType` enum value (e.g., FP16, BF16, FP32). +template +constexpr builder::DataType conv_data_type() +{ + using InstTraits = InstanceTraits; + using ADataType = typename InstTraits::ADataType; + + if constexpr(std::is_same_v) + { + return builder::DataType::FP16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::BF16; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP32; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::FP8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::I8; + } + else if constexpr(std::is_same_v) + { + return builder::DataType::U8; + } + else + { + // Default fallback + return builder::DataType::FP32; + } +} + +/// @brief Derives the elementwise operation from op type. +/// @tparam ElementwiseOp Elementwise operation functor type. +/// @return A `builder::ElementwiseOperation` enum value corresponding to elementwise operation. +template +constexpr builder::ElementwiseOperation elementwise_op() +{ + constexpr std::string_view name = detail::elementwise_op_name(); + if constexpr(detail::case_insensitive_equal(name, "Bias")) + { + return builder::ElementwiseOperation::BIAS; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasClamp")) + { + return builder::ElementwiseOperation::BIAS_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "BiasBnormClamp")) + { + return builder::ElementwiseOperation::BIAS_BNORM_CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Bilinear")) + { + return builder::ElementwiseOperation::BILINEAR; + } + else if constexpr(detail::case_insensitive_equal(name, "Clamp")) + { + return builder::ElementwiseOperation::CLAMP; + } + else if constexpr(detail::case_insensitive_equal(name, "Scale")) + { + return builder::ElementwiseOperation::SCALE; + } + else if constexpr(detail::case_insensitive_equal(name, "PassThrough")) + { + return builder::ElementwiseOperation::PASS_THROUGH; + } +} + +/// @brief Derives a gemm padding from a kernel instance type. +/// @tparam Instance - A Device Kernel object type. +/// @return A `builder::GemmPadding` enum value corresponding to kernel padding. +template +constexpr builder::GemmPadding gemm_spec() +{ + using InstTraits = InstanceTraits; + using enum builder::GemmPadding; + using enum ck::tensor_operation::device::GemmSpecialization; + + constexpr auto gemm_spec = InstTraits::kGemmSpecialization; + + if constexpr(gemm_spec == Default) + { + return DEFAULT; + } + else if constexpr(gemm_spec == MPadding) + { + return M_PADDING; + } + else if constexpr(gemm_spec == NPadding) + { + return N_PADDING; + } + else if constexpr(gemm_spec == KPadding) + { + return K_PADDING; + } + else if constexpr(gemm_spec == MNPadding) + { + return MN_PADDING; + } + else if constexpr(gemm_spec == MKPadding) + { + return MK_PADDING; + } + else if constexpr(gemm_spec == NKPadding) + { + return NK_PADDING; + } + else if constexpr(gemm_spec == MNKPadding) + { + return MNK_PADDING; + } + else if constexpr(gemm_spec == OPadding) + { + return O_PADDING; + } + else if constexpr(gemm_spec == MOPadding) + { + return MO_PADDING; + } + else if constexpr(gemm_spec == NOPadding) + { + return NO_PADDING; + } + else if constexpr(gemm_spec == KOPadding) + { + return KO_PADDING; + } + else if constexpr(gemm_spec == MNOPadding) + { + return MNO_PADDING; + } + else if constexpr(gemm_spec == MKOPadding) + { + return MKO_PADDING; + } + else if constexpr(gemm_spec == NKOPadding) + { + return NKO_PADDING; + } + else if constexpr(gemm_spec == MNKOPadding) + { + return MNKO_PADDING; + } +} + +/// @brief Primary template for extracting convolution traits. +/// @details This struct is the main entry point for reflecting on a convolution +/// kernel's properties. It is specialized to handle different kinds of input types. +template +struct ConvTraits; + +/// @brief Specialization of `ConvTraits` for a direct device kernel `Instance`. +/// @details This is the primary specialization used to extract a comprehensive +/// set of traits directly from a fully-formed device kernel `Instance` type. +/// It uses `InstanceTraits` to access the kernel's template parameters. +template + requires requires { typename InstanceTraits; } +struct ConvTraits +{ + using InstTraits = InstanceTraits; + + // --- Signature Information --- + /// @brief The number of spatial dimensions in the convolution (1, 2, or 3). + static constexpr int spatial_dim = InstTraits::kSpatialDim; + /// @brief The direction of the convolution (Forward, Backward Data, or Backward Weight). + static constexpr builder::ConvDirection direction = conv_direction(); + /// @brief The memory layout of the convolution tensors (e.g., GNHWC_GKYXC_GNHWK). + static constexpr auto layout = conv_layout(); + /// @brief The primary data type used in the computation (e.g., FP16, FP32). + static constexpr builder::DataType data_type = conv_data_type(); + + static constexpr builder::ElementwiseOperation input_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation weight_element_op = + elementwise_op(); + static constexpr builder::ElementwiseOperation output_element_op = + elementwise_op(); + + /// @brief The GEMM specialization used by the kernel - padding + static constexpr auto gemm_padding = gemm_spec(); + /// @brief The convolution-specific specialization (e.g., Default, 1x1). + static constexpr auto conv_specialization = conv_spec(); + + // --- Algorithm Information --- + /// @brief The total number of threads in a thread block (workgroup). + static constexpr int thread_block_size = InstTraits::kBlockSize; + /// @brief The dimensions of the data tile processed by the thread block. + static constexpr DataTileInfo tile_dims = { + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; + + /// @brief Configuration for the A-matrix (input) tile transfer. + static constexpr InputTileTransferInfo a_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; + + /// @brief Configuration for the B-matrix (weights) tile transfer. + static constexpr InputTileTransferInfo b_tile_transfer = { + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; + + /// @brief Parameters for the warp-level GEMM computation. + static constexpr WarpGemmParams warp_gemm = {.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; + + /// @brief Configuration for the C-matrix (output) tile transfer. + static constexpr OutputTileTransferInfo c_tile_transfer = { + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; + + /// @brief Helper to safely get the pipeline version. + /// @details This is only available for some convolutions (e.g., forward). + /// If not present in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_version() + { + if constexpr(requires { T::kPipelineVersion; }) + { + return convert_pipeline_version(); + } + else + { + // Return a default or indicate not available + return builder::PipelineVersion::V1; + } + } + + /// @brief The block GEMM pipeline version used by the kernel. + static constexpr auto pipeline_version = get_pipeline_version(); + + /// @brief Helper to safely get the pipeline scheduler. + /// @details This is only available for some convolutions. If not present + /// in `InstanceTraits`, it returns a default value. + template + static constexpr auto get_pipeline_scheduler() + { + if constexpr(requires { T::kPipelineScheduler; }) + { + return convert_pipeline_scheduler(); + } + else if constexpr(requires { T::kLoopScheduler; }) + { + return convert_pipeline_scheduler(); + } + else + { + // Return a default or indicate not available + return builder::PipelineScheduler::DEFAULT; + } + } + + /// @brief The pipeline scheduler used by the kernel. + static constexpr auto pipeline_scheduler = get_pipeline_scheduler(); +}; + +/// @brief Specialization of `ConvTraits` for a `ConvBuilder` type. +/// @details This specialization provides backward compatibility for reflecting +/// on kernels defined via the `ConvBuilder` interface. It works by first +/// creating the `Instance` via the builder's factory, and then delegating +/// all trait extraction to the `ConvTraits` specialization. +template +struct ConvTraits> +{ + using Factory = builder::ConvFactory; + using Instance = typename Factory::Instance; + + // Delegate to Instance-based ConvTraits + using InstanceConvTraits = ConvTraits; + + // Forward all members from Instance-based traits + static constexpr int spatial_dim = InstanceConvTraits::spatial_dim; + static constexpr builder::ConvDirection direction = InstanceConvTraits::direction; + static constexpr auto layout = InstanceConvTraits::layout; + static constexpr builder::DataType data_type = InstanceConvTraits::data_type; + + static constexpr builder::ElementwiseOperation input_element_op = + InstanceConvTraits::input_element_op; + static constexpr builder::ElementwiseOperation weight_element_op = + InstanceConvTraits::weight_element_op; + static constexpr builder::ElementwiseOperation output_element_op = + InstanceConvTraits::output_element_op; + + static constexpr auto gemm_padding = InstanceConvTraits::gemm_padding; + static constexpr auto conv_specialization = InstanceConvTraits::conv_specialization; + + static constexpr int thread_block_size = InstanceConvTraits::thread_block_size; + static constexpr DataTileInfo tile_dims = InstanceConvTraits::tile_dims; + static constexpr InputTileTransferInfo a_tile_transfer = InstanceConvTraits::a_tile_transfer; + static constexpr InputTileTransferInfo b_tile_transfer = InstanceConvTraits::b_tile_transfer; + static constexpr WarpGemmParams warp_gemm = InstanceConvTraits::warp_gemm; + static constexpr OutputTileTransferInfo c_tile_transfer = InstanceConvTraits::c_tile_transfer; + static constexpr auto pipeline_version = InstanceConvTraits::pipeline_version; + static constexpr auto pipeline_scheduler = InstanceConvTraits::pipeline_scheduler; +}; + +} // namespace ck_tile::reflect::conv diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp index a47ad0ef57..29c687f491 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits.hpp @@ -14,18 +14,9 @@ #pragma once -#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include "instance_traits_util.hpp" +#include namespace ck_tile::reflect { diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 5784938fc6..c3d7367ee2 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -15,6 +15,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 9f64c7d6e8..22b4f06772 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -14,6 +14,7 @@ #pragma once #include "instance_traits.hpp" +#include "instance_traits_util.hpp" // Forward declaration to avoid circular dependency. // This file will be included by the device implementation header, so we cannot include diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp index b13675a7b9..14aeb5074a 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_util.hpp @@ -9,9 +9,11 @@ #include #include +#include #include #include #include +#include #include #include #include @@ -356,4 +358,30 @@ constexpr std::string type_or_type_tuple_name() } } +/// @brief Makes a case insensitive comparison of two string views. +/// @param a First string view +/// @param b Second string view +/// @return Whether two string views a equal case insensitive +constexpr bool case_insensitive_equal(std::string_view a, std::string_view b) +{ + if(a.size() != b.size()) + return false; + + for(size_t i = 0; i < a.size(); ++i) + { + char c1 = a[i]; + char c2 = b[i]; + + // Convert to lowercase for comparison + if(c1 >= 'A' && c1 <= 'Z') + c1 += 32; + if(c2 >= 'A' && c2 <= 'Z') + c2 += 32; + + if(c1 != c2) + return false; + } + return true; +} + } // namespace ck_tile::reflect::detail diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index a2ef89da2e..875a9c8816 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -128,29 +128,14 @@ enum class ElementwiseOperation PASS_THROUGH }; -// Enums for the current block GEMM pipeline versions. -enum class BlockGemmPipelineVersion +// Enums for pipeline versions & schedulers +enum class PipelineVersion { V1, V2, V3, V4, - V5 -}; - -enum struct BlockGemmPipelineScheduler -{ - INTRAWAVE, - INTERWAVE, -}; - -// Enums for the gridwise GEMM pipeline versions. -enum class GridwiseGemmPipelineVersion -{ - V1, - V2, - V3, // Only used in stream-K implementation - V4, + V5, WEIGHT_ONLY }; @@ -186,9 +171,47 @@ enum class ConvFwdSpecialization FILTER_3x3 }; -enum class LoopScheduler +// Enums for the backward data convolution specialization. +enum class ConvBwdDataSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, +}; + +// Enums for the backward weight convolution specialization. +enum class ConvBwdWeightSpecialization +{ + DEFAULT, + FILTER_1X1_STRIDE1_PAD0, + FILTER_1X1_PAD0, + ODD_C, +}; + +// Enums for the Gemm padding. +enum class GemmPadding { DEFAULT, + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING, + O_PADDING, + MO_PADDING, + NO_PADDING, + KO_PADDING, + MNO_PADDING, + MKO_PADDING, + NKO_PADDING, + MNKO_PADDING, +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, INTERWAVE }; diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 26a666a805..c7bfac3c16 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -62,6 +62,9 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_bias_bnorm_clam add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) +add_ck_builder_test(test_conv_traits + conv/test_conv_traits.cpp) + # Function to add all test_ckb targets to a list function(collect_test_ckb_targets result_var) # Get all targets in current directory diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp index 472c43438d..388edee873 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -24,7 +24,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V2, + PipelineVersion::V2, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp index b9969f7e95..6420da951d 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } @@ -44,7 +44,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp index cd5186cc10..fd8e8718c4 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V3, + PipelineVersion::V3, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp index 584e0ab182..a4e1c1de38 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp index 17caf98457..9bfbdbb838 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp @@ -22,7 +22,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp index ec4649a6ff..b2bff4cde9 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V4, + PipelineVersion::V4, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp index 393ea9206d..df237d6a0c 100644 --- a/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp +++ b/experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp @@ -23,7 +23,7 @@ TEST(FwdConvInstances, run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< FwdConvSignature, FwdThreadBlock, - BlockGemmPipelineVersion::V1, + PipelineVersion::V1, ConvFwdSpecialization::FILTER_1X1_PAD0>(); } diff --git a/experimental/builder/test/conv/test_conv_traits.cpp b/experimental/builder/test/conv/test_conv_traits.cpp new file mode 100644 index 0000000000..ca453d2ad4 --- /dev/null +++ b/experimental/builder/test/conv/test_conv_traits.cpp @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +using ::testing::ElementsAre; + +// Test fixture for ConvTraits tests +class ConvTraitsTest : public ::testing::Test +{ +}; + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 +TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched + ck::PipelineVersion::v1, // BlkGemmPipelineVer + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + false>; // DirectLoad + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); + + // Verify A tile transfer info + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding); + + // Verify B tile transfer info + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128); + EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2)); + EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2)); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8); + EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8); + EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding); + + // Verify warp GEMM params + EXPECT_EQ(Traits::warp_gemm.gemm_m, 32); + EXPECT_EQ(Traits::warp_gemm.gemm_n, 32); + EXPECT_EQ(Traits::warp_gemm.m_iter, 4); + EXPECT_EQ(Traits::warp_gemm.n_iter, 4); + + // Verify output tile transfer info + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1); + EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1); + EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8)); + EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8); + + // Verify pipeline configuration + EXPECT_EQ(Traits::pipeline_scheduler, ck_tile::builder::PipelineScheduler::INTRAWAVE); + EXPECT_EQ(Traits::pipeline_version, ck_tile::builder::PipelineVersion::V1); +} + +// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle +TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default, // LoopSched + 1>; // NumGroupsToMerge + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor +TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction) +{ + // Define a concrete instance type with specific template parameters + using DeviceInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor< + 2, // NDimSpatial + ck::tensor_layout::convolution::GNHWC, // ALayout + ck::tensor_layout::convolution::GKYXC, // BLayout + ck::Tuple<>, // DsLayout + ck::tensor_layout::convolution::GNHWK, // ELayout + ck::half_t, // ADataType + ck::half_t, // BDataType + float, // AccDataType + ck::half_t, // CShuffleDataType + ck::Tuple<>, // DsDataType + ck::half_t, // EDataType + ck::tensor_operation::element_wise::PassThrough, // AElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // BElementwiseOperation + ck::tensor_operation::element_wise::PassThrough, // CDEElementwiseOperation + ck::tensor_operation::device::ConvolutionForwardSpecialization:: + Default, // ConvForwardSpecialization + ck::tensor_operation::device::GemmSpecialization::Default, // GemmSpec + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + ck::Sequence<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + ck::Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + ck::Sequence<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + ck::Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + ck::Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + ck::Sequence<1, + 32, + 1, + 8>, // CDEBlockTransferClusterLengths + 8, // CDEBlockTransferScalarPerVector_NPerBlock + ck::half_t, // AComputeDataType + ck::half_t, // BComputeDataType + ck::LoopScheduler::Default>; // LoopSched + + // Use ConvTraits to extract compile-time information + using Traits = ck_tile::reflect::conv::ConvTraits; + + // Verify signature information + EXPECT_EQ(Traits::spatial_dim, 2); + EXPECT_EQ(Traits::direction, ck_tile::builder::ConvDirection::FORWARD); + EXPECT_EQ(Traits::layout, ck_tile::builder::GroupConvLayout2D::GNHWC_GKYXC_GNHWK); + EXPECT_EQ(Traits::data_type, ck_tile::builder::DataType::FP16); + EXPECT_EQ(Traits::input_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::weight_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + EXPECT_EQ(Traits::output_element_op, ck_tile::builder::ElementwiseOperation::PASS_THROUGH); + + // Verify specializations + EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT); + EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvFwdSpecialization::DEFAULT); + + // Verify algorithm information + EXPECT_EQ(Traits::thread_block_size, 256); + + // Verify tile dimensions + EXPECT_EQ(Traits::tile_dims.m, 128); + EXPECT_EQ(Traits::tile_dims.n, 128); + EXPECT_EQ(Traits::tile_dims.k, 16); +} +} // anonymous namespace diff --git a/experimental/builder/test/impl/conv_algorithm_types.hpp b/experimental/builder/test/impl/conv_algorithm_types.hpp index 9c5ca9b97b..d22c40e4c2 100644 --- a/experimental/builder/test/impl/conv_algorithm_types.hpp +++ b/experimental/builder/test/impl/conv_algorithm_types.hpp @@ -49,14 +49,14 @@ struct GridwiseWmmaGemm size_t n_per_wmma = 0; size_t m_wmma_per_wave = 0; size_t n_wmma_per_wave = 0; - GridwiseGemmPipelineVersion pipeline_version; + PipelineVersion pipeline_version; }; static_assert(ckb::GridwiseWmmaGemmDescriptor); struct BlockGemm { - BlockGemmPipelineVersion pipeline_version; - BlockGemmPipelineScheduler scheduler; + PipelineVersion pipeline_version; + PipelineScheduler scheduler; }; static_assert(ckb::BlockGemmDescriptor); @@ -156,7 +156,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; size_t num_groups_to_merge; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); @@ -191,7 +191,7 @@ struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ConvFwdSpecialization fwd_specialization; GemmSpecialization gemm_specialization; size_t num_gemm_k_prefetch_stages; - LoopScheduler loop_scheduler; + PipelineScheduler loop_scheduler; }; static_assert( ckb::ConvAlgorithmDescriptor); diff --git a/experimental/builder/test/utils/ckb_conv_test_common.hpp b/experimental/builder/test/utils/ckb_conv_test_common.hpp index d18a008015..5813ab565a 100644 --- a/experimental/builder/test/utils/ckb_conv_test_common.hpp +++ b/experimental/builder/test/utils/ckb_conv_test_common.hpp @@ -16,7 +16,7 @@ using namespace test; // Common test implementation template constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() { @@ -52,7 +52,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() .src_access_order_b = {1, 0, 2}}; constexpr BlockGemm BlockGemmDesc = {.pipeline_version = FwdPipelineVersion, - .scheduler = BlockGemmPipelineScheduler::INTRAWAVE}; + .scheduler = PipelineScheduler::INTRAWAVE}; constexpr ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 FwdConvAlgorithm{ .thread_block = FwdThreadBlock, @@ -73,13 +73,13 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3() EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")); // Verify pipeline version is correct - if(FwdPipelineVersion == BlockGemmPipelineVersion::V1) + if(FwdPipelineVersion == PipelineVersion::V1) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3) + else if(FwdPipelineVersion == PipelineVersion::V3) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4) + else if(FwdPipelineVersion == PipelineVersion::V4) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos); - else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5) + else if(FwdPipelineVersion == PipelineVersion::V5) EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos); // Verify specialization is correct @@ -140,7 +140,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle() .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, .num_groups_to_merge = 2, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; @@ -176,7 +176,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .n_per_wmma = 32, .m_wmma_per_wave = 2, .n_wmma_per_wave = 1, - .pipeline_version = GridwiseGemmPipelineVersion::V1}; + .pipeline_version = PipelineVersion::V1}; constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 32, .k1 = 1}, .block_transfer_b = {.k0 = 4, .m_n = 32, .k1 = 1}, @@ -209,7 +209,7 @@ constexpr void run_test_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle() .fwd_specialization = FwdConvSpecialization, .gemm_specialization = GemmSpecialization::MNKPadding, .num_gemm_k_prefetch_stages = 1, - .loop_scheduler = LoopScheduler::DEFAULT}; + .loop_scheduler = PipelineScheduler::DEFAULT}; using Builder = ConvBuilder; diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp index 01bb806789..219206c5ce 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -3,6 +3,8 @@ #pragma once +#include + namespace ck { namespace tensor_operation { namespace device {