diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 03bde86421..a9ae0b2a6a 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index e482953e46..32110759f4 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -356,11 +356,18 @@ inline __host__ __device__ constexpr double get_rtol() } template -inline __host__ __device__ constexpr double get_atol() +inline __host__ __device__ constexpr double get_atol(size_t K = 0) { if constexpr(std::is_same_v && std::is_same_v) { - return 1e-3; + if(K == 0) + { + throw std::runtime_error("K is 0"); + } + // tf32 has 10 mantissa bits, so epsilon = 2^(-10) = 1/1024 + constexpr double epsilon_tf32 = 1.0 / 1024.0; // 2^(-10) + constexpr double epsilon_fp32 = std::numeric_limits::epsilon(); + return (epsilon_tf32 - epsilon_fp32) * K; } else if constexpr(std::is_same_v) { diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7fb0c1e812..cdabcc9fa8 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -212,7 +212,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) c_m_n_host_result, "Error: Incorrect results!", get_rtol(), - get_atol()); + get_atol(K)); #endif } diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 4f174bfcbb..791d81e264 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 20cbc5fdca..20d9bab7e1 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() + +list(APPEND gpu_list_tf32 gfx942 gfx950) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32) + set(target 1) + endif() +endforeach() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..78eb90e311 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 87ccebc3c4..13698f3394 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + struct ProblemSize final { std::vector Ms; @@ -231,7 +236,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co AccDataType, AElementOp, BElementOp, - CDEElementOp>; + CDEElementOp, + ComputeDataType, + ComputeDataType>; for(std::size_t i = 0; i < gemm_descs.size(); i++) { @@ -253,7 +260,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); #else - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + pass &= ck::utils::check_err, Tensor, ComputeDataType>( + c_device_tensors[i], c_host_tensors[i]); #endif } } diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 0c4f056a46..53f4c27399 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,7 +129,10 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } -inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; } +inline bool is_tf32_supported() +{ + return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; +} } // namespace ck #endif diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 3637053e14..f34f91acfc 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/io.hpp" #include "ck/library/utility/ranges.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace utils { @@ -171,6 +172,21 @@ check_err(const Range& out, double rtol = 1e-5, double atol = 3e-5) { +#ifndef __HIPCC_RTC__ + if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950") + { + rtol = 1e-2; + atol = 1e-2; + } +#else +// In RTC mode, use preprocessor macros to check device architecture +#if defined(__gfx942__) || defined(__gfx950__) + { + rtol = 1e-2; + atol = 1e-2; + } +#endif +#endif // __HIPCC_RTC__ if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 55015dd30f..7648e9a92d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -619,6 +619,7 @@ template constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { + if constexpr(LoopSched == LoopScheduler::Default) { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + typename CElementwiseOperation, + typename ComputeDataType = ADataType> struct DeviceGroupedGemm : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 7a1944cc68..0ae1aa321a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -134,7 +134,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeDataType = ADataType> struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceGroupedGemm_Xdl; GET_NXDL_PER_WAVE_IMPL @@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); - using ComputeDataType = ADataType; - // GridwiseGemm template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 4a6ed62c0e..c8643a4087 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -279,7 +279,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Sequence, Sequence, Sequence>; - static_for<0, tuple_element_t::Size(), 1>{}( [&](auto v_idx) { constexpr auto VectorLoadSize = diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index ce2d9299f9..0e1a5aa9ac 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -80,8 +80,10 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, - mfma_f32_16x16x8xf32, // tf32 - mfma_f32_32x32x4xf32, + mfma_f32_16x16x8xf32, // tf32 on gfx942 + mfma_f32_32x32x4xf32, // tf32 on gfx942 + mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950 + mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950 // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -994,24 +996,47 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { - static constexpr index_t wave_size = 64; // fixed - static constexpr index_t m_per_blk = 32; // from the instruction - static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t num_threads_per_blk = n_per_blk; // 32 - static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 - static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 - static constexpr index_t group_size = 4; // corresponding to CD rows mapping + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; static constexpr index_t num_output_blks = 1; - static constexpr index_t k_per_blk = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16xf32::Run(a, b, reg_c); + } +}; +template <> +struct mfma_type +{ + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; static constexpr bool is_k_reduction = true; - // AB register size: 2, CD register size: 16 + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x32xf32::Run(a, b, reg_c); } }; @@ -1275,12 +1300,14 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_32x32x4xf32; #else @@ -1289,12 +1316,14 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_16x16x8xf32; #else @@ -2185,6 +2214,10 @@ struct XdlopsGemm (is_same::value && KPack <= 8) || ((is_same::value || is_same::value) && KPack < 32) || is_same::value) +#if defined(__gfx950__) + // tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16. + || (is_same::value && KPack <= 4) +#endif ? true : false; static constexpr auto mfma = MfmaSelector +__device__ __forceinline__ void +convert_float_to_bf16_pairs(const vector_type& reg_f32, + vector_type& reg_bf16_big, + vector_type& reg_bf16_small) +{ + static_for<0, VecSize, 1>{}([&](auto k) { + using IK = Number; + reg_bf16_big.template AsType()(k) = + type_convert(reg_f32.template AsType()[IK{}]); + reg_bf16_small.template AsType()(k) = type_convert( + reg_f32.template AsType()[IK{}] - + type_convert(reg_bf16_big.template AsType()[IK{}])); + }); +} +/* */ + // fp32 template struct intrin_mfma_f32_32x32x1f32; @@ -1636,7 +1655,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; -/******************* tf32 *************************************/ +/******************* tf32 on gfx942 *************************************/ template struct intrin_mfma_f32_16x16x8xf32; @@ -1646,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1666,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1677,4 +1696,102 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> } }; +/******************* tf32/xf32 on gfx950 ********************************/ +/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */ +/* step: */ +/* 1. separate one input to 2 bf16 registers: */ +/* in_bf16_big = f32_to_bf16(in_f32) */ +/* in_bf16_small = in_f32 - in_bf16_big */ +/* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. */ +/* out_f32 = A_bf16_big * B_bf16_big */ +/* out_f32 += A_bf16_small * B_bf16_big */ +/* out_f32 += A_bf16_big * B_bf16_small */ +/************************************************************************/ +template +struct intrin_mfma_f32_16x16x32xf32; + +template <> +struct intrin_mfma_f32_16x16x32xf32<16, 16> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); + + // Run 3 times: big*big, small*big, big*small + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_32x32x16xf32; + +template <> +struct intrin_mfma_f32_32x32x16xf32<32, 32> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); + + // Run 3 times: big*big, small*big, big*small + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +/******************* tf32/xf32 on gfx950 end ************************************/ } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 8b9b973b2d..660ec64f97 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -142,18 +142,8 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - if constexpr(is_same_v && - is_same_v) - { // only for tf32 now - v_acc += - ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); - } - else - { - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); - } + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index cf30bc7dda..bcc8b95500 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -80,16 +80,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - if constexpr(is_same_v && - is_same_v) - { // only for tf32 now - v_acc += ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); - } - else - { - v_acc += type_convert(v_a) * type_convert(v_b); - } + v_acc += type_convert(v_a) * type_convert(v_b); } // apply c_element_op c_element_op(v_c, v_acc); diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 13f5cd1cda..8400b020f7 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -105,7 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) using TF32 = ck::tf32_t; #endif @@ -228,7 +228,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -253,7 +253,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -280,7 +280,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -306,7 +306,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -331,7 +331,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -352,7 +352,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -373,7 +373,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -416,7 +416,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif } @@ -439,7 +439,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) } else if(data_type == ConvDataType::F32_F32_F32_TF32) { -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); #endif }