Basic docs for universal gemm & ck-tile gemm. #2014

Merged Apr 2, 2025 · 27 commits
8b37f94
Basic docs for universal gemm & ck-tile gemm.
aosewski Mar 25, 2025
a12fa13
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b985b09
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
9cc3431
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b945cd9
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
010aad1
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
3d6aea2
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
5efbe86
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
ad1b09f
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
c671eae
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
2534f77
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
e8a237c
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
5b4d443
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
5832414
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
da89b40
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
bb95119
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
782cba4
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b57ebf6
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
5eef1e9
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
1b66dc3
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
eca5168
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
faff800
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
66821a2
Reviewers suggestions.
aosewski Mar 26, 2025
cbff8c2
Align tparam names in doc with class tparams.
aosewski Mar 26, 2025
dc01b65
More reviewers fine tuning ;)
aosewski Mar 28, 2025
50828b1
Merge remote-tracking branch 'origin/develop' into aosewski/ck_tparam…
aosewski Mar 28, 2025
510870e
Merge branch 'develop' into aosewski/ck_tparam_doc
aosewski Apr 1, 2025
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -21,6 +21,105 @@ namespace ck {
namespace tensor_operation {
namespace device {

/// @brief "Universal" GEMM operation with SplitK support.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The "universal" GEMM comes with multiple pipelines optimized for different usage
/// scenarios, hence the name: it is universal through its design and versatility.
///
/// @note This kernel implementation supports the SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension across multiple workgroups.
/// The partial products of the different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam CDataType C tensor data type.
/// @tparam GemmAccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during the "CShuffle" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines the padding variant used.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerXDL M size of matrix-fused-multiply-add instruction.
/// @tparam NPerXDL N size of matrix-fused-multiply-add instruction.
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
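The equation in the doc comment above, C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})), can be sketched as a plain host-side reference, the kind of loop nest often used to validate kernel output. All names below are illustrative stand-ins, not part of the CK API:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Host-side reference for C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})),
// with row-major A, B, and C. `reference_gemm` is a hypothetical helper,
// not a CK function.
std::vector<float> reference_gemm(const std::vector<float>& A,
                                  const std::vector<float>& B,
                                  std::size_t M, std::size_t N, std::size_t K,
                                  std::function<float(float)> a_op,
                                  std::function<float(float)> b_op,
                                  std::function<float(float)> c_op)
{
    std::vector<float> C(M * N, 0.f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f; // plays the role of GemmAccDataType
            for(std::size_t k = 0; k < K; ++k)
                acc += a_op(A[m * K + k]) * b_op(B[k * N + n]);
            C[m * N + n] = c_op(acc); // CElementwiseOperation applied after GEMM
        }
    return C;
}
```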
@@ -130,9 +229,22 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,

using Argument = typename GridwiseGemm::Argument;

// Invoker
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for the preparation and invocation of the
/// actual GPU kernel function. It usually determines the launched grid size, prepares
/// the kernel arguments, and performs kernel configuration selection based on
/// runtime arguments.
///
/// @note If appropriately configured, it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
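The Invoker pattern documented here (prepare the launch, run it, optionally return the measured time) can be illustrated with a minimal self-contained analogue. `Argument` and `Invoker` below are stand-ins for the CK types, and the loop stands in for the kernel launch:

```cpp
#include <cassert>
#include <chrono>

// Minimal sketch of the Invoker design: the invoker owns launch logic and
// can optionally time the "kernel". Illustration only, not the CK class.
struct Argument
{
    int n; // stand-in for the real kernel arguments
};

struct Invoker
{
    // Returns elapsed milliseconds when timing is enabled, 0 otherwise,
    // mirroring Run()'s "average execution time if enabled" contract.
    float Run(const Argument& arg, bool time_kernel = false)
    {
        auto t0 = std::chrono::steady_clock::now();
        volatile long sum = 0; // stand-in for the GPU kernel launch
        for(int i = 0; i < arg.n; ++i)
            sum += i;
        auto t1 = std::chrono::steady_clock::now();
        return time_kernel ? std::chrono::duration<float, std::milli>(t1 - t0).count() : 0.f;
    }
};
```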
103 changes: 103 additions & 0 deletions include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
100755 → 100644
@@ -82,6 +82,109 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}

/// @brief "Universal" GEMM kernel with SplitK support.
///
/// @par Overview
/// This GEMM kernel implements the following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The "universal" GEMM comes with multiple pipelines optimized for different usage
/// scenarios, hence the name: it is universal through its design and versatility.
///
/// @note This kernel implementation supports the SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension across multiple workgroups.
/// The partial products of the different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during the "CShuffle" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines the padding variant used.
/// @tparam BlockSize The number of threads within workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1Value The vector load size from global memory for A tensor.
/// @tparam BK1Value The vector load size from global memory for B tensor.
/// @tparam MPerXdl M size of matrix-fused-multiply-add instruction.
/// @tparam NPerXdl N size of matrix-fused-multiply-add instruction.
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether the thread coordinate is
/// reset (moved back to the window origin) after all threads finish the data copy.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
/// access - the one with contiguous memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether the thread coordinate is
/// reset (moved back to the window origin) after all threads finish the data copy.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
template <typename ALayout,
typename BLayout,
typename CLayout,
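The SplitK scheme described in the @note above, where the K dimension is split across workgroups whose partial products are reduced via AtomicAdd, can be illustrated with a serial host-side sketch. The function name and structure are illustrative only; on the GPU the per-chunk loops run concurrently and the final addition is atomic:

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Serial illustration of SplitK for a single dot product: the K range is
// split into k_batch chunks, each chunk (one "workgroup") computes a
// partial sum, and partials are reduced by addition (AtomicAdd on the GPU).
float splitk_dot(const std::vector<float>& a, const std::vector<float>& b, std::size_t k_batch)
{
    const std::size_t K       = a.size();
    const std::size_t k_chunk = (K + k_batch - 1) / k_batch; // ceil(K / k_batch)
    float c = 0.f; // the shared accumulator all "workgroups" reduce into
    for(std::size_t wg = 0; wg < k_batch; ++wg)
    {
        float partial             = 0.f;
        const std::size_t k_begin = wg * k_chunk;
        const std::size_t k_end   = std::min(K, k_begin + k_chunk);
        for(std::size_t k = k_begin; k < k_end; ++k)
            partial += a[k] * b[k];
        c += partial; // AtomicAdd in the real kernel
    }
    return c;
}
```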
60 changes: 60 additions & 0 deletions include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
@@ -12,6 +12,11 @@

namespace ck_tile {

/// @brief The GEMM problem definition.
///
/// @par Overview
/// This structure defines the GEMM problem configuration by stating all required
/// information, such as the M, N, K sizes and the respective strides.
struct GemmProblem
{
CK_TILE_HOST GemmProblem() = default;
@@ -29,6 +34,12 @@ struct GemmProblem
index_t stride_C;
};

/// @brief The GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating the kernel
/// arguments object. It contains all information required to build the kernel arguments
/// and launch the kernel on the GPU.
struct GemmHostArgs : public GemmProblem
{
CK_TILE_HOST GemmHostArgs() = default;
@@ -56,20 +67,69 @@ struct GemmHostArgs : public GemmProblem
index_t k_batch;
};

/// @brief The GEMM kernel device arguments.
struct GemmKernelArgs
{
/// @brief The A input tensor's pointer to device memory.
const void* a_ptr;
/// @brief The B input tensor's pointer to device memory.
const void* b_ptr;
/// @brief The C output tensor's pointer to device memory.
void* c_ptr;
/// @brief GEMM's M dimension size.
index_t M;
/// @brief GEMM's N dimension size.
index_t N;
/// @brief GEMM's K dimension size.
index_t K;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of A tensor.
index_t stride_A;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of B tensor.
index_t stride_B;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of C tensor.
index_t stride_C;
/// @brief The number of batches the K dimension is split into (SplitK).
index_t k_batch;
};
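The stride convention documented for `stride_A`/`stride_B`/`stride_C` (the distance, in elements, between consecutive elements of the non-contiguous dimension) can be made concrete with two tiny offset helpers. The helper names are hypothetical, introduced only for this illustration:

```cpp
#include <cassert>

// For a row-major M x K matrix, the stride is the row pitch (>= K):
// element (row, col) lives at row * stride + col.
int row_major_offset(int row, int col, int stride) { return row * stride + col; }

// For a column-major matrix, the stride is the column pitch (>= M):
// element (row, col) lives at col * stride + row.
int col_major_offset(int row, int col, int stride) { return col * stride + row; }
```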

/// @brief The GEMM kernel template.
///
/// @paragraph Overview Overview
/// This class provides the generic matrix multiplication kernel template. By
/// semantically dividing the GEMM algorithm into the following parts, we achieve a
/// flexible, versatile, and robust kernel implementation.
///
/// @li @b Prolog - The start of the GEMM kernel implementation in the @ref operator()
/// "function call operator", which determines the work scope of each workgroup.
/// @li @b GemmPipeline - The core part (the @a "heart") of the matrix multiplication
/// algorithm. This is where each workgroup loads data from global memory and
/// carries out the dot products.
/// @li @b Epilogue - The @a "final" part of the matrix multiplication implementation,
/// responsible for storing results to global memory. This is also the place where
/// any additional operator fusion may take place.
///
/// Additionally, both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
/// "EpiloguePipeline" are parameterized with a so-called @a Policy, which determines all
/// internal details of those functional parts. You can think of the gemm and
/// epilogue pipelines as providing the control-flow logic, steered by their policies.
/// Moreover, the policy is responsible for defining all necessary data layouts and the
/// threads' work distribution.
///
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the
/// output data tile to be calculated. It determines the workgroup to
/// data relationship (or in other words - which data would be
/// processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
/// multiplication. This class should provide the implementation of data
/// loading from global memory and block-wise matrix
/// multiplication, i.e. the work done by a single
/// workgroup.
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
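The Prolog / GemmPipeline / Epilogue decomposition described in the `GemmKernel` comment can be sketched as a three-template composition. Every type below is an illustrative stand-in (not the CK classes), with integers playing the role of tiles and accumulators:

```cpp
#include <cassert>

// Sketch of the semantic division: partitioner maps workgroup -> tile,
// pipeline computes the tile, epilogue transforms and "stores" the result.
template <typename TilePartitioner, typename GemmPipeline, typename EpiloguePipeline>
struct GemmKernelSketch
{
    int operator()(int block_id) const
    {
        // Prolog: determine the work scope of this workgroup.
        const int tile = TilePartitioner::TileOf(block_id);
        // GemmPipeline: the "heart" - load data and carry out dot products.
        const int acc = GemmPipeline::Compute(tile);
        // Epilogue: final transform + store (fusion would happen here).
        return EpiloguePipeline::Store(acc);
    }
};

// Toy policies standing in for real partitioner/pipeline/epilogue types.
struct IdentityPartitioner { static int TileOf(int b) { return b; } };
struct DoublePipeline      { static int Compute(int t) { return 2 * t; } };
struct PlusOneEpilogue     { static int Store(int v) { return v + 1; } };
```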