Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_welford_wmma.hpp"
#include "ck/tensor_operation/gpu/grid/epilogue_cshuffle_v3_reduce_wmma.hpp"

namespace ck {

Expand Down Expand Up @@ -605,6 +606,29 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
BlockwiseGemmPipe,
BlockSize>;

template <typename ReduceTrait>
using EpilogueReduceCShuffle = EpilogueReduceCShuffle<
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
MPerBlock,
NPerBlock,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
CDEElementwiseOperation,
ThisThreadBlock,
BlockwiseGemmPipe,
GemmSpec,
BlockSize,
ReduceTrait>;

template <typename DEGridDesc>
__device__ static constexpr auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DEGridDesc& de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_reduce_instance
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp

device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp
device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp
device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp
device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp
)

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

55 changes: 49 additions & 6 deletions profiler/include/profiler/profile_gemm_reduce_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmReduceNoOpPtr =
ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;

#ifdef CK_USE_XDL
void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

Expand All @@ -45,6 +46,20 @@ void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(

void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);

void add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmReduceNoOpPtr>&);
#endif

} // namespace instance
} // namespace device
Expand Down Expand Up @@ -211,33 +226,61 @@ bool profile_gemm_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs);
#endif
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
#ifdef CK_USE_XDL
ck::tensor_operation::device::instance::
add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
#endif
#ifdef CK_USE_WMMA
ck::tensor_operation::device::instance::
add_device_gemm_reduce_wmma_cshuffle_v3_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs);
#endif
}
}

Expand Down Expand Up @@ -274,6 +317,8 @@ bool profile_gemm_reduce_impl(int do_verification,

auto invoker_ptr = gemm_ptr->MakeInvokerPointer();

std::string gemm_name = gemm_ptr->GetTypeString();

if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
Expand All @@ -289,8 +334,6 @@ bool profile_gemm_reduce_impl(int do_verification,
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

std::string gemm_name = gemm_ptr->GetTypeString();

std::size_t flop = std::size_t(2) * M * N * K;

std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
Expand All @@ -317,9 +360,9 @@ bool profile_gemm_reduce_impl(int do_verification,
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());

ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
pass = pass & ck::utils::check_err(reduce0_m_device_result, reduce0_m_host_result);
pass = pass & ck::utils::check_err(reduce1_m_device_result, reduce1_m_host_result);

if(do_log)
{
Expand All @@ -346,7 +389,7 @@ bool profile_gemm_reduce_impl(int do_verification,
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
std::cout << gemm_name << ": does not support this GEMM problem" << std::endl;
}
}

Expand Down
10 changes: 6 additions & 4 deletions test/gemm_reduce/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance)
endif()
endif()
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

Expand Down