10 changes: 10 additions & 0 deletions example/15_grouped_gemm/CMakeLists.txt
@@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()

list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32)
set(target 1)
endif()
endforeach()
67 changes: 67 additions & 0 deletions example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp
@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#define EXAMPLE_WITH_COMPUTE_DATATYPE

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ComputeDataType = ck::tf32_t;

using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>;
// clang-format on

#include "run_grouped_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

#undef EXAMPLE_WITH_COMPUTE_DATATYPE
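Note on the new example: only ComputeDataType switches to ck::tf32_t; A, B, and E stay fp32 and accumulation stays fp32, so the kernel reads tf32-reduced operands while keeping fp32 I/O. A minimal host-side sketch of that numerical model, assuming tf32 is modelled by truncating fp32 to 10 mantissa bits (illustration only, not CK code; hardware rounding may differ):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Drop the low 13 mantissa bits of an fp32 value, keeping sign, exponent,
    // and the top 10 mantissa bits -- a simple stand-in for tf32 precision.
    inline float truncate_to_tf32(float x)
    {
        std::uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        bits &= 0xFFFFE000u;
        std::memcpy(&x, &bits, sizeof(bits));
        return x;
    }

    // Naive reference GEMM mirroring this example's layouts and types:
    // A is row-major MxK, B is column-major KxN, accumulation is fp32.
    void reference_gemm_fp32_tf32(const std::vector<float>& a,
                                  const std::vector<float>& b,
                                  std::vector<float>& e,
                                  int M, int N, int K)
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(int k = 0; k < K; ++k)
                    acc += truncate_to_tf32(a[m * K + k]) * truncate_to_tf32(b[n * K + k]);
                e[m * N + n] = acc;
            }
    }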
12 changes: 10 additions & 2 deletions example/15_grouped_gemm/run_grouped_gemm_example.inc
@@ -3,6 +3,11 @@

#pragma once

// Use a macro to minimize changes to existing examples: fall back to
// AccDataType unless the example defines its own ComputeDataType.
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif

struct ProblemSize final
{
std::vector<ck::index_t> Ms;
@@ -231,7 +236,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
CDEElementOp,
ComputeDataType,
ComputeDataType>;

for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
@@ -253,7 +260,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);

#else
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
pass &= ck::utils::check_err<Tensor<EDataType>, Tensor<EDataType>, ComputeDataType>(
c_device_tensors[i], c_host_tensors[i]);
#endif
}
}
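The include now forwards ComputeDataType into both the reference GEMM and check_err, so verification can account for the reduced compute precision instead of assuming full fp32. A self-contained sketch of that idea, using a hypothetical tolerance helper (not the check_err overload shipped in CK):

    #include <type_traits>
    #include <utility>

    struct tf32_t {}; // stand-in for ck::tf32_t in this sketch

    // Pick looser (rtol, atol) bounds for a ~10-mantissa-bit compute type,
    // and the usual tight fp32 defaults otherwise.
    template <typename ComputeType>
    constexpr std::pair<double, double> default_tolerances()
    {
        if constexpr(std::is_same_v<ComputeType, tf32_t>)
            return {1e-2, 1e-2};
        return {1e-5, 3e-5};
    }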
5 changes: 4 additions & 1 deletion include/ck/host_utility/device_prop.hpp
@@ -129,7 +129,10 @@ inline bool is_wmma_supported()
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}

inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
inline bool is_tf32_supported()
{
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}

} // namespace ck
#endif
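With gfx950 added to the allow-list, callers can keep gating the tf32 path on this helper. A hypothetical usage sketch (not part of this PR):

    #include <iostream>
    #include "ck/host_utility/device_prop.hpp"

    // Skip the tf32 instance on unsupported devices instead of failing at
    // kernel launch time.
    bool tf32_example_should_run()
    {
        if(!ck::is_tf32_supported())
        {
            std::cout << "tf32 not supported on " << ck::get_device_name()
                      << ", skipping" << std::endl;
            return false;
        }
        return true;
    }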
16 changes: 16 additions & 0 deletions include/ck/library/utility/check_err.hpp
@@ -19,6 +19,7 @@
#include "ck/host_utility/io.hpp"

#include "ck/library/utility/ranges.hpp"
#include "ck/host_utility/device_prop.hpp"

namespace ck {
namespace utils {
@@ -171,6 +172,21 @@ check_err(const Range& out,
double rtol = 1e-5,
double atol = 3e-5)
{
#ifndef __HIPCC_RTC__
if(ck::get_device_name() == "gfx942")
{
rtol = 1e-2;
atol = 1e-2;
}
#else
// In RTC mode, use preprocessor macros to check device architecture
#if defined(__gfx942__)
{
rtol = 1e-2;
atol = 1e-2;
}
#endif
#endif
if(out.size() != ref.size())
{
std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
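For context on why loosening rtol/atol to 1e-2 on gfx942 is sufficient: tolerances of this kind typically feed a combined absolute/relative test. An illustrative sketch (the exact formula inside CK's check_err may differ):

    #include <cmath>

    // Accept `out` when its deviation from `ref` fits within atol plus a
    // relative margin scaled by the reference magnitude.
    inline bool within_tolerance(double out, double ref, double rtol, double atol)
    {
        return std::abs(out - ref) <= atol + rtol * std::abs(ref);
    }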