Basic docs for universal gemm & ck-tile gemm. #2014


Merged
27 commits merged on Apr 2, 2025
Commits
8b37f94
Basic docs for universal gemm & ck-tile gemm.
aosewski Mar 25, 2025
a12fa13
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b985b09
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
9cc3431
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b945cd9
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
010aad1
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
3d6aea2
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
5efbe86
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
ad1b09f
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
c671eae
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
2534f77
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
e8a237c
Update include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffl…
aosewski Mar 26, 2025
5b4d443
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
5832414
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
da89b40
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
bb95119
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
782cba4
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
b57ebf6
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
5eef1e9
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
1b66dc3
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
eca5168
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
faff800
Update include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cs…
aosewski Mar 26, 2025
66821a2
Reviewers suggestions.
aosewski Mar 26, 2025
cbff8c2
Align tparam names in doc with class tparams.
aosewski Mar 26, 2025
dc01b65
More reviewers fine tuning ;)
aosewski Mar 28, 2025
50828b1
Merge remote-tracking branch 'origin/develop' into aosewski/ck_tparam…
aosewski Mar 28, 2025
510870e
Merge branch 'develop' into aosewski/ck_tparam_doc
aosewski Apr 1, 2025
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -21,6 +21,105 @@ namespace ck {
namespace tensor_operation {
namespace device {

/// @brief "Universal" GEMM operation with SplitK support.
///
/// @par Overview
/// This GEMM operation implements the following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The "universal" GEMM provides multiple pipelines optimized for different usage
/// scenarios; hence the name: it is universal through its design and versatility.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension across multiple workgroups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam CDataType C tensor data type.
/// @tparam GemmAccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during the "CShuffle" data layout optimization.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines which "padding" variant is used.
/// @tparam BlockSize The number of threads within a workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1 The vector load size from global memory for A tensor.
/// @tparam BK1 The vector load size from global memory for B tensor.
/// @tparam MPerXDL M size of matrix-fused-multiply-add instruction.
/// @tparam NPerXDL N size of matrix-fused-multiply-add instruction.
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of the axis along which vectorized memory
/// access can be performed - the one contiguous in memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of the axis along which vectorized memory
/// access can be performed - the one contiguous in memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
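The equation and SplitK note documented above can be sketched on the CPU as follows. This is an illustrative reference only, with identity A_op/B_op and a hypothetical helper name `splitk_gemm`; it is not CK's implementation:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// CPU reference for the documented equation
//   C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
// with identity elementwise ops, plus the SplitK scheme: K is split into
// `k_batch` chunks, and each chunk's partial product is accumulated into C
// (on the GPU this accumulation is what the AtomicAdd reduction performs).
std::vector<float> splitk_gemm(const std::vector<float>& A, // M x K, row-major
                               const std::vector<float>& B, // K x N, row-major
                               int M, int N, int K, int k_batch)
{
    std::vector<float> C(M * N, 0.f); // must start zeroed, as with AtomicAdd
    const int k_per_batch = (K + k_batch - 1) / k_batch;
    for(int kb = 0; kb < k_batch; ++kb) // one chunk ~ one group of workgroups
    {
        const int k_begin = kb * k_per_batch;
        const int k_end   = std::min(K, k_begin + k_per_batch);
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(int k = k_begin; k < k_end; ++k)
                    partial += A[m * K + k] * B[k * N + n];
                C[m * N + n] += partial; // models the AtomicAdd reduction
            }
    }
    return C;
}
```

Because the per-chunk partials are simply summed, the result is independent of `k_batch` (up to floating-point rounding on the GPU), which is why the kernel only requires C to be zero-initialized before the atomic reduction.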
template <typename ALayout,
typename BLayout,
typename CLayout,
@@ -130,9 +229,22 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,

using Argument = typename GridwiseGemm::Argument;

// Invoker
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for preparing and invoking the actual GPU
/// kernel function. It typically determines the launch grid size, prepares the kernel
/// arguments, and selects a specific kernel configuration based on runtime arguments.
///
/// @note If appropriately configured, it may also measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief Launches the GPU kernel.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
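As a rough illustration of the time-measurement behaviour noted above, an Invoker-style `Run` can average the duration of repeated launches. This is a sketch only: the function name `average_run_ms` is made up, the "kernel" is a host callable, and the real Invoker times the GPU via HIP events and `StreamConfig` rather than wall-clock time:

```cpp
#include <cassert>
#include <chrono>
#include <functional>

// Illustrative sketch only: average the wall-clock duration of `nrepeat`
// "kernel" launches, as an Invoker-style Run might do when time
// measurement is enabled. The real Invoker times the GPU with HIP events.
float average_run_ms(const std::function<void()>& kernel, int nrepeat)
{
    using clock      = std::chrono::steady_clock;
    const auto start = clock::now();
    for(int i = 0; i < nrepeat; ++i)
        kernel(); // stands in for the GPU kernel launch
    const std::chrono::duration<float, std::milli> total = clock::now() - start;
    return total.count() / nrepeat;
}
```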
103 changes: 103 additions & 0 deletions include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
100755 → 100644
@@ -82,6 +82,109 @@ __global__ void
#endif // end of if (defined(__gfx9__))
}

/// @brief "Universal" GEMM kernel with SplitK support.
///
/// @par Overview
/// This GEMM kernel carries out the following mathematical equation:
/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N}))
/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are
/// elementwise operations applied to the A, B, and C tensors, respectively.
/// The "universal" GEMM provides multiple pipelines optimized for different usage
/// scenarios; hence the name: it is universal through its design and versatility.
///
/// @note This Kernel implementation supports SplitK algorithm. It can be configured
/// to split the dot product accumulated over the K dimension across multiple workgroups.
/// The partial products of different workgroups are then reduced using the AtomicAdd
/// operation.
///
/// @tparam ALayout A tensor data layout.
/// @tparam BLayout B tensor data layout.
/// @tparam CLayout C tensor data layout.
/// @tparam ADataType A tensor data type.
/// @tparam BDataType B tensor data type.
/// @tparam AccDataType The accumulation data type related to the hardware
/// matrix-multiplication instruction.
/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into
/// LDS memory during the "CShuffle" data layout optimization.
/// @tparam CDataType C tensor data type.
/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements.
/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements.
/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor
/// (after GEMM).
/// @tparam GemmSpec Determines which "padding" variant is used.
/// @tparam BlockSize The number of threads within a workgroup.
/// @tparam MPerBlock The input/output data tile size in the M dimension.
/// @tparam NPerBlock The input/output data tile size in the N dimension.
/// @tparam KPerBlock The input data tile size in the K dimension.
/// @tparam AK1Value The vector load size from global memory for A tensor.
/// @tparam BK1Value The vector load size from global memory for B tensor.
/// @tparam MPerXdl M size of matrix-fused-multiply-add instruction.
/// @tparam NPerXdl N size of matrix-fused-multiply-add instruction.
/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront.
/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront.
/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question, "How many threads can be
/// arranged on each input data axis?"
/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam ABlockTransferSrcVectorDim The index of the axis along which vectorized memory
/// access can be performed - the one contiguous in memory.
/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether the thread coordinate is
/// reset (moved back to the window origin) after all threads finish the data copy.
/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
/// data. Can be interpreted as the answer
/// to the question: "How many threads to
/// arrange on each input data axis?"
/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
/// the input tensor dimension. Can be interpreted
/// as the answer to the question: "In which
/// order to spread threads through tensor axes?".
/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
/// interpreted as the answer to the question "Which dimension
/// to read first? And which next?" etc.
/// @tparam BBlockTransferSrcVectorDim The index of the axis along which vectorized memory
/// access can be performed - the one contiguous in memory.
/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
/// elements accessed per thread per instruction.
/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether the thread coordinate is
/// reset (moved back to the window origin) after all threads finish the data copy.
/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With
/// universal GEMM there's no need for padding.
/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in M dimension.
/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions
/// results to process per wave per iteration of CShuffle
/// in N dimension.
/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial
/// thread distribution used for storing data into output
/// tensor across output data layout dimensions.
/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access.
/// Used when storing data to output tensor.
/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or
/// intrawave).
/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline.
/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication
/// instructions.
/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication
/// instructions.
/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout
/// in global memory. Currently not supported!
/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout
/// in global memory (pre-shuffled).
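Several of the tile-shape parameters above are linked by simple arithmetic. The following sketch uses made-up but typical values (the names mirror the template parameters; the concrete numbers, including the wave counts `MWaves`/`NWaves`, are assumptions for illustration, not defaults) to show how per-wave XDL iteration counts and per-thread copy sizes fall out of the tile sizes:

```cpp
#include <array>
#include <cassert>

// Hypothetical but typical configuration values; real ones come from the
// chosen kernel instance.
constexpr int BlockSize = 256;                  // threads per workgroup
constexpr int MPerBlock = 128, NPerBlock = 128; // block output tile
constexpr int MPerXdl = 32, NPerXdl = 32;       // one MFMA output patch
constexpr int MWaves = 2, NWaves = 2;           // wavefront grid in the block

// Per-wave iteration counts over the output tile (MXdlPerWave/NXdlPerWave):
// each wave sweeps its MFMA patch across its share of the tile.
constexpr int MXdlPerWave = MPerBlock / (MWaves * MPerXdl);
constexpr int NXdlPerWave = NPerBlock / (NWaves * NPerXdl);

// A thread-cluster sequence such as
// ABlockTransferThreadClusterLengths_AK0_M_AK1 = <4, 64, 1> spreads the
// workgroup's threads over a K0 x M x K1 input tile.
constexpr std::array<int, 3> tile_lengths{8, 128, 8};   // K0, M, K1
constexpr std::array<int, 3> cluster_lengths{4, 64, 1}; // threads per axis

constexpr int threads_used()
{
    return cluster_lengths[0] * cluster_lengths[1] * cluster_lengths[2];
}

// Elements each thread copies per tile load: per-axis quotient of tile
// extent over cluster extent, multiplied across the three axes.
constexpr int elements_per_thread()
{
    return (tile_lengths[0] / cluster_lengths[0]) *
           (tile_lengths[1] / cluster_lengths[1]) *
           (tile_lengths[2] / cluster_lengths[2]);
}
```

With these numbers the thread cluster uses exactly the full workgroup, and the product of the per-axis quotients gives each thread's copy workload for one tile load.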
template <typename ALayout,
typename BLayout,
typename CLayout,