diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp index 0376820b7b..2f8eac113b 100644 --- a/example/01_gemm/gemm_wmma_fp8_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t; using ComputeTypeA = ck::f8_t; using ComputeTypeB = ck::f8_t; -using ALayout = Row; +using ALayout = Col; using BLayout = Col; using CLayout = Row; @@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 64, - 8, 8, + 16, 16, // AK1, BK1 16, 16, 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 4, 16, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, - S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 8, 8, 0, + 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ComputeTypeA, ComputeTypeB>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp index 8cff087ddb..89952910e6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -28,6 +28,7 @@ template constexpr auto BlockGemmPipeline_Selector() { @@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) @@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, + KInner, TransposeC>{}; } else diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 265db9166a..abc9720714 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -30,6 +30,7 @@ template struct BlockwiseGemmWmmaops_pipeline_base { @@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; using ThisThreadBlock = ThisThreadBlock; @@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base static constexpr index_t B_KRow = 1; #endif - static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); - static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + static constexpr auto wmma_gemm = WmmaGemm{}; + + static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner; + static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread); + static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread); static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - static constexpr index_t KRepeat = KPerBlock / KPack; static constexpr auto WmmaK = Number{}; @@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |MRepeat|MWave |KRow |MLane |KPack - return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); } __device__ static auto CalculateBThreadOriginDataIndex() @@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base const auto wmma_krow = 0; #endif - // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack - return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); } template @@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base return make_tuple(c_thread_m, c_thread_n); } - using Tuple6 = 
decltype(CalculateAThreadOriginDataIndex()); + using Tuple7 = decltype(CalculateAThreadOriginDataIndex()); /** * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. @@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base * repeat dimensions. */ __host__ __device__ - BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), - Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(), + Tuple7 b_origin = CalculateBThreadOriginDataIndex()) : a_thread_copy_(a_origin), b_thread_copy_(b_origin) { static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && @@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base Number{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); // C[M, N, NumRegWmma] @@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 5d7c570428..5f731933e2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 
{ @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; - using Base::WaveSize; using typename Base::HotLoopInstList; using Base::A_K1; @@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0, I0, I0), - a_thread_buf); - + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0, I0, I0, I0), + a_thread_buf); if constexpr(m0 == I0) { if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0, I0), + b_thread_buf); }); } else @@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple( - Number{}, n0, I0, I0, I0, I0), + make_tuple(I0, n0, k0, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs( I0)[Number{}], b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), + make_tuple(I0, n0, I0, I0, I0, I0, I0), b_thread_buf); }); } } - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + 
vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, I0, I0, I0, I0, Number{}))>{}]; - }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, I0, I0, I0, Number{}))>{}]; - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + I0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + I0, + I0, + I0, + I0, + Number{}))>{}]; + }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); }); }); @@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + static_for<0, KInner, 1>{}([&](auto) { + static_for<0, NRepeat, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); }); }); }); @@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, I1, I1, Number{})); + make_tuple(Number{}, I1, I1, I1, I1, I1, Number{})); // B[NRepeat, N1, N2, KPack] - 
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; @@ -399,6 +418,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v1 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; using Base::I1; @@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - m0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0_inner, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same::value == true) { static_for<0, NRepeat, 1>{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } 
@@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto n0) { b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - n0, - I0, - I0, - I0, - I0), + make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0), b_block_buf, b_scale_struct.b_scale_thread_bufs(I0)[Number< n0 * BScaleStruct::num_scale_k_block + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], b_thread_desc_, - make_tuple(I0, n0, k0_inner, I0, I0, I0), + make_tuple(I0, n0, k0_inner, I0, I0, I0, I0), b_thread_buf); }); } @@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_inner) { - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0_inner, - I0, - I0, - Number{}))>{}]; - }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0_inner, - I0, - I0, - Number{}))>{}]; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && + m0 == MRepeat - 1 && n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - // The block_sync_lds() here performs double duty: - // A) safeguard against data hazard. - // B) reduce VMEM FIFO congestion by applying small delays to - // different wavefronts. 
- // It is performed near the end of MAC cluster to minimize lgkmcnt - // penalty - if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && - n0 == NRepeat - 1) - { - __builtin_amdgcn_sched_barrier(0); - block_sync_lds(); - __builtin_amdgcn_sched_barrier(0); - } - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) - { - __builtin_amdgcn_sched_barrier(0); - __builtin_amdgcn_s_setprio(1); - __builtin_amdgcn_sched_barrier(0); - } }); }); }); @@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); static constexpr auto b_thread_desc_ = @@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}, I1, I1, + I1, Number{}), make_tuple(Number{}, Number{}, Number{}, I0, I0, + I0, I1)); using AThreadCopy = @@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, A_K1, A_K1>; @@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1, - Sequence<0, 1, 2, 3, 4, 5>, - 5, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, B_K1, B_K1>; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 83dadb2175..cbe13b6e00 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -32,6 +32,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 { @@ -55,6 +56,7 @@ template struct BlockwiseGemmWmmaops_pipeline_v3 : BlockwiseGemmWmmaops_pipeline_base { using Base = BlockwiseGemmWmmaops_pipeline_base; using Base::I0; @@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto 
m0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, k0, I0, I0, I0), - a_thread_buf); + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); }); if constexpr(ck::is_same_v) { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } else { static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_struct.b_scale_thread_bufs( - I0)[Number{}], - b_thread_desc_, - make_tuple(I0, n0, k0, I0, I0, I0), - b_thread_buf); + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); }); } }); @@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, - m0, - k0, - I0, - I0, - Number{}))>{}]; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + 
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, - n0, - k0, - I0, - I0, - Number{}))>{}]; - }); - - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); @@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; - }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); + static_for<0, KPack / A_KRow 
/ KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); }); }); @@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = - a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; - }); - static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = - b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; - }); + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); + 
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + I0, + Number{}))>{}]; + }); - using wmma_input_type_a = - typename vector_type::type; - using wmma_input_type_b = - typename vector_type::type; + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); - wmma_gemm.Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); }); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 465952e285..23f16d38e9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -17,6 +17,9 @@ template {}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + 
make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(ABK1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). 
+ // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } } __device__ static constexpr auto GetBlockStep() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 68476ef3bf..a36ccd43ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -313,14 +313,16 @@ struct ABTransferWaveTiles // This is a block descriptor used to read LDS memory into register // It's defined in a way consistent with the existing implementation to // avoid changes in the pipelines - return make_naive_tensor_descriptor(make_tuple(Number{}, + return 
make_naive_tensor_descriptor(make_tuple(I1, Number{}, + Number{}, Number{}, Number{}, Number{}, Number{}), - make_tuple(Number{}, + make_tuple(I0, Number{}, + Number{}, Number{}, Number{}, Number{}, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index fa7eb4faaa..38ebdab65e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); - // TODO: I am pretty sure this is always 16 and *should* always be 16. - static constexpr auto KPack = - math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16); + static constexpr index_t KPerWmmaBlk = + WmmaSelector::selected_wmma + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 return b1_block_copy_step; } - template - __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) { - constexpr auto a_wave_desc = [&]() { - // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 - constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); - constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + // K0_MN_K1 -> 
K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); #ifdef __gfx12__ - constexpr auto A_KRow = I2; + constexpr auto KRow = I2; #else - constexpr auto A_KRow = I1; + constexpr auto KRow = I1; #endif + + if constexpr(KInner > 1) + { + // KPack = KInner * KPerWmma + // K1 = KInner * KPerWmmaBlk + // Each thread loads multiple tiles with one instruction + // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 return transform_tensor_descriptor( - ABlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + BlockDesc{}, + make_tuple( + make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + else + { + // KPack = KPerWmma (KInner == 1) + if constexpr(K1 <= KPerWmmaBlk) + { + // K1 <= single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1 + // (rest of the single WMMA tile for single thread) and then over KRow + // (rest of the single WMMA tile for single wave) + // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); + } + 
else + { + // K1 > single tile (KPerWmmaBlk) + // Each thread will load KPerWmmaBlk for the WMMA instruction + // Since K1 > single tile, each thread loads KPerWmmaBlk and the next + // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout). + // This layout is needed to support for example AK1 > single tile and + // BK1 <= single tile in the same gemm + // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - + // K1 + constexpr auto desc1 = transform_tensor_descriptor( + BlockDesc{}, + make_tuple( + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{})); + + return transform_tensor_descriptor( + desc1, + make_tuple(make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_merge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{})); + } + } + } - return a_wave_desc; + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + return MakeWmmaTileDescriptor(ABlockDesc_{}); } template __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) { - constexpr auto b0_wave_desc = [&]() { - // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1 - constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); - constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); 
-#ifdef __gfx12__ - constexpr auto B_KRow = I2; -#else - constexpr auto B_KRow = I1; -#endif - return transform_tensor_descriptor( - B0BlockDesc_{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - }(); - - return b0_wave_desc; + return MakeWmmaTileDescriptor(B0BlockDesc_{}); } template @@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 MRepeat, LRepeat, KPack, + KInner, true>())>; // TransposeC (must be true to work), C' = B' x A' // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7a5e324468..56f09cee96 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), + static constexpr index_t KPerWmmaBlk = WmmaSelector::selected_wmma - .k_per_wmma); + .k_per_blk; + + static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk); + + static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk); + + static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB); + + static constexpr index_t KPack = + KInner * + WmmaSelector::selected_wmma + .k_per_wmma; using ThisThreadBlock = ThisThreadBlock; @@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, MPerWmma, AK1Value, + KPack, + 
KInner, + KPerWmmaBlk, UseBlockPaddingA, PermuteA, ABlockTransferThreadClusterLengths_AK0_M_AK1, @@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base KPerBlock, NPerWmma, BK1Value, + KPack, + KInner, + KPerWmmaBlk, UseBlockPaddingB, PermuteB, BBlockTransferThreadClusterLengths_BK0_N_BK1, @@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base NPerWmma, MRepeat, NRepeat, - KPack>())>; + KPack, + KInner>())>; // Used to create obj in global function and pass it to Run method using EpilogueCShuffle = diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index bca68764f9..55ede990af 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -95,6 +95,7 @@ struct wmma_type __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 71b5c5e7cf..806b6e684d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -48,7 +48,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, 
DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index f4489dc45f..4516d06492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -50,7 +50,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, 
BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 423f86365c..5ace0594f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -53,7 +53,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 
8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index 2eb28958e6..27deab1c8c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -56,7 +56,9 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 2, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index d10b9facd5..bd5c7d8783 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -48,7 +48,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 
64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index d9d16ede65..1956d1a951 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, 
BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 9277e5e901..934c6aa7ef 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp index e97a649c19..9860b81b78 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 16, 16, 16, 2, 
8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp index c8f1b85ddb..4d7169565a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, 
PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp index fc0220a502..3728368bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, 
DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp index b87cf64b0f..3506575f5d 100644 --- 
a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + 
DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp index 31ad66409e..eef0d6de6a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp @@ -50,7 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 8, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp index 4c37c398fe..2418be62b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -55,7 +55,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp 
b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp index 6b5314b701..38f2869303 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -51,7 +51,8 @@ using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, - DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, 
GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 16, 16, 16, 16, 2, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, F8, F8> // clang-format on >; } // namespace instance