From 3e263e6d85e662e3e1bde8cb7bde8df36b9de1cf Mon Sep 17 00:00:00 2001 From: xly Date: Thu, 5 Jun 2025 21:55:06 +0100 Subject: [PATCH 1/3] new allocator --- core/common/constant.h | 9 + core/common/generator.h | 83 ++ core/common/types.h | 99 ++ core/kernel/b2b_gemm.h | 982 ++++++++++++++++++ .../kernel/b2b_gemm_grouped_problem_visitor.h | 149 +++ core/kernel/b2b_mma_pipelined.h | 590 +++++++++++ core/kernel/epilogue_utils.h | 133 +++ core/kernel/grouped_threadblock_swizzle.h | 123 +++ core/kernel/mlp_tile_op.h | 299 ++++++ core/kernel/tile_size.h | 405 ++++++++ core/kernel/utils.h | 38 + core/memory/caching_allocator.cpp | 299 ++++++ core/memory/caching_allocator.h | 205 ++-- core/memory/caching_allocator_bk.h | 108 ++ core/memory/shared_memory.cpp | 47 + core/memory/shared_memory.h | 42 + core/memory/torch_caching_allocator.cpp | 6 + core/memory/torch_caching_allocator.h | 45 + core/parallel/expert_dispatcher.cpp | 4 +- core/parallel/expert_module.cpp | 31 +- core/parallel/expert_module.h | 307 ++++-- core/utils/cuda_utils.h | 56 + core/utils/logger.h | 16 + examples/interface_example.py | 10 +- moe_infinity/kernel/__init__.py | 1 + moe_infinity/kernel/router.py | 234 +++++ moe_infinity/models/deepseek.py | 72 +- moe_infinity/runtime/model_offload.py | 16 +- op_builder/prefetch.py | 35 +- tests/cuda/CMakeLists.txt | 153 +++ tests/cuda/test_autosize_tileload.cu | 255 +++++ tests/cuda/test_autosize_tileload_stage.cu | 329 ++++++ tests/cuda/test_autotune_blocksize.cu | 38 + tests/cuda/test_expert_fusion.cu | 343 ++++++ tests/cuda/test_expert_fusion_v2.cu | 191 ++++ tests/cuda/test_fused_mlp.cu | 161 +++ tests/cuda/test_load_tile.cu | 250 +++++ tests/cuda/test_load_tile_templated.cu | 332 ++++++ tests/cuda/test_single_gemm_tiled.cu | 130 +++ tests/cuda/test_tile_size.cu | 254 +++++ tests/cuda/test_uvm_kernel.cu | 216 ++++ 41 files changed, 6873 insertions(+), 223 deletions(-) create mode 100644 core/common/constant.h create mode 100644 core/common/generator.h create mode 100644 core/kernel/b2b_gemm.h create mode 100644 core/kernel/b2b_gemm_grouped_problem_visitor.h create mode 100644 core/kernel/b2b_mma_pipelined.h create mode 100644 core/kernel/epilogue_utils.h create mode 100644 core/kernel/grouped_threadblock_swizzle.h create mode 100644 core/kernel/mlp_tile_op.h create mode 100644 core/kernel/tile_size.h create mode 100644 core/kernel/utils.h create mode 100644 core/memory/caching_allocator.cpp create mode 100644 core/memory/caching_allocator_bk.h create mode 100644 core/memory/shared_memory.cpp create mode 100644 core/memory/shared_memory.h create mode 100644 core/memory/torch_caching_allocator.cpp create mode 100644 core/memory/torch_caching_allocator.h create mode 100644 moe_infinity/kernel/__init__.py create mode 100644 moe_infinity/kernel/router.py create mode 100644 tests/cuda/CMakeLists.txt create mode 100644 tests/cuda/test_autosize_tileload.cu create mode 100644 tests/cuda/test_autosize_tileload_stage.cu create mode 100644 tests/cuda/test_autotune_blocksize.cu create mode 100644 tests/cuda/test_expert_fusion.cu create mode 100644 tests/cuda/test_expert_fusion_v2.cu create mode 100644 tests/cuda/test_fused_mlp.cu create mode 100644 tests/cuda/test_load_tile.cu create mode 100644 tests/cuda/test_load_tile_templated.cu create mode 100644 tests/cuda/test_single_gemm_tiled.cu create mode 100644 tests/cuda/test_tile_size.cu create mode 100644 tests/cuda/test_uvm_kernel.cu diff --git a/core/common/constant.h b/core/common/constant.h new file mode 100644 index 0000000..921156a --- 
/dev/null
+++ b/core/common/constant.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include <cstdint>
+
+constexpr int64_t KB = 1024;
+constexpr int64_t MB = KB * KB;
+constexpr int64_t GB = KB * KB * KB;
+
+constexpr int kWrapSize = 32;
diff --git a/core/common/generator.h b/core/common/generator.h
new file mode 100644
index 0000000..ca4ec33
--- /dev/null
+++ b/core/common/generator.h
@@ -0,0 +1,83 @@
+#pragma once
+
+#include <uuid/uuid.h>
+
+#include <atomic>
+#include <bitset>
+#include <chrono>
+#include <iomanip>
+#include <mutex>
+#include <random>
+#include <sstream>
+#include <string>
+
+class NumGenerator {
+ public:
+  // 0 is a special id; we must guarantee that 0 is never generated.
+  static uint32_t ctx_id() {
+    std::lock_guard<std::mutex> g(mutex_);
+    uint32_t ret = ctx_id_++;
+    if (ret == 0) ret = ctx_id_++;
+    return ret;
+  }
+  static uint32_t flowno() {
+    static std::atomic<uint32_t> flowno(1024);
+    return flowno++;
+  }
+
+ private:
+  static std::mutex mutex_;
+  static uint32_t ctx_id_;  // Start from 1 to avoid 0
+};
+
+// Static member definitions (inline so the header can be included in more
+// than one translation unit without violating the ODR)
+inline std::mutex NumGenerator::mutex_;
+inline uint32_t NumGenerator::ctx_id_ = 1;
+
+inline std::string GenUUID() {
+  uuid_t uuid;
+  uuid_generate(uuid);
+  char uuid_str[37];
+  uuid_unparse(uuid, uuid_str);
+  return std::string(uuid_str);
+}
+
+inline uint64_t GenUUID64() {
+  static std::random_device rd;
+  static std::mt19937_64 eng(rd());
+  static std::uniform_int_distribution<uint64_t> distr;
+
+  std::bitset<64> uuid;
+  uuid = std::chrono::high_resolution_clock::now().time_since_epoch().count();
+  uuid ^= distr(eng);
+
+  return uuid.to_ullong();
+}
+
+inline std::string CurrentTimeString() {
+  // Get current time as time_point
+  auto now = std::chrono::system_clock::now();
+
+  // Convert time_point to system time for breaking down into components
+  auto now_c = std::chrono::system_clock::to_time_t(now);
+  auto now_tm = *std::localtime(&now_c);
+
+  // Get the current time as milliseconds
+  auto now_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
+                    now.time_since_epoch()) %
+                1000;
+
+  // Use stringstream to format the time
+  std::ostringstream oss;
+  oss << std::put_time(&now_tm, "%Y-%m-%d %H:%M:%S");
+  oss << '.' << std::setfill('0') << std::setw(3) << now_ms.count();
+
+  return oss.str();
+}
+
+// Current time in microseconds since the epoch
+inline uint64_t CurrentTimeMicros() {
+  return std::chrono::duration_cast<std::chrono::microseconds>(
+             std::chrono::system_clock::now().time_since_epoch())
+      .count();
+}
diff --git a/core/common/types.h b/core/common/types.h
index 8dec394..84a9280 100644
--- a/core/common/types.h
+++ b/core/common/types.h
@@ -6,6 +6,9 @@
 #pragma once
 
 #include <cstdint>
+#include <cstddef>
+#include <string>
+#include <type_traits>
 
 typedef std::uint32_t TensorID;
 typedef std::size_t HashID;
@@ -35,3 +38,99 @@ template <typename T>
 struct DoNothingDeleter {
   void operator()(T* ptr) const {}
 };
+
+// Helper to get the Nth type from a parameter pack
+template <std::size_t N, typename First, typename... Rest>
+struct GetNthType;
+
+template <typename First, typename... Rest>
+struct GetNthType<0, First, Rest...> {
+  using type = First;
+};
+
+template <std::size_t N, typename First, typename... Rest>
+struct GetNthType {
+  using type = typename GetNthType<N - 1, Rest...>::type;
+};
+
+template <std::size_t N, typename... Ts>
+using GetNthType_t = typename GetNthType<N, Ts...>::type;
+
+// Compile-time integer square root
+template <int N>
+struct ConstexprSqrt {
+  static constexpr int compute(int low = 1, int high = N) {
+    if (low == high) return low;
+    int mid = (low + high + 1) / 2;
+    return (mid * mid > N) ? compute(low, mid - 1) : compute(mid, high);
+  }
+  static constexpr int value = compute();
+};
+
+// Round to multiple helper
+template <int N, int Multiple>
+struct RoundToMultiple {
+  static constexpr int value = ((N + Multiple - 1) / Multiple) * Multiple;
+};
+
+// A constexpr function to convert any const T* pointer to void*
+template <typename T>
+constexpr void* pointer_to_void(const T* ptr) {
+  return const_cast<void*>(reinterpret_cast<const void*>(
+      ptr));  // Cast away constness to obtain a mutable void*
+}
+
+// Helper macros to generate enum and string mappings
+#define ENUM_ENTRY_COMMA(value, EnumType) value,
+#define ENUM_CASE(value, EnumType) \
+  case EnumType::value:            \
+    return #value;
+#define STRING_CASE(value, EnumType) \
+  if (s == #value) return EnumType::value;
+
+// General enum to string conversion using SFINAE
+template <typename E>
+constexpr auto enum_to_string(E e) noexcept
+    -> std::enable_if_t<std::is_enum_v<E>, const char*> {
+  // This will be specialized for each enum type
+  return "Unknown";
+}
+
+// General string to enum conversion
+template <typename E>
+constexpr auto string_to_enum(const std::string& s) noexcept
+    -> std::enable_if_t<std::is_enum_v<E>, E> {
+  // This will be specialized for each enum type
+  return static_cast<E>(0);  // Default to first enum value
+}
+
+// Macro to define enum class, enum to string, and string to enum functions
+#define DEFINE_ENUM_CLASS(EnumType, ENUM_VALUES)                           \
+  enum class EnumType { ENUM_VALUES(ENUM_ENTRY_COMMA, EnumType) Unknown }; \
+                                                                           \
+  /* Enum to string function */                                            \
+  constexpr const char* EnumType##ToString(EnumType v) {                   \
+    switch (v) {                                                           \
+      ENUM_VALUES(ENUM_CASE, EnumType)                                     \
+      default:                                                             \
+        return "Unknown";                                                  \
+    }                                                                      \
+  }                                                                        \
+                                                                           \
+  /* String to enum function */                                            \
+  EnumType StringTo##EnumType(const std::string& s) {                      \
+    ENUM_VALUES(STRING_CASE, EnumType)                                     \
+    return EnumType::Unknown;                                              \
+  }                                                                        \
+                                                                           \
+  /* Specialize generic template functions for this enum type */           \
+  template <>                                                              \
+  constexpr auto enum_to_string<EnumType>(EnumType e) noexcept             \
+      -> const char* {                                                     \
+    return EnumType##ToString(e);                                          \
+  }                                                                        \
+                                                                           \
+  template <>                                                              \
+  auto string_to_enum<EnumType>(const std::string& s) noexcept             \
+      -> EnumType {                                                        \
+    return StringTo##EnumType(s);                                          \
+  }
diff --git a/core/kernel/b2b_gemm.h b/core/kernel/b2b_gemm.h
new file mode 100644
index 0000000..76d0bfc
--- /dev/null
+++ b/core/kernel/b2b_gemm.h
@@ -0,0 +1,982 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
+ *reserved. SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice,
+ *this list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ *ARE DISCLAIMED.
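For reference, here is a minimal usage sketch of the core/common/types.h helpers added above (DEFINE_ENUM_CLASS, GetNthType_t, ConstexprSqrt, RoundToMultiple). The Color enum and its COLOR_VALUES list are hypothetical examples, not part of the patch, and the include path assumes the repository root is on the compiler's include path.

#include <iostream>
#include <type_traits>

#include "core/common/types.h"

// Hypothetical X-macro list in the shape DEFINE_ENUM_CLASS expects: the list
// macro applies f(value, EnumType) once per enumerator.
#define COLOR_VALUES(f, EnumType) \
  f(Red, EnumType)                \
  f(Green, EnumType)              \
  f(Blue, EnumType)

DEFINE_ENUM_CLASS(Color, COLOR_VALUES)

int main() {
  // Enum <-> string round trip generated by the macro.
  std::cout << ColorToString(Color::Green) << "\n";             // prints "Green"
  std::cout << (StringToColor("Blue") == Color::Blue) << "\n";  // prints 1

  // Compile-time helpers.
  static_assert(std::is_same_v<GetNthType_t<1, int, float, double>, float>);
  static_assert(ConstexprSqrt<144>::value == 12);
  static_assert(RoundToMultiple<100, 32>::value == 128);
  return 0;
}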
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or + support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "b2b_gemm_grouped_problem_visitor.h" +#include "grouped_threadblock_swizzle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +namespace detail { + +/// Utility struct for returning the type of the problem visitor used by the +/// swizzling function, if it is a grouped swizzling function, or a default +/// visitor. This is used only for defining the parameters of the problem +/// visitor used in GroupedParams. +template +struct ProblemVisitorOrDefault; + +/// Return a generic problem visitor for GEMM problems +template +struct ProblemVisitorOrDefault< + B2bMma_, ThreadblockSwizzle_, + typename platform::enable_if< + !cutlass::gemm::threadblock::detail::IsGroupedSwizzle< + ThreadblockSwizzle_>::value>::type> { + using value = B2bGemmGroupedProblemVisitor< + typename B2bMma_::Shape, GroupScheduleMode::kDeviceOnly, 128, 128, + platform::is_same::value>; +}; + +/// Return the problem visitor specified by the swizzling function +template +struct ProblemVisitorOrDefault< + B2bMma_, ThreadblockSwizzle_, + typename platform::enable_if< + cutlass::gemm::threadblock::detail::IsGroupedSwizzle< + ThreadblockSwizzle_>::value>::type> { + using value = typename ThreadblockSwizzle_::ProblemVisitor; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct B2bGemm { + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using OutputOp0 = typename B2bMma::OutputOp; + using OutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA0 = typename B2bMma::IteratorA0::Element; + using LayoutA0 = typename B2bMma::IteratorA0::Layout; + using ElementB0 = typename B2bMma::IteratorB0::Element; + using LayoutB0 = typename B2bMma::IteratorB0::Layout; + using ElementB1 = typename B2bMma::IteratorB1::Element; + using LayoutB1 = typename B2bMma::IteratorB1::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; + + /// Data types needed for higher-level containers. In some cases, a single + /// type must be exposed despite the B2b GEMM using two GEMMs under the hood. 
+ /// In such cases, we select the values from the second GEMM (other than for + /// ElementA/ElementB) + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + + static ComplexTransform const kTransformA = B2bMma::kTransformA; + static ComplexTransform const kTransformB = B2bMma::kTransformB; + using Operator = typename B2bMma::Operator0; + + using OperatorClass = typename Operator::OperatorClass; + using ThreadblockShape = typename B2bMma::Shape0; + using WarpShape = typename Operator::Shape; + using InstructionShape = typename Operator::InstructionShape; + using ArchTag = typename B2bMma::ArchTag; + + static int const kStages = B2bMma::kStages; + static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + using Mma = B2bMma; + using EpilogueOutputOp = OutputOp1; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + GemmCoord problem_size_0{0, 0, 0}; + GemmCoord problem_size_1{0, 0, 0}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + typename OutputOp0::Params epilogue0{}; + typename OutputOp1::Params epilogue1{}; + int batch_count{1}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode_, GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + typename B2bMma::IteratorA0::TensorRef ref_A0_, + typename B2bMma::IteratorB0::TensorRef ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef ref_D1_, + int64_t batch_stride_A0_, int64_t batch_stride_B0_, + int64_t batch_stride_B1_, int64_t batch_stride_C1_, + int64_t batch_stride_D1_, int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int batch_count_ = 1) + : mode(mode_), + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + 
ref_Bias0(ref_Bias0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + batch_count(batch_count_) {} + }; + + // Arguments structure for grouped B2B problems + struct GroupedArguments { + GemmCoord* problem_size_0; + GemmCoord* problem_size_1; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + + int problem_count; + int threadblock_count; + GemmCoord* host_problem_sizes; + + CUTLASS_HOST_DEVICE + GroupedArguments( + int problem_count, GemmCoord* problem_size_0_, + GemmCoord* problem_size_1_, + typename B2bMma::IteratorA0::TensorRef* ref_A0_, + typename B2bMma::IteratorB0::TensorRef* ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef* ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_D1_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int threadblock_count = 0) + : problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + problem_count(problem_count), + threadblock_count(threadblock_count) {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmUniversalMode mode = + cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmCoord problem_size_0{}; + cutlass::gemm::GemmCoord problem_size_1{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + typename B2bMma::IteratorA0::Params params_A0{}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::Params params_B0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::Params params_C0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::Params params_B1{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::Params params_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::Params params_D1{}; + typename 
Epilogue::OutputTileIterator::TensorRef ref_D1{}; + typename OutputOp0::Params output_op_0{}; + typename OutputOp1::Params output_op_1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + int* semaphore = nullptr; + int gemm_k_iterations_0{0}; + int gemm_k_size_0{0}; + int gemm_k_iterations_1{0}; + int gemm_k_size_1{0}; + + // + // Methods + // + + Params() = default; + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord const& problem_size_0, + cutlass::gemm::GemmCoord const& problem_size_1, + cutlass::gemm::GemmCoord const& grid_tiled_shape, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1, + int64_t batch_stride_A0, int64_t batch_stride_B0, + int64_t batch_stride_B1, int64_t batch_stride_C1, + int64_t batch_stride_D1, int64_t batch_stride_Bias0, + int64_t batch_stride_Scale0, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + int* workspace = nullptr) + : mode(mode), + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)), + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + batch_stride_A0(batch_stride_A0), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C1(batch_stride_C1), + batch_stride_D1(batch_stride_D1), + batch_stride_Bias0(batch_stride_Bias0), + batch_stride_Scale0(batch_stride_Scale0), + output_op_0(output_op_0), + output_op_1(output_op_1) { + int total_gemm_k_iterations_0 = + (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + int gemm_k_iterations_0 = + (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / + grid_tiled_shape.k(); + gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK; + int total_gemm_k_iterations_1 = + (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + int gemm_k_iterations_1 = + (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / + grid_tiled_shape.k(); + gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; + + semaphore = workspace; + } + }; + + struct GroupedParams { + cutlass::gemm::GemmCoord* problem_size_0; + cutlass::gemm::GemmCoord* problem_size_1; + cutlass::gemm::GemmCoord* grid_tiled_shape; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + 
typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + + using ProblemVisitor = + typename detail::ProblemVisitorOrDefault::value; + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int* workspace; + + CUTLASS_HOST_DEVICE + GroupedParams() {} + + CUTLASS_HOST_DEVICE + GroupedParams(GroupedArguments const& args, void* workspace = nullptr, + int tile_count = 0) + : problem_size_0(args.problem_size_0), + problem_size_1(args.problem_size_1), + ref_A0(args.ref_A0), + ref_B0(args.ref_B0), + ref_C0(args.ref_C0), + ref_Scale0(args.ref_Scale0), + ref_Bias0(args.ref_Bias0), + ref_B1(args.ref_B1), + ref_C1(args.ref_C1), + ref_D1(args.ref_D1), + output_op_0(args.epilogue0), + output_op_1(args.epilogue1), + problem_visitor(args.problem_size_0, args.problem_size_1, + args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + workspace(reinterpret_cast(workspace)) {} + + CUTLASS_HOST_DEVICE + void transpose() { + // Only row-major outputs are currently supported, so no transpose is + // performed + } + + /// Returns non-grouped parameters to be used as input to the kernel-level + /// operator for the problem indicated by problem_visitor. + CUTLASS_HOST_DEVICE + Params to_single_params(const ProblemVisitor& problem_visitor) const { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + int32_t idx = problem_visitor.problem_index(); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1); + + return Params(cutlass::gemm::GemmUniversalMode::kGemm, problem_size0, + problem_size1, grid_shape, ref_A0[idx], ref_B0[idx], + ref_C0[idx], ref_Scale0[idx], ref_Bias0[idx], ref_B1[idx], + ref_C1[idx], ref_D1[idx], 0, 0, 0, 0, 0, 0, + 0, // Batched B2B GEMMs within the grouped kernel are + // currently unsupported + output_op_0, output_op_1, workspace); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const& problem_size_0, + cutlass::gemm::GemmCoord const& problem_size_1, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1) { + static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = + Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if 
(!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size_0.m() % kAlignmentA) || + (problem_size_0.k() % kAlignmentA) || + (problem_size_0.n() % kAlignmentB) || + (problem_size_0.k() % kAlignmentB) || + (problem_size_0.m() % kAlignmentC) || + (problem_size_0.n() % kAlignmentC) || + (problem_size_1.m() % kAlignmentA) || + (problem_size_1.k() % kAlignmentA) || + (problem_size_1.n() % kAlignmentB) || + (problem_size_1.k() % kAlignmentB) || + (problem_size_1.m() % kAlignmentC) || + (problem_size_1.n() % kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + // Determine if fusion sizes are valid + if (problem_size_0.m() != problem_size_1.m()) + return Status::kErrorInvalidProblem; + + if (problem_size_0.n() != problem_size_1.k()) + return Status::kErrorInvalidProblem; + + if (problem_size_0.n() > B2bMma::Shape0::kN) + return Status::kErrorInvalidProblem; + + if (problem_size_1.n() > B2bMma::Shape1::kN) + return Status::kErrorInvalidProblem; + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + CUTLASS_DEVICE + void run_with_swizzle_nobias(Params const& params, + SharedStorage& shared_storage, + ThreadblockSwizzle& threadblock_swizzle) { + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + ElementA0* ptr_A0 = static_cast(params.ref_A0.data()); + ElementB0* ptr_B0 = static_cast(params.ref_B0.data()); + ElementB1* ptr_B1 = static_cast(params.ref_B1.data()); + + int offset_k_0 = 0; + int offset_k_1 = 0; + int problem_size_k_0 = params.problem_size_0.k(); + int problem_size_k_1 = params.problem_size_1.k(); + + if (params.mode == GemmUniversalMode::kGemm) { + // Problem size is a function of threadblock index in the K dimension + problem_size_k_0 = + min(problem_size_k_0, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_1 = + min(problem_size_k_1, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0; + offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * B2bMma::Shape0::kM, + offset_k_0, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k_0, threadblock_tile_offset.n() * B2bMma::Shape0::kN}; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k_1, threadblock_tile_offset.n() * B2bMma::Shape1::kN}; + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations_0 = + (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / + B2bMma::Shape0::kK; + + // Compute threadblock-scoped matrix multiply-add 
+ // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + + // B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.params_A0, ptr_A0, {params.problem_size_0.m(), problem_size_k_0}, + thread_idx, tb_offset_A0); + + typename B2bMma::IteratorB0 iterator_B0( + params.params_B0, ptr_B0, {problem_size_k_0, params.problem_size_0.n()}, + thread_idx, tb_offset_B0); + + typename B2bMma::IteratorB1 iterator_B1( + params.params_B1, ptr_B1, {problem_size_k_1, params.problem_size_1.n()}, + thread_idx, tb_offset_B1); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + OutputOp0 output_op_0(params.output_op_0); + + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle< + ThreadblockSwizzle>::value) { + // Wait for all threads to finish their epilogue phases from the previous + // tile. + __syncthreads(); + } + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, + params.problem_size_0.n()); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, + output_op_0); + + // + // Epilogue + // + + OutputOp1 output_op_1(params.output_op_1); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * B2bMma::Shape1::kM, + threadblock_tile_offset.n() * B2bMma::Shape1::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC* ptr_C1 = static_cast(params.ref_C1.data()); + ElementC* ptr_D1 = static_cast(params.ref_D1.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial + // synchronization + + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, ptr_C1, params.problem_size_1.mn(), thread_idx, + threadblock_offset); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue::OutputTileIterator iterator_D1( + params.params_D1, ptr_D1, params.problem_size_1.mn(), thread_idx, + threadblock_offset); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const& params, SharedStorage& shared_storage, + ThreadblockSwizzle& threadblock_swizzle) { + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + ElementA0* ptr_A0 = static_cast(params.ref_A0.data()); + ElementB0* ptr_B0 = static_cast(params.ref_B0.data()); + ElementB1* ptr_B1 = static_cast(params.ref_B1.data()); + + ScaleBiasData* ptr_Bias0 = + static_cast(params.ref_Bias0.data()); + ScaleBiasData* ptr_Scale0 = + static_cast(params.ref_Scale0.data()); + + int offset_k_0 = 0; + int offset_k_1 = 0; + + int problem_size_k_0 = params.problem_size_0.k(); + int problem_size_k_1 = params.problem_size_1.k(); + + if (params.mode == GemmUniversalMode::kGemm) { + // Problem size is a function of threadblock index in the K dimension + problem_size_k_0 = + min(problem_size_k_0, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_1 = + min(problem_size_k_1, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0; + offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1; + } + + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0; + ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * B2bMma::Shape0::kM, + offset_k_0, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k_0, threadblock_tile_offset.n() * B2bMma::Shape0::kN}; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k_1, threadblock_tile_offset.n() * B2bMma::Shape1::kN}; + + // Compute 
threadblock-scoped matrix multiply-add + int gemm_k_iterations_0 = + (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / + B2bMma::Shape0::kK; + + // Compute threadblock-scoped matrix multiply-add + // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + + // B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.params_A0, ptr_A0, {params.problem_size_0.m(), problem_size_k_0}, + thread_idx, tb_offset_A0); + + typename B2bMma::IteratorB0 iterator_B0( + params.params_B0, ptr_B0, {problem_size_k_0, params.problem_size_0.n()}, + thread_idx, tb_offset_B0); + + typename B2bMma::IteratorB1 iterator_B1( + params.params_B1, ptr_B1, {problem_size_k_1, params.problem_size_1.n()}, + thread_idx, tb_offset_B1); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + ptr_Scale0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, + MatrixCoord(0, threadblock_tile_offset.n() * B2bMma::Shape0::kN)); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + ptr_Bias0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, + MatrixCoord(0, threadblock_tile_offset.n() * B2bMma::Shape0::kN)); + + // + // Main loop + // + + OutputOp0 output_op_0(params.output_op_0); + + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle< + ThreadblockSwizzle>::value) { + // Wait for all threads to finish their epilogue phases from the previous + // tile. + __syncthreads(); + } + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, + params.problem_size_0.n()); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, + output_op_0); + + // + // Epilogue + // + + OutputOp1 output_op_1(params.output_op_1); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * B2bMma::Shape1::kM, + threadblock_tile_offset.n() * B2bMma::Shape1::kN); + + int block_idx = threadblock_tile_offset.m() + + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC* ptr_C1 = static_cast(params.ref_C1.data()); + ElementC* ptr_D1 = static_cast(params.ref_D1.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial + // synchronization + + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. 
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is + // currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), + params.grid_tiled_shape.k()); + } + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, ptr_C1, params.problem_size_1.mn(), thread_idx, + threadblock_offset); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D1( + params.params_D1, ptr_D1, params.problem_size_1.mn(), thread_idx, + threadblock_offset); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator + // construction + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' + // tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && + params.grid_tiled_shape.k() > 1) { + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/core/kernel/b2b_gemm_grouped_problem_visitor.h b/core/kernel/b2b_gemm_grouped_problem_visitor.h new file mode 100644 index 0000000..9fd82da --- /dev/null +++ b/core/kernel/b2b_gemm_grouped_problem_visitor.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. 
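To summarize the shape constraints that B2bGemm::can_implement enforces above: both GEMMs must share the same M extent, the N extent of GEMM0 must equal the K extent of GEMM1, and each N extent must not exceed the corresponding threadblock tile so the intermediate accumulator tile can stay resident in the threadblock. A hedged host-side sketch of that check follows; the tile extents kShape0N and kShape1N are assumed example values rather than values from the patch, and the alignment checks are omitted.

#include <cstdio>

#include "cutlass/gemm_coord.h"

// Assumed threadblock tile N extents of the two fused GEMMs (example only).
constexpr int kShape0N = 64;
constexpr int kShape1N = 128;

bool b2b_shapes_fusable(cutlass::gemm::GemmCoord p0,
                        cutlass::gemm::GemmCoord p1) {
  if (p0.m() != p1.m()) return false;   // GEMM0 and GEMM1 share M
  if (p0.n() != p1.k()) return false;   // N of GEMM0 is K of GEMM1
  if (p0.n() > kShape0N) return false;  // N0 must fit one threadblock tile
  if (p1.n() > kShape1N) return false;  // N1 must fit one threadblock tile
  return true;
}

int main() {
  cutlass::gemm::GemmCoord p0(256, 64, 4096);  // M x N0 x K0
  cutlass::gemm::GemmCoord p1(256, 128, 64);   // M x N1 x K1 (K1 == N0)
  std::printf("fusable: %s\n", b2b_shapes_fusable(p0, p1) ? "yes" : "no");
  return 0;
}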
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped B2b GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct B2bGemmGroupedProblemVisitor + : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, + ThreadCount> { + using ProblemSizeHelper = + detail::GemmGroupedProblemSizeHelper; + using Base = + GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + static bool const kTransposed = Transposed; + + cutlass::gemm::GemmCoord const* problem_sizes0; + cutlass::gemm::GemmCoord const* problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const* problem_sizes0; + cutlass::gemm::GemmCoord const* problem_sizes1; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : problem_sizes0(nullptr), + problem_sizes1(nullptr), + problem_count(0), + workspace(nullptr), + tile_count(0) {} + + /// Ctor + CUTLASS_HOST_DEVICE + Params(cutlass::gemm::GemmCoord const* problem_sizes0, + cutlass::gemm::GemmCoord const* problem_sizes1, + int32_t problem_count, void const* workspace = nullptr, + int32_t tile_count = 0) + : problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) {} + + /// Convert the B2b-GEMM-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams( // Set problem_sizes as problem_sizes0 because these + // determine shape of the grid used in the non-grouped + // B2b GEMM + problem_sizes0, problem_count, workspace, tile_count); + } + }; + + // + // Methods + // + CUTLASS_DEVICE + B2bGemmGroupedProblemVisitor(Params const& params_, + SharedStorage& shared_storage_, + int32_t block_idx) + : Base(params_.to_base(), shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem 
+ CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/core/kernel/b2b_mma_pipelined.h b/core/kernel/b2b_mma_pipelined.h new file mode 100644 index 0000000..97bdbd3 --- /dev/null +++ b/core/kernel/b2b_mma_pipelined.h @@ -0,0 +1,590 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused + GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
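As a sketch of how the grouped path above might be fed from the host for a mixture-of-experts MLP, the arrays below mirror what B2bGemmGroupedProblemVisitor::Params takes (problem_sizes0, problem_sizes1, problem_count, workspace, tile_count). The expert count, per-expert token counts, and hidden sizes are invented for illustration and are not values from this patch.

#include <vector>

#include "cutlass/gemm_coord.h"

int main() {
  int const num_experts = 4;
  int const hidden = 1024;  // model dimension (example)
  int const inter = 128;    // intermediate dimension; must not exceed the
                            // first GEMM's threadblock N (see can_implement)

  // Tokens routed to each expert, e.g. produced by the router.
  std::vector<int> tokens_per_expert = {8, 32, 5, 19};

  std::vector<cutlass::gemm::GemmCoord> sizes0, sizes1;
  for (int e = 0; e < num_experts; ++e) {
    int m = tokens_per_expert[e];
    sizes0.emplace_back(m, inter, hidden);  // GEMM0: [m x hidden] * [hidden x inter]
    sizes1.emplace_back(m, hidden, inter);  // GEMM1: [m x inter]  * [inter x hidden]
  }

  // sizes0.data(), sizes1.data() and num_experts are what the grouped
  // problem visitor's Params constructor consumes; workspace and tile_count
  // are filled in by the launcher.
  return 0;
}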
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// FragmentIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: + /// epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaPipelined + : public B2bMmaBase { + public: + ///< Base class + using Base = B2bMmaBase; + + using Shape0 = + Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = + IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; + using IteratorB0 = + IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; + using Policy0 = Policy0_; ///< Policy describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = + Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = + FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile + using IteratorAccumulatorScaleBias = + IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and + ///< bias vectors in global memory + // using 
FragmentIteratorA1ScaleBias = + // FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or + // Bias + // ///< vector from the threadblock + // fragment + using IteratorB1 = + IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy describing tuning details + using Policy = + Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; + + using SmemIteratorB1 = SmemIteratorB1_; + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + static const bool PerChannelScale = + (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered + /// pipeline) + static_assert((Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock + /// fragment + // using WarpFragmentA1ScaleBias = + // typename FragmentIteratorA1ScaleBias::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelined( + typename Base::B2bMmaSharedStorage& + shared_storage, ///< Shared storage 
needed for internal use by + ///< threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), + thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), + thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), + thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + // These should stay the same across different GEMM layers + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1& accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias + iterator_A1_scale, ///< iterator over A1 operand scale vectors in + ///< global memory + IteratorAccumulatorScaleBias + iterator_A1_bias, ///< iterator over A1 operand bias vectors in + ///< global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const& src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = + TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = + TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = + TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + 
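+    // Editorial note (descriptive comment, not part of the algorithm): from
+    // here the loop is software double-buffered. The prologue above staged
+    // the first k-block of A0/B0 into shared-memory stage 0; the fragment
+    // loads below prefetch the first warp-level operands so that, in the
+    // mainloop, shared-memory traffic for iteration k+1 overlaps with the
+    // warp-level MMAs for iteration k (which is why Base::kStages == 2 is
+    // statically asserted).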
this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations_0 <= 1); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations_0 <= 2); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + // 2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + // + // Prologue + // + + // FragmentA1ScaleBias tb_frag_A1_scale; + // FragmentA1ScaleBias tb_frag_A1_bias; + // FragmentIteratorA1ScaleBias + // warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + // FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + FragmentB1 tb_frag_B1; + + // if (PerChannelScale) tb_frag_A1_scale.clear(); + // tb_frag_A1_bias.clear(); + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + // if (PerChannelScale) iterator_A1_scale.load(tb_frag_A1_scale); + // iterator_A1_bias.load(tb_frag_A1_bias); + iterator_B1.load(tb_frag_B1); + + // if (PerChannelScale) ++iterator_A1_scale; + // 
++iterator_A1_bias; + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + // WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + // WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + // if (PerChannelScale) + // warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + // warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + // if (PerChannelScale) ++warp_tile_iterator_A1_scale_; + // ++warp_tile_iterator_A1_bias_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = + FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Avoid reading out of bounds + iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); + + // + // Mainloop + // + + // Note: The main loop does not support Base::WarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + if (PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + this->warp_tile_iterator_B1_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + + if (PerChannelScale) + warp_tile_iterator_A1_scale_.load( + warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load( + warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + if (PerChannelScale) ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + iterator_B1.load(tb_frag_B1); + ++iterator_B1; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace 
threadblock +} // namespace gemm +} // namespace cutlass diff --git a/core/kernel/epilogue_utils.h b/core/kernel/epilogue_utils.h new file mode 100644 index 0000000..c415dda --- /dev/null +++ b/core/kernel/epilogue_utils.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include +#include +#include +#include + +// Data type +using ElementInput = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; +using ElementAccumulator = float; +using ElementCompute = cutlass::bfloat16_t; + +// Tile sizes +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +// Layouts +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationSiLU< + ElementOutput, // Element type for output + 128 / cutlass::sizeof_bits::value, // Elements per + // vectorized access + ElementAccumulator, // Accumulator (from GEMM) + ElementCompute // Compute type (for scale) + >; + +// Define the GEMM with SiLU fused in epilogue +using FusedGemmSiLU = cutlass::gemm::device::Gemm< + ElementInput, LayoutA, ElementInput, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, + EpilogueOutputOp // Fused epilogue with SiLU + >; + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Define data types +using ElementInput = cutlass::half_t; +using ElementOutput = cutlass::half_t; +using ElementAccumulator = float; +using ElementCompute = float; + +// Layouts +using LayoutA = cutlass::layout::RowMajor; // X +using LayoutB = cutlass::layout::ColumnMajor; // Weights +using LayoutC = cutlass::layout::RowMajor; // Output + +// Tile sizes (adjust for your GPU architecture) +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +// Epilogue for GEMM3 (down projection) +using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>; + +// GEMM3 definition (fused output * Wd^T) +using Gemm3 = cutlass::gemm::device::Gemm< + ElementOutput, LayoutA, ElementInput, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp>; + +// Fully fused kernel for GEMM1+SiLU+GEMM2+Mul +__global__ void FusedMoEMLPKernel( + ElementInput const* X, // [B, InputSize] + ElementInput const* Wg, // [HiddenSize, InputSize] + ElementInput const* Wu, // [UpSize, InputSize] + ElementInput const* Wd, // [OutSize, UpSize], optional + ElementOutput* Output, // [B, OutSize] if Wd != nullptr else [B, UpSize] + int B, int InputSize, int HiddenSize, int UpSize, int OutSize, + bool has_Wd) { + int row = blockIdx.x * blockDim.x + threadIdx.x; // Batch + if (row >= B) return; + + // Pointers + const ElementInput* X_row = X + row * InputSize; + const ElementInput* Wg_col = Wg; + const ElementInput* Wu_col = Wu; + + // Accumulators for GEMM1 and GEMM2 + ElementAccumulator acc_g[HiddenSize] = {0}; + ElementAccumulator acc_u[UpSize] = {0}; + + // Compute GEMM1 and GEMM2 in registers + for (int k = 0; k < InputSize; 
++k) { + ElementInput x = X_row[k]; + for (int n = 0; n < HiddenSize; ++n) + acc_g[n] += static_cast(x) * + static_cast(Wg_col[n * InputSize + k]); + for (int n = 0; n < UpSize; ++n) + acc_u[n] += static_cast(x) * + static_cast(Wu_col[n * InputSize + k]); + } + + // Apply SiLU to GEMM1 result + for (int n = 0; n < HiddenSize; ++n) { + float x = static_cast(acc_g[n]); + acc_g[n] = x * (1.0f / (1.0f + expf(-x))); // SiLU + } + + // Fused output = SiLU(GEMM1) * GEMM2 + ElementAccumulator fused[UpSize]; + for (int n = 0; n < UpSize; ++n) { + fused[n] = acc_u[n]; + } + + for (int n = 0; n < min(HiddenSize, UpSize); ++n) { + fused[n] *= acc_g[n]; // Elementwise multiply + } + + for (int n = 0; n < OutSize; ++n) { + ElementAccumulator acc_out = 0; + for (int k = 0; k < UpSize; ++k) + acc_out += fused[k] * static_cast(Wd[n * UpSize + k]); + Output[row * OutSize + n] = static_cast(acc_out); + } +} diff --git a/core/kernel/grouped_threadblock_swizzle.h b/core/kernel/grouped_threadblock_swizzle.h new file mode 100644 index 0000000..58f72ad --- /dev/null +++ b/core/kernel/grouped_threadblock_swizzle.h @@ -0,0 +1,123 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implements several threadblock-swizzling functions for grouped + kernels +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct GroupedThreadblockSwizzleBase {}; + +/// Helper for determining if a swizzling function is specialized for grouped +/// operation +template +struct IsGroupedSwizzle { + static bool const value = + cutlass::platform::is_base_of::value; +}; + +} // namespace detail + +/// Swizzling function for grouped kernels +template +struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase { + using ProblemVisitor = ProblemVisitor_; + ProblemVisitor problem_visitor; + + CUTLASS_HOST_DEVICE + GroupedThreadblockSwizzle( + typename ProblemVisitor::Params& params, + typename ProblemVisitor::SharedStorage& shared_storage, int block_idx) + : problem_visitor(params, shared_storage, block_idx) {} + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int /*log_tile*/) const { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + return GemmCoord(int(threadblock_idx / grid_shape.n()), + int(threadblock_idx % grid_shape.n()), 0); + } + + /// Dummy method to satisfy API for threadblock swizzling functions + CUTLASS_HOST_DEVICE + static int get_log_tile(GemmCoord /*tiled_shape*/) { return 0; } +}; + +template +struct B2bGemmGroupedThreadblockSwizzle + : GroupedThreadblockSwizzle< + cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor< + ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, + ThreadCount, + platform::is_same::value>> { + using Base = GroupedThreadblockSwizzle< + cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor< + ThreadblockShape, GroupScheduleMode_, PrefetchTileCount, ThreadCount, + platform::is_same::value>>; + + CUTLASS_HOST_DEVICE + B2bGemmGroupedThreadblockSwizzle( + typename Base::ProblemVisitor::Params& params, + typename Base::ProblemVisitor::SharedStorage& shared_storage, + int block_idx) + : Base(params, shared_storage, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/core/kernel/mlp_tile_op.h b/core/kernel/mlp_tile_op.h new file mode 100644 index 0000000..5155ae1 --- /dev/null +++ b/core/kernel/mlp_tile_op.h @@ -0,0 +1,299 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "common/types.h" +#include "tile_size.h" + +// template +// struct MmaOperation { +// using ElementA = cutlass::bfloat16_t; +// using ElementB = cutlass::bfloat16_t; +// using ElementC = float; +// using ElementAccumulator = float; + +// // Instruction shape for tensor cores (e.g., 16x8x16 for Ampere) +// using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +// // Warp-level tile shape +// using WarpShape = 
cutlass::gemm::GemmShape<64, 64, 16>; + +// // MMA operator +// using MmaTensorOp = cutlass::gemm::warp::MmaTensorOp< +// WarpShape, ElementA, cutlass::layout::RowMajor, ElementB, +// cutlass::layout::ColumnMajor, ElementC, cutlass::layout::RowMajor, +// cutlass::arch::OpMultiplyAdd, ArchTag>; + +// // Fragment shapes - A and B have different shapes! +// using FragmentA = typename MmaTensorOp::FragmentA; +// using FragmentB = typename MmaTensorOp::FragmentB; +// using FragmentC = typename MmaTensorOp::FragmentC; + +// // Iterator shapes for A and B are different +// static constexpr int kFragmentARows = WarpShape::kM; +// static constexpr int kFragmentACols = WarpShape::kK; +// static constexpr int kFragmentBRows = WarpShape::kK; +// static constexpr int kFragmentBCols = WarpShape::kN; +// }; + +template +struct TileLoader2D { + using Element = Element_; + using Layout = Layout_; + using ThreadblockShape = ThreadblockShape_; + using ThreadMap = ThreadMap_; + + using GmemIterator = cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockShape, Element, Layout, 1, ThreadMap>; + + static const int ElementSize = cutlass::sizeof_bits::value; + static const int Crosswise = 64; + + using SmemLayout = std::conditional_t< + std::is_same_v, + cutlass::layout::RowMajorTensorOpMultiplicandCongruous, + cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous>; + + using SmemIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockShape, Element, SmemLayout, 1, ThreadMap, 16>; + + Element* smem_ptr; + + __device__ TileLoader2D(Element* smem_ptr_) : smem_ptr(smem_ptr_) {} + + __device__ void operator()( + cutlass::TensorView const& global_view, + cutlass::MatrixCoord const& tb_offset) { + int thread_idx = threadIdx.x; + + auto extent = global_view.extent(); + + GmemIterator gmem_it(global_view.layout(), global_view.data(), + global_view.extent(), thread_idx, tb_offset); + + typename GmemIterator::Fragment frag; + frag.clear(); + gmem_it.load(frag); + + cutlass::TensorRef smem_ref( + smem_ptr, SmemLayout::packed( + {ThreadblockShape::kRow, ThreadblockShape::kColumn})); + SmemIterator smem_it(smem_ref, thread_idx); + smem_it.store(frag); + + // __syncthreads(); + } +}; + +// Flexible multi-view tile loader class +template +class MultiViewTileLoader { + public: + static constexpr size_t NumViews = sizeof...(ViewConfigs); + + // Extract types from ViewConfigs + template + using Element = typename GetNthType_t::Element; + + template + using Layout = typename GetNthType_t::Layout; + + template + using Shape = typename GetNthType_t::Shape; + + template + using ThreadMap = typename GetNthType_t::ThreadMap; + + template + using GemmCalc = + BusAwareGemmTileCalculator, Element, float>; + + // Calculate total shared memory size + static constexpr size_t calculate_smem_size() { + return calculate_smem_size_impl<0>(); // For pipelining + } + + private: + template + static constexpr size_t calculate_smem_size_impl() { + if constexpr (I >= NumViews) { + return 0; + } else { + using CurrentShape = Shape; + using CurrentElement = Element; + return CurrentShape::kRow * CurrentShape::kColumn * + sizeof(CurrentElement) + + calculate_smem_size_impl(); + } + } + + public: + // Helper to load a specific view + template + __device__ static void load_view( + cutlass::TensorView, Layout> const& view, + char* smem_base, uint32_t* signal, uint32_t stage_idx) { + // // Only assign warps that are within NumViews + // if (warp_id >= NumViews) return; // Add this check + // int lane_idx = threadIdx.x 
% 32; + if (threadIdx.y % GemmCalc::PipelineConfig::stages != stage_idx) + return; + + // Calculate shared memory offset for this view + size_t smem_offset = calculate_view_offset(); + Element* smem_ptr = + reinterpret_cast*>(smem_base + smem_offset); + + // Calculate threadblock offset based on view configuration + cutlass::MatrixCoord tb_offset = + GetNthType_t::calculate_offset(); + + // Create and use loader + using Loader = TileLoader2D, Layout, + Shape, ThreadMap>; + Loader loader(smem_ptr); + loader(view, tb_offset); + + __syncwarp(); // Ensure all threads have loaded their data + if (signal == nullptr) return; + if (threadIdx.x == 0) { + atomicAdd(&signal[stage_idx], 1); + printf( + "Block(%d,%d) Warp %d loaded view %zu at offset %zu; signal[%d] = " + "%u\n", + blockIdx.x, blockIdx.y, threadIdx.y, ViewIndex, smem_offset, + stage_idx, signal[stage_idx]); + } + } + + // Get the auto-calculated shape for a specific view + template + static constexpr auto get_shape() { + return Shape{}; + } + + static void print_configuration() { + printf("=== Auto-Sized Multi-View Loader Configuration ===\n"); + printf("Bus Width: %d bits (%d bytes/cycle)\n", BusWidthBits, + BusWidthBits / 8); + printf("Number of Views: %zu\n", NumViews); + printf("Total Shared Memory: %zu bytes\n", calculate_smem_size()); + print_view_configs<0>(); + printf("================================================\n"); + } + + private: + private: + template + static void print_view_configs() { + if constexpr (I < NumViews) { + printf("View %zu: %d-bit elements, %dx%d tile (%zu bytes)\n", I, + cutlass::sizeof_bits>::value, Shape::kRow, + Shape::kColumn, + Shape::kRow * Shape::kColumn * sizeof(Element)); + print_view_configs(); + } + } + + template + static constexpr size_t calculate_view_offset() { + if constexpr (ViewIndex == 0) { + return 0; + } else { + return calculate_view_offset() + + Shape::kRow * Shape::kColumn * + sizeof(Element); + } + } +}; + +// View configuration helper +template +struct ViewConfig { + using Element = Element_; + using Layout = Layout_; + using Shape = Shape_; + using ThreadMap = ThreadMap_; + + // Virtual method to be specialized for different matrix positions + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(0, 0); + } +}; + +// Specialized view configurations for A and B matrices +template +struct ViewConfigA : ViewConfig { + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(int(blockIdx.y * Shape_::kRow), + int(blockIdx.z * Shape_::kColumn)); + } +}; + +template +struct ViewConfigB : ViewConfig { + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(int(blockIdx.z * Shape_::kRow), + int(blockIdx.x * Shape_::kColumn)); + } +}; + +// Helper function templates to load each view by index +template +__device__ void load_each(char* smem, uint32_t* signal, uint32_t stage_idx, + FirstView& first, RestViews&... rest) { + if (threadIdx.x == 0) { + printf("Loading view %zu at stage %u\n", Index, stage_idx); + } + MultiLoader::template load_view(first, smem, signal, stage_idx); + if constexpr (sizeof...(RestViews) > 0) { + load_each(smem, signal, stage_idx, rest...); + } +} + +// Base case for recursion +template +__device__ void load_each(char* smem, uint32_t* signal, uint32_t stage_idx) {} + +// // Kernel that uses the flexible loader +// template +// __global__ void flexible_load_kernel(Views... 
views) { +// extern __shared__ __align__(16) char smem[]; +// load_each(smem, views...); +// } + +// Add this helper function before the kernel +template +__device__ void apply_load_each_impl(char* smem, uint32_t* signal, + uint32_t stage_idx, TupleType&& t, + std::index_sequence) { + load_each(smem, signal, stage_idx, thrust::get(t)...); + // (load_each(smem, signal, stage_idx, thrust::get(t)), + // ...); +} + +template +__device__ void apply_load_each(char* smem, uint32_t* signal, + uint32_t stage_idx, TupleType&& t) { + apply_load_each_impl( + smem, signal, stage_idx, std::forward(t), + std::make_index_sequence>>{}); +} diff --git a/core/kernel/tile_size.h b/core/kernel/tile_size.h new file mode 100644 index 0000000..b34284f --- /dev/null +++ b/core/kernel/tile_size.h @@ -0,0 +1,405 @@ +#pragma once + +#include +#include +#include + +#include "common/constant.h" +#include "common/types.h" + +// Bus width aware GEMM tile calculator +template +struct BusAwareGemmTileCalculator { + // Memory bandwidth parameters + static constexpr int bytes_per_cycle = BusWidthBits / 8; + static constexpr int target_cycles = TargetCycles; + + // Element sizes + static constexpr int size_A = cutlass::sizeof_bits::value / 8; + static constexpr int size_B = cutlass::sizeof_bits::value / 8; + static constexpr int size_C = cutlass::sizeof_bits::value / 8; + + // Total bytes we can move in target cycles + static constexpr int total_bandwidth_bytes = bytes_per_cycle * target_cycles; + + // Scale factor based on bus width (normalized to 256-bit baseline) + static constexpr int bus_scale_factor = BusWidthBits / 256; + static constexpr int bus_sqrt_scale = ConstexprSqrt::value; + + // Base tile dimensions for 256-bit bus + struct BaseTileSizes { + static constexpr int kM = 64; + static constexpr int kN = 64; + static constexpr int kK = 32; + }; + + // Scale tiles based on bus width + struct ScaledGemmTile { + // Scale M and N with square root of bus scaling to maintain aspect ratio + // Scale K less aggressively to maintain data reuse + static constexpr int kM_raw = BaseTileSizes::kM * bus_sqrt_scale; + static constexpr int kN_raw = BaseTileSizes::kN * bus_sqrt_scale; + static constexpr int kK_raw = + BaseTileSizes::kK * (bus_sqrt_scale + 1) / 2; // Scale K by half + + // Round to tensor core friendly sizes + static constexpr int kM = RoundToMultiple::value; + static constexpr int kN = RoundToMultiple::value; + static constexpr int kK = RoundToMultiple::value; + + // Calculate actual memory usage + static constexpr int elements_A = kM * kK; + static constexpr int elements_B = kK * kN; + static constexpr int elements_C = kM * kN; + + static constexpr int bytes_A = elements_A * size_A; + static constexpr int bytes_B = elements_B * size_B; + static constexpr int bytes_C = elements_C * size_C; + + // Total bytes for one tile computation (read A, B and write C) + static constexpr int total_bytes = bytes_A + bytes_B + bytes_C; + + // Cycles needed to transfer this data + static constexpr int cycles_needed = + (total_bytes + bytes_per_cycle - 1) / bytes_per_cycle; + + // Check if we fit within bandwidth budget + static constexpr bool fits_bandwidth = cycles_needed <= target_cycles; + }; + + // Architecture-specific optimized tiles + struct OptimizedTiles { + // For narrow bus (consumer GPUs): prioritize square tiles + struct NarrowBus { + static constexpr bool is_narrow = BusWidthBits <= 384; + static constexpr int kM = is_narrow ? 64 : ScaledGemmTile::kM; + static constexpr int kN = is_narrow ? 
64 : ScaledGemmTile::kN; + static constexpr int kK = is_narrow ? 32 : ScaledGemmTile::kK; + }; + + // For wide bus (HBM GPUs): can afford larger tiles + struct WideBus { + static constexpr bool is_wide = BusWidthBits >= 4096; + static constexpr int kM = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kM; + static constexpr int kN = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kN; + static constexpr int kK = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kK; + }; + + // Choose based on bus width + static constexpr int kM = + WideBus::is_wide + ? WideBus::kM + : (NarrowBus::is_narrow ? NarrowBus::kM : ScaledGemmTile::kM); + static constexpr int kN = + WideBus::is_wide + ? WideBus::kN + : (NarrowBus::is_narrow ? NarrowBus::kN : ScaledGemmTile::kN); + static constexpr int kK = + WideBus::is_wide + ? WideBus::kK + : (NarrowBus::is_narrow ? NarrowBus::kK : ScaledGemmTile::kK); + }; + + // Threadblock clusters for CUTLASS 3.x + struct ClusterShape { + // Larger clusters for wider memory interfaces + static constexpr int kM = BusWidthBits >= 4096 ? 2 : 1; + static constexpr int kN = BusWidthBits >= 4096 ? 2 : 1; + static constexpr int kK = 1; // K dimension clustering less beneficial + }; + + // Pipeline stages based on bandwidth + struct PipelineConfig { + // More stages for wider interfaces to hide latency + static constexpr int stages = BusWidthBits >= 4096 ? 4 + : BusWidthBits >= 384 ? 3 + : 2; + }; + + // Warp arrangement + struct WarpArrangement { + // More warps for larger tiles + static constexpr int warps_m = OptimizedTiles::kM / 32; + static constexpr int warps_n = OptimizedTiles::kN / 32; + static constexpr int total_warps = warps_m * warps_n; + + // Ensure we don't exceed SM warp limits + static constexpr int max_warps = 16; // Typical limit + static constexpr bool valid = total_warps <= max_warps; + }; + + static void print_config() { + printf("=== Bus-Aware GEMM Tile Configuration ===\n"); + printf("Memory Bus: %d bits (%d bytes/cycle)\n", BusWidthBits, + bytes_per_cycle); + printf("Element sizes: A=%d, B=%d, C=%d bytes\n", size_A, size_B, size_C); + printf("Bus scale factor: %dx (sqrt: %dx)\n", bus_scale_factor, + bus_sqrt_scale); + printf("\nScaled tile dimensions:\n"); + printf(" Raw: %dx%dx%d\n", ScaledGemmTile::kM_raw, ScaledGemmTile::kN_raw, + ScaledGemmTile::kK_raw); + printf(" Aligned: %dx%dx%d\n", ScaledGemmTile::kM, ScaledGemmTile::kN, + ScaledGemmTile::kK); + printf(" Memory: A=%d, B=%d, C=%d bytes (total: %d)\n", + ScaledGemmTile::bytes_A, ScaledGemmTile::bytes_B, + ScaledGemmTile::bytes_C, ScaledGemmTile::total_bytes); + printf(" Cycles needed: %d (budget: %d)\n", ScaledGemmTile::cycles_needed, + target_cycles); + printf("\nOptimized tile: %dx%dx%d\n", OptimizedTiles::kM, + OptimizedTiles::kN, OptimizedTiles::kK); + printf("Cluster shape: %dx%dx%d\n", ClusterShape::kM, ClusterShape::kN, + ClusterShape::kK); + printf("Pipeline stages: %d\n", PipelineConfig::stages); + printf("Warp arrangement: %dx%d = %d warps\n", WarpArrangement::warps_m, + WarpArrangement::warps_n, WarpArrangement::total_warps); + printf("=========================================\n"); + } +}; + +template +struct ThreadBlockAutoTuner { + using GemmCalc = + BusAwareGemmTileCalculator; + + // Hardware constraints + static constexpr int MAX_THREADS_PER_BLOCK = 1024; + static constexpr int WARP_SIZE = 32; + static constexpr int MAX_WARPS_PER_SM = 32; + static constexpr int MAX_SHARED_MEMORY_PER_BLOCK = 49152; // 48KB typical + + // Memory bandwidth parameters + static constexpr int 
bytes_per_cycle = BusWidthBits / 8; + static constexpr int element_size = cutlass::sizeof_bits::value / 8; + + struct ThreadConfig { + int threads_x; // First dimension (within warp) + int threads_y; // Second dimension (warp count) + int total_threads; + float score; // Efficiency score + + bool is_valid() const { + return threads_x == WARP_SIZE && // First dim must be warp size + total_threads <= MAX_THREADS_PER_BLOCK && + total_threads % WARP_SIZE == 0; + } + }; + + // Calculate optimal thread configuration + static ThreadConfig autotune(int M, int N, int K, int tile_m, int tile_n, + int tile_k) { + std::vector candidates; + + // First dimension is always 32 (warp size) for coalesced access + const int threads_x = WARP_SIZE; + + // Calculate workload + int tiles_m = (M + tile_m - 1) / tile_m; + int tiles_n = (N + tile_n - 1) / tile_n; + int tiles_k = (K + tile_k - 1) / tile_k; + int total_tiles = tiles_m * tiles_n * tiles_k; + + // Memory requirements per block + size_t smem_per_block = calculate_smem_requirement(tile_m, tile_n, tile_k); + + // Try different warp counts (threads_y) + for (int warps = 1; warps <= 32; warps++) { + ThreadConfig config; + config.threads_x = threads_x; + config.threads_y = warps; + config.total_threads = threads_x * warps; + + if (!config.is_valid()) continue; + + // Calculate efficiency score + config.score = calculate_efficiency_score(config, M, N, K, tile_m, tile_n, + tile_k, tiles_m, tiles_n, + total_tiles, smem_per_block); + + candidates.push_back(config); + } + + // Select best configuration + auto best = + std::max_element(candidates.begin(), candidates.end(), + [](const ThreadConfig& a, const ThreadConfig& b) { + return a.score < b.score; + }); + + return *best; + } + + private: + static size_t calculate_smem_requirement(int tile_m, int tile_n, int tile_k) { + // For typical GEMM: A tile + B tile + optional C tile + size_t size_a = tile_m * tile_k * element_size; + size_t size_b = tile_k * tile_n * element_size; + size_t size_c = tile_m * tile_n * sizeof(float); // Accumulator + + // Add padding for bank conflict avoidance + size_t padding = 256; // Conservative padding + + return size_a + size_b + size_c + padding; + } + + static float calculate_efficiency_score(const ThreadConfig& config, int M, + int N, int K, int tile_m, int tile_n, + int tile_k, int tiles_m, int tiles_n, + int total_tiles, + size_t smem_per_block) { + float score = 0.0f; + + // 1. Occupancy score (more warps = better, up to a point) + float occupancy = calculate_occupancy(config, smem_per_block); + score += occupancy * 30.0f; + + // 2. Memory bandwidth utilization + float bandwidth_efficiency = + calculate_bandwidth_efficiency(config, tile_m, tile_n, tile_k); + score += bandwidth_efficiency * 25.0f; + + // 3. Load balance score + float load_balance = + calculate_load_balance(config, tiles_m, tiles_n, total_tiles); + score += load_balance * 20.0f; + + // 4. Warp efficiency (avoid partial warps) + float warp_efficiency = + (config.total_threads % WARP_SIZE == 0) ? 1.0f : 0.5f; + score += warp_efficiency * 15.0f; + + // 5. 
Bank conflict avoidance score + float bank_score = calculate_bank_conflict_score(config, tile_m, tile_n); + score += bank_score * 10.0f; + + return score; + } + + static float calculate_occupancy(const ThreadConfig& config, + size_t smem_per_block) { + // Estimate blocks per SM + int blocks_limited_by_threads = MAX_THREADS_PER_BLOCK * MAX_WARPS_PER_SM / + (config.total_threads * WARP_SIZE); + int blocks_limited_by_smem = MAX_SHARED_MEMORY_PER_BLOCK / smem_per_block; + + int blocks_per_sm = + std::min(blocks_limited_by_threads, blocks_limited_by_smem); + int warps_per_sm = blocks_per_sm * config.threads_y; + + return float(warps_per_sm) / MAX_WARPS_PER_SM; + } + + static float calculate_bandwidth_efficiency(const ThreadConfig& config, + int tile_m, int tile_n, + int tile_k) { + // Bytes moved per block + size_t bytes_per_block = (tile_m * tile_k + tile_k * tile_n) * element_size; + + // Threads available for loading + int loading_threads = config.total_threads; + + // Bytes per thread + float bytes_per_thread = float(bytes_per_block) / loading_threads; + + // Ideal: each thread loads 128 bytes (one cache line) + float ideal_bytes = 128.0f; + + // Score based on how close we are to ideal + float ratio = bytes_per_thread / ideal_bytes; + if (ratio > 1.0f) + ratio = 2.0f - ratio; // Penalize too much work per thread + + return std::max(0.0f, ratio); + } + + static float calculate_load_balance(const ThreadConfig& config, int tiles_m, + int tiles_n, int total_tiles) { + // For multi-view loading, we want threads_y to evenly divide work + float balance = 1.0f; + + if (config.threads_y == 2) { + // 2 warps: ideal for A and B loading + balance = 1.0f; + } else if (config.threads_y == 3) { + // 3 warps: ideal for A, B1, B2 (your case) + balance = 1.0f; + } else if (config.threads_y == 4) { + // 4 warps: good for A, B, C prefetch + compute + balance = 0.9f; + } else if (config.threads_y > 4) { + // More warps: may have idle threads + balance = 4.0f / config.threads_y; + } + + return balance; + } + + static float calculate_bank_conflict_score(const ThreadConfig& config, + int tile_m, int tile_n) { + // Estimate bank conflicts based on access pattern + // 32 banks in shared memory + const int BANK_COUNT = 32; + + // For row-major: threads in a warp access consecutive elements + int stride = tile_n; // Assuming row-major + int conflicts = std::gcd(stride, BANK_COUNT); + + // Perfect score if no conflicts + float score = (conflicts == 1) ? 
1.0f : (1.0f / conflicts); + + return score; + } + + public: + // Convenience function to get recommended configuration + static dim3 get_optimal_block_dims(int M, int N, int K) { + int tile_m = GemmCalc::OptimizedTiles::kM; + int tile_n = GemmCalc::OptimizedTiles::kN; + int tile_k = GemmCalc::OptimizedTiles::kK; + auto config = autotune(M, N, K, tile_m, tile_n, tile_k); + return dim3(config.threads_x, config.threads_y, 1); + } + + static dim3 get_optimal_grid_dims(int M, int N, int K) { + int tile_m = GemmCalc::OptimizedTiles::kM; + int tile_n = GemmCalc::OptimizedTiles::kN; + // int tile_k = GemmCalc::OptimizedTiles::kK; + int grid_x = (M + tile_m - 1) / tile_m; + int grid_y = (N + tile_n - 1) / tile_n; + return dim3(grid_x, grid_y, 1); + } + + static constexpr dim3 default_block_dims = dim3(32, 9, 1); + + // Get configuration with rationale + static void print_autotuning_result(int M, int N, int K, int tile_m, + int tile_n, int tile_k) { + printf("=== ThreadBlock Autotuning ===\n"); + printf("Problem size: %dx%dx%d\n", M, N, K); + printf("Tile size: %dx%dx%d\n", tile_m, tile_n, tile_k); + printf("Bus width: %d bits\n", BusWidthBits); + + auto config = autotune(M, N, K, tile_m, tile_n, tile_k); + + printf("\nOptimal configuration:\n"); + printf(" threads.x = %d (warp size, for coalescing)\n", config.threads_x); + printf(" threads.y = %d (number of warps)\n", config.threads_y); + printf(" Total threads = %d\n", config.total_threads); + printf(" Score = %.2f\n", config.score); + + // Additional analysis + size_t smem = calculate_smem_requirement(tile_m, tile_n, tile_k); + printf("\nResource usage:\n"); + printf(" Shared memory per block: %zu bytes\n", smem); + printf(" Registers per thread: ~64 (estimated)\n"); + + float occupancy = calculate_occupancy(config, smem); + printf(" Theoretical occupancy: %.1f%%\n", occupancy * 100); + } +}; diff --git a/core/kernel/utils.h b/core/kernel/utils.h new file mode 100644 index 0000000..0792747 --- /dev/null +++ b/core/kernel/utils.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +#define KERNEL_LOG_DEBUG(msg, ...) \ + do { \ + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && \ + blockIdx.z == 0) { \ + printf(msg, ##__VA_ARGS__); \ + } \ + } while (0) + +template +using DetectedArchT = ArchTagT(ArchCode >= 900) ? cutlass::arch::Sm90 + : (ArchCode >= 800) ? 
cutlass::arch::Sm80 + : void > ; + +template +struct DetectedArch { + using SM = void; +}; + +template +struct DetectedArch 0)>> { + using SM = DetectedArchT<__CUDA_ARCH__>; +}; + +template +struct OptimalTileCalculator { + // Assuming we want to saturate memory in 4-8 cycles + static constexpr int bytes_per_cycle = BusWidthBits / 8; + static constexpr int target_cycles = 4; + + static constexpr int optimal_tile_bytes = bytes_per_cycle * target_cycles; + static constexpr int optimal_tile_elements_fp16 = optimal_tile_bytes / 2; + static constexpr int optimal_tile_elements_fp32 = optimal_tile_bytes / 4; +}; diff --git a/core/memory/caching_allocator.cpp b/core/memory/caching_allocator.cpp new file mode 100644 index 0000000..f5d454f --- /dev/null +++ b/core/memory/caching_allocator.cpp @@ -0,0 +1,299 @@ +#include "caching_allocator.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/generator.h" +#include "utils/logger.h" + +std::unique_ptr kCachingAllocator = nullptr; +std::once_flag kInitCachingAllocatorFlag; + +void* CachingAllocator::Allocate(const size_t bytes) { + std::lock_guard guard(mutex_); + // DLOG_DEBUG("Try Allocate: size: {}, used: {}", bytes, used_bytes_); + const auto& it = available_map_.find(bytes); + void* ptr = nullptr; + if (it == available_map_.end() || it->second.empty()) { + if (bytes == 0) { + ptr = malloc(bytes); + DLOG_WARN("Attempted to allocate 0 bytes, return ", ptr); + } else { + ptr = AllocateAndCache(bytes); + } + } else { + ptr = it->second.back(); + it->second.pop_back(); + } + used_bytes_ += bytes; + allocation_map_[ptr] = bytes; + // DLOG_DEBUG("Allocate: {:p}, size: {}, used: {}", ptr, bytes, used_bytes_); + return ptr; +} + +void CachingAllocator::Free(void* ptr) { + std::lock_guard guard(mutex_); + // DLOG_DEBUG("Try Free: {:p}", ptr); + const auto& it = allocation_map_.find(ptr); + DLOG_FATAL_IF(it == allocation_map_.end(), + "Attempted to free unallocated memory ", ptr); + const size_t alloc_size = it->second; + available_map_[alloc_size].push_back(ptr); + used_bytes_ -= alloc_size; + allocation_map_.erase(it); + // DLOG_DEBUG("Free: {:p}, size: {}, used: {}", ptr, alloc_size, used_bytes_); +} + +void CachingAllocator::InsertShmMeta(ShmMeta meta) { + std::lock_guard guard(mutex_); + shm_id_map_[meta.ptr] = meta; +} + +ShmMeta CachingAllocator::FindShmMetaByRange(void* ptr) { + std::lock_guard guard(mutex_); + { + auto it = shm_id_map_.find(ptr); + if (it != shm_id_map_.end()) { + return it->second; + } + } + for (const auto& it : shm_id_map_) { + if (it.first <= ptr && ptr < (char*)it.first + it.second.size) { + DLOG_WARN_IF(it.first != ptr, "FindShmMetaByRange not exact: expected ", + ptr, ", got ", it.first); + return it.second; + } + } + DLOG_FATAL("Cannot find shm meta by range ", ptr); + return {}; +} + +// int CachingAllocator::GetShmId(void* ptr) { +// std::lock_guard guard(mutex_); +// // const auto& it = shm_id_map_.find(ptr); +// // DLOG_FATAL_IF(it == shm_id_map_.end(), "Cannot find shm id {:p}", ptr); +// auto meta = FindShmMetaByRange(ptr); +// return meta.id; +// // return it->second.id; +// } + +// std::string CachingAllocator::GetShmName(void* ptr) { +// std::lock_guard guard(mutex_); +// auto meta = FindShmMetaByRange(ptr); +// // const auto& it = shm_id_map_.find(ptr); +// // DLOG_FATAL_IF(it == shm_id_map_.end(), "Cannot find shm name {:p}", +// ptr); return meta.name; +// } + +// size_t CachingAllocator::GetShmSize(void* ptr) { +// std::lock_guard guard(mutex_); +// auto meta = 
FindShmMetaByRange(ptr); +// size_t remain_size = meta.size - (ptr - meta.ptr); +// // const auto& it = shm_id_map_.find(ptr); +// // DLOG_FATAL_IF(it == shm_id_map_.end(), "Cannot find shm size {:p}", +// ptr); return remain_size; +// } + +void CachingAllocator::FreeCached() { + for (const auto& it : available_map_) { + for (const auto& ptr : it.second) { + FreeMemory(ptr); + allocated_bytes_ -= it.first; + allocation_map_.erase(ptr); + used_bytes_ -= it.first; + } + } + available_map_.clear(); +} + +void* CachingAllocator::AllocateAndCache(const size_t bytes) { + // DLOG_DEBUG("AllocateAndCache: size: {}, used: {}", bytes, used_bytes_); + if (allocated_bytes_ + bytes > max_bytes_) { + FreeCached(); + DLOG_FATAL_IF(allocated_bytes_ + bytes > max_bytes_, + "Out of memory; attempted to allocate ", bytes / GB, + "GB, allocated ", allocated_bytes_ / GB, "GB, ", "but only ", + (max_bytes_ - allocated_bytes_) / GB, "GB available"); + } + void* ptr = AllocateMemory(bytes); + return ptr; +} + +void* CachingAllocator::AllocateMemory(size_t bytes) { + switch (type_) { + case MemoryType::SHM: + return AllocShmMemory(bytes); + case MemoryType::PIN: + return AllocPinMemory(bytes); + case MemoryType::CUDA: + return AllocCudaMemory(bytes); + case MemoryType::PIN_SHM: + return AllocPinShmMemory(bytes); + default: + DLOG_FATAL("Unknown memory type"); + return nullptr; + } + DLOG_FATAL("Unknown memory type"); + return nullptr; +} +void* CachingAllocator::AllocCudaMemory(size_t bytes) { + void* ptr; + cudaSetDevice(device_id_); + cudaMalloc(&ptr, bytes); + return ptr; +} +void* CachingAllocator::AllocPinMemory(size_t bytes) { + void* ptr = aligned_alloc(4096, bytes); + // int ret = mlock(ptr, bytes); + // DLOG_FATAL_IF(ret != 0, "mlock failed: errno {}, message {}", errno, + // strerror(errno)); + cudaHostRegister(ptr, bytes, cudaHostRegisterDefault); + // void* ptr; + // cudaHostAlloc(&ptr, bytes, cudaHostAllocDefault); + return ptr; +} +void* CachingAllocator::AllocShmMemory(size_t bytes) { + int shm_id = shmget(IPC_PRIVATE, bytes, IPC_CREAT | 0666); + void* ptr = shmat(shm_id, nullptr, 0); + DLOG_FATAL_IF(ptr == (void*)-1, "shmat failed: errno ", errno, ", message ", + strerror(errno)); + shm_id_map_[ptr] = {shm_id, ptr, bytes, ""}; + return ptr; +} +void* CachingAllocator::AllocPinShmMemory(size_t bytes) { + ShmMeta shm_meta; + shm_meta.name = "/emulator_shm_" + GenUUID(); + + // DLOG_DEBUG("shm_meta name: {}", shm_meta.name); + + int shm_fd = shm_open(shm_meta.name.c_str(), O_CREAT | O_RDWR, 0666); + DLOG_FATAL_IF(shm_fd == -1, "shm_open failed: errno ", errno, ", message ", + strerror(errno)); + DLOG_FATAL_IF(ftruncate(shm_fd, bytes) == -1, "ftruncate failed: errno ", + errno, ", message ", strerror(errno)); + + size_t page_size = sysconf(_SC_PAGESIZE); + size_t aligned_bytes = ((bytes + page_size - 1) / page_size) * page_size; + + // Specify the fixed address + void* buf = AllocPinMemory(aligned_bytes); + + // DLOG_DEBUG("alloc pin shm: bytes: {}, preferred_addr: {:p}, aligned_bytes: + // {}", + // bytes, buf, aligned_bytes); + void* shm_addr = mmap(buf, aligned_bytes, PROT_READ | PROT_WRITE, + MAP_SHARED | MAP_FIXED | MAP_LOCKED, shm_fd, 0); + DLOG_FATAL_IF(shm_addr == MAP_FAILED, "mmap failed: errno ", errno, + ", message ", strerror(errno)); + + DLOG_FATAL_IF(shm_addr != buf, "mmap failed: expected addr: ", buf, ", got ", + shm_addr); + + shm_meta.id = shm_fd; + shm_meta.ptr = shm_addr; + shm_meta.size = aligned_bytes; + shm_id_map_[shm_addr] = shm_meta; + // DLOG_DEBUG("shm_meta: {}", shm_meta); 
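+  // Descriptive note: AllocPinMemory above reserves a page-aligned,
+  // CUDA-registered host range, and the named POSIX shared-memory object is
+  // then mapped over that range with MAP_FIXED | MAP_LOCKED. The result is a
+  // buffer at a known address that other processes can attach to via
+  // shm_meta.name, tracked in shm_id_map_ for later FindShmMetaByRange().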
+ + return shm_addr; +} + +void CachingAllocator::FreeMemory(void* ptr) { + switch (type_) { + case MemoryType::SHM: + FreeShmMemory(ptr); + break; + case MemoryType::PIN: + FreePinMemory(ptr); + break; + case MemoryType::CUDA: + FreeCudaMemory(ptr); + break; + case MemoryType::PIN_SHM: + FreePinShmMemory(ptr); + break; + default: + DLOG_FATAL("Unknown memory type"); + break; + } +} +void CachingAllocator::FreeCudaMemory(void* ptr) { + cudaSetDevice(device_id_); + cudaFree(ptr); +} +void CachingAllocator::FreePinMemory(void* ptr) { cudaFreeHost(ptr); } +void CachingAllocator::FreeShmMemory(void* ptr) { + shmdt(ptr); + auto shmid = shm_id_map_[ptr].id; + shmctl(shmid, IPC_RMID, nullptr); + shm_id_map_.erase(ptr); +} +void CachingAllocator::FreePinShmMemory(void* ptr) { + auto meta = FindShmMetaByRange(ptr); + auto shm_name = meta.name; + auto size = meta.size; + // auto id = meta.id; + + // DLOG_DEBUG("FreePinShmMemory: addr: {:p}, name: {}, size: {}", ptr, + // shm_name, size); + munmap(ptr, size); + // shm_unlink(shm_id_map_[ptr].name); + close(shm_id_map_[ptr].id); + FreePinMemory(ptr); + shm_id_map_.erase(ptr); +} + +CachingAllocator::CachingAllocator(size_t bytes, MemoryType type, int device_id) + : max_bytes_(bytes), + allocated_bytes_(0), + used_bytes_(0), + type_(type), + device_id_(device_id) {} + +CachingAllocator::~CachingAllocator() { FreeCached(); } + +extern "C" { +void* TorchAllocate(size_t bytes) { + // DLOG_DEBUG("TorchAllocate: size: {}", bytes); + InitCachingAllocator(MemoryType::PIN); + void* ptr = kCachingAllocator->Allocate(bytes); + return ptr; +} + +void TorchFree(void* ptr) { + // DLOG_DEBUG("TorchFree: {:p}", ptr); + InitCachingAllocator(MemoryType::PIN); + kCachingAllocator->Free(ptr); +} + +void TorchFreeCtx(void* ctx) { + InitCachingAllocator(MemoryType::PIN); + TorchCtx* torch_ctx = static_cast(ctx); + kCachingAllocator->Free(torch_ctx->ptr); + delete torch_ctx; +} + +void* TorchAllocateDevice(size_t bytes) { + InitCachingAllocator(MemoryType::CUDA); + void* ptr = kCachingAllocator->Allocate(bytes); + return ptr; +} + +void TorchFreeDevice(void* ptr) { + InitCachingAllocator(MemoryType::CUDA); + kCachingAllocator->Free(ptr); +} + +void TorchFreeCtxDevice(void* ctx) { + InitCachingAllocator(MemoryType::CUDA); + TorchCtx* torch_ctx = static_cast(ctx); + kCachingAllocator->Free(torch_ctx->ptr); + delete torch_ctx; +} +} diff --git a/core/memory/caching_allocator.h b/core/memory/caching_allocator.h index 01251f4..28106ea 100644 --- a/core/memory/caching_allocator.h +++ b/core/memory/caching_allocator.h @@ -1,108 +1,137 @@ #pragma once #include + +#include +#include +#include +#include #include -#include -#include +#include "common/types.h" +#include "shared_memory.h" #include "utils/cuda_utils.h" +#include "utils/logger.h" -// Templated CachingAllocator class -template -class CachingAllocator { - public: - static CachingAllocator* instance(int idx) { - static std::array*, 8> instances; - if (instances[idx] == nullptr) { - instances[idx] = new CachingAllocator(); - } - return instances[idx]; - } - - void* allocate(const size_t bytes) { - const auto& it = available_map_.find(bytes); - if (it == available_map_.end() || it->second.empty()) { - return allocate_and_cache(bytes); - } - void* ptr = it->second.back(); - it->second.pop_back(); - return ptr; - } - - void free(void* ptr) { - const auto& it = allocation_map_.find(ptr); - if (it == allocation_map_.end()) { - Allocator::deallocate(ptr); - return; - } - const size_t alloc_size = it->second; - 
available_map_[alloc_size].push_back(ptr); - } - - void record_free(void* ptr) { - const auto& it = allocation_map_.find(ptr); - if (it != allocation_map_.end()) { - allocation_map_.erase(it); - } - } +#define MEMORY_TYPE_VALUES(X, EnumType) \ + X(SHM, EnumType) \ + X(PIN, EnumType) \ + X(CUDA, EnumType) \ + X(PIN_SHM, EnumType) - void free_cached() { - for (const auto& it : available_map_) { - for (const auto ptr : it.second) { - Allocator::deallocate(ptr); - allocation_map_.erase(ptr); - } - } - available_map_.clear(); - } +DEFINE_ENUM_CLASS(MemoryType, MEMORY_TYPE_VALUES) - ~CachingAllocator() { free_cached(); } +class CachingAllocator; +extern std::unique_ptr kCachingAllocator; - private: - void* allocate_and_cache(const size_t bytes) { - void* ptr = Allocator::allocate(bytes); - allocation_map_[ptr] = bytes; - return ptr; - } - - std::unordered_map> available_map_; - std::unordered_map allocation_map_; +struct TorchCtx { + void* ptr; + size_t size; }; -// Example Allocator for CUDA -struct CudaDeviceAllocator { - static void* allocate(size_t bytes) { - void* ptr; - CUDA_CHECK(cudaMalloc(&ptr, bytes)); - return ptr; - } +extern "C" { +void* TorchAllocate(size_t bytes); +void TorchFree(void* ptr); +void TorchFreeCtx(void* ctx); +void* TorchAllocateDevice(size_t bytes); +void TorchFreeDevice(void* ptr); +void TorchFreeCtxDevice(void* ctx); +} + +// the caching allocator that supports CPU and CUDA memory +// work as an offset manager for the memory pool +class CachingAllocator : public base::noncopyable { + public: + explicit CachingAllocator(size_t bytes, MemoryType type, int device_id = -1); + virtual ~CachingAllocator(); - static void deallocate(void* ptr) { CUDA_CHECK(cudaFree(ptr)); } -}; + virtual void* Allocate(const size_t bytes); + virtual void Free(void* ptr); -// Example Allocator for Unified Memory -struct CudaUnifiedAllocator { - static void* allocate(size_t bytes) { - void* ptr; - CUDA_CHECK(cudaMallocManaged(&ptr, bytes)); - return ptr; + bool IsAllocated(void* ptr) { + std::lock_guard guard(mutex_); + return allocation_map_.find(ptr) != allocation_map_.end(); } - static void deallocate(void* ptr) { CUDA_CHECK(cudaFree(ptr)); } -}; + ShmMeta FindShmMetaByRange(void* ptr); + void InsertShmMeta(ShmMeta meta); -// Example Allocator for cudaHostAlloc -struct CudaHostAllocator { - static void* allocate(size_t bytes) { - void* ptr; - CUDA_CHECK(cudaHostAlloc(&ptr, bytes, cudaHostAllocDefault)); - return ptr; - } + MemoryType GetType() const { return type_; } + + size_t GetMaxBytes() const { return max_bytes_; } + size_t GetAllocatedBytes() const { return allocated_bytes_; } + size_t GetUsedBytes() const { return used_bytes_; } - static void deallocate(void* ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } + private: + void* AllocateAndCache(const size_t bytes); + void FreeCached(); + + void* AllocateMemory(size_t bytes); + + void* AllocCudaMemory(size_t bytes); + void* AllocPinMemory(size_t bytes); + void* AllocShmMemory(size_t bytes); + void* AllocPinShmMemory(size_t bytes); + + void FreeMemory(void* ptr); + void FreeCudaMemory(void* ptr); + void FreePinMemory(void* ptr); + void FreeShmMemory(void* ptr); + void FreePinShmMemory(void* ptr); + + protected: + int device_id_; + MemoryType type_; + const size_t max_bytes_; + size_t allocated_bytes_; + size_t used_bytes_; + + std::unordered_map> available_map_; + std::unordered_map allocation_map_; + std::mutex mutex_; + std::unordered_map shm_id_map_; }; -// Template specialization for all types of CachingAllocator -typedef CachingAllocator 
CudaDeviceCachingAllocator; -typedef CachingAllocator CudaUnifiedCachingAllocator; -typedef CachingAllocator CudaHostCachingAllocator; +extern std::once_flag kInitCachingAllocatorFlag; + +static void InitCachingAllocator(MemoryType type, int device_id = -1) { + std::call_once(kInitCachingAllocatorFlag, [&]() { + size_t bytes = 0; + DLOG_DEBUG("InitCachingAllocator: type: {}, device_id: {}", type, + device_id); + if (type == MemoryType::CUDA) { + // DLOG_FATAL_IF(device_id < 0, "Invalid device id"); + if (device_id < 0) { + CUDA_CHECK(cudaGetDevice(&device_id)); + } + // Get environment variable MOEINF_SHM_SIZE + const char* size = std::getenv("MOEINF_GPU_SIZE"); + if (size == nullptr) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device_id)); + bytes = prop.totalGlobalMem; + } else { + DLOG_FATAL_IF(size == nullptr, "MOEINF_GPU_SIZE is not set"); + bytes = std::stoull(size); + } + } else if (type == MemoryType::SHM) { + // Get environment variable MOEINF_SHM_SIZE + const char* size = std::getenv("MOEINF_SHM_SIZE"); + DLOG_FATAL_IF(size == nullptr, "MOEINF_SHM_SIZE is not set"); + bytes = std::stoull(size); + } else if (type == MemoryType::PIN or type == MemoryType::PIN_SHM) { + const char* size = std::getenv("MOEINF_PIN_SIZE"); + DLOG_FATAL_IF(size == nullptr, "MOEINF_PIN_SIZE is not set"); + bytes = std::stoull(size); + } else { + DLOG_FATAL("Unknown memory type"); + } + DLOG_FATAL_IF(kCachingAllocator != nullptr, + "Caching allocator is already initialized"); + + kCachingAllocator = + std::make_unique(bytes, type, device_id); + DLOG_INFO("Caching allocator initialized with {}GB, type: {}", bytes / GB, + type); + }); +} diff --git a/core/memory/caching_allocator_bk.h b/core/memory/caching_allocator_bk.h new file mode 100644 index 0000000..01251f4 --- /dev/null +++ b/core/memory/caching_allocator_bk.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include + +#include "utils/cuda_utils.h" + +// Templated CachingAllocator class +template +class CachingAllocator { + public: + static CachingAllocator* instance(int idx) { + static std::array*, 8> instances; + if (instances[idx] == nullptr) { + instances[idx] = new CachingAllocator(); + } + return instances[idx]; + } + + void* allocate(const size_t bytes) { + const auto& it = available_map_.find(bytes); + if (it == available_map_.end() || it->second.empty()) { + return allocate_and_cache(bytes); + } + void* ptr = it->second.back(); + it->second.pop_back(); + return ptr; + } + + void free(void* ptr) { + const auto& it = allocation_map_.find(ptr); + if (it == allocation_map_.end()) { + Allocator::deallocate(ptr); + return; + } + const size_t alloc_size = it->second; + available_map_[alloc_size].push_back(ptr); + } + + void record_free(void* ptr) { + const auto& it = allocation_map_.find(ptr); + if (it != allocation_map_.end()) { + allocation_map_.erase(it); + } + } + + void free_cached() { + for (const auto& it : available_map_) { + for (const auto ptr : it.second) { + Allocator::deallocate(ptr); + allocation_map_.erase(ptr); + } + } + available_map_.clear(); + } + + ~CachingAllocator() { free_cached(); } + + private: + void* allocate_and_cache(const size_t bytes) { + void* ptr = Allocator::allocate(bytes); + allocation_map_[ptr] = bytes; + return ptr; + } + + std::unordered_map> available_map_; + std::unordered_map allocation_map_; +}; + +// Example Allocator for CUDA +struct CudaDeviceAllocator { + static void* allocate(size_t bytes) { + void* ptr; + CUDA_CHECK(cudaMalloc(&ptr, bytes)); + return ptr; + } 
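+  // Note: CachingAllocator::free() falls back to Allocator::deallocate() for
+  // pointers it is not tracking, and free_cached() uses it to release every
+  // cached block, so deallocate() must accept any pointer allocate() returned.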
+ + static void deallocate(void* ptr) { CUDA_CHECK(cudaFree(ptr)); } +}; + +// Example Allocator for Unified Memory +struct CudaUnifiedAllocator { + static void* allocate(size_t bytes) { + void* ptr; + CUDA_CHECK(cudaMallocManaged(&ptr, bytes)); + return ptr; + } + + static void deallocate(void* ptr) { CUDA_CHECK(cudaFree(ptr)); } +}; + +// Example Allocator for cudaHostAlloc +struct CudaHostAllocator { + static void* allocate(size_t bytes) { + void* ptr; + CUDA_CHECK(cudaHostAlloc(&ptr, bytes, cudaHostAllocDefault)); + return ptr; + } + + static void deallocate(void* ptr) { CUDA_CHECK(cudaFreeHost(ptr)); } +}; + +// Template specialization for all types of CachingAllocator +typedef CachingAllocator CudaDeviceCachingAllocator; +typedef CachingAllocator CudaUnifiedCachingAllocator; +typedef CachingAllocator CudaHostCachingAllocator; diff --git a/core/memory/shared_memory.cpp b/core/memory/shared_memory.cpp new file mode 100644 index 0000000..c34fa3a --- /dev/null +++ b/core/memory/shared_memory.cpp @@ -0,0 +1,47 @@ +#include "shared_memory.h" + +void* OpenSharedMemory(const char* name, size_t size) { + int shm_fd = shm_open(name, O_RDWR, 0666); + LOG_FATAL_IF(shm_fd == -1, + "shm_open failed. name: {}, size: {}; errno: {}, message: {}", + name, size, errno, strerror(errno)); + + void* ptr = + mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + LOG_FATAL_IF(ptr == MAP_FAILED, + "mmap failed. name: {}, size: {}; errno: {}, message: {}", name, + size, errno, strerror(errno)); + return ptr; +} + +void CloseSharedMemory(void* ptr, size_t size) { + int ret = munmap(ptr, size); + LOG_FATAL_IF(ret == -1, + "munmap failed. ptr: {0:x}, size: {}; errno: {}, message: {}", + ptr, size, errno, strerror(errno)); +} + +std::tuple AttachSharedMemory(const char* name, size_t size) { + int shm_fd = shm_open(name, O_RDWR, 0666); + LOG_FATAL_IF(shm_fd == -1, + "shm_open failed. name: {}, size: {}; errno: {}, message: {}", + name, size, errno, strerror(errno)); + + void* ptr = + mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + LOG_FATAL_IF(ptr == MAP_FAILED, + "mmap failed. name: {}, size: {}; errno: {}, message: {}", name, + size, errno, strerror(errno)); + return {ptr, shm_fd}; +} + +void DetachSharedMemory(void* ptr, int fd, size_t size) { + int ret = munmap(ptr, size); + LOG_FATAL_IF(ret == -1, + "munmap failed. ptr: {0:x}, size: {}; errno: {}, message: {}", + ptr, size, errno, strerror(errno)); + + ret = close(fd); + LOG_FATAL_IF(ret == -1, "close failed. 
fd: {}, errno: {}, message: {}", fd, + errno, strerror(errno)); +} diff --git a/core/memory/shared_memory.h b/core/memory/shared_memory.h new file mode 100644 index 0000000..7627bd7 --- /dev/null +++ b/core/memory/shared_memory.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include // for memset +#include +#include + +#include "common/types.h" +#include "utils/logger.h" + +struct ShmMeta { + int id; + void* ptr; + size_t size; + std::string name; + bool is_remote = false; +}; + +void* OpenSharedMemory(const char* name, size_t size); +void CloseSharedMemory(void* ptr, size_t size); +std::tuple AttachSharedMemory(const char* name, size_t size); +void DetachSharedMemory(void* ptr, int fd, size_t size); + +struct ShmDeleter { + void operator()(void* ptr) const { + LOG_DEBUG << "ShmDeleter: ptr: " << ptr; + DetachSharedMemory(ptr, fd, size); + } + size_t size; + int fd; +}; + +template +std::shared_ptr AttachSharedMemoryPtr(const char* name, size_t size) { + auto [ptr, shm_fd] = AttachSharedMemory(name, size); + return std::shared_ptr(static_cast(ptr), ShmDeleter{size, shm_fd}); +} diff --git a/core/memory/torch_caching_allocator.cpp b/core/memory/torch_caching_allocator.cpp new file mode 100644 index 0000000..699188a --- /dev/null +++ b/core/memory/torch_caching_allocator.cpp @@ -0,0 +1,6 @@ +#include "torch_caching_allocator.h" + +// std::unique_ptr kTorchCachingAllocator = +// std::make_unique(); + +ReplaceTorchAllocatorOnLoad kReplaceTorchAllocatorOnLoad; diff --git a/core/memory/torch_caching_allocator.h b/core/memory/torch_caching_allocator.h new file mode 100644 index 0000000..1f62f49 --- /dev/null +++ b/core/memory/torch_caching_allocator.h @@ -0,0 +1,45 @@ +#pragma once + +#include + +#include "caching_allocator.h" + +struct TorchCachingAllocator : public torch::Allocator { + // For Torch Interface + torch::DataPtr allocate(size_t n) override { + void* data = TorchAllocate(n); + return {data, data, &TorchFree, torch::DeviceType::CPU}; + } + + void copy_data(void* dest, const void* src, size_t count) const override { + LOG_DEBUG("Copy data from {:p} to {:p}, size: {}", src, dest, count); + memcpy(dest, src, count); + } + + // // Optional: Handle deallocation (if needed) + // void deallocate(void* ptr) override { + // Free(ptr); // Custom deallocation logic + // } +}; + +// extern std::unique_ptr kTorchCachingAllocator; + +class ReplaceTorchAllocatorOnLoad { + public: + ReplaceTorchAllocatorOnLoad() { + std::call_once(flag_, [&]() { + InitCachingAllocator(MemoryType::PINNED); + torch_caching_allocator_ = new TorchCachingAllocator(); + LOG_INFO("Replace torch allocator with caching allocator"); + torch::SetAllocator(torch::DeviceType::CPU, torch_caching_allocator_); + LOG_INFO("Torch allocator replaced"); + }); + } + + private: + TorchCachingAllocator* torch_caching_allocator_; + std::once_flag flag_; +}; + +// Create a static instance of this class +extern ReplaceTorchAllocatorOnLoad kReplaceTorchAllocatorOnLoad; diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index 8ceceb0..c797840 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -189,8 +189,8 @@ void ExpertDispatcher::RegisterExpert( if (cached_node == nullptr) { cached_node = node; experts_[expert_idx][layer_idx]->node = node; - experts_[expert_idx][layer_idx]->jit_module = - new torch::jit::script::Module(torch::jit::load(jit_path)); + // experts_[expert_idx][layer_idx]->jit_module = + // new 
torch::jit::script::Module(torch::jit::load(jit_path)); } else if (cached_node != node) { DLOG_FATAL("RegisterExpert: tensor_id has multiple nodes", tensor_id); } diff --git a/core/parallel/expert_module.cpp b/core/parallel/expert_module.cpp index 4cd8c78..51ab673 100644 --- a/core/parallel/expert_module.cpp +++ b/core/parallel/expert_module.cpp @@ -4,12 +4,13 @@ // EfficientMoE Team #include "expert_module.h" -#include "aio/archer_tensor_handle.h" -#include "memory/caching_allocator.h" +// #include "memory/caching_allocator.h" +#include "utils/cuda_utils.h" #include "utils/logger.h" static const int64_t kMaxTokens = 128; +/* SwitchTransformersDenseActDense::SwitchTransformersDenseActDense(int dtype) { // auto tensor_dtype = dtype_to_torch(dtype); auto options = torch::TensorOptions().device(torch::kCPU); @@ -226,11 +227,12 @@ void MixtralMoEDenseActDense::SetModuleFromBlob( torch::Tensor MixtralMoEDenseActDense::forward(torch::Tensor hidden_states, cudaStream_t stream) { - /* - current_hidden_states = self.silu(self.w1(hidden_states)) * - self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - */ + + // current_hidden_states = self.silu(self.w1(hidden_states)) * + // self.w3(hidden_states) current_hidden_states = +self.w2(current_hidden_states) + // return current_hidden_states + // int w1_nan = torch::sum(torch::isnan(w1)).item(); // int w2_nan = torch::sum(torch::isnan(w2)).item(); // int w3_nan = torch::sum(torch::isnan(w3)).item(); @@ -303,33 +305,34 @@ torch::Tensor DeepSeekMoEDenseActDense::forward(torch::Tensor hidden_states, torch::matmul(hidden_states, up_proj.transpose(0, 1)), down_proj.transpose(0, 1)); } +*/ void ExpertNode::SetTensorsFromBlob(const torch::Device& device) { - int expert_type = this->expert_type; + auto expert_type = static_cast(this->expert_type); switch (expert_type) { - case SWITCH_TRANSFORMERS_DENSE_ACT_DENSE: + case ExpertType::SwitchTransformersDenseActDense: reinterpret_cast(module) ->SetTensorsFromBlob(node->device_memory_ptr, node->tensor_ids, device); break; - case SWITCH_TRANSFORMERS_DENSE_GATED_ACT_DENSE: + case ExpertType::SwitchTransformersDenseGatedActDense: reinterpret_cast(module) ->SetTensorsFromBlob(node->device_memory_ptr, node->tensor_ids, device); break; - case NLLB_MOE_DENSE_ACT_DENSE: + case ExpertType::NllbMoeDenseActDense: reinterpret_cast(module)->SetTensorsFromBlob( node->device_memory_ptr, node->tensor_ids, device); break; - case FSGPT_MOE_DENSE_ACT_DENSE: + case ExpertType::FSGPTMoeDenseActDense: reinterpret_cast(module)->SetTensorsFromBlob( node->device_memory_ptr, node->tensor_ids, device); break; - case MIXTRAL_MOE_DENSE_ACT_DENSE: + case ExpertType::MixtralMoeDenseActDense: reinterpret_cast(module)->SetTensorsFromBlob( node->device_memory_ptr, node->tensor_ids, device); break; - case DEEPSEEK_MOE_DENSE_ACT_DENSE: + case ExpertType::DeepSeekMoeDenseActDense: reinterpret_cast(module)->SetTensorsFromBlob( node->device_memory_ptr, node->tensor_ids, device); break; diff --git a/core/parallel/expert_module.h b/core/parallel/expert_module.h index 4597ede..663f482 100644 --- a/core/parallel/expert_module.h +++ b/core/parallel/expert_module.h @@ -5,7 +5,216 @@ #include #include + #include "model/model_topology.h" +#include "aio/archer_tensor_handle.h" + +// Expert type enum +enum class ExpertType { + SwitchTransformersDenseActDense = 0, + SwitchTransformersDenseGatedActDense = 1, + NllbMoeDenseActDense = 2, + FSGPTMoeDenseActDense = 3, + MixtralMoeDenseActDense = 4, + 
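 // The numeric values are the integer expert_type codes carried by ExpertNode;
+  // ExpertNode::SetTensorsFromBlob() casts that int back to this enum.
+ 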
DeepSeekMoeDenseActDense = 5 +}; + +// Activation functions enum +enum class ActivationType { ReLU, GELU, SiLU, Identity }; + +// Base traits for expert architectures +template +struct ExpertTraits; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 2; + static constexpr size_t num_biases = 0; + static constexpr std::array weight_names = {"wi", "wo"}; + static constexpr std::array bias_names = {}; +}; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 3; + static constexpr size_t num_biases = 0; + static constexpr std::array weight_names = {"wi_0", "wi_1", + "wo"}; + static constexpr std::array bias_names = {}; +}; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 2; + static constexpr size_t num_biases = 2; + static constexpr std::array weight_names = {"fc1", "fc2"}; + static constexpr std::array bias_names = {"fc1_bias", + "fc2_bias"}; +}; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 2; + static constexpr size_t num_biases = 2; + static constexpr std::array weight_names = {"fc1", "fc2"}; + static constexpr std::array bias_names = {"fc1_bias", + "fc2_bias"}; +}; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 3; + static constexpr size_t num_biases = 0; + static constexpr std::array weight_names = {"w1", "w2", "w3"}; + static constexpr std::array bias_names = {}; +}; + +template <> +struct ExpertTraits { + static constexpr size_t num_weights = 3; + static constexpr size_t num_biases = 0; + static constexpr std::array weight_names = { + "gate_proj", "up_proj", "down_proj"}; + static constexpr std::array bias_names = {}; +}; + +// Templated expert implementation +template +class Expert : public torch::nn::Module { + private: + using Traits = ExpertTraits; + std::array weights_; + std::array biases_; + + public: + explicit Expert(int dtype) { + auto tensor_dtype = dtype_to_torch(dtype); + auto options = + torch::TensorOptions().dtype(tensor_dtype).device(torch::kCPU); + + // Register weights + for (size_t i = 0; i < Traits::num_weights; ++i) { + weights_[i] = register_parameter(Traits::weight_names[i], + torch::zeros({1}, options)); + } + + // Register biases if any + for (size_t i = 0; i < Traits::num_biases; ++i) { + biases_[i] = + register_parameter(Traits::bias_names[i], torch::zeros({1}, options)); + } + } + + torch::Tensor forward(torch::Tensor hidden_states, + cudaStream_t stream = nullptr); + + void SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { + size_t idx = 0; + + // Set weights + for (size_t i = 0; i < Traits::num_weights; ++i) { + weights_[i] = kTensorIndex->find(tensor_ids[idx++])->second.tensor; + } + + // Set biases + for (size_t i = 0; i < Traits::num_biases; ++i) { + biases_[i] = kTensorIndex->find(tensor_ids[idx++])->second.tensor; + } + } + + void SetModuleFromBlob(torch::jit::script::Module* ptr) { + for (auto it = ptr->parameters().begin(); it != ptr->parameters().end(); + ++it) { + // Set weights + for (size_t i = 0; i < Traits::num_weights; ++i) { + if ((*it).name() == Traits::weight_names[i]) { + (*it).set_data(weights_[i]); + } + } + + // Set biases + for (size_t i = 0; i < Traits::num_biases; ++i) { + if ((*it).name() == Traits::bias_names[i]) { + (*it).set_data(biases_[i]); + } + } + } + } +}; + +// Forward specializations +template <> +inline torch::Tensor +Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + return torch::matmul( + 
torch::relu(torch::matmul(hidden_states, weights_[0].transpose(0, 1).to( + hidden_states.dtype()))), + weights_[1].transpose(0, 1).to(hidden_states.dtype())); +} + +template <> +inline torch::Tensor +Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + auto gate = + torch::gelu(torch::matmul(hidden_states, weights_[0].transpose(0, 1))); + auto linear = torch::matmul(hidden_states, weights_[1].transpose(0, 1)); + return torch::matmul(torch::mul(gate, linear), weights_[2].transpose(0, 1)); +} + +template <> +inline torch::Tensor Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + return torch::matmul(torch::relu(torch::matmul(hidden_states, + weights_[0].transpose(0, 1)) + + biases_[0]), + weights_[1].transpose(0, 1)) + + biases_[1]; +} + +template <> +inline torch::Tensor Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + if (hidden_states.dtype() != weights_[0].dtype()) { + hidden_states = hidden_states.to(weights_[0].dtype()); + } + return torch::matmul(torch::relu(torch::matmul(hidden_states, + weights_[0].transpose(0, 1)) + + biases_[0]), + weights_[1].transpose(0, 1)) + + biases_[1]; +} + +template <> +inline torch::Tensor Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + return torch::matmul( + torch::silu(torch::matmul(hidden_states, weights_[0].transpose(0, 1))) * + torch::matmul(hidden_states, weights_[2].transpose(0, 1)), + weights_[1].transpose(0, 1)); +} + +template <> +inline torch::Tensor Expert::forward( + torch::Tensor hidden_states, cudaStream_t stream) { + return torch::matmul( + torch::silu(torch::matmul(hidden_states, weights_[0].transpose(0, 1))) * + torch::matmul(hidden_states, weights_[1].transpose(0, 1)), + weights_[2].transpose(0, 1)); +} + +// Type aliases for compatibility +using SwitchTransformersDenseActDense = + Expert; +using SwitchTransformersDenseGatedActDense = + Expert; +using NllbMoeDenseActDense = Expert; +using FSGPTMoEDenseActDense = Expert; +using MixtralMoEDenseActDense = Expert; +using DeepSeekMoEDenseActDense = Expert; #ifndef EXPERT_TYPE #define EXPERT_TYPE 0 @@ -25,32 +234,6 @@ torch::Tensor launch_fused_moe_ffn(torch::Tensor hidden, // [M, K] torch::Tensor w3, // [K, N] cudaStream_t stream); // CUDA stream -struct ModuleUtils { - virtual void SetTensorsFromBlob(void* ptr, - const std::vector& tensor_ids, - const torch::Device& device) = 0; - virtual void SetModuleFromBlob(torch::jit::script::Module* ptr) = 0; -}; - -#define DECLARE_MODULE(name, ...) 
\ - struct name : public torch::nn::Module, public ModuleUtils { \ - name(int dtype); \ - torch::Tensor forward(torch::Tensor hidden_states, \ - cudaStream_t stream = nullptr); \ - torch::Tensor __VA_ARGS__; \ - void SetTensorsFromBlob(void* ptr, \ - const std::vector& tensor_ids, \ - const torch::Device& device) override; \ - void SetModuleFromBlob(torch::jit::script::Module* ptr) override; \ - }; - -DECLARE_MODULE(SwitchTransformersDenseActDense, wi, wo) -DECLARE_MODULE(SwitchTransformersDenseGatedActDense, wi_0, wi_1, wo) -DECLARE_MODULE(NllbMoeDenseActDense, fc1, fc2, fc1_bias, fc2_bias) -DECLARE_MODULE(FSGPTMoEDenseActDense, fc1, fc2, fc1_bias, fc2_bias) -DECLARE_MODULE(MixtralMoEDenseActDense, w1, w2, w3) -DECLARE_MODULE(DeepSeekMoEDenseActDense, gate_proj, up_proj, down_proj) - struct MoEMLP : public torch::nn::Module { explicit MoEMLP(int dtype, int expert_type); torch::Tensor forward(torch::Tensor hidden_states, cudaStream_t stream); @@ -84,78 +267,6 @@ struct MoEMLP : public torch::nn::Module { int expert_type_; }; -// struct SwitchTransformersDenseActDense : public torch::nn::Module, -// public ModuleUtils { -// SwitchTransformersDenseActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states); -// torch::Tensor wi, wo; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - -// struct SwitchTransformersDenseGatedActDense : public torch::nn::Module, -// public ModuleUtils { -// SwitchTransformersDenseGatedActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states); -// torch::Tensor wi_0, wi_1, wo; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - -// struct NllbMoeDenseActDense : public torch::nn::Module, public ModuleUtils { -// NllbMoeDenseActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states); -// torch::Tensor fc1, fc2; -// torch::Tensor fc1_bias, fc2_bias; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - -// struct FSGPTMoEDenseActDense : public torch::nn::Module, public ModuleUtils { -// FSGPTMoEDenseActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states); -// torch::Tensor fc1, fc2; -// torch::Tensor fc1_bias, fc2_bias; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - -// struct MixtralMoEDenseActDense : public torch::nn::Module, public ModuleUtils -// { -// MixtralMoEDenseActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states); -// torch::Tensor w1, w2, w3; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - -// struct DeepSeekMoEDenseActDense : public torch::nn::Module, public -// ModuleUtils { -// DeepSeekMoEDenseActDense(int dtype); -// torch::Tensor forward(torch::Tensor hidden_states, cudaStream_t stream); -// torch::Tensor gate_proj, up_proj, down_proj; - -// void SetTensorsFromBlob(void* ptr, -// const std::vector& tensor_ids, -// const torch::Device& 
device) override; -// void SetModuleFromBlob(torch::jit::script::Module* ptr) override; -// }; - struct ExpertNode { NodePtr node; torch::nn::Module* module; diff --git a/core/utils/cuda_utils.h b/core/utils/cuda_utils.h index 898e014..a3a2f9a 100644 --- a/core/utils/cuda_utils.h +++ b/core/utils/cuda_utils.h @@ -6,9 +6,11 @@ #pragma once #include +#include #include #include #include +#include inline void throwOnCudaError(cudaError_t error, const char* file, int line, const char* function, const char* call) { @@ -21,9 +23,63 @@ inline void throwOnCudaError(cudaError_t error, const char* file, int line, } }; +inline void throwOnCutlassError(cutlass::Status status, const char* file, + int line, const char* function, + const char* call) { + if (status != cutlass::Status::kSuccess) { + std::stringstream ss; + ss << "CUTLASS error " << static_cast(status) << " at " << file << ":" + << line << " in function " << function << ": " + << cutlassGetStatusString(status) << "\nCall: " << call; + throw std::runtime_error(ss.str()); + } +} + #define CUDA_CHECK(call) \ throwOnCudaError(call, __FILE__, __LINE__, __FUNCTION__, #call) +#define CUTLASS_CHECK(call) \ + throwOnCutlassError(call, __FILE__, __LINE__, __FUNCTION__, #call) + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU + * stream + */ +struct GpuTimer { + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() { CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; + int GetDevice(); bool IsDevicePointer(const void* ptr); int GetDeviceCount(); diff --git a/core/utils/logger.h b/core/utils/logger.h index 00a1e83..18bd1a7 100644 --- a/core/utils/logger.h +++ b/core/utils/logger.h @@ -6,6 +6,7 @@ #pragma once #include +#include "common/types.h" #include "base/logging.h" inline void print(base::LogStream& stream) {} @@ -34,6 +35,14 @@ LogStream& operator<<(LogStream& stream, const std::vector& vec) { return stream; } +// define a custom operator<< for enum classes +template +typename std::enable_if::value, LogStream&>::type operator<<( + LogStream& stream, const T& value) { + // This will call the EnumTypeToString function defined in the macro + return stream << enum_to_string(value); +} + } // namespace base #define DLOG_TRACE(...) \ @@ -72,6 +81,13 @@ LogStream& operator<<(LogStream& stream, const std::vector& vec) { __VA_ARGS__); \ } while (0) +#define DLOG_WARN_IF(condition, ...) \ + do { \ + if (condition) { \ + DLOG_WARN(__VA_ARGS__); \ + } \ + } while (0) + #define DLOG_FATAL(...) 
\ do { \ if (base::Logger::logLevel() <= base::Logger::FATAL) \ diff --git a/examples/interface_example.py b/examples/interface_example.py index 37c0b5c..a3e4b7c 100644 --- a/examples/interface_example.py +++ b/examples/interface_example.py @@ -66,6 +66,11 @@ def end(self): args = parser.parse_args() model_name = args.model_name_or_path.split("/")[-1] +config = { + "offload_path": os.path.join(args.offload_dir, model_name), + "device_memory_ratio": args.device_memory_ratio, +} +model = MoE(args.model_name_or_path, config) tokenizer = None if "grok" in model_name: @@ -107,11 +112,6 @@ def end(self): # text for dataset in all_inputs for text in dataset["test"]["question"] if "test" in dataset # ] -config = { - "offload_path": os.path.join(args.offload_dir, model_name), - "device_memory_ratio": args.device_memory_ratio, -} -model = MoE(args.model_name_or_path, config) custom_kwargs = {} if "switch" in args.model_name_or_path.lower(): diff --git a/moe_infinity/kernel/__init__.py b/moe_infinity/kernel/__init__.py new file mode 100644 index 0000000..c8f314d --- /dev/null +++ b/moe_infinity/kernel/__init__.py @@ -0,0 +1 @@ +from .router import launch_fused_softmax_topk, launch_fused_softmax_topk_nobias diff --git a/moe_infinity/kernel/router.py b/moe_infinity/kernel/router.py new file mode 100644 index 0000000..bf508b8 --- /dev/null +++ b/moe_infinity/kernel/router.py @@ -0,0 +1,234 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_rows, + n_cols, + BLOCK_SIZE: tl.constexpr, + num_stages: tl.constexpr, +): + # starting row of the program + row_start = tl.program_id(0) + row_step = tl.num_programs(0) + for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages): + # The stride represents how much we need to increase the pointer to advance 1 row + row_start_ptr = input_ptr + row_idx * input_row_stride + # The block size is the next power of two greater than n_cols, so we can fit each + # row in a single block + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols + mask = col_offsets < n_cols + row = tl.load(input_ptrs, mask=mask, other=-float("inf")) + # Subtract maximum for numerical stability + row_minus_max = row - tl.max(row, axis=0) + # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + # Write back output to DRAM + output_row_start_ptr = output_ptr + row_idx * output_row_stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=mask) + + +@triton.jit +def fused_softmax_topk_kernel( + hidden_ptr, # [B, H] + weight_ptr, # [H, E] + routing_mask_ptr, # [B, E] (bool) + routing_weight_ptr, # [B, E] (float16) + B: tl.constexpr, + H: tl.constexpr, + E: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_E: tl.constexpr, + normalize_topk: tl.constexpr, # New! 
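+    # normalize_topk: when set, the selected top-k probabilities are
+    # renormalized to sum to 1 before being written to routing_weight_ptr
+    # (mirrors the `norm_topk_prob` option of the DeepSeek router).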
+): + batch_id = tl.program_id(0) + off_e = tl.arange(0, BLOCK_E) + + # [E] logits init + logits = tl.zeros([BLOCK_E], dtype=tl.float32) + + for h in range(0, H): + hidden_val = tl.load(hidden_ptr + batch_id * H + h) + weight_row = weight_ptr + h * E + off_e + w = tl.load(weight_row, mask=off_e < E, other=0.0) + logits += hidden_val * w + + logits += tl.load(bias_ptr + off_e, mask=off_e < E, other=0.0) + + # Compute softmax + max_logit = tl.max(logits, axis=0) + logits = logits - max_logit + exp_logits = tl.exp(logits) + sum_exp = tl.sum(exp_logits, axis=0) + probs = exp_logits / sum_exp + + # Top-k selection (insertion sort) + top_vals = tl.full([TOPK], -float("inf"), dtype=tl.float32) + top_idxs = tl.full([TOPK], -1, dtype=tl.int32) + + for i in range(BLOCK_E): + p = probs[i] + idx = i + + # insert into sorted list + for j in range(TOPK): + if p > top_vals[j]: + for k in range(TOPK - 1, j, -1): + top_vals[k] = top_vals[k - 1] + top_idxs[k] = top_idxs[k - 1] + top_vals[j] = p + top_idxs[j] = idx + break + + # normalize + if normalize_topk: + sum_top = tl.sum(top_vals) + top_vals = top_vals / sum_top + + for i in range(TOPK): + expert_idx = top_idxs[i] + expert_val = top_vals[i] + if expert_idx >= 0: + tl.store(routing_mask_ptr + batch_id * E + expert_idx, True) + tl.store( + routing_weight_ptr + batch_id * E + expert_idx, + expert_val.to(tl.float16), + ) + + +def launch_fused_softmax_topk(hidden_states, weight, bias, top_k): + B, H = hidden_states.shape + E = weight.shape[1] + dtype = hidden_states.dtype + + routing_mask = torch.zeros( + (B, E), dtype=torch.bool, device=hidden_states.device + ) + routing_weight = torch.zeros( + (B, E), dtype=dtype, device=hidden_states.device + ) + + BLOCK_E = 32 # Must divide E + + fused_softmax_topk_kernel[(B,)]( + hidden_states, + weight, + bias, + routing_mask, + routing_weight, + B=B, + H=H, + E=E, + TOPK=top_k, + BLOCK_E=BLOCK_E, + ) + + return routing_mask, routing_weight + + +@triton.jit +def fused_softmax_topk_kernel_nobias( + hidden_ptr, # [B, H] + weight_ptr, # [E, H] + routing_mask_ptr, # [B, E] + routing_weight_ptr, # [B, E] + B: tl.constexpr, + H: tl.constexpr, + E: tl.constexpr, + TOPK: tl.constexpr, + BLOCK_E: tl.constexpr, + normalize_topk: tl.constexpr, +): + batch_id = tl.program_id(0) + off_e = tl.arange(0, BLOCK_E) + + logits = tl.full([BLOCK_E], -float("inf"), dtype=tl.float32) + + for h in range(H): + h_val = tl.load(hidden_ptr + batch_id * H + h) + w_ptr = weight_ptr + off_e * H + h + valid = off_e < E + w_val = tl.load(w_ptr, mask=valid, other=0.0) + logits = tl.where(valid, logits + h_val * w_val, logits) + + # Softmax + max_logit = tl.max(logits, axis=0) + logits = logits - max_logit + exp_logits = tl.exp(logits) + sum_exp = tl.sum(exp_logits, axis=0) + probs = exp_logits / sum_exp + + # Top-k (insertion sort) + top_vals = tl.full([TOPK], -float("inf"), dtype=tl.float32) + top_idxs = tl.full([TOPK], -1, dtype=tl.int32) + + for i in range(BLOCK_E): + if i < E: + p = tl.load(probs + batch_id * E + i) + idx = i + + for j in range(TOPK): + if p > top_vals[j]: + for k in range(TOPK - 1, j, -1): + top_vals[k] = top_vals[k - 1] + top_idxs[k] = top_idxs[k - 1] + top_vals[j] = p + top_idxs[j] = idx + break + + if normalize_topk: + sum_top = tl.sum(top_vals) + top_vals = top_vals / sum_top + + for i in range(TOPK): + idx = top_idxs[i] + val = top_vals[i] + if idx >= 0: + tl.store(routing_mask_ptr + batch_id * E + idx, True) + tl.store( + routing_weight_ptr + batch_id * E + idx, val.to(tl.float16) + ) + + +def 
launch_fused_softmax_topk_nobias( + hidden_states, weight, top_k, normalize_topk=True +): + B, H = hidden_states.shape + E = weight.shape[0] + dtype = hidden_states.dtype + + routing_mask = torch.zeros( + (B, E), dtype=torch.bool, device=hidden_states.device + ) + routing_weight = torch.zeros( + (B, E), dtype=dtype, device=hidden_states.device + ) + + BLOCK_E = 128 + assert BLOCK_E >= E, "BLOCK_E must be greater than or equal to E" + + fused_softmax_topk_kernel_nobias[(B,)]( + hidden_states, + weight, + routing_mask, + routing_weight, + B=B, + H=H, + E=E, + TOPK=top_k, + BLOCK_E=BLOCK_E, + normalize_topk=normalize_topk, + ) + + return routing_mask, routing_weight diff --git a/moe_infinity/models/deepseek.py b/moe_infinity/models/deepseek.py index 31bf9d0..0f6146d 100644 --- a/moe_infinity/models/deepseek.py +++ b/moe_infinity/models/deepseek.py @@ -1,10 +1,12 @@ -from typing import Dict, Optional, Tuple +from typing import Dict import nvtx import torch import torch.nn as nn import torch.nn.functional as F +from moe_infinity.kernel.router import launch_fused_softmax_topk_nobias + class DeepseekMoEBlock(nn.Module): """ @@ -37,7 +39,10 @@ def __init__(self, config): ] ) - self.gate = self.gate_cls(config) + # self.gate = self.gate_cls(config) + self.gate = nn.Linear( + config.hidden_size, config.n_routed_experts, bias=False + ) if config.n_shared_experts is not None: intermediate_size = ( config.moe_intermediate_size * config.n_shared_experts @@ -50,10 +55,70 @@ def __init__(self, config): self.archer_engine = None self.expert_tensor_ids: Dict[int, int] = None + @nvtx.annotate("DeepSeekPrepare", color="blue") + def __prepare_expert_route(self, hidden_states): + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.num_experts_per_tok, dim=-1 + ) + # if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
+ # routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + # print(f"hidden_states shape: {hidden_states.shape}") + # print(f"routing_weights shape: {routing_weights.shape}") + + # Compute sparse mask via scatter + B, E = routing_weights.shape[0], self.num_expert + router_mask = torch.zeros( + B, E, dtype=torch.bool, device=selected_experts.device + ) + + # print("selected_experts", selected_experts.shape) + # print("routing_weights", routing_weights.shape) + # print("router_mask", router_mask.shape) + # print("router_logits", router_logits.shape) + router_mask.scatter_(1, selected_experts, True) + + routing_weights_mask = torch.zeros( + B, E, dtype=routing_weights.dtype, device=routing_weights.device + ) + routing_weights_mask.scatter_add_(1, selected_experts, routing_weights) + + return router_mask, routing_weights_mask + @nvtx.annotate(message="DeepseekMoEBlock", color="blue") def forward(self, hidden_states): identity = hidden_states - orig_shape = hidden_states.shape + batch_size, sequence_length, hidden_dim = identity.shape + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + # routing_mask, routing_weight = launch_fused_softmax_topk_nobias( + # hidden_states, + # self.gate.weight.data, + # self.num_experts_per_tok, + # self.config.norm_topk_prob, + # ) + routing_mask, routing_weight = self.__prepare_expert_route( + hidden_states + ) + + self.expert_executor.dispatch_local( + self.layer_id, hidden_states, routing_mask, routing_weight + ) + final_hidden_states = self.expert_executor.wait_dispatch_local() + + final_hidden_states = final_hidden_states.view( + batch_size, sequence_length, hidden_dim + ).to(hidden_states.dtype) + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + self.shared_experts( + identity + ) + return final_hidden_states gate_output = self.gate(hidden_states) if len(gate_output) == 3: @@ -91,7 +156,6 @@ def forward(self, hidden_states): ) routing_weights_mask.scatter_add_(1, selected_experts, routing_weights) - batch_size, sequence_length, hidden_dim = orig_shape # router_mask = F.one_hot( # topk_idx, num_classes=self.config.n_routed_experts # ) diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 4794345..494043f 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -140,6 +140,12 @@ def init( # print("Distributed init done") self.prefetch_lib = PrefetchBuilder().load() if use_jit else prefetch_op + # new_alloc = torch.cuda.memory.CUDAPluggableAllocator( + # self.prefetch_lib.__file__, "TorchAllocateDevice", "TorchFreeDevice" + # ) + # # Swap the current allocator + # torch.cuda.memory.change_current_allocator(new_alloc) + self.archer_engine = self.prefetch_lib.prefetch_handle( self.checkpoint, _archer_config.device_memory_ratio ) @@ -420,11 +426,11 @@ def archer_from_pretrained(cls, *args, **kwargs): ), ) - script_expert( - self.checkpoint, - self.config.model_type, - self.config, - ) + # script_expert( + # self.checkpoint, + # self.config.model_type, + # self.config, + # ) if self.config.model_type == "deepseek_v3": model = model.to(torch.float8_e4m3fn) diff --git a/op_builder/prefetch.py b/op_builder/prefetch.py index 85ea6b6..3d5623a 100644 --- a/op_builder/prefetch.py +++ b/op_builder/prefetch.py @@ -37,6 +37,7 @@ def sources(self): "core/prefetch/archer_prefetch_handle.cpp", "core/prefetch/task_scheduler.cpp", 
"core/prefetch/task_thread.cpp", + "core/memory/caching_allocator.cpp", "core/memory/memory_pool.cpp", "core/memory/stream_pool.cpp", "core/memory/host_caching_allocator.cpp", @@ -64,13 +65,30 @@ def sources(self): "core/python/py_archer_prefetch.cpp", ] + def cutlass_dir(self): + CUTLASS_DIR = os.path.expanduser("~") + "/cutlass" + if not os.path.exists(CUTLASS_DIR): + raise FileNotFoundError( + f"Cutlass directory not found: {CUTLASS_DIR}" + ) + else: + print(f"Using Cutlass directory: {CUTLASS_DIR}") + return CUTLASS_DIR + def include_paths(self): - return ["core"] + CUTLASS_DIR = self.cutlass_dir() + + return [ + "core", + f"{CUTLASS_DIR}/include", + f"{CUTLASS_DIR}/tools/util/include", + ] def cxx_args(self): # -O0 for improved debugging, since performance is bound by I/O CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() + CUTLASS_DIR = self.cutlass_dir() return [ "-g", @@ -83,16 +101,19 @@ def cxx_args(self): CPU_ARCH, "-fopenmp", SIMD_WIDTH, - "-I/usr/local/cuda/include", - "-L/usr/local/cuda/lib64", - "-lcuda", - "-lcudart", - "-lcublas", "-lpthread", + "-L/usr/local/cuda/lib64", + f"-L{CUTLASS_DIR}/build/tools/library", + "-lcutlass", ] def extra_ldflags(self): - return [] + return [ + "-luuid", + "-lcublas", + "-lcudart", + "-lcuda", + ] def is_compatible(self, verbose=True): return super().is_compatible(verbose) diff --git a/tests/cuda/CMakeLists.txt b/tests/cuda/CMakeLists.txt new file mode 100644 index 0000000..1a553da --- /dev/null +++ b/tests/cuda/CMakeLists.txt @@ -0,0 +1,153 @@ +cmake_minimum_required(VERSION 3.10) +project(LockFreeQueueTests) + +add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) # needed for torch backward compatibility + +# Set C++ standard +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED True) + +# +# Attempt to find the python package that uses the same python executable as +# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`. +# +macro(find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python3_EXECUTABLE ${EXECUTABLE}) + find_package(Python3 COMPONENTS Interpreter Development.Module) + + if(NOT Python3_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() + + set(_VER "${Python3_VERSION_MAJOR}.${Python3_VERSION_MINOR}") + set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) + + if(NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) + message(FATAL_ERROR + "Python version (${_VER}) is not one of the supported versions: " + "${_SUPPORTED_VERSIONS_LIST}.") + endif() + + message(STATUS "Found python matching: ${EXECUTABLE}.") +endmacro() + + +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function(run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${Python3_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. 
+macro(append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() + +# Add include directories and link for CUTLASS +set(CUTLASS_DIR $ENV{HOME}/cutlass) +message(STATUS "Using CUTLASS from: ${CUTLASS_DIR}") + +if(DEFINED ENV{CONDA_PREFIX}) + set(CONDA_PREFIX_PATH $ENV{CONDA_PREFIX}) + message(STATUS "Conda environment path: ${CONDA_PREFIX_PATH}") +else() + message(WARNING "CONDA_PREFIX is not set. Make sure your Conda environment is activated.") +endif() + +set(CONDA_INCLUDE_DIRS ${CONDA_PREFIX_PATH}/include) +set(CONDA_LINK_DIRS ${CONDA_PREFIX_PATH}/lib) + +find_package(PythonInterp REQUIRED) + +string(REGEX MATCH "([0-9]+)\\.([0-9]+)\\..*" _ ${PYTHON_VERSION_STRING}) +set(Python_VERSION ${CMAKE_MATCH_1}.${CMAKE_MATCH_2}) +message(STATUS "Python version: ${Python_VERSION}") + +# Supported python versions. These versions will be searched in order, the +# first match will be selected. These should be kept in sync with setup.py. +set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11") + +set(PYTHON_EXECUTABLE_PATH ${CONDA_PREFIX_PATH}/bin/python) +message(STATUS "Using Python executable: ${PYTHON_EXECUTABLE_PATH}") +find_python_from_executable(${PYTHON_EXECUTABLE_PATH} "${PYTHON_SUPPORTED_VERSIONS}") +append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path") + +# add torch +find_package(Torch REQUIRED) + +find_library(torch_python_LIBRARY torch_python PATHS + "${TORCH_INSTALL_PREFIX}/lib") +message(STATUS "torch_python_LIBRARY: ${torch_python_LIBRARY}") + +include_directories( + ${CUTLASS_DIR}/include + ${CUTLASS_DIR}/tools/util/include + ${CUTLASS_DIR}/examples/13_two_tensor_op_fusion + ${CUTLASS_DIR}/examples/common + ${CONDA_INCLUDE_DIRS} + ${CMAKE_SOURCE_DIR}/../../core +) + +# Link CUTLASS library +link_directories( + ${CUTLASS_DIR}/build/tools/library + ${CONDA_LINK_DIRS} +) + +# Add CUDA kernel compilation +find_package(CUDA REQUIRED) + +# set cuda architecture +set(CUDA_ARCHITECTURES 86) + +# set nvcc flags +set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -g -G -lineinfo -rdynamic -O3 -gencode arch=compute_86,code=sm_86 -Xcompiler -fopenmp") + +set(SRC_LIST + test_uvm_kernel.cu + test_fused_mlp.cu + test_expert_fusion.cu + test_expert_fusion_v2.cu + # test_single_gemm_tiled.cu + test_load_tile.cu + test_load_tile_templated.cu + test_tile_size.cu + test_autosize_tileload.cu + test_autosize_tileload_stage.cu + test_autotune_blocksize.cu +) + +FOREACH(SRC ${SRC_LIST}) + get_filename_component(SRC_NAME ${SRC} NAME_WE) + add_executable(${SRC_NAME} ${SRC}) + target_link_libraries(${SRC_NAME} cutlass ${CUDA_LIBRARIES}) + + # if file is test_expert_fusion or test_expert_fusion_v2, link torch_python + IF(${SRC_NAME} STREQUAL "test_expert_fusion" OR ${SRC_NAME} STREQUAL "test_expert_fusion_v2") + target_link_libraries(${SRC_NAME} ${torch_python_LIBRARY} ${Python3_LIBRARIES} ${TORCH_LIBRARIES}) + target_include_directories(${SRC_NAME} PRIVATE ${CONDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} ${Python3_INCLUDE_DIRS}) + ENDIF() + + # IF(${SRC_NAME} STREQUAL "test_autosize_tileload") + # target_link_libraries(${SRC_NAME} thrust) + # ENDIF() +ENDFOREACH() diff --git a/tests/cuda/test_autosize_tileload.cu b/tests/cuda/test_autosize_tileload.cu new file mode 100644 index 0000000..109eceb --- /dev/null +++ b/tests/cuda/test_autosize_tileload.cu @@ -0,0 +1,255 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + 
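+// Project headers: bus-aware tile sizing (BusAwareGemmTileCalculator), the
+// thread-block autotuner and the MultiViewTileLoader exercised below.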
+#include "kernel/tile_size.h" +#include "kernel/mlp_tile_op.h" // Include the header with BusAwareGemmTileCalculator + +// GEMM-style kernel with K iteration inside +template +__global__ void gemm_auto_load_kernel(Views... views) { + extern __shared__ __align__(16) char smem[]; + + // Get problem dimensions from first view (assuming all views have consistent + // dims) + auto& view0 = std::get<0>(std::tie(views...)); + const int K = view0.extent().column(); // Assuming A is MxK + + // Get tile K dimension from loader + using ShapeK = typename MultiLoader::template Shape<0>; + constexpr int TILE_K = ShapeK::kColumn; + + // Iterate over K dimension inside kernel + for (int k_start = 0; k_start < K; k_start += TILE_K) { + // Create sub-views for current K tile + auto make_k_tile_view = [k_start, K](auto& view, int view_idx) { + using ViewElement = typename std::decay_t::Element; + using ViewLayout = typename std::decay_t::Layout; + + int tile_k = min(TILE_K, K - k_start); + return cutlass::TensorView( + view.data() + k_start * tile_k, // Offset in K dimension + view.layout(), {view.extent().row(), tile_k}); + + // if (view_idx == 0) { // A matrix: slice columns + // int tile_k = min(TILE_K, K - k_start); + // return cutlass::TensorView( + // view.data() + k_start, // Offset in K dimension + // view.layout(), {view.extent().row(), tile_k}); + // } else { // B matrices: slice rows + // int tile_k = min(TILE_K, K - k_start); + // return cutlass::TensorView( + // view.data() + + // k_start * view.extent().row(), // Offset in K dimension + // view.layout(), {tile_k, view.extent().row()}); + // } + }; + + // Apply K-tiling to each view + int idx = 0; + auto tiled_views = thrust::make_tuple(make_k_tile_view(views, idx++)...); + + // Load tiles for current K iteration + apply_load_each(smem, nullptr, 0, tiled_views); + + // Synchronize after loading + __syncthreads(); + + // In a real GEMM, computation would happen here + + // Ensure all threads are done before next K iteration + __syncthreads(); + } +} + +// Example usage +int main() { + // Define bus width (384 bits for RTX 4090) + constexpr int BusWidth = 384; + + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::ColumnMajor; + + // Calculate auto-sized shapes using GEMM calculator + using GemmCalc = + BusAwareGemmTileCalculator; + + printf("Auto-calculated tile sizes for %d-bit bus:\n", BusWidth); + printf(" GEMM tile (A/B with BF16): %dx%dx%d\n", + GemmCalc::OptimizedTiles::kM, GemmCalc::OptimizedTiles::kN, + GemmCalc::OptimizedTiles::kK); + printf(" Pipeline stages: %d\n", GemmCalc::PipelineConfig::stages); + printf(" Cluster shape: %dx%dx%d\n", GemmCalc::ClusterShape::kM, + GemmCalc::ClusterShape::kN, GemmCalc::ClusterShape::kK); + + // Auto-sized thread maps based on GEMM tile dimensions + using ThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, + 32, cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + using ThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, + 32, cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + // Define view configurations + using ViewA = ViewConfigA, + ThreadMapA>; + using ViewB1 = ViewConfigB, + ThreadMapB>; + using ViewB2 = ViewConfigB, + ThreadMapB>; + + // Create auto-sized multi-view loader + using Loader = MultiViewTileLoader; + + // Print 
configuration + Loader::print_configuration(); + + constexpr int M = 128, N = 768, K = 2048; + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B1({N, K}); + cutlass::HostTensor B2({N, K}); + + cutlass::reference::host::TensorFill(A.host_view(), ElementA(1.0f)); + cutlass::reference::host::TensorFill(B1.host_view(), ElementB(1.0f)); + cutlass::reference::host::TensorFill(B2.host_view(), ElementB(1.0f)); + + A.sync_device(); + B1.sync_device(); + B2.sync_device(); + + // Create views + cutlass::TensorView viewA(A.device_data(), A.layout(), + {M, K}); + cutlass::TensorView viewB1(B1.device_data(), B1.layout(), + {N, K}); + cutlass::TensorView viewB2(B2.device_data(), B2.layout(), + {N, K}); + + // Get auto-sized shapes + using ShapeA = typename Loader::template Shape<0>; + using ShapeB = typename Loader::template Shape<1>; + + size_t smem_size = Loader::calculate_smem_size(); + + // Calculate 2D grid (no Z dimension for K) + dim3 grid = + ThreadBlockAutoTuner::get_optimal_grid_dims(M, N, K); + dim3 block(32, 4, 1); // Fixed block size for simplicity + // Auto-tune thread block configuration + // dim3 block = + // ThreadBlockAutoTuner::get_optimal_block_dims(M, N, + // K); + // block.y = std::max((uint32_t)3, block.y); // Ensure at least one warp in Y + printf("\nKernel launch configuration:\n"); + printf(" Problem size: %dx%dx%d\n", M, N, K); + printf(" Tile size: %dx%dx%d\n", ShapeA::kRow, ShapeB::kColumn, + ShapeA::kColumn); + printf(" Grid: (%d, %d) - 2D grid, K iteration in kernel\n", grid.x, grid.y); + printf(" Block: (%d, %d) - %d warps total\n", block.x, block.y, block.y); + printf(" Shared memory: %zu bytes\n", smem_size); + printf(" K iterations per block: %d\n", + (K + ShapeA::kColumn - 1) / ShapeA::kColumn); + + cudaFuncSetAttribute( + gemm_auto_load_kernel, + cutlass::TensorView, + cutlass::TensorView>, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // Timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + constexpr int iterations = 1000; + + // Warm-up + gemm_auto_load_kernel + <<>>(viewA, viewB1, viewB2); + cudaStreamSynchronize(stream); + + cudaEventRecord(start, stream); + for (int i = 0; i < iterations; ++i) { + gemm_auto_load_kernel + <<>>(viewA, viewB1, viewB2); + + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl; + return -1; + } + } + cudaEventRecord(stop, stream); + + cudaEventSynchronize(stop); + + float milliseconds = 0.0f; + cudaEventElapsedTime(&milliseconds, start, stop); + + // Compute throughput - corrected for actual tile loading + int tiles_m = (M + ShapeA::kRow - 1) / ShapeA::kRow; + int tiles_n = (N + ShapeB::kColumn - 1) / ShapeB::kColumn; + int tiles_k = (K + ShapeA::kColumn - 1) / ShapeA::kColumn; + + // Each threadblock loads one A tile and two B tiles per K iteration + size_t bytes_per_tile_a = ShapeA::kRow * ShapeA::kColumn * sizeof(ElementA); + size_t bytes_per_tile_b = ShapeB::kRow * ShapeB::kColumn * sizeof(ElementB); + + // Total data movement + size_t total_a_tiles = tiles_m * tiles_k; + size_t total_b_tiles = tiles_n * tiles_k * 2; // B1 and B2 + size_t total_bytes = + total_a_tiles * bytes_per_tile_a + total_b_tiles * bytes_per_tile_b; + + float time_sec = (milliseconds / 1000.0f) / iterations; + float gb_transferred = float(total_bytes) / 1e9f; + float throughput = gb_transferred / time_sec; + + cudaDeviceProp prop; + 
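 // Theoretical peak bandwidth in GB/s: bus width in bytes (BusWidth / 8)
+  // times two transfers per clock (DDR) times the memory clock in kHz / 1e6.
+ 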
cudaGetDeviceProperties(&prop, 0); + int memory_clock_khz = prop.memoryClockRate; // in kHz + + // Calculate efficiency + float theoretical_bandwidth = + (BusWidth / 8) * 2.0 * memory_clock_khz / 1e6; // GB/s + float efficiency = (throughput / theoretical_bandwidth) * 100.0f; + + printf("\n=== Performance Results ===\n"); + printf("Time: %.3f ms total, %.3f μs per iteration\n", milliseconds, + milliseconds * 1000.0f / iterations); + printf("Data transferred: %.3f GB\n", gb_transferred); + printf("Throughput: %.2f GB/s\n", throughput); + printf("Theoretical peak: %.2f GB/s\n", theoretical_bandwidth); + printf("Efficiency: %.1f%%\n", efficiency); + + // Cleanup + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaStreamDestroy(stream); + + return 0; +} diff --git a/tests/cuda/test_autosize_tileload_stage.cu b/tests/cuda/test_autosize_tileload_stage.cu new file mode 100644 index 0000000..a728059 --- /dev/null +++ b/tests/cuda/test_autosize_tileload_stage.cu @@ -0,0 +1,329 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernel/tile_size.h" +#include "kernel/mlp_tile_op.h" // Include the header with BusAwareGemmTileCalculator + +// GEMM-style kernel with K iteration inside +template +__global__ void gemm_auto_load_kernel(Views... views) { + extern __shared__ __align__(16) char smem[]; + + // Get problem dimensions from first view (assuming all views have consistent + // dims) + auto& view0 = std::get<0>(std::tie(views...)); + auto& view1 = std::get<1>(std::tie(views...)); + const uint32_t M = view0.extent().row(); // Assuming A is MxK + const uint32_t K = view0.extent().column(); // Assuming A is MxK + const uint32_t N = view1.extent().row(); // Assuming B is NxK + + using ShapeA = typename MultiLoader::template Shape<0>; + using ShapeB = typename MultiLoader::template Shape<1>; + // Bounds check + if (blockIdx.x * ShapeB::kColumn >= N || blockIdx.y * ShapeA::kRow >= M) + return; + + // Get tile K dimension from loader + using ShapeA = typename MultiLoader::template Shape<0>; + constexpr int TILE_K = ShapeA::kColumn; + + using GemmCalc = typename MultiLoader::GemmCalc<0>; + constexpr int NUM_STAGES = GemmCalc::PipelineConfig::stages; + constexpr int WARPS_PER_GROUP = 9; // Assuming 9 warps per group + + __shared__ __align__(16) uint32_t signal[NUM_STAGES]; + + // Multi-stage buffer setup + const size_t stage_size = MultiLoader::calculate_smem_size(); + char* smem_stages[NUM_STAGES]; + for (int i = 0; i < NUM_STAGES; ++i) { + smem_stages[i] = smem + i * stage_size; + } + + // Create pipeline barriers per warp + __shared__ cuda::pipeline_shared_state + pipeline_state; + + auto load_stage = [&](uint32_t stage_idx, int k_offset) { + if (k_offset >= K) return; + // if (threadIdx.y >= MultiLoader::NumViews) return; + + int tile_k = min(TILE_K, K - k_offset); + char* stage_ptr = smem_stages[stage_idx % NUM_STAGES]; + + auto make_view = [&](auto& view, int view_idx) { + using ViewElement = typename std::decay_t::Element; + using ViewLayout = typename std::decay_t::Layout; + + if (view_idx == 0) { // A + size_t offset = blockIdx.y * ShapeA::kRow * K + k_offset; + return cutlass::TensorView( + view.data() + offset, view.layout(), + {min(ShapeA::kRow, M - blockIdx.y * ShapeA::kRow), tile_k}); + } else { // B + size_t offset = k_offset * N + blockIdx.x * ShapeB::kColumn; + return cutlass::TensorView( + view.data() + offset, view.layout(), + {tile_k, min(ShapeB::kColumn, N - blockIdx.x * 
ShapeB::kColumn)}); + } + }; + + // Apply K-tiling to each view + int idx = 0; + auto tiled_views = thrust::make_tuple(make_view(views, idx++)...); + + if (threadIdx.x == 0) { + printf("Loading stage %d, Wrap %d, K offset %d, tile size %dx%d\n", + stage_idx, threadIdx.y, k_offset, ShapeA::kRow, ShapeB::kColumn); + } + + // Load tiles for current K iteration + apply_load_each(smem, signal, stage_idx % NUM_STAGES, + tiled_views); + }; + + // Pipeline state + int read_stage = 0; + int write_stage = 0; + int compute_stage = 0; + + uint32_t stage_idx = threadIdx.y % NUM_STAGES; + int lane_idx = threadIdx.x % 32; // Assuming 32 threads per warp + + // // Prologue: fill pipeline + // for (; read_stage < NUM_STAGES - 1 && read_stage * TILE_K < K; + // ++read_stage) { + // load_stage(read_stage, read_stage * TILE_K); + // } + load_stage(read_stage++, 0); + + // Wait for pipeline 0 to be ready + while (MultiLoader::NumViews > signal[stage_idx] && 0 == stage_idx) { + // just busy wait + } + + // Main loop for compute stage + for (int k_start = 0; k_start < K; k_start += TILE_K) { + // Start loading future stage + int next_k = k_start + (NUM_STAGES - 1) * TILE_K; + if (next_k < K) { + load_stage(read_stage, next_k); + } + int current_stage = k_start % NUM_STAGES; + + // Wait for current read stage to finish + while (MultiLoader::NumViews > signal[current_stage] && + stage_idx == current_stage) { + // just busy wait + } + + if (lane_idx == 0 && stage_idx == current_stage) { + atomicSub(&signal[current_stage], 1); + } + + // Compute on current stage + // Actual computation would use smem_stages[compute_stage % NUM_STAGES] + + // Advance pipeline + __syncthreads(); + read_stage++; + write_stage++; + compute_stage++; + } +} + +// Example usage +int main() { + // Define bus width (384 bits for RTX 4090) + constexpr int BusWidth = 384; + + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::ColumnMajor; + + // Calculate auto-sized shapes using GEMM calculator + using GemmCalc = + BusAwareGemmTileCalculator; + + printf("Auto-calculated tile sizes for %d-bit bus:\n", BusWidth); + printf(" GEMM tile (A/B with BF16): %dx%dx%d\n", + GemmCalc::OptimizedTiles::kM, GemmCalc::OptimizedTiles::kN, + GemmCalc::OptimizedTiles::kK); + printf(" Pipeline stages: %d\n", GemmCalc::PipelineConfig::stages); + printf(" Cluster shape: %dx%dx%d\n", GemmCalc::ClusterShape::kM, + GemmCalc::ClusterShape::kN, GemmCalc::ClusterShape::kK); + + // Auto-sized thread maps based on GEMM tile dimensions + using ThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, + 32, cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + using ThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, + 32, cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + // Define view configurations + using ViewA = ViewConfigA, + ThreadMapA>; + using ViewB1 = ViewConfigB, + ThreadMapB>; + using ViewB2 = ViewConfigB, + ThreadMapB>; + + // Create auto-sized multi-view loader + using Loader = MultiViewTileLoader; + + // Print configuration + Loader::print_configuration(); + + constexpr int M = 128, N = 768, K = 2048; + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B1({N, K}); + cutlass::HostTensor B2({N, K}); + + cutlass::reference::host::TensorFill(A.host_view(), ElementA(1.0f)); + 
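 // Operands are filled with a constant; this test measures tile-load
+  // bandwidth only and performs no numerical verification of the results.
+ 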
cutlass::reference::host::TensorFill(B1.host_view(), ElementB(1.0f)); + cutlass::reference::host::TensorFill(B2.host_view(), ElementB(1.0f)); + + A.sync_device(); + B1.sync_device(); + B2.sync_device(); + + // Create views + cutlass::TensorView viewA(A.device_data(), A.layout(), + {M, K}); + cutlass::TensorView viewB1(B1.device_data(), B1.layout(), + {N, K}); + cutlass::TensorView viewB2(B2.device_data(), B2.layout(), + {N, K}); + + // Get auto-sized shapes + using ShapeA = typename Loader::template Shape<0>; + using ShapeB = typename Loader::template Shape<1>; + + size_t smem_size = + Loader::calculate_smem_size() * GemmCalc::PipelineConfig::stages; + + // Calculate 2D grid (no Z dimension for K) + dim3 grid = + ThreadBlockAutoTuner::get_optimal_grid_dims(M, N, K); + dim3 block(32, 9, 1); // Fixed block size for simplicity + // Auto-tune thread block configuration + // dim3 block = + // ThreadBlockAutoTuner::get_optimal_block_dims(M, N, + // K); + // block.y = std::max((uint32_t)3, block.y); // Ensure at least one warp in Y + printf("\nKernel launch configuration:\n"); + printf(" Problem size: %dx%dx%d\n", M, N, K); + printf(" Tile size: %dx%dx%d\n", ShapeA::kRow, ShapeB::kColumn, + ShapeA::kColumn); + printf(" Grid: (%d, %d) - 2D grid, K iteration in kernel\n", grid.x, grid.y); + printf(" Block: (%d, %d) - %d warps total\n", block.x, block.y, block.y); + printf(" Shared memory: %zu bytes\n", smem_size); + printf(" K iterations per block: %d\n", + (K + ShapeA::kColumn - 1) / ShapeA::kColumn); + + cudaFuncSetAttribute( + gemm_auto_load_kernel, + cutlass::TensorView, + cutlass::TensorView>, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // Timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + constexpr int iterations = 1000; + + // Warm-up + gemm_auto_load_kernel + <<>>(viewA, viewB1, viewB2); + cudaStreamSynchronize(stream); + + cudaEventRecord(start, stream); + for (int i = 0; i < iterations; ++i) { + gemm_auto_load_kernel + <<>>(viewA, viewB1, viewB2); + + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl; + return -1; + } + } + cudaEventRecord(stop, stream); + + cudaEventSynchronize(stop); + + float milliseconds = 0.0f; + cudaEventElapsedTime(&milliseconds, start, stop); + + // Compute throughput - corrected for actual tile loading + int tiles_m = (M + ShapeA::kRow - 1) / ShapeA::kRow; + int tiles_n = (N + ShapeB::kColumn - 1) / ShapeB::kColumn; + int tiles_k = (K + ShapeA::kColumn - 1) / ShapeA::kColumn; + + // Each threadblock loads one A tile and two B tiles per K iteration + size_t bytes_per_tile_a = ShapeA::kRow * ShapeA::kColumn * sizeof(ElementA); + size_t bytes_per_tile_b = ShapeB::kRow * ShapeB::kColumn * sizeof(ElementB); + + // Total data movement + size_t total_a_tiles = tiles_m * tiles_k; + size_t total_b_tiles = tiles_n * tiles_k * 2; // B1 and B2 + size_t total_bytes = + total_a_tiles * bytes_per_tile_a + total_b_tiles * bytes_per_tile_b; + + float time_sec = (milliseconds / 1000.0f) / iterations; + float gb_transferred = float(total_bytes) / 1e9f; + float throughput = gb_transferred / time_sec; + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int memory_clock_khz = prop.memoryClockRate; // in kHz + + // Calculate efficiency + float theoretical_bandwidth = + (BusWidth / 8) * 2.0 * memory_clock_khz / 1e6; // GB/s + float efficiency = 
(throughput / theoretical_bandwidth) * 100.0f; + + printf("\n=== Performance Results ===\n"); + printf("Time: %.3f ms total, %.3f μs per iteration\n", milliseconds, + milliseconds * 1000.0f / iterations); + printf("Data transferred: %.3f GB\n", gb_transferred); + printf("Throughput: %.2f GB/s\n", throughput); + printf("Theoretical peak: %.2f GB/s\n", theoretical_bandwidth); + printf("Efficiency: %.1f%%\n", efficiency); + + // Cleanup + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaStreamDestroy(stream); + + return 0; +} diff --git a/tests/cuda/test_autotune_blocksize.cu b/tests/cuda/test_autotune_blocksize.cu new file mode 100644 index 0000000..ee41326 --- /dev/null +++ b/tests/cuda/test_autotune_blocksize.cu @@ -0,0 +1,38 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernel/tile_size.h" + +// Example usage +int main() { + constexpr int BusWidth = 5120; + constexpr int M = 2048, N = 2048, K = 2048; + + using ElementA = cutlass::bfloat16_t; + using ElementB = cutlass::bfloat16_t; + + // Get GEMM tile sizes + using GemmCalc = + BusAwareGemmTileCalculator; + + // Run autotuner + using Tuner = ThreadBlockAutoTuner; + Tuner::print_autotuning_result(M, N, K, GemmCalc::OptimizedTiles::kM, + GemmCalc::OptimizedTiles::kN, + GemmCalc::OptimizedTiles::kK); + + // Get optimal block dimensions + auto block_dims = Tuner::get_optimal_block_dims(M, N, K); + + printf("\nRecommended kernel launch:\n"); + printf(" <<>>\n", block_dims.x, + block_dims.y); + + return 0; +} diff --git a/tests/cuda/test_expert_fusion.cu b/tests/cuda/test_expert_fusion.cu new file mode 100644 index 0000000..0d93518 --- /dev/null +++ b/tests/cuda/test_expert_fusion.cu @@ -0,0 +1,343 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int B = 128, K = 2048, Ng = 768, Nu = 768, No = 2048; + +// Type definitions +using ElementInput = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; +using ElementAccumulator = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +// Epilogues +using EpilogueSiLU = cutlass::epilogue::thread::LinearCombinationSilu< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; + +using EpilogueLinear = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; + +// GEMM Definitions +using Gemm1 = cutlass::gemm::device::Gemm< + ElementInput, LayoutA, ElementInput, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, EpilogueSiLU>; + +using Gemm2 = cutlass::gemm::device::Gemm< + ElementInput, LayoutA, ElementInput, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape>; + +using Gemm3 = cutlass::gemm::device::Gemm< + ElementOutput, LayoutC, ElementInput, LayoutB, ElementOutput, 
LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape>; + +// CUDA kernel for elementwise multiplication +__global__ void ElementwiseMultiply(ElementOutput const* G, + ElementOutput const* U, + ElementOutput* Fused, int total_elements) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_elements) { + float g = static_cast(G[idx]); + float u = static_cast(U[idx]); + Fused[idx] = static_cast(g * u); + } +} + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \ + << " at: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +void cutlass_mlp(cutlass::HostTensor& X, + cutlass::HostTensor& Wg, + cutlass::HostTensor& Wu, + cutlass::HostTensor& Wd, + cutlass::HostTensor& G, + cutlass::HostTensor& U, + cutlass::HostTensor& F, + cutlass::HostTensor& O) { + // create async stream + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // auto alpha = ElementAccumulator(1.0f); + // int split_k_slices = 1; + + auto start = std::chrono::high_resolution_clock::now(); + Gemm1 gemm1; + Gemm1::Arguments args1({B, Ng, K}, X.device_ref(), Wg.device_ref(), + G.device_ref(), G.device_ref()); + CUTLASS_CHECK(gemm1(args1, stream = stream)); + + Gemm2 gemm2; + Gemm2::Arguments args2({B, Nu, K}, X.device_ref(), Wu.device_ref(), + U.device_ref(), U.device_ref()); + CUTLASS_CHECK(gemm2(args2, stream = stream)); + + // G.sync_device(); + // U.sync_device(); + + int total_elements = B * Nu; + int threads = 256; + int blocks = (total_elements + threads - 1) / threads; + ElementwiseMultiply<<>>( + G.device_data(), U.device_data(), F.device_data(), total_elements); + + // cudaError_t err = cudaGetLastError(); + // if (err != cudaSuccess) { + // std::cerr << "CUDA error before GEMM: " << cudaGetErrorString(err) + // << std::endl; + // } + + // cudaStreamSynchronize(stream); + + // F.sync_host(); + // G.sync_host(); + // U.sync_host(); + + // std::cout << "F output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) { + // std::cout << float(F.host_data()[i]) << " "; + // } + // std::cout << std::endl; + + // std::cout << "G output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) { + // std::cout << float(G.host_data()[i]) << " "; + // } + // std::cout << std::endl; + + // std::cout << "U output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) { + // std::cout << float(U.host_data()[i]) << " "; + // } + // std::cout << std::endl; + + Gemm3 gemm3; + Gemm3::Arguments args3({B, No, Nu}, F.device_ref(), Wd.device_ref(), + O.device_ref(), O.device_ref()); + + CUTLASS_CHECK(gemm3(args3, stream = stream)); + + cudaStreamSynchronize(stream); + std::cout << "cudaStreamSynchronize completed." << std::endl; + + // std::cout << "O sync_host completed." 
<< std::endl; + + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Gemm1 and Gemm2 execution time: " << elapsed.count() + << " seconds" << std::endl; + + // Print outputs + O.sync_host(); + std::cout << "O output (first 10 elements): "; + for (int i = 0; i < 10; ++i) std::cout << float(O.host_data()[i]) << " "; + std::cout << std::endl; + + cudaStreamDestroy(stream); +} + +void torch_mlp(torch::Tensor& torch_X, torch::Tensor& torch_Wg, + torch::Tensor& torch_Wu, torch::Tensor& torch_Wd, + torch::Tensor& torch_G, torch::Tensor& torch_U, + torch::Tensor& torch_F, torch::Tensor& torch_O) { + // create async stream + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // stream guard + c10::cuda::CUDAStream torch_stream = + c10::cuda::getStreamFromExternal(stream, 0); + auto start = std::chrono::high_resolution_clock::now(); + { + c10::cuda::CUDAStreamGuard guard(torch_stream); + // gate step + torch::matmul_out(torch_G, torch_X, torch_Wg.transpose(0, 1)); + + // // activation step + // torch::silu_out(torch_F, torch_G); + + // // up step + // torch::matmul_out(torch_U, torch_X, torch_Wu.transpose(0, 1)); + + // // multiplication step, reuse gate_out + // torch::mul_out(torch_G, torch_F, torch_U); + + // // down step + // torch::matmul_out(torch_O, torch_G, torch_Wd.transpose(0, 1)); + } + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed = end - start; + std::cout << "Torch MLP execution time: " << elapsed.count() << " seconds" + << std::endl; + + // Print outputs + auto flatten_O = torch_O.view({-1}); + std::cout << "Torch O output (first 10 elements): "; + for (int i = 0; i < 10; ++i) { + std::cout << flatten_O[i].item() << " "; + } + std::cout << std::endl; + + cudaStreamDestroy(stream); +} + +int main() { + // int B = 128, K = 256, Ng = 512, Nu = 512, No = 256; + + cutlass::HostTensor X({B, K}); + cutlass::HostTensor Wg({Ng, K}); + cutlass::HostTensor Wu({Nu, K}); + cutlass::HostTensor Wd({No, Nu}); + cutlass::HostTensor G({B, Ng}); + cutlass::HostTensor U({B, Nu}); + cutlass::HostTensor F({B, Nu}); + cutlass::HostTensor O({B, No}); + + cutlass::reference::host::TensorFillRandomUniform(X.host_view(), 1, 1.0f, + -0.5f); + cutlass::reference::host::TensorFillRandomUniform(Wg.host_view(), 1, 1.0f, + -0.5f); + cutlass::reference::host::TensorFillRandomUniform(Wu.host_view(), 1, 1.0f, + -0.5f); + cutlass::reference::host::TensorFillRandomUniform(Wd.host_view(), 1, 1.0f, + -0.5f); + + torch::Tensor torch_X = torch::empty({B, K}, torch::kBFloat16).cuda(); + torch::Tensor torch_Wg = torch::empty({Ng, K}, torch::kBFloat16).cuda(); + torch::Tensor torch_Wu = torch::empty({Nu, K}, torch::kBFloat16).cuda(); + torch::Tensor torch_Wd = torch::empty({No, Nu}, torch::kBFloat16).cuda(); + torch::Tensor torch_G = torch::empty({B, Ng}, torch::kBFloat16).cuda(); + torch::Tensor torch_U = torch::empty({B, Nu}, torch::kBFloat16).cuda(); + torch::Tensor torch_F = torch::empty({B, Nu}, torch::kBFloat16).cuda(); + torch::Tensor torch_O = torch::empty({B, No}, torch::kBFloat16).cuda(); + + X.sync_device(); + Wg.sync_device(); + Wu.sync_device(); + Wd.sync_device(); + F.sync_device(); + G.sync_device(); + U.sync_device(); + O.sync_device(); + + // copy data to torch tensors + cudaMemcpy(torch_X.data_ptr(), X.device_data(), + X.size() * sizeof(cutlass::bfloat16_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(torch_Wg.data_ptr(), Wg.device_data(), + Wg.size() * 
sizeof(cutlass::bfloat16_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(torch_Wu.data_ptr(), Wu.device_data(), + Wu.size() * sizeof(cutlass::bfloat16_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(torch_Wd.data_ptr(), Wd.device_data(), + Wd.size() * sizeof(cutlass::bfloat16_t), cudaMemcpyDeviceToDevice); + + // print all tensors + std::cout << "Input X (first 10 elements): "; + for (int i = 0; i < 10; ++i) std::cout << float(X.host_data()[i]) << " "; + std::cout << std::endl; + + std::cout << "Weight Wg (first 10 elements): "; + for (int i = 0; i < 10; ++i) std::cout << float(Wg.host_data()[i]) << " "; + std::cout << std::endl; + + std::cout << "Weight Wu (first 10 elements): "; + for (int i = 0; i < 10; ++i) std::cout << float(Wu.host_data()[i]) << " "; + std::cout << std::endl; + + std::cout << "Weight Wd (first 10 elements): "; + for (int i = 0; i < 10; ++i) std::cout << float(Wd.host_data()[i]) << " "; + std::cout << std::endl; + + // warm up torch + auto tmp = torch::matmul(torch_X, torch_Wg.transpose(0, 1)); + + torch_mlp(torch_X, torch_Wg, torch_Wu, torch_Wd, torch_G, torch_U, torch_F, + torch_O); + + cutlass_mlp(X, Wg, Wu, Wd, G, U, F, O); + + // Gemm1 gemm1; + // Gemm1::Arguments args1({B, Ng, K}, {X.device_data(), K}, + // {Wg.device_data(), Ng}, {G.device_data(), Ng}, + // {G.device_data(), Ng}); + // gemm1(args1); + + // Gemm2 gemm2; + // Gemm2::Arguments args2({B, Nu, K}, {X.device_data(), K}, + // {Wu.device_data(), Nu}, {U.device_data(), Nu}, + // {U.device_data(), Nu}); + // gemm2(args2); + + // G.sync_device(); + // U.sync_device(); + + // std::cout << "G output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) std::cout << float(G.host_data()[i]) << " "; + // std::cout << std::endl; + + // std::cout << "U output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) std::cout << float(U.host_data()[i]) << " "; + // std::cout << std::endl; + + // int total_elements = B * Nu; + // int threads = 256; + // int blocks = (total_elements + threads - 1) / threads; + // ElementwiseMultiply<<>>(G.device_data(), + // U.device_data(), + // F.device_data(), + // total_elements); + + // F.sync_device(); + // std::cout << "Fused G * U output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) std::cout << float(F.host_data()[i]) << " "; + // std::cout << std::endl; + + // Gemm3 gemm3; + // Gemm3::Arguments args3({B, No, Nu}, {F.device_data(), Nu}, + // {Wd.device_data(), No}, {O.device_data(), No}, + // {O.device_data(), No}); + // gemm3(args3); + + // O.sync_host(); + // std::cout << "Fused MoEMLP output (first 10 elements): "; + // for (int i = 0; i < 10; ++i) std::cout << float(O.host_data()[i]) << " "; + // std::cout << std::endl; + // return 0; +} diff --git a/tests/cuda/test_expert_fusion_v2.cu b/tests/cuda/test_expert_fusion_v2.cu new file mode 100644 index 0000000..b52af3d --- /dev/null +++ b/tests/cuda/test_expert_fusion_v2.cu @@ -0,0 +1,191 @@ +// Fully fused CUTLASS GEMM kernel where the same A matrix is used with two +// different B matrices, with explicit tiling, iterators, and MMA, optimized for +// tensor cores, and merging C1 and C2 with element-wise multiplication. 
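+//
+// Reference semantics being fused (a host-side sketch for clarity; the tensor
+// names follow the variables in main() below and are illustrative only):
+//   G = X * Wg^T                 // GEMM 1, shared A operand (X)
+//   U = X * Wu^T                 // GEMM 2, same A operand (X)
+//   C = silu(G) .* U             // element-wise merge after the two MMAs,
+//                                // with silu(x) = x / (1 + exp(-x))
+// i.e. the gate/up half of a SwiGLU-style expert MLP in one kernel launch.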
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +int B = 128, K = 2048, Ng = 768, Nu = 768, No = 2048; + +// Type definitions +using ElementInput = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; +using ElementInputccumulator = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; +using MMAOp = cutlass::arch::OpClassTensorOp; +using Arch = cutlass::arch::Sm80; + +using WarpMatrixShape = cutlass::MatrixShape; + +using Operator = + cutlass::arch::Mma, 32, ElementInput, + LayoutA, ElementInput, LayoutB, ElementInputccumulator, + LayoutC, cutlass::arch::OpMultiplyAdd>; + +using WrapMmaPolicy = + cutlass::gemm::warp::MmaTensorOpPolicy; + +using WarpMma = cutlass::gemm::warp::MmaTensorOp< + WarpShape, ElementInput, LayoutA, ElementInput, LayoutB, + ElementInputccumulator, LayoutC, WrapMmaPolicy, 1, true>; + +__global__ void fused_gemm_kernel(ElementInput const* A, ElementInput const* B1, + ElementInput const* B2, ElementOutput* C, + int M, int N, int K) { + extern __shared__ char shared_mem[]; + ElementInput* smem_A = reinterpret_cast(shared_mem); + ElementInput* smem_B1 = reinterpret_cast( + shared_mem + + ThreadblockShape::kM * ThreadblockShape::kK * sizeof(ElementInput)); + ElementInput* smem_B2 = smem_B1 + ThreadblockShape::kK * ThreadblockShape::kN; + + WarpMma mma; + typename WarpMma::FragmentC accum1; + typename WarpMma::FragmentC accum2; + accum1.clear(); + accum2.clear(); + + using IteratorA = typename WarpMma::IteratorA; + using IteratorB = typename WarpMma::IteratorB; + + for (int tile_k = 0; + tile_k < (K + ThreadblockShape::kK - 1) / ThreadblockShape::kK; + ++tile_k) { + int lane = threadIdx.x; + int num_threads = blockDim.x; + + for (int i = lane; i < ThreadblockShape::kM * ThreadblockShape::kK; + i += num_threads) { + int row = i / ThreadblockShape::kK; + int col = i % ThreadblockShape::kK; + int global_row = blockIdx.y * ThreadblockShape::kM + row; + int global_col = tile_k * ThreadblockShape::kK + col; + smem_A[i] = (global_row < M && global_col < K) + ? A[global_row * K + global_col] + : ElementInput(0); + } + + for (int i = lane; i < ThreadblockShape::kK * ThreadblockShape::kN; + i += num_threads) { + int row = i / ThreadblockShape::kN; + int col = i % ThreadblockShape::kN; + int global_row = tile_k * ThreadblockShape::kK + row; + int global_col = blockIdx.x * ThreadblockShape::kN + col; + smem_B1[i] = (global_row < K && global_col < N) + ? B1[global_row * N + global_col] + : ElementInput(0); + smem_B2[i] = (global_row < K && global_col < N) + ? 
B2[global_row * N + global_col] + : ElementInput(0); + } + + __syncthreads(); + + typename WarpMma::FragmentA frag_A; + typename WarpMma::FragmentB frag_B1; + typename WarpMma::FragmentB frag_B2; + +#pragma unroll + for (int i = 0; i < frag_A.size(); ++i) frag_A[i] = smem_A[i]; +#pragma unroll + for (int i = 0; i < frag_B1.size(); ++i) frag_B1[i] = smem_B1[i]; +#pragma unroll + for (int i = 0; i < frag_B2.size(); ++i) frag_B2[i] = smem_B2[i]; + + // IteratorA iter_A({smem_A, ThreadblockShape::kK}, lane); + // IteratorB iter_B1({smem_B1, ThreadblockShape::kN}, lane); + // IteratorB iter_B2({smem_B2, ThreadblockShape::kN}, lane); + + // iter_A.load(frag_A); + // iter_B1.load(frag_B1); + // iter_B2.load(frag_B2); + + mma(accum1, frag_A, frag_B1, accum1); + mma(accum2, frag_A, frag_B2, accum2); + __syncthreads(); + } + + int c_row = blockIdx.y * ThreadblockShape::kM + threadIdx.y * WarpShape::kM; + int c_col = blockIdx.x * ThreadblockShape::kN + threadIdx.z * WarpShape::kN; + for (int i = 0; i < WarpShape::kM; ++i) { + for (int j = 0; j < WarpShape::kN; ++j) { + int global_row = c_row + i; + int global_col = c_col + j; + if (global_row < M && global_col < N) { + size_t idx = i * WarpShape::kN + j; + float silu = accum1[idx] / (1.0f + expf(-accum1[idx])); + C[global_row * N + global_col] = ElementOutput(silu * accum2[idx]); + } + } + } +} + +template +void fill_host_tensor(cutlass::HostTensor& tensor, + Element value, float scale = 1.0f, float offset = 0.0f) { + cutlass::reference::host::TensorFillRandomUniform(tensor.host_view(), value, + scale, offset); +} + +int main() { + cutlass::HostTensor X({B, K}); + cutlass::HostTensor Wg({Ng, K}); + cutlass::HostTensor Wu({Nu, K}); + // cutlass::HostTensor Wd({No, Nu}); + cutlass::HostTensor C({B, No}); + + fill_host_tensor(X, ElementInput(1.0f), 1.0f, 0.0f); + fill_host_tensor(Wg, ElementInput(1.0f), 1.0f, 0.0f); + fill_host_tensor(Wu, ElementInput(1.0f), 1.0f, 0.0f); + // fill_host_tensor(Wd, ElementInput(1.0f), 1.0f, 0.0f); + + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + dim3 grid((Ng + ThreadblockShape::kN - 1) / ThreadblockShape::kN, + (B + ThreadblockShape::kM - 1) / ThreadblockShape::kM); + dim3 block(32, 4, 1); // 128 threads total + size_t shared_mem_size = + sizeof(ElementInput) * ThreadblockShape::kM * ThreadblockShape::kK + + sizeof(ElementInput) * ThreadblockShape::kK * ThreadblockShape::kN * 2; + fused_gemm_kernel<<>>( + X.device_data(), Wg.device_data(), Wu.device_data(), C.device_data(), B, + No, K); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl; + return -1; + } + + cudaStreamSynchronize(stream); + C.sync_host(); + + // Check results + std::cout << "First 10 elements of output C:\n" << std::endl; + for (int i = 0; i < 10; ++i) { + std::cout << C.host_data()[i] << " "; + } + std::cout << std::endl; +} diff --git a/tests/cuda/test_fused_mlp.cu b/tests/cuda/test_fused_mlp.cu new file mode 100644 index 0000000..1e4d6d0 --- /dev/null +++ b/tests/cuda/test_fused_mlp.cu @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights + *reserved. SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, + *this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + *ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + *LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + *CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + *SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + *INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + *CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + *ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + *POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "b2b_gemm_run.h" +#include "test_run.h" + +//////////////////////////////////////////////////////////////////////////////// + +cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_0(128 * 640, 768, 2048); +cutlass::gemm::GemmCoord gemm_f16_sm80_problem_size_1(128 * 640, 2048, 768); + +bool run_nonfused_gemm_f16_sm80() { + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + using ElementCompute = cutlass::bfloat16_t; + + ElementCompute alpha0 = ElementCompute(1); + ElementCompute beta0 = ElementCompute(0); // beta=0 for no-bias + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(0); // beta=0 for no-bias + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Gemm0 = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, cutlass::bfloat16_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape0, WarpShape0, InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::Default>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3>; + using Gemm1 = cutlass::gemm::device::Gemm< + cutlass::bfloat16_t, cutlass::layout::RowMajor, 
cutlass::bfloat16_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape1, WarpShape1, InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::Default>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3>; + + B2bNonFusedGemmRun nonFusedGemm; + + std::cout << "Running Non-fused back-to-back FP16 TN GEMMs...\n"; + bool pass = nonFusedGemm.run(gemm_f16_sm80_problem_size_0, + gemm_f16_sm80_problem_size_1, alpha0, beta0, + alpha1, beta1); + if (pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return pass; +} + +bool run_fused_gemm_f16_sm80_rf_res() { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + ElementCompute alpha0 = ElementCompute(1); + // Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(0); // beta=1 for bias + + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, InstructionShape::kM * InstructionShape::kN / 32, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; + + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute, + cutlass::epilogue::thread::ScaleType::Default>; + + using B2bGemm = cutlass::gemm::device::B2bGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, 3>; + + B2bFusedGemmRun fusedGemm; + + std::cout + << "Running Fused back-to-back FP16 TN GEMMs with RF residency...\n"; + bool passed = + fusedGemm.run(gemm_f16_sm80_problem_size_0, gemm_f16_sm80_problem_size_1, + alpha0, beta0, alpha1, beta1); + if (passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +int main() { + std::vector funcs = {&run_nonfused_gemm_f16_sm80, + &run_fused_gemm_f16_sm80_rf_res}; + + return testRun(80, funcs, "gemm f16 RF residency"); +} + +//////////////////////////////////////////////////////////////////////////////// diff --git a/tests/cuda/test_load_tile.cu b/tests/cuda/test_load_tile.cu new file mode 100644 index 0000000..ba977ce --- /dev/null +++ b/tests/cuda/test_load_tile.cu @@ -0,0 +1,250 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using ShapeA = cutlass::MatrixShape<64, 64>; +using ShapeB = cutlass::MatrixShape<64, 64>; + +template +struct TileLoader2D { + using Element = Element_; + using Layout = Layout_; + using ThreadblockShape = 
ThreadblockShape_; + using ThreadMap = ThreadMap_; + + using GmemIterator = cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockShape, Element, Layout, 1, ThreadMap>; + + static const int ElementSize = cutlass::sizeof_bits::value; + static const int Crosswise = 64; // Typical for SM80 + + using SmemLayout = std::conditional_t< + std::is_same_v, + cutlass::layout::RowMajorTensorOpMultiplicandCongruous, + cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous>; + + using SmemIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockShape, Element, SmemLayout, 1, ThreadMap, 16>; + + // Shared memory pointer + Element* smem_ptr; + + // Constructor + __device__ TileLoader2D(Element* smem_ptr_) : smem_ptr(smem_ptr_) {} + + // Load a tile from global to shared memory + __device__ void operator()( + cutlass::TensorView const& global_view, + cutlass::MatrixCoord const& tb_offset) { + int thread_idx = threadIdx.x; + // int thread_idx = threadIdx.x + threadIdx.y * blockDim.x; + + // In loader_A and loader_B + auto extent = global_view.extent(); + + GmemIterator gmem_it(global_view.layout(), global_view.data(), + global_view.extent(), thread_idx, tb_offset); + + typename GmemIterator::Fragment frag; + frag.clear(); + gmem_it.load(frag); + + cutlass::TensorRef smem_ref( + smem_ptr, SmemLayout::packed( + {ThreadblockShape::kRow, ThreadblockShape::kColumn})); + SmemIterator smem_it(smem_ref, thread_idx); + smem_it.store(frag); // This assumes total number of threads is wrap size + + __syncthreads(); // Ensure all threads have completed their loads + } +}; + +template +__global__ void load_3d_kernel(cutlass::TensorView viewA, + cutlass::TensorView viewB1, + cutlass::TensorView viewB2) { + size_t M = viewA.extent().row(); + size_t N = viewB1.extent().column(); + size_t K = viewA.extent().column(); + + // Calculate threadblock tile offsets + cutlass::MatrixCoord tb_offset_A(int(blockIdx.y * ShapeA::kRow), + int(blockIdx.z * ShapeA::kColumn)); + cutlass::MatrixCoord tb_offset_B(int(blockIdx.z * ShapeB::kRow), + int(blockIdx.x * ShapeB::kColumn)); + + if (tb_offset_A.row() >= M || tb_offset_A.column() >= K || + tb_offset_B.row() >= K || tb_offset_B.column() >= N) { + // This assumes matrix size is multiple of tile size in all dimensions + return; // Skip invalid tiles + } + + extern __shared__ __align__(16) char smem[]; + int smem_offset = 0; + + // Shared Memory Allocation + ElementA* smem_A = reinterpret_cast(smem + smem_offset); + constexpr int size_A = ShapeA::kRow * ShapeA::kColumn * sizeof(ElementA); + constexpr int size_B = ShapeB::kRow * ShapeB::kColumn * sizeof(ElementB); + smem_offset += size_A; + ElementB* smem_B1 = reinterpret_cast(smem + smem_offset); + smem_offset += size_B; + ElementB* smem_B2 = reinterpret_cast(smem + smem_offset); + + // KERNEL_LOG_DEBUG("size_A = %d, size_B = %d, smem_A = %p, smem_B = %p\n", + // size_A, size_B, static_cast(smem_A), + // static_cast(smem_B)); + + using ThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, 32, + cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + using ThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, 32, + cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + // Instantiate loaders + TileLoader2D loader_A(smem_A); + TileLoader2D loader_B1(smem_B1); + TileLoader2D loader_B2(smem_B2); + + // Distinguish between warp 0 and warp 1 + if (threadIdx.y 
== 0 && tb_offset_A.row() < M && tb_offset_A.column() < K) { + loader_A(viewA, tb_offset_A); + // int error_cnt_A = 0; + // #pragma unroll + // for (int i = 0; i < size_A / sizeof(ElementA); ++i) { + // if (smem_A[i] != ElementA(1.0f) && threadIdx.x == 0 && blockIdx.x + // == 0 && + // blockIdx.y == 0 && blockIdx.z == 0 && error_cnt_A < 10) { + // printf("smem_A[%d] = %f (expected 1.0)\n", i, float(smem_A[i])); + // error_cnt_A++; + // } + // } + } + if (threadIdx.y == 1 && tb_offset_B.row() < K && tb_offset_B.column() < N) { + loader_B1(viewB1, tb_offset_B); + // int error_cnt_B = 0; + // #pragma unroll + // for (int i = 0; i < size_B / sizeof(ElementB); ++i) { + // if (smem_B1[i] != ElementB(1.0f) && threadIdx.x == 0 && blockIdx.x + // == 0 && + // blockIdx.y == 0 && blockIdx.z == 0 && error_cnt_B < 10) { + // printf("smem_B[%d] = %f (expected 1.0)\n", i, float(smem_B1[i])); + // error_cnt_B++; + // } + // } + } + if (threadIdx.y == 2 && tb_offset_B.row() < K && tb_offset_B.column() < N) { + loader_B2(viewB2, tb_offset_B); + } +} + +int main() { + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::ColumnMajor; + + constexpr int M = 128, N = 768, K = 2048; + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B1({K, N}); + cutlass::HostTensor B2({K, N}); + cutlass::reference::host::TensorFill(A.host_view(), ElementA(1.0f)); + cutlass::reference::host::TensorFill(B1.host_view(), ElementB(1.0f)); + cutlass::reference::host::TensorFill(B2.host_view(), ElementB(1.0f)); + + A.sync_device(); + B1.sync_device(); + B2.sync_device(); + + // Create views + cutlass::TensorView viewA(A.device_data(), A.layout(), + {M, K}); + cutlass::TensorView viewB1(B1.device_data(), B1.layout(), + {K, N}); + cutlass::TensorView viewB2(B2.device_data(), B2.layout(), + {K, N}); + + dim3 grid((N + ShapeB::kColumn - 1) / ShapeB::kColumn, + (M + ShapeA::kRow - 1) / ShapeA::kRow, + (K + std::max(ShapeA::kColumn, ShapeB::kRow) - 1) / + std::max(ShapeA::kColumn, ShapeB::kRow)); + dim3 block(32, 3, 1); + size_t smem_size = ShapeA::kRow * ShapeA::kColumn * sizeof(ElementA) + + ShapeB::kRow * ShapeB::kColumn * sizeof(ElementB) * 2; + printf("smem_size = %zu bytes\n", smem_size); + cudaFuncSetAttribute(load_3d_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + constexpr int iterations = 1000; + // Launch kernel and time + auto start_chrono = std::chrono::high_resolution_clock::now(); + cudaEventRecord(start, stream); + for (int i = 0; i < iterations; ++i) { + load_3d_kernel + <<>>(viewA, viewB1, viewB2); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl; + return -1; + } + } + cudaEventRecord(stop, stream); + + cudaEventSynchronize(stop); + cudaStreamSynchronize(stream); + auto end_chrono = std::chrono::high_resolution_clock::now(); + + cudaEventSynchronize(stop); + + float milliseconds = 0.0f; + cudaEventElapsedTime(&milliseconds, start, stop); + + // milliseconds = + // std::chrono::duration(end_chrono - start_chrono) + // .count(); + + // Compute throughput + size_t total_elements_A = size_t(M) * K; + size_t total_elements_B = size_t(K) * N; + size_t total_bytes_A = total_elements_A * sizeof(ElementA); + size_t total_bytes_B 
= total_elements_B * sizeof(ElementB) * 2; + size_t total_bytes = total_bytes_A + total_bytes_B; + + float time_sec = (milliseconds / 1000.0f) / iterations; + float gb_transferred = float(total_bytes) / 1e9f; + float throughput = gb_transferred / time_sec; + + std::cout << "3D matrix tile load completed.\n"; + std::cout << "Time: " << milliseconds << " ms\n"; + std::cout << "Data transferred: " << gb_transferred << " GB\n"; + std::cout << "Throughput: " << throughput << " GB/s\n"; + + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaStreamDestroy(stream); + return 0; +} diff --git a/tests/cuda/test_load_tile_templated.cu b/tests/cuda/test_load_tile_templated.cu new file mode 100644 index 0000000..907e4fe --- /dev/null +++ b/tests/cuda/test_load_tile_templated.cu @@ -0,0 +1,332 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/types.h" + +// TileLoader2D remains the same +template +struct TileLoader2D { + using Element = Element_; + using Layout = Layout_; + using ThreadblockShape = ThreadblockShape_; + using ThreadMap = ThreadMap_; + + using GmemIterator = cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockShape, Element, Layout, 1, ThreadMap>; + + static const int ElementSize = cutlass::sizeof_bits::value; + static const int Crosswise = 64; + + using SmemLayout = std::conditional_t< + std::is_same_v, + cutlass::layout::RowMajorTensorOpMultiplicandCongruous, + cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous>; + + using SmemIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockShape, Element, SmemLayout, 1, ThreadMap, 16>; + + Element* smem_ptr; + + __device__ TileLoader2D(Element* smem_ptr_) : smem_ptr(smem_ptr_) {} + + __device__ void operator()( + cutlass::TensorView const& global_view, + cutlass::MatrixCoord const& tb_offset) { + int thread_idx = threadIdx.x; + + auto extent = global_view.extent(); + + GmemIterator gmem_it(global_view.layout(), global_view.data(), + global_view.extent(), thread_idx, tb_offset); + + typename GmemIterator::Fragment frag; + frag.clear(); + gmem_it.load(frag); + + cutlass::TensorRef smem_ref( + smem_ptr, SmemLayout::packed( + {ThreadblockShape::kRow, ThreadblockShape::kColumn})); + SmemIterator smem_it(smem_ref, thread_idx); + smem_it.store(frag); + + __syncthreads(); + } +}; + +// Flexible multi-view tile loader class +template +class MultiViewTileLoader { + public: + static constexpr size_t NumViews = sizeof...(ViewConfigs); + + // Extract types from ViewConfigs + template + using Element = typename GetNthType_t::Element; + + template + using Layout = typename GetNthType_t::Layout; + + template + using Shape = typename GetNthType_t::Shape; + + template + using ThreadMap = typename GetNthType_t::ThreadMap; + + // Calculate total shared memory size + static constexpr size_t calculate_smem_size() { + return calculate_smem_size_impl<0>(); + } + + private: + template + static constexpr size_t calculate_smem_size_impl() { + if constexpr (I >= NumViews) { + return 0; + } else { + using CurrentShape = Shape; + using CurrentElement = Element; + return CurrentShape::kRow * CurrentShape::kColumn * + sizeof(CurrentElement) + + calculate_smem_size_impl(); + } + } + + public: + // Helper to load a specific view + template + __device__ static void load_view( + cutlass::TensorView, Layout> const& view, + char* smem_base, int warp_id) { + if (threadIdx.y != warp_id) return; + + // Calculate shared memory offset for this view + 
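+    // (The offset is a compile-time prefix sum: view i starts at
+    //  sum over j < i of Shape<j>::kRow * Shape<j>::kColumn * sizeof(Element<j>)
+    //  bytes into the dynamic shared memory block, so the per-view tiles are
+    //  packed back to back.)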
size_t smem_offset = calculate_view_offset(); + Element* smem_ptr = + reinterpret_cast*>(smem_base + smem_offset); + + // Calculate threadblock offset based on view configuration + cutlass::MatrixCoord tb_offset = + GetNthType_t::calculate_offset(); + + // Create and use loader + using Loader = TileLoader2D, Layout, + Shape, ThreadMap>; + Loader loader(smem_ptr); + loader(view, tb_offset); + + // int err_cnt = 0; + // #pragma unroll + // for (int i = 0; i < Shape::kRow * + // Shape::kColumn; ++i) { + // if (smem_ptr[i] != Element(1.0f) && threadIdx.x == 0 && + // blockIdx.x == 0 && + // blockIdx.y == 0 && blockIdx.z == 0 && err_cnt < 10) { + // printf("smem[%d] = %f (expected 1.0)\n", i, float(smem_ptr[i])); + // err_cnt++; + // } + // } + } + + private: + template + static constexpr size_t calculate_view_offset() { + if constexpr (ViewIndex == 0) { + return 0; + } else { + return calculate_view_offset() + + Shape::kRow * Shape::kColumn * + sizeof(Element); + } + } + + template + static constexpr size_t calculate_view_size() { + return Shape::kRow * Shape::kColumn * + sizeof(Element); + } +}; + +// View configuration helper +template +struct ViewConfig { + using Element = Element_; + using Layout = Layout_; + using Shape = Shape_; + using ThreadMap = ThreadMap_; + + // Virtual method to be specialized for different matrix positions + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(0, 0); + } +}; + +// Specialized view configurations for A and B matrices +template +struct ViewConfigA : ViewConfig { + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(int(blockIdx.y * Shape_::kRow), + int(blockIdx.z * Shape_::kColumn)); + } +}; + +template +struct ViewConfigB : ViewConfig { + __device__ static cutlass::MatrixCoord calculate_offset() { + return cutlass::MatrixCoord(int(blockIdx.z * Shape_::kRow), + int(blockIdx.x * Shape_::kColumn)); + } +}; + +// Helper function templates to load each view by index +template +__device__ void load_each(char* smem, FirstView& first, RestViews&... rest) { + MultiLoader::template load_view(first, smem, Index); + if constexpr (sizeof...(RestViews) > 0) { + load_each(smem, rest...); + } +} + +// Base case for recursion +template +__device__ void load_each(char* smem) {} + +// Kernel that uses the flexible loader +template +__global__ void flexible_load_kernel(Views... 
views) { + extern __shared__ __align__(16) char smem[]; + load_each(smem, views...); +} + +// Example usage +int main() { + using ElementA = cutlass::bfloat16_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::bfloat16_t; + using LayoutB = cutlass::layout::ColumnMajor; + + using ShapeA = cutlass::MatrixShape<64, 64>; + using ShapeB = cutlass::MatrixShape<64, 64>; + + using ThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, 32, + cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + using ThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap< + cutlass::layout::PitchLinearShape, 32, + cutlass::layout::PitchLinearShape<8, 4>, + 128 / cutlass::sizeof_bits::value>; + + // Define view configurations + using ViewA = ViewConfigA; + using ViewB1 = ViewConfigB; + using ViewB2 = ViewConfigB; + + // Create multi-view loader + using Loader = MultiViewTileLoader; + + constexpr int M = 128, N = 768, K = 2048; + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B1({K, N}); + cutlass::HostTensor B2({K, N}); + + cutlass::reference::host::TensorFill(A.host_view(), ElementA(1.0f)); + cutlass::reference::host::TensorFill(B1.host_view(), ElementB(1.0f)); + cutlass::reference::host::TensorFill(B2.host_view(), ElementB(1.0f)); + + A.sync_device(); + B1.sync_device(); + B2.sync_device(); + + // Create views + cutlass::TensorView viewA(A.device_data(), A.layout(), + {M, K}); + cutlass::TensorView viewB1(B1.device_data(), B1.layout(), + {K, N}); + cutlass::TensorView viewB2(B2.device_data(), B2.layout(), + {K, N}); + + // Calculate grid and block dimensions + dim3 grid((N + ShapeB::kColumn - 1) / ShapeB::kColumn, + (M + ShapeA::kRow - 1) / ShapeA::kRow, + (K + std::max(ShapeA::kColumn, ShapeB::kRow) - 1) / + std::max(ShapeA::kColumn, ShapeB::kRow)); + dim3 block(32, Loader::NumViews, + 1); // Automatically adjust to number of views + + size_t smem_size = Loader::calculate_smem_size(); + printf("Number of views: %zu\n", Loader::NumViews); + printf("Shared memory size: %zu bytes\n", smem_size); + + cudaFuncSetAttribute( + flexible_load_kernel, + cutlass::TensorView, + cutlass::TensorView>, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + cudaStream_t stream; + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + + // CUDA events for timing + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + constexpr int iterations = 2000; + + cudaEventRecord(start, stream); + for (int i = 0; i < iterations; ++i) { + flexible_load_kernel + <<>>(viewA, viewB1, viewB2); + + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(err) << std::endl; + return -1; + } + } + cudaEventRecord(stop, stream); + + cudaEventSynchronize(stop); + cudaStreamSynchronize(stream); + + float milliseconds = 0.0f; + cudaEventElapsedTime(&milliseconds, start, stop); + + // Compute throughput + size_t total_elements_A = size_t(M) * K; + size_t total_elements_B = size_t(K) * N; + size_t total_bytes_A = total_elements_A * sizeof(ElementA); + size_t total_bytes_B = total_elements_B * sizeof(ElementB) * 2; + size_t total_bytes = total_bytes_A + total_bytes_B; + + float time_sec = (milliseconds / 1000.0f) / iterations; + float gb_transferred = float(total_bytes) / 1e9f; + float throughput = gb_transferred / time_sec; + + std::cout << "Flexible multi-view tile load completed.\n"; + std::cout << "Time: " << milliseconds << " ms\n"; + std::cout 
<< "Data transferred: " << gb_transferred << " GB\n"; + std::cout << "Throughput: " << throughput << " GB/s\n"; + + cudaEventDestroy(start); + cudaEventDestroy(stop); + cudaStreamDestroy(stream); + return 0; +} diff --git a/tests/cuda/test_single_gemm_tiled.cu b/tests/cuda/test_single_gemm_tiled.cu new file mode 100644 index 0000000..9a8e0bd --- /dev/null +++ b/tests/cuda/test_single_gemm_tiled.cu @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementC = cutlass::bfloat16_t; +using ElementAccumulator = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; + +// Tile configurations +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; +using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +// MmaCore definition +using MmaCore = cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, ElementB, + LayoutB, ElementAccumulator, LayoutC, cutlass::arch::OpClassTensorOp, 1, + cutlass::arch::OpMultiplyAdd, false, cutlass::arch::CacheOperation::Global, + cutlass::arch::CacheOperation::Global>; + +using Mma = typename MmaCore::Mma; +using FragmentC = typename MmaCore::FragmentC; +using SharedStorage = typename Mma::SharedStorage; + +// CUDA Kernel +__global__ void mma_kernel(cutlass::TensorRef A, + cutlass::TensorRef B, + cutlass::TensorRef C, int M, + int N, int K) { + extern __shared__ __align__(16) char shared_storage[]; + SharedStorage& smem = *reinterpret_cast(shared_storage); + + Mma mma(smem, threadIdx.x, threadIdx.y, threadIdx.z); + + // Threadblock tile offsets + int block_row = blockIdx.y * ThreadblockShape::kM; + int block_col = blockIdx.x * ThreadblockShape::kN; + + FragmentC accum; + cutlass::arch::mma::clear(accum); // Zero accumulator + + for (int k_tile = 0; k_tile < K; k_tile += ThreadblockShape::kK) { + typename Mma::FragmentA fragA; + typename Mma::FragmentB fragB; + + // Load A tile from global memory + mma.load_a(fragA, A.data() + block_row * K + k_tile, A.stride(0), k_tile); + // Load B tile from global memory + mma.load_b(fragB, B.data() + k_tile * N + block_col, B.stride(0), k_tile); + + // Fused multiply-add + mma(accum, fragA, fragB, accum); + } + + // Write result C to global memory + mma.store_c(C.data() + block_row * N + block_col, C.stride(0), accum); +} + +// Host code +int main() { + constexpr int M = 256, N = 256, K = 256; + + cutlass::HostTensor A({M, K}); + cutlass::HostTensor B({K, N}); + cutlass::HostTensor C({M, N}); + cutlass::HostTensor C_ref({M, N}); + + cutlass::reference::host::TensorFill(A.host_view(), ElementA(1.0f)); + cutlass::reference::host::TensorFill(B.host_view(), ElementB(1.0f)); + + A.sync_device(); + B.sync_device(); + + cutlass::TensorRef A_ref(A.device_data(), A.layout()); + cutlass::TensorRef B_ref(B.device_data(), B.layout()); + cutlass::TensorRef C_ref_device(C.device_data(), + C.layout()); + + // Kernel launch + dim3 grid((N + ThreadblockShape::kN - 1) / ThreadblockShape::kN, + (M + ThreadblockShape::kM - 1) / ThreadblockShape::kM); + dim3 block(32, 4, 1); // Typical for warp-level threading + size_t smem_size = sizeof(SharedStorage); + + std::cout << "Launching kernel with grid = (" << grid.x << ", " << grid.y + << "), block = (" << block.x << ", " << block.y << ")\n"; + 
mma_kernel<<>>(A_ref, B_ref, C_ref_device, M, N, K); + cudaDeviceSynchronize(); + + // Copy result back + C.sync_host(); + + // Reference computation + cutlass::reference::host::Gemm( + {M, N, K}, ElementAccumulator(1.0f), A.host_ref(), B.host_ref(), + ElementAccumulator(0.0f), C_ref.host_ref()); + + // Validate result + bool passed = true; + for (int i = 0; i < M * N; ++i) { + float diff = + std::abs(float(C.host_data()[i]) - float(C_ref.host_data()[i])); + if (diff > 1e-2) { + std::cout << "Mismatch at " << i << ": " << float(C.host_data()[i]) + << " vs " << float(C_ref.host_data()[i]) << "\n"; + passed = false; + break; + } + } + + if (passed) { + std::cout << "GEMM passed!\n"; + } else { + std::cout << "GEMM failed!\n"; + } + + return 0; +} diff --git a/tests/cuda/test_tile_size.cu b/tests/cuda/test_tile_size.cu new file mode 100644 index 0000000..77a0fcf --- /dev/null +++ b/tests/cuda/test_tile_size.cu @@ -0,0 +1,254 @@ +#include +#include +#include + +#include "common/types.h" + +// Bus width aware GEMM tile calculator +template +struct BusAwareGemmTileCalculator { + // Memory bandwidth parameters + static constexpr int bytes_per_cycle = BusWidthBits / 8; + static constexpr int target_cycles = TargetCycles; + + // Element sizes + static constexpr int size_A = cutlass::sizeof_bits::value / 8; + static constexpr int size_B = cutlass::sizeof_bits::value / 8; + static constexpr int size_C = cutlass::sizeof_bits::value / 8; + + // Total bytes we can move in target cycles + static constexpr int total_bandwidth_bytes = bytes_per_cycle * target_cycles; + + // Scale factor based on bus width (normalized to 256-bit baseline) + static constexpr int bus_scale_factor = BusWidthBits / 256; + static constexpr int bus_sqrt_scale = ConstexprSqrt::value; + + // Base tile dimensions for 256-bit bus + struct BaseTileSizes { + static constexpr int kM = 64; + static constexpr int kN = 64; + static constexpr int kK = 32; + }; + + // Scale tiles based on bus width + struct ScaledGemmTile { + // Scale M and N with square root of bus scaling to maintain aspect ratio + // Scale K less aggressively to maintain data reuse + static constexpr int kM_raw = BaseTileSizes::kM * bus_sqrt_scale; + static constexpr int kN_raw = BaseTileSizes::kN * bus_sqrt_scale; + static constexpr int kK_raw = + BaseTileSizes::kK * (bus_sqrt_scale + 1) / 2; // Scale K by half + + // Round to tensor core friendly sizes + static constexpr int kM = RoundToMultiple::value; + static constexpr int kN = RoundToMultiple::value; + static constexpr int kK = RoundToMultiple::value; + + // Calculate actual memory usage + static constexpr int elements_A = kM * kK; + static constexpr int elements_B = kK * kN; + static constexpr int elements_C = kM * kN; + + static constexpr int bytes_A = elements_A * size_A; + static constexpr int bytes_B = elements_B * size_B; + static constexpr int bytes_C = elements_C * size_C; + + // Total bytes for one tile computation (read A, B and write C) + static constexpr int total_bytes = bytes_A + bytes_B + bytes_C; + + // Cycles needed to transfer this data + static constexpr int cycles_needed = + (total_bytes + bytes_per_cycle - 1) / bytes_per_cycle; + + // Check if we fit within bandwidth budget + static constexpr bool fits_bandwidth = cycles_needed <= target_cycles; + }; + + // Architecture-specific optimized tiles + struct OptimizedTiles { + // For narrow bus (consumer GPUs): prioritize square tiles + struct NarrowBus { + static constexpr bool is_narrow = BusWidthBits <= 384; + static constexpr int kM = is_narrow ? 
64 : ScaledGemmTile::kM; + static constexpr int kN = is_narrow ? 64 : ScaledGemmTile::kN; + static constexpr int kK = is_narrow ? 32 : ScaledGemmTile::kK; + }; + + // For wide bus (HBM GPUs): can afford larger tiles + struct WideBus { + static constexpr bool is_wide = BusWidthBits >= 4096; + static constexpr int kM = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kM; + static constexpr int kN = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kN; + static constexpr int kK = + is_wide ? RoundToMultiple::value + : ScaledGemmTile::kK; + }; + + // Choose based on bus width + static constexpr int kM = + WideBus::is_wide + ? WideBus::kM + : (NarrowBus::is_narrow ? NarrowBus::kM : ScaledGemmTile::kM); + static constexpr int kN = + WideBus::is_wide + ? WideBus::kN + : (NarrowBus::is_narrow ? NarrowBus::kN : ScaledGemmTile::kN); + static constexpr int kK = + WideBus::is_wide + ? WideBus::kK + : (NarrowBus::is_narrow ? NarrowBus::kK : ScaledGemmTile::kK); + }; + + // Threadblock clusters for CUTLASS 3.x + struct ClusterShape { + // Larger clusters for wider memory interfaces + static constexpr int kM = BusWidthBits >= 4096 ? 2 : 1; + static constexpr int kN = BusWidthBits >= 4096 ? 2 : 1; + static constexpr int kK = 1; // K dimension clustering less beneficial + }; + + // Pipeline stages based on bandwidth + struct PipelineConfig { + // More stages for wider interfaces to hide latency + static constexpr int stages = BusWidthBits >= 4096 ? 4 + : BusWidthBits >= 384 ? 3 + : 2; + }; + + // Warp arrangement + struct WarpArrangement { + // More warps for larger tiles + static constexpr int warps_m = OptimizedTiles::kM / 32; + static constexpr int warps_n = OptimizedTiles::kN / 32; + static constexpr int total_warps = warps_m * warps_n; + + // Ensure we don't exceed SM warp limits + static constexpr int max_warps = 16; // Typical limit + static constexpr bool valid = total_warps <= max_warps; + }; + + static void print_config() { + printf("=== Bus-Aware GEMM Tile Configuration ===\n"); + printf("Memory Bus: %d bits (%d bytes/cycle)\n", BusWidthBits, + bytes_per_cycle); + printf("Element sizes: A=%d, B=%d, C=%d bytes\n", size_A, size_B, size_C); + printf("Bus scale factor: %dx (sqrt: %dx)\n", bus_scale_factor, + bus_sqrt_scale); + printf("\nScaled tile dimensions:\n"); + printf(" Raw: %dx%dx%d\n", ScaledGemmTile::kM_raw, ScaledGemmTile::kN_raw, + ScaledGemmTile::kK_raw); + printf(" Aligned: %dx%dx%d\n", ScaledGemmTile::kM, ScaledGemmTile::kN, + ScaledGemmTile::kK); + printf(" Memory: A=%d, B=%d, C=%d bytes (total: %d)\n", + ScaledGemmTile::bytes_A, ScaledGemmTile::bytes_B, + ScaledGemmTile::bytes_C, ScaledGemmTile::total_bytes); + printf(" Cycles needed: %d (budget: %d)\n", ScaledGemmTile::cycles_needed, + target_cycles); + printf("\nOptimized tile: %dx%dx%d\n", OptimizedTiles::kM, + OptimizedTiles::kN, OptimizedTiles::kK); + printf("Cluster shape: %dx%dx%d\n", ClusterShape::kM, ClusterShape::kN, + ClusterShape::kK); + printf("Pipeline stages: %d\n", PipelineConfig::stages); + printf("Warp arrangement: %dx%d = %d warps\n", WarpArrangement::warps_m, + WarpArrangement::warps_n, WarpArrangement::total_warps); + printf("=========================================\n"); + } +}; + +// Predefined configurations for common GPUs +template +struct GPUOptimalGemmTiles { + // Consumer GPUs + using RTX4090 = BusAwareGemmTileCalculator<384, ElementA, ElementB, ElementC>; + using RTX4080 = BusAwareGemmTileCalculator<256, ElementA, ElementB, ElementC>; + using RTX3090 = BusAwareGemmTileCalculator<384, ElementA, 
ElementB, ElementC>; + + // Data center GPUs + using H200 = BusAwareGemmTileCalculator<6144, ElementA, ElementB, ElementC>; + using H100 = BusAwareGemmTileCalculator<5120, ElementA, ElementB, ElementC>; + using A100 = BusAwareGemmTileCalculator<5120, ElementA, ElementB, ElementC>; + using V100 = BusAwareGemmTileCalculator<4096, ElementA, ElementB, ElementC>; + + static void compare_all() { + printf("=== GEMM Tile Sizes Across GPUs ===\n"); + printf("Data types: A=%d-bit, B=%d-bit, C=%d-bit\n\n", + cutlass::sizeof_bits::value, + cutlass::sizeof_bits::value, + cutlass::sizeof_bits::value); + + printf("GPU Bus Tile (MxNxK) Cluster Stages\n"); + printf("---------- ------ --------------- -------- ------\n"); + + auto print_gpu = [](const char* name, int bus, auto calc) { + using Calc = decltype(calc); + printf("%-10s %4d %3dx%3dx%2d %dx%dx%d %d\n", name, bus, + Calc::OptimizedTiles::kM, Calc::OptimizedTiles::kN, + Calc::OptimizedTiles::kK, Calc::ClusterShape::kM, + Calc::ClusterShape::kN, Calc::ClusterShape::kK, + Calc::PipelineConfig::stages); + }; + + print_gpu("RTX 4080", 256, RTX4080{}); + print_gpu("RTX 3090", 384, RTX3090{}); + print_gpu("RTX 4090", 384, RTX4090{}); + print_gpu("V100", 4096, V100{}); + print_gpu("A100", 5120, A100{}); + print_gpu("H100", 5120, H100{}); + print_gpu("H200", 6144, H200{}); + } +}; + +// Example usage +int main() { + // Show how tile sizes scale with bus width + printf("=== Scaling Analysis ===\n"); + + // Same data types, different bus widths + using BF16_Gemm_256 = BusAwareGemmTileCalculator<256, cutlass::bfloat16_t, + cutlass::bfloat16_t, float>; + using BF16_Gemm_512 = BusAwareGemmTileCalculator<512, cutlass::bfloat16_t, + cutlass::bfloat16_t, float>; + using BF16_Gemm_1024 = BusAwareGemmTileCalculator<1024, cutlass::bfloat16_t, + cutlass::bfloat16_t, float>; + using BF16_Gemm_4096 = BusAwareGemmTileCalculator<4096, cutlass::bfloat16_t, + cutlass::bfloat16_t, float>; + using BF16_Gemm_5120 = BusAwareGemmTileCalculator<5120, cutlass::bfloat16_t, + cutlass::bfloat16_t, float>; + + printf("\nBF16xBF16->FP32 GEMM tiles vs bus width:\n"); + printf("Bus Width Tile Size Bandwidth Usage\n"); + printf("--------- -------------- ----------------\n"); + + auto print_config = [](int bus, auto calc) { + using Calc = decltype(calc); + printf("%4d-bit %3dx%3dx%2d %d/%d cycles\n", bus, + Calc::OptimizedTiles::kM, Calc::OptimizedTiles::kN, + Calc::OptimizedTiles::kK, Calc::ScaledGemmTile::cycles_needed, + Calc::target_cycles); + }; + + print_config(256, BF16_Gemm_256{}); + print_config(512, BF16_Gemm_512{}); + print_config(1024, BF16_Gemm_1024{}); + print_config(4096, BF16_Gemm_4096{}); + print_config(5120, BF16_Gemm_5120{}); + + // Compare across different GPUs + printf("\n"); + GPUOptimalGemmTiles::compare_all(); + + // Detailed config for H100 + printf("\n"); + BF16_Gemm_5120::print_config(); + + printf("\n"); + BF16_Gemm_256::print_config(); + + return 0; +} diff --git a/tests/cuda/test_uvm_kernel.cu b/tests/cuda/test_uvm_kernel.cu new file mode 100644 index 0000000..a6169fa --- /dev/null +++ b/tests/cuda/test_uvm_kernel.cu @@ -0,0 +1,216 @@ +#include +#include +#include +#include + +__device__ __forceinline__ int load_lu(const int* ptr) { + int val; + asm volatile("ld.global.lu.s32 %0, [%1];" : "=r"(val) : "l"(ptr)); + return val; +} + +__device__ __forceinline__ void store_wb(int* ptr, int val) { + asm volatile("st.global.wb.s32 [%0], %1;" ::"l"(ptr), "r"(val)); +} + +__global__ __launch_bounds__(512) void readEvery64B_Uncached( + const int* __restrict__ data, int N) 
{ + const int stride = 4096 / sizeof(int); // 64-byte stride + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * stride; + + unsigned int local = 0; + if (idx < N) { + local = load_lu(&data[idx]); + } + + if (local == 0xDEADBEEF) { + printf("unreachable: %u\n", local); + } +} + +__global__ __launch_bounds__(512) void copyAllUVM_Uncached( + const int* __restrict__ src, int* __restrict__ dst, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) { + dst[idx] = load_lu(&src[idx]); + } +} + +__device__ __forceinline__ void copy_64B_uncached(const void* src, void* dst) { + int4 reg0, reg1, reg2, reg3; + + // Load 64 bytes (4 x int4 = 4 x 16B) from uncached global memory + asm volatile("ld.global.lu.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(reg0.x), "=r"(reg0.y), "=r"(reg0.z), "=r"(reg0.w) + : "l"(src)); + asm volatile("ld.global.lu.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(reg1.x), "=r"(reg1.y), "=r"(reg1.z), "=r"(reg1.w) + : "l"((const char*)src + 16)); + asm volatile("ld.global.lu.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(reg2.x), "=r"(reg2.y), "=r"(reg2.z), "=r"(reg2.w) + : "l"((const char*)src + 32)); + asm volatile("ld.global.lu.v4.s32 {%0, %1, %2, %3}, [%4];" + : "=r"(reg3.x), "=r"(reg3.y), "=r"(reg3.z), "=r"(reg3.w) + : "l"((const char*)src + 48)); + + // Store 64 bytes with write-through to global memory (bypass L1) + asm volatile("st.global.wb.v4.s32 [%0], {%1, %2, %3, %4};" + : + : "l"(dst), "r"(reg0.x), "r"(reg0.y), "r"(reg0.z), "r"(reg0.w)); + asm volatile("st.global.wb.v4.s32 [%0], {%1, %2, %3, %4};" + : + : "l"((char*)dst + 16), "r"(reg1.x), "r"(reg1.y), "r"(reg1.z), + "r"(reg1.w)); + asm volatile("st.global.wb.v4.s32 [%0], {%1, %2, %3, %4};" + : + : "l"((char*)dst + 32), "r"(reg2.x), "r"(reg2.y), "r"(reg2.z), + "r"(reg2.w)); + asm volatile("st.global.wb.v4.s32 [%0], {%1, %2, %3, %4};" + : + : "l"((char*)dst + 48), "r"(reg3.x), "r"(reg3.y), "r"(reg3.z), + "r"(reg3.w)); +} + +__global__ __launch_bounds__(512) void copyEvery4KB_Uncached( + const int* __restrict__ src, int* __restrict__ dst, int N) { + const int stride = 4096 / sizeof(int); // 4KB = 1024 int32 + int idx = (blockIdx.x * blockDim.x + threadIdx.x) * stride; + + if (idx < N) { + dst[idx] = load_lu(&src[idx]); // uncached load + // for (int i = 0; i < stride; ++i) { + // if (idx + i < N) { + // int val = load_lu(&src[idx+i]); + // store_wb(&dst[idx+i], val); + // } + // } + } +} + +int main() { + const int N = 1 << 26; // 64M ints = 256MB + const int strideInts = 4096 / sizeof(int); // 4KB stride + + cudaSetDevice(0); + + // ✅ Allocate Unified Memory + int* uvm_src = nullptr; + int* dev_dst = nullptr; + cudaMallocManaged(&uvm_src, N * sizeof(int)); + cudaMalloc(&dev_dst, N * sizeof(int)); // plain device memory + + // ✅ Touch source on host to keep pages resident in CPU memory initially + for (int i = 0; i < N; ++i) uvm_src[i] = i; + + cudaDeviceSynchronize(); // ensure all CPU writes are visible + + // ⛔ Do NOT prefetch — we want fault-triggered migration + + // Launch read kernel (will trigger page migrations) + auto t_start = std::chrono::high_resolution_clock::now(); + + int threads = 512; + int blocks = (N + threads - 1) / threads; + copyAllUVM_Uncached<<>>(uvm_src, dev_dst, N); + cudaDeviceSynchronize(); + + auto t_end = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed_sec = t_end - t_start; + + // ✅ Validate correctness (check values at stride points) + int* host_check = (int*)malloc(N * sizeof(int)); + cudaMemcpy(host_check, dev_dst, N * sizeof(int), 
cudaMemcpyDeviceToHost); + + int errors = 0; + for (int i = 0; i < N; ++i) { + if (host_check[i] != i) { + if (++errors < 10) { + std::cerr << "Mismatch at index " << i << ": expected " << i << ", got " + << host_check[i] << "\n"; + } + } + } + + if (errors == 0) + std::cout << "✅ UVM migration correctness passed.\n"; + else + std::cout << "❌ UVM migration correctness failed. Errors: " << errors + << "\n"; + + // ✅ Print timing + double gb = N * sizeof(int) / double(1 << 30); + std::cout << "Time: " << elapsed_sec.count() * 1000.0 << " ms\n"; + std::cout << "Throughput: " << gb / elapsed_sec.count() << " GB/s\n"; + + // Cleanup + free(host_check); + cudaFree(uvm_src); + cudaFree(dev_dst); + + return 0; +} + +// int main() { +// const int N = 1 << 26; // 64M ints = 256MB +// const int chunkSize = 1 << 24; // 16M ints = 64MB +// const int numChunks = N / chunkSize; + +// // Allocate host memory with malloc and register as pinned +// int* host_data = (int*)aligned_alloc(4096, N * sizeof(int)); +// if (!host_data) { +// std::cerr << "Host malloc failed.\n"; +// return -1; +// } + +// // Fill with data +// for (int i = 0; i < N; ++i) { +// host_data[i] = i; +// } + +// // Register memory as pinned and mapped +// cudaHostRegister(host_data, N * sizeof(int), cudaHostRegisterMapped); + +// // Get device-accessible pointer +// int* device_ptr = nullptr; +// cudaHostGetDevicePointer(&device_ptr, host_data, 0); + +// // Create compute stream +// cudaStream_t stream; +// cudaStreamCreate(&stream); + +// // Timing setup +// cudaEvent_t start, stop; +// cudaEventCreate(&start); +// cudaEventCreate(&stop); +// cudaEventRecord(start); + +// for (int i = 0; i < numChunks; ++i) { +// int* chunk_ptr = device_ptr + i * chunkSize; + +// int stride = 4096 / sizeof(int); +// int effectiveElems = chunkSize / stride; +// int threads = 512; +// int blocks = (effectiveElems + threads - 1) / threads; + +// readEvery64B_Uncached<<>>(chunk_ptr, +// chunkSize); +// } + +// cudaEventRecord(stop, stream); +// cudaEventSynchronize(stop); + +// float ms = 0; +// cudaEventElapsedTime(&ms, start, stop); +// double gb = N * sizeof(int) / double(1 << 30); +// std::cout << "Pinned Host Mem Read Throughput: " << gb / (ms / 1000.0) << +// " GB/s\n"; + +// // Cleanup +// cudaHostUnregister(host_data); +// free(host_data); +// cudaStreamDestroy(stream); +// cudaEventDestroy(start); +// cudaEventDestroy(stop); + +// return 0; +// } From 54407b11f65729ea55ad4cea235e2357bbbd3bc1 Mon Sep 17 00:00:00 2001 From: xly Date: Sun, 6 Jul 2025 19:53:42 +0100 Subject: [PATCH 2/3] add kernel compilation --- CITATIONS.md | 1 + README.md | 1 + core/common/context.h | 69 +++ core/common/generator.h | 5 + core/common/pytorch.h | 10 + core/kernel/activation_kernels.cu | 215 ++++++++ core/kernel/common_device.h | 119 ++++ core/kernel/dispatch_utils.h | 9 + core/kernel/fused_mlp.cu | 828 ++++++++++++++++++++++++++++ core/kernel/masked_select.h | 494 +++++++++++++++++ core/kernel/ops.h | 36 ++ core/kernel/topk_softmax_kernels.cu | 536 ++++++++++++++++++ core/kernel/torch_bindings.h | 0 core/model/moe.h | 181 ++++++ core/parallel/expert_dispatcher.cpp | 70 +-- core/parallel/expert_module.cpp | 15 +- core/parallel/expert_module.h | 9 - examples/readme_example.py | 2 +- op_builder/builder.py | 12 +- op_builder/prefetch.py | 2 + tests/cuda/CMakeLists.txt | 4 +- tests/cuda/test_fused_mlp_wmma.cu | 405 ++++++++++++++ tests/cuda/tests_masked_select.cu | 358 ++++++++++++ 23 files changed, 3288 insertions(+), 93 deletions(-) create mode 100644 
core/common/context.h create mode 100644 core/kernel/activation_kernels.cu create mode 100644 core/kernel/common_device.h create mode 100644 core/kernel/dispatch_utils.h create mode 100644 core/kernel/fused_mlp.cu create mode 100644 core/kernel/masked_select.h create mode 100644 core/kernel/ops.h create mode 100644 core/kernel/topk_softmax_kernels.cu create mode 100644 core/kernel/torch_bindings.h create mode 100644 core/model/moe.h create mode 100644 tests/cuda/test_fused_mlp_wmma.cu create mode 100644 tests/cuda/tests_masked_select.cu diff --git a/CITATIONS.md b/CITATIONS.md index 979e99d..bfdbdd5 100644 --- a/CITATIONS.md +++ b/CITATIONS.md @@ -3,6 +3,7 @@ author = {Leyang Xue and Yao Fu and Zhan Lu and + Chuanhao Sun and Luo Mai and Mahesh Marina}, title = {MoE-Infinity: Efficient MoE Inference on Personal Machines with Sparsity-Aware Expert Cache}, diff --git a/README.md b/README.md index bc29eab..3d3f834 100644 --- a/README.md +++ b/README.md @@ -204,6 +204,7 @@ If you use MoE-Inifity for your research, please cite our [paper](https://arxiv. author = {Leyang Xue and Yao Fu and Zhan Lu and + Chuanhao Sun and Luo Mai and Mahesh Marina}, title = {MoE{-}Infinity: Efficient MoE Inference on Personal Machines with Sparsity-Aware Expert Cache}, diff --git a/core/common/context.h b/core/common/context.h new file mode 100644 index 0000000..0c085d3 --- /dev/null +++ b/core/common/context.h @@ -0,0 +1,69 @@ +// Copyright (c) EfficientMoE. +// SPDX-License-Identifier: Apache-2.0 + +// EfficientMoE Team + +#pragma once + +#include +#include +#include + +enum class DataType { BFLOAT16 = 0, FLOAT32 = 1, FLOAT16 = 2, FP8_E4M3FN = 3 }; + +struct Context { + // Add any necessary member variables or methods here + int64_t max_expert_tokens = 128; // Default maximum expert tokens + int64_t max_tokens = 4096; // Default maximum tokens + int num_experts = 8; // Default number of experts + int topk = 2; // Default top-k value + int64_t hidden_dim = 1024; // Default hidden dimension + int64_t intermediate_dim = + 4096; // Default intermediate dimension for experts + DataType dtype = DataType::FLOAT32; // Default data type + + void SetFromDict(const std::unordered_map& dict) { + if (dict.find("max_expert_tokens") != dict.end()) { + max_expert_tokens = dict.at("max_expert_tokens"); + } + if (dict.find("max_tokens") != dict.end()) { + max_tokens = dict.at("max_tokens"); + } + if (dict.find("num_experts") != dict.end()) { + num_experts = dict.at("num_experts"); + } + if (dict.find("topk") != dict.end()) { + topk = dict.at("topk"); + } + if (dict.find("hidden_dim") != dict.end()) { + hidden_dim = dict.at("hidden_dim"); + } + if (dict.find("intermediate_dim") != dict.end()) { + intermediate_dim = dict.at("intermediate_dim"); + } + if (dict.find("dtype") != dict.end()) { + int dtype_value = dict.at("dtype"); + switch (dtype_value) { + case 0: + dtype = DataType::BFLOAT16; + break; + case 1: + dtype = DataType::FLOAT32; + break; + case 2: + dtype = DataType::FLOAT16; + break; + case 3: + dtype = DataType::FP8_E4M3FN; + break; + default: + throw std::invalid_argument("Invalid dtype value"); + } + } + } +}; + +Context& getContext() { + static Context instance; + return instance; +} diff --git a/core/common/generator.h b/core/common/generator.h index ca4ec33..549d264 100644 --- a/core/common/generator.h +++ b/core/common/generator.h @@ -1,3 +1,8 @@ +// Copyright (c) EfficientMoE. 
+// SPDX-License-Identifier: Apache-2.0 + +// EfficientMoE Team + #pragma once #include diff --git a/core/common/pytorch.h b/core/common/pytorch.h index 9f40ccf..2dcc269 100644 --- a/core/common/pytorch.h +++ b/core/common/pytorch.h @@ -7,6 +7,7 @@ #include #include "aio/archer_prio_aio_handle.h" +#include "types.h" #include "base/noncopyable.h" #define CPU_DEVICE torch::Device(torch::kCPU) @@ -27,6 +28,15 @@ #define INT64_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kInt64, target) #define BFLOAT16_TENSOR_OPTIONS(target) TENSOR_OPTIONS(torch::kBFloat16, target) +#define TENSOR_FROM_BLOB(blob, shape, dtype, target) \ + torch::from_blob(blob, shape, DoNothingDeleter{}, \ + TENSOR_OPTIONS(dtype, target)) + +// when dtype is a cpp type use type trait to get the torch dtype +#define TENSOR_FROM_BLOB_CPP(blob, shape, dtype, target) \ + torch::from_blob(blob, shape, DoNothingDeleter{}, \ + TENSOR_OPTIONS(torch::ScalarType(dtype), target)) + #define FAKE_TENSOR_SIZES torch::IntArrayRef({1}) inline std::vector list_to_vector(py::list list) { diff --git a/core/kernel/activation_kernels.cu b/core/kernel/activation_kernels.cu new file mode 100644 index 0000000..cac5aef --- /dev/null +++ b/core/kernel/activation_kernels.cu @@ -0,0 +1,215 @@ +// Adapted from +// https://github.dev/vllm-project/vllm/blob/main/csrc/moe/topk_softmax_kernels.cu +// Copyright (c) EfficientMoE. +// SPDX-License-Identifier: Apache-2.0 + +// EfficientMoE Team + +#include +#include +#include + +#include + +#include "ops.h" +#include "dispatch_utils.h" + +template +__device__ __forceinline__ scalar_t compute(const scalar_t& x, + const scalar_t& y) { + return act_first ? ACT_FN(x) * y : x * ACT_FN(y); +} +// Activation and gating kernel template. + +template +__global__ void act_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); + const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = compute(x, y); + } +} + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T gelu_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 + const float f = (float)x; + constexpr float ALPHA = M_SQRT1_2; + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); +} + +template +__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'tanh' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 + const float f = (float)x; + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + float x_cube = f * f * f; + float inner = BETA * (f + KAPPA * x_cube); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); +} + +// Launch activation and gating kernel. +// Use ACT_FIRST (bool) indicating whether to apply the activation function +// first. 
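+// A minimal per-token sketch of what the launched kernel computes (restating
+// compute() above): with x = input[token][0:d] and y = input[token][d:2d],
+//   out[token][i] = ACT_FIRST ? ACT_FN(x[i]) * y[i] : x[i] * ACT_FN(y[i]);
+// so silu_and_mul gives silu(x) * y and mul_and_silu gives x * silu(y).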
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { \ + return; \ + } \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + DISPATCH_FLOATING_TYPES(input.scalar_type(), "act_and_mul_kernel", [&] { \ + act_and_mul_kernel, ACT_FIRST> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(silu_kernel, true); +} + +void mul_and_silu(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + // The difference between mul_and_silu and silu_and_mul is that mul_and_silu + // applies the silu to the latter half of the input. + LAUNCH_ACTIVATION_GATE_KERNEL(silu_kernel, false); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(gelu_kernel, true); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(gelu_tanh_kernel, true); +} + +template +__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { + const float f = (float)x; + return (T)(f > threshold ? f : 0.0f); +} + +template +__global__ void act_and_mul_kernel_with_param( + scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, + const float param) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); + const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = ACT_FN(x, param) * y; + } +} + +#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ + act_and_mul_kernel_with_param> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, \ + PARAM); \ + }); + +void fatrelu_and_mul(torch::Tensor& out, // [..., d], + torch::Tensor& input, // [..., 2 * d] + double threshold) { + LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(fatrelu_kernel, threshold); +} + +// Element-wise activation kernel template. +template +__global__ void activation_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = __ldg(&input[token_idx * d + idx]); + out[token_idx * d + idx] = ACT_FN(x); + } +} + +// Launch element-wise activation kernel. 
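+// (Launch shape, as with the gated kernels above: one block per token,
+// min(d, 1024) threads, each thread striding over a token's d elements on the
+// current CUDA stream.) Illustrative call, assuming `input` is a CUDA
+// floating-point tensor of shape [..., d]:
+//   torch::Tensor out = torch::empty_like(input);
+//   gelu_fast(out, input);  // element-wise gelu_fast applied to every value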
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + activation_kernel><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); + +template +__device__ __forceinline__ T gelu_new_kernel(const T& x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_fast_kernel(const T& x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_quick_kernel(const T& x) { + // x * sigmoid(1.702 * x) + return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x))); +} + +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(gelu_new_kernel); +} + +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(gelu_fast_kernel); +} + +void gelu_quick(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(gelu_quick_kernel); +} diff --git a/core/kernel/common_device.h b/core/kernel/common_device.h new file mode 100644 index 0000000..610b929 --- /dev/null +++ b/core/kernel/common_device.h @@ -0,0 +1,119 @@ +#pragma once + +#include +#include +#include + +enum class ActFunc { + SiLU, + ReLU, + GeLU, +}; + +// Device activation function implementations +template +__device__ __forceinline__ T relu(T x) { + return fmaxf(x, T(0.0f)); +} + +template +__device__ __forceinline__ T silu(T x) { + return x / (T(1.0f) + expf(-x)); +} + +template +__device__ __forceinline__ T gelu(T x) { + // Approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + const T sqrt_2_over_pi = T(0.7978845608f); + const T coeff = T(0.044715f); + T x_cubed = x * x * x; + T inner = sqrt_2_over_pi * (x + coeff * x_cubed); + return T(0.5f) * x * (T(1.0f) + tanhf(inner)); +} + +// Specializations for __nv_bfloat16 +template <> +__device__ __forceinline__ __nv_bfloat16 relu(__nv_bfloat16 x) { + return __hmax(x, __float2bfloat16(0.0f)); +} + +template <> +__device__ __forceinline__ __nv_bfloat16 silu(__nv_bfloat16 x) { + float x_f = __bfloat162float(x); + float result = x_f / (1.0f + expf(-x_f)); + return __float2bfloat16(result); +} + +template <> +__device__ __forceinline__ __nv_bfloat16 gelu(__nv_bfloat16 x) { + float x_f = __bfloat162float(x); + const float sqrt_2_over_pi = 0.7978845608f; + const float coeff = 0.044715f; + float x_cubed = x_f * x_f * x_f; + float inner = sqrt_2_over_pi * (x_f + coeff * x_cubed); + float result = 0.5f * x_f * (1.0f + tanhf(inner)); + return __float2bfloat16(result); +} + +// Specializations for half precision +#ifdef __CUDA_ARCH__ +template <> +__device__ __forceinline__ half relu(half x) { + return __hmax(x, __float2half(0.0f)); +} + +template <> +__device__ __forceinline__ half silu(half x) { + float x_f = __half2float(x); + float result = x_f / (1.0f + expf(-x_f)); + return __float2half(result); +} + +template <> +__device__ __forceinline__ half gelu(half x) { + float x_f = __half2float(x); + const float sqrt_2_over_pi 
= 0.7978845608f; + const float coeff = 0.044715f; + float x_cubed = x_f * x_f * x_f; + float inner = sqrt_2_over_pi * (x_f + coeff * x_cubed); + float result = 0.5f * x_f * (1.0f + tanhf(inner)); + return __float2half(result); +} +#endif + +template +__host__ __device__ void warp_activation(ActFunc activation, + const fragment_t& frag, + fragment_t& result) { + switch (activation) { + case ActFunc::ReLU: +#pragma unroll + for (int t = 0; t < result.num_elements; t++) { + result.x[t] = relu(static_cast(frag.x[t])); + } + return; + case ActFunc::SiLU: +#pragma unroll + for (int t = 0; t < result.num_elements; t++) { + result.x[t] = silu(static_cast(frag.x[t])); + } + return; + case ActFunc::GeLU: +#pragma unroll + for (int t = 0; t < result.num_elements; t++) { + result.x[t] = gelu(static_cast(frag.x[t])); + } + return; + default: + // Unsupported activation +#ifdef __CUDA_ARCH__ + // On device, we can't use assert, so we'll set result to zero + for (int t = 0; t < result.num_elements; t++) { + result.x[t] = T(0); + } +#else + assert(false && "Unsupported activation function"); +#endif + return; + } +} diff --git a/core/kernel/dispatch_utils.h b/core/kernel/dispatch_utils.h new file mode 100644 index 0000000..58a7f8a --- /dev/null +++ b/core/kernel/dispatch_utils.h @@ -0,0 +1,9 @@ +#pragma once + +#define DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) diff --git a/core/kernel/fused_mlp.cu b/core/kernel/fused_mlp.cu new file mode 100644 index 0000000..46915a8 --- /dev/null +++ b/core/kernel/fused_mlp.cu @@ -0,0 +1,828 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived from this + * software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +/** @file fully_fused_mlp.cu + * @author Thomas Müller and Nikolaus Binder, NVIDIA + * @brief Fully fused CUDA implementation of a multi-layer perceptron. + * Supports online training and simultaneous inference. + */ + +#include +#include + +#include "common_device.h" + +void check_shmem_error(cudaError_t error) { + if (error != cudaSuccess) { + throw std::runtime_error{ + "FullyFusedMLP: insufficient shared memory available on the GPU."}; + } +} + +template +__device__ void threadblock_layer( + ActFunc activation, __nv_bfloat16* __restrict__ act_shmem, + const __nv_bfloat16* __restrict__ weights_this_layer, + OUT_T* __restrict__ out_intermediate_threadblock_this_layer, + const OUT_T* __restrict__ activation_aux = nullptr) { + // act_shmem contains the intermediate activations (shared memory) of the + // thread block's chunk of the batch. + // Can be forward activations or backward activations, depending on + // caller. + // weights_this_layer points to the weight matrix of the current layer. + // out_intermediate_threadblock_this_layer points to the location where + // intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // activation_aux points to additional arguments that the activation function + // may depend on. Points to the hidden forward activations when computing + // backward activations. + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // If we're performing the backward pass, weights must be loaded in transposed + // form, which is achieved by interpreting the memory in row_major instead of + // col_major order. + using weights_layout_t = + std::conditional_t; + + // Fragments + wmma::fragment + act_frag; + wmma::fragment + weights_frag[N_BLOCKS]; + wmma::fragment result_frag[N_ITERS]; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + const uint32_t weights_col = 16 * wi; + + __syncthreads(); + +// Load N_BLOCKS chunks of weights from global memory into registers. 
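+// (Each warp handles the 16 output columns starting at weights_col = 16 * wi;
+// its N_BLOCKS fragments stay resident in registers, so the block reads the
+// WIDTH x WIDTH weight matrix from global memory once and reuses it for all
+// N_ITERS 16-row batch chunks.)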
+#pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + wmma::load_matrix_sync(weights_frag[i], + weights_this_layer + 16 * i + weights_col * WIDTH, + WIDTH); + } + +#pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::fill_fragment(result_frag[l], 0.0f); + +#pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and + // multiply with chunk of weights + wmma::load_matrix_sync(act_frag, + act_shmem + 16 * i + (16 * l) * (WIDTH + SKEW), + WIDTH + SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag[i], result_frag[l]); + } + + // ActFunc + warp_activation<__nv_bfloat16>(activation, result_frag[l], result_frag[l]); + } + + __syncthreads(); + +#pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW), + result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + +#pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + + (row + 16 * l) * WIDTH] = + *(int4*)&act_shmem[lane_offset + (row + 16 * l) * (WIDTH + SKEW)]; + } + } +} + +template +__device__ void threadblock_load_input_static( + __nv_bfloat16* __restrict__ act_shmem, + const __nv_bfloat16* __restrict__ input_threadblock) { + // act_shmem will be filled by the thread block's chunk of input_threadblock + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + +#pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] = + *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH]; + } +} + +template +__device__ void threadblock_input_layer_forward_dynamic( + ActFunc activation, __nv_bfloat16* __restrict__ act_shmem, + const __nv_bfloat16* __restrict__ input_threadblock, + const __nv_bfloat16* __restrict__ weights_this_layer, + OUT_T* __restrict__ out_intermediate_threadblock_this_layer, + const uint32_t in_width, const uint32_t batch_size) { + // act_shmem contains the intermediate activations (shared memory) of the + // thread block's chunk of the batch input_threadblock points to the thread + // block's chunk of the input batch in global memory weights_this_layer points + // to the weight matrix of the current layer + // out_intermediate_threadblock_this_layer points to the location where + // intermediate activations produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + // in_width is the dynamic width of the input layer + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 
8 : 0; + constexpr uint32_t INPUT_SKEW = 8; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // Fragments + wmma::fragment + act_frag; + wmma::fragment + weights_frag; + wmma::fragment result_frag[N_ITERS]; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + const uint32_t weights_col = 16 * wi; + + __nv_bfloat16* __restrict__ weights_shmem = + act_shmem + 16 * (in_width + INPUT_SKEW); + + // Load input weight matrix (fits completely into shared memory) + // Each thread can load 8 fp16 elements (16 bytes) at once; we have N_BLOCKS + // warps + const uint32_t n_elems_per_load = N_BLOCKS * 32 * 8; + const uint32_t thread_elem_idx = (li + wi * 32) * 8; + + const uint32_t n_elems_b = WIDTH * in_width; + +#pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_b; + idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&weights_shmem[idx_skewed] = *(int4*)&weights_this_layer[idx]; + } + + const uint32_t n_tensor_ops = in_width / 16; + + if (std::is_same::value) { + __syncthreads(); + } + +#pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + if (std::is_same::value) { + // Load chunk of inputs into shmem. + // This is faster than loading it from gmem directly, even though it is + // only used once. (Possibly due to latency hiding through staging.) + const uint32_t n_elems_a = 16 * in_width; + +#pragma unroll + for (uint32_t idx = thread_elem_idx; idx < n_elems_a; + idx += n_elems_per_load) { + const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; + *(int4*)&act_shmem[idx_skewed] = + *(int4*)&input_threadblock[l * n_elems_a + idx]; + } + + __syncthreads(); + } + + wmma::fill_fragment(result_frag[l], 0.0f); +#pragma unroll + for (uint32_t i = 0; i < n_tensor_ops; ++i) { + // Load chunk of inputs and weights from shared memory and multiply them + if (std::is_same::value) { + wmma::load_matrix_sync(act_frag, act_shmem + 16 * i, + in_width + INPUT_SKEW); + } else { + wmma::load_matrix_sync(act_frag, + input_threadblock + 16 * i * batch_size + 16 * l, + batch_size); + } + wmma::load_matrix_sync( + weights_frag, + weights_shmem + 16 * i + weights_col * (in_width + INPUT_SKEW), + in_width + INPUT_SKEW); + wmma::mma_sync(result_frag[l], act_frag, weights_frag, result_frag[l]); + } + + if (std::is_same::value) { + __syncthreads(); + } + + warp_activation<__nv_bfloat16>(activation, result_frag[l], result_frag[l]); + } + + if (std::is_same::value) { + __syncthreads(); + } + +#pragma unroll + for (int l = 0; l < N_ITERS; ++l) { + wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), + result_frag[l], WIDTH + SKEW, wmma::mem_row_major); + } + + if (out_intermediate_threadblock_this_layer != nullptr) { + __syncthreads(); + +#pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + + (row + 16 * i) * WIDTH] = + *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } + } +} + +template +__device__ void threadblock_last_layer_forward( + ActFunc activation, __nv_bfloat16* __restrict__ act_shmem, + const __nv_bfloat16* __restrict__ weights_this_layer, + OUT_T* __restrict__ out, const uint32_t output_stride, + const nvcuda::wmma::layout_t output_layout) { + // act_shmem contains the intermediate activations (shared 
memory) of the + // thread block's chunk of the batch weights_this_layer points to the weight + // matrix of the current layer out points to the location where the result + // produced by the thread block should be written to. + // Can be nullptr if nothing should be written. + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 8 : 0; + constexpr uint32_t N_BLOCKS = WIDTH / 16; + + using namespace nvcuda; + + // Fragments + wmma::fragment + act_frag; + wmma::fragment + weights_frag[N_BLOCKS]; + wmma::fragment result_frag; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + __nv_bfloat16* __restrict__ weights_shmem = + act_shmem + N_ITERS * 16 * (WIDTH + SKEW); + + const uint32_t weights_row = (8 * li) % WIDTH; + const uint32_t weights_col = (8 * li + 8 * 32 * wi) / WIDTH; + + // Load weight matrix into shared memory for the last multiplication. + // Loading into shared memory as opposed to directly into registers is faster + // because unlike in the previous layers, each warp uses the same entries of + // the weight matrix. + *(int4*)&weights_shmem[weights_row + weights_col * (WIDTH + SKEW)] = + *(int4*)&weights_this_layer[weights_row + weights_col * WIDTH]; + + __syncthreads(); + +#pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) + wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, + WIDTH + SKEW); + + // Perform last layer by parallelizing over iters + for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { + wmma::fill_fragment(result_frag, 0.0f); +#pragma unroll + for (uint32_t i = 0; i < N_BLOCKS; ++i) { + // Load a chunk of intermediate activations from shared memory and + // multiply with chunk of the weight matrix + wmma::load_matrix_sync(act_frag, + act_shmem + 16 * i + (16 * idx) * (WIDTH + SKEW), + WIDTH + SKEW); + wmma::mma_sync(result_frag, act_frag, weights_frag[i], result_frag); + } + + warp_activation<__nv_bfloat16>(activation, result_frag, result_frag); + + if (output_layout == wmma::mem_row_major) { + wmma::store_matrix_sync(out + idx * 16 * output_stride, result_frag, + output_stride, output_layout); + } else { + wmma::store_matrix_sync(out + idx * 16, result_frag, output_stride, + output_layout); + } + } +} + +template +__device__ void threadblock_write_output_static( + const __nv_bfloat16* __restrict__ act_shmem, + __nv_bfloat16* __restrict__ output_threadblock) { + // output_threadblock will be filled by the thread block's act_shmem + + constexpr uint32_t SKEW = WIDTH % 16 == 0 ? 
8 : 0; + + // Indices + const uint32_t li = threadIdx.x; // index in warp ("lane index") + const uint32_t wi = threadIdx.y; // index in block ("warp index") + + const uint32_t lane_offset = (8 * li) % WIDTH; + const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; + + __syncthreads(); + +#pragma unroll + for (int i = 0; i < N_ITERS; ++i) { + *(int4*)&output_threadblock[lane_offset + (row + 16 * i) * WIDTH] = + *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; + } +} + +template +__global__ void kernel_mlp_fused( + const ActFunc output_activation, const __nv_bfloat16* __restrict__ input, + const __nv_bfloat16* __restrict__ weights, + OUT_T* __restrict__ out_intermediate, OUT_T* __restrict__ out, + const uint32_t output_stride, const uint32_t batch_size, + const uint32_t in_width, const uint32_t out_width, + const uint32_t n_hidden_matmuls, const nvcuda::wmma::layout_t input_layout, + const nvcuda::wmma::layout_t output_layout) { + // `input` points to the input matrix. Can be any width. + // `weights` points to the weight matrices (contiguous in memory). + // `out_intermediate` points to the memory where intermediate activations + // should be written. When performing inference, a value of nullptr is + // expected (intermediate results are not written). `out` points to the memory + // where the network output should be written. (Output width is assumed to be + // 16 neurons.) + + // Commented out due to isolated strange side-effects on Windows + // if (INFERENCE) { + // assert(out_intermediate == nullptr); + // } else { + // assert(out_intermediate); + // } + + // Shared memory contains the intermediate activations of blockDim.y*16 + // elements. In some cases, it also contains the weight matrix for the first + // and last layer. + extern __shared__ __nv_bfloat16 shmem[]; + __nv_bfloat16* act_shmem = shmem; + + // Each block computes exactly one 16-element chunk of the batch. + const uint32_t elem_idx = 16 * blockIdx.x * N_ITERS; + + // First layer + if (input_layout == nvcuda::wmma::mem_col_major || in_width != WIDTH) { + if (input_layout == nvcuda::wmma::mem_row_major) { + threadblock_input_layer_forward_dynamic( + ACTIVATION, act_shmem, input + elem_idx * in_width, weights, + !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, + in_width, batch_size); + } else { + threadblock_input_layer_forward_dynamic( + ACTIVATION, act_shmem, input + elem_idx, weights, + !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr, + in_width, batch_size); + } + } else { + // If the input has the same width & layout as the hidden layers, we can + // simply use the network's regular layer routine (with static size) instead + // of using the slower dynamic input layer routine. + threadblock_load_input_static(act_shmem, + input + elem_idx * WIDTH); + threadblock_layer( + ACTIVATION, act_shmem, weights, + !INFERENCE ? (out_intermediate + elem_idx * WIDTH) : nullptr); + } + + const uint32_t first_weights_stride = WIDTH * in_width; + const uint32_t weights_stride = WIDTH * WIDTH; + const uint32_t layer_stride = WIDTH * batch_size; + + // Hidden layers + for (uint32_t k = 0; k < n_hidden_matmuls; ++k) { + threadblock_layer( + ACTIVATION, act_shmem, + weights + first_weights_stride + weights_stride * k, + !INFERENCE + ? (out_intermediate + layer_stride * (k + 1) + elem_idx * WIDTH) + : nullptr); + } + + if (out_width > 16) { + // In the forward pass, intermediate activations are already written out. 
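+    // (When !INFERENCE, the hidden layers above already streamed their
+    // activations to out_intermediate, so only the inference path still needs
+    // to flush the last hidden layer's shared-memory tile for the wide output
+    // layer handled outside this kernel.)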
+ if (INFERENCE) { + threadblock_write_output_static( + act_shmem, out_intermediate + elem_idx * WIDTH); + } + } else if (out) { + // Last layer + if (output_layout == nvcuda::wmma::mem_row_major) { + threadblock_last_layer_forward( + output_activation, act_shmem, + weights + first_weights_stride + weights_stride * n_hidden_matmuls, + out + elem_idx * output_stride, output_stride, output_layout); + } else { + threadblock_last_layer_forward( + output_activation, act_shmem, + weights + first_weights_stride + weights_stride * n_hidden_matmuls, + out + elem_idx, output_stride, output_layout); + } + } +} + +template +std::enable_if_t::value> mlp_fused_forward( + cudaStream_t stream, ActFunc output_activation, + const GPUMatrix& weights, const GPUMatrixDynamic& input, + GPUMatrix& output_intermediate, GPUMatrixDynamic* output, + const uint32_t n_hidden_layers) { + throw std::runtime_error{ + "The fully fused forward pass only supports __nv_bfloat16 precision."}; +} + +template +std::enable_if_t::value> mlp_fused_forward( + cudaStream_t stream, ActFunc output_activation, + const GPUMatrix& weights, const GPUMatrixDynamic& input, + GPUMatrix& output_intermediate, GPUMatrixDynamic* output, + const uint32_t n_hidden_layers) { + const uint32_t batch_size = input.cols(); + const uint32_t in_width = input.rows(); + + constexpr uint32_t SKEW = + WIDTH % 16 == 0 ? 8 : 0; // <- always going to be 8 as we only support + // multiple-of-16 widths + constexpr uint32_t INPUT_SKEW = 8; // <- likewise with inputs + constexpr uint32_t N_BLOCK_ROWS = WIDTH / 16; + + static_assert(WIDTH % 16 == 0, "Width must be a multiply of 16."); + + CHECK_THROW(in_width % 16 == 0); + CHECK_THROW(weights.rows() == WIDTH); + CHECK_THROW(weights.cols() % 16 == 0); + CHECK_THROW(output_intermediate.cols() == batch_size); + CHECK_THROW(!output || output->cols() == batch_size); + CHECK_THROW(input.layout() == RM || input.stride() == input.m()); + + const int N_ITERS = WIDTH >= 256 ? 2 : 8; + + if (batch_size % (16 * N_ITERS) != 0) { + throw std::runtime_error{ + fmt::format("Batch size must be a multiple of {}.", 16 * N_ITERS)}; + } + + const dim3 threads = { + 32u, N_BLOCK_ROWS, + 1}; // 32 threads = 1 warp, N_BLOCK_ROWS warps per block for 16 rows, up + // to 2x 8 warps can share input (does not help vs. 1) + + uint32_t n_elems_per_block = 16 * N_ITERS; + uint32_t n_blocks = div_round_up(batch_size, n_elems_per_block); + + size_t shmem_size = + sizeof(__nv_bfloat16) * (16 + 16 * N_ITERS) * + (WIDTH + SKEW); // 16*WIDTH rows of weights (for the last layer; others + // are in registers only) + 16*WIDTH*N_ITERS rows of + // intermediate activations + if (in_width != WIDTH || input.layout() == RM) { + // If the input width is dynamic, the input weight matrix as well as part of + // the input will live in extra shared memory + shmem_size = std::max(shmem_size, sizeof(__nv_bfloat16) * (WIDTH + 16) * + (in_width + INPUT_SKEW)); + } + + const dim3 blocks = {n_blocks, 1u, 1u}; + + check_shmem_error(cudaFuncSetAttribute( + kernel_mlp_fused, + cudaFuncAttributeMaxDynamicSharedMemorySize, (int)shmem_size)); + kernel_mlp_fused<<>>( + output_activation, input.data(), weights.data(), + output_intermediate.data(), output ? output->data() : nullptr, + output ? output->stride() : 0, batch_size, in_width, + output ? output->rows() : 0, n_hidden_layers, + // The kernels operate with transposed layouts compared with the MLP code + input.layout() == RM ? nvcuda::wmma::mem_col_major + : nvcuda::wmma::mem_row_major, + output && output->layout() == RM ? 
nvcuda::wmma::mem_col_major + : nvcuda::wmma::mem_row_major); +} + +template +FullyFusedMLP::FullyFusedMLP(uint32_t input_width, + uint32_t output_width, + uint32_t n_hidden_layers, + ActFunc activation, + ActFunc output_activation) + : m_input_width{input_width}, + m_network_width{WIDTH}, + m_output_width{output_width}, + m_n_hidden_layers{n_hidden_layers}, + m_activation{activation}, + m_output_activation{output_activation} { + if (m_n_hidden_layers <= 0) { + throw std::runtime_error( + "FullyFusedMLP requires at least 1 hidden layer (3 layers in total)."); + } + + m_n_hidden_matmuls = n_hidden_layers - 1; + + m_padded_output_width = next_multiple(m_output_width, REQUIRED_ALIGNMENT()); + + // Create matrices related to weights + m_weight_matrices.emplace_back(nullptr, m_network_width, m_input_width); + m_weight_matrices_inference.emplace_back(nullptr, m_network_width, + m_input_width); + m_gradient_matrices.emplace_back(nullptr, m_network_width, m_input_width); + + for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) { + m_weight_matrices.emplace_back(nullptr, m_network_width, m_network_width); + m_weight_matrices_inference.emplace_back(nullptr, m_network_width, + m_network_width); + m_gradient_matrices.emplace_back(nullptr, m_network_width, m_network_width); + } + + m_weight_matrices.emplace_back(nullptr, m_padded_output_width, + m_network_width); + m_weight_matrices_inference.emplace_back(nullptr, m_padded_output_width, + m_network_width); + m_gradient_matrices.emplace_back(nullptr, m_padded_output_width, + m_network_width); + + // Determine total number of memory entries and set it + m_total_n_params = 0; + for (const auto& m : m_weight_matrices) { + m_total_n_params += m.n_elements(); + } +} + +template +void FullyFusedMLP::inference_mixed_precision_impl( + cudaStream_t stream, const GPUMatrixDynamic& input, + GPUMatrixDynamic& output, bool use_inference_params) { + // Make sure our temporary buffers have the correct size for the given batch + // size + uint32_t batch_size = input.n(); + + GPUMatrix inference_tmp = + m_output_width > 16 ? 
GPUMatrix{m_network_width, batch_size, stream} + : GPUMatrix{nullptr, m_network_width, batch_size}; + + // ASSUMPTION: weight matrices are contiguous in memory + switch (m_activation) { + case ActFunc::None: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::Exponential: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::Sigmoid: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::ReLU: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::LeakyReLU: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::Squareplus: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::Softplus: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + case ActFunc::Tanh: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, inference_tmp, + &output, m_n_hidden_matmuls); + break; + default: + throw std::runtime_error{"Unsupported activation."}; + } + + // If we have more than 16 output dimensions, these will be taken care of by + // CUTLASS rather than the fully fused kernel (which will have written out the + // second-to-last layer activations). 
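+  // (inference_tmp holds the last hidden layer's WIDTH-wide activations,
+  // which kernel_mlp_fused wrote out because out_width > 16; fc_multiply then
+  // applies the final [m_padded_output_width x WIDTH] weight matrix together
+  // with the output activation.)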
+ if (m_output_width > 16) { + fc_multiply(stream, output_weight_matrix(use_inference_params), + inference_tmp, output, m_output_activation); + } +} + +template +std::unique_ptr FullyFusedMLP::forward_impl( + cudaStream_t stream, const GPUMatrixDynamic& input, + GPUMatrixDynamic* output, bool use_inference_params, + bool prepare_input_gradients) { + // Make sure our temporary buffers have the correct size for the given batch + // size + uint32_t batch_size = input.n(); + auto forward = allocate_forward_buffers(stream, batch_size); + + // ASSUMPTION: weight matrices & forward_tmp matrices are contiguous in memory + switch (m_activation) { + case ActFunc::None: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::Exponential: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::Sigmoid: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::ReLU: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::LeakyReLU: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::Squareplus: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::Softplus: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + case ActFunc::Tanh: + mlp_fused_forward( + stream, m_output_activation, + input_weight_matrix(use_inference_params), input, + forward->hidden.at(0), output, m_n_hidden_matmuls); + break; + default: + throw std::runtime_error{"Unsupported activation."}; + } + + // If we have more than 16 output dimensions, these will be taken care of by + // CUTLASS rather than the fully fused kernel (which will have written out the + // second-to-last layer activations). + if (output && m_output_width > 16) { + fc_multiply(stream, output_weight_matrix(use_inference_params), + forward->hidden.back(), *output, + m_output_activation); + } + + return forward; +} + +template +std::unique_ptr::ForwardContext> +FullyFusedMLP::allocate_forward_buffers(cudaStream_t stream, + uint32_t batch_size) { + auto forward = std::make_unique(); + + // Use GPUMatrixBase::allocate_shared_memory to ensure the matrices occupy + // contiguous memory. (Needed in the fully-fused kernels.) 
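+  // (kernel_mlp_fused addresses every hidden buffer from one base pointer
+  // with layer_stride = WIDTH * batch_size, so hidden[k] must start exactly
+  // one layer_stride after hidden[k - 1].)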
+ forward->hidden.resize(num_forward_activations()); + for (uint32_t i = 0; i < num_forward_activations(); ++i) { + forward->hidden[i].set_size_unsafe(m_network_width, batch_size); + } + + forward->alloc = + GPUMatrixBase::allocate_shared_memory(stream, forward->hidden); + + return forward; +} + +template +void FullyFusedMLP::set_params_impl(T* params, T* inference_params, + T* gradients) { + size_t current_pos = 0; + for (size_t i = 0; i < m_weight_matrices.size(); ++i) { + m_weight_matrices[i].set_data_unsafe(params + current_pos); + m_weight_matrices_inference[i].set_data_unsafe(inference_params + + current_pos); + m_gradient_matrices[i].set_data_unsafe(gradients + current_pos); + current_pos += m_weight_matrices[i].n_elements(); + } +} + +template +void FullyFusedMLP::initialize_params(pcg32& rnd, + float* params_full_precision, + float scale) { + // Construct weight matrices + std::vector> weight_matrices_full_precision; + weight_matrices_full_precision.emplace_back(params_full_precision, + m_network_width, m_input_width); + params_full_precision += weight_matrices_full_precision.back().n_elements(); + + for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) { + weight_matrices_full_precision.emplace_back( + params_full_precision, m_network_width, m_network_width); + params_full_precision += weight_matrices_full_precision.back().n_elements(); + } + + weight_matrices_full_precision.emplace_back( + params_full_precision, m_padded_output_width, m_network_width); + + // Initialize matrices + for (size_t i = 0; i < weight_matrices_full_precision.size(); ++i) { + if (m_activation == ActFunc::Sine) { + if (i == 0) { + weight_matrices_full_precision[i].initialize_siren_uniform_first(rnd, + scale); + } else { + weight_matrices_full_precision[i].initialize_siren_uniform(rnd, scale); + } + } else { + weight_matrices_full_precision[i].initialize_xavier_uniform(rnd, scale); + } + } +} + +template class FullyFusedMLP; +template class FullyFusedMLP; +template class FullyFusedMLP; +template class FullyFusedMLP; diff --git a/core/kernel/masked_select.h b/core/kernel/masked_select.h new file mode 100644 index 0000000..6ef4e1b --- /dev/null +++ b/core/kernel/masked_select.h @@ -0,0 +1,494 @@ +#include +#include +#include +#include + +// using namespace nvcuda; +namespace cg = cooperative_groups; + +// Fused kernel: count + collect indices + extract tokens in one pass +__global__ void fused_extract_expert_tokens_bf16( + const __nv_bfloat16* __restrict__ hidden_states, + const bool* __restrict__ router_mask, __nv_bfloat16* __restrict__ output, + int* __restrict__ output_count, const int num_tokens, const int hidden_dim, + const int expert_idx, const int num_experts) { + // Use cooperative groups for better warp-level primitives + auto g = cg::this_thread_block(); + auto warp = cg::tiled_partition<32>(g); + + // Shared memory for warp-level scan + extern __shared__ int shared_data[]; + int* warp_counts = shared_data; + int* block_offset = &shared_data[32]; + + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int warps_per_block = blockDim.x / 32; + + // Initialize shared memory + if (threadIdx.x == 0) *block_offset = 0; + if (lane_id == 0) warp_counts[warp_id] = 0; + __syncthreads(); + + // Process tokens in chunks for better memory access + const int tokens_per_thread = + (num_tokens + blockDim.x * gridDim.x - 1) / (blockDim.x * gridDim.x); + const int start_token = + (blockIdx.x * blockDim.x + threadIdx.x) * tokens_per_thread; + const int end_token = min(start_token + 
tokens_per_thread, num_tokens); + + // Phase 1: Count and mark tokens + int local_count = 0; +#pragma unroll 4 + for (int token_idx = start_token; token_idx < end_token; token_idx++) { + if (router_mask[token_idx * num_experts + expert_idx]) { + local_count++; + } + } + +// Warp-level reduction +#pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + local_count += __shfl_down_sync(0xffffffff, local_count, offset); + } + + // Store warp count + if (lane_id == 0) { + warp_counts[warp_id] = local_count; + } + __syncthreads(); + + // Block-level scan to get offsets + if (threadIdx.x < warps_per_block) { + int val = warp_counts[threadIdx.x]; +#pragma unroll + for (int i = 1; i < warps_per_block; i <<= 1) { + int n = __shfl_up_sync(0xffffffff, val, i); + if (threadIdx.x >= i) val += n; + } + warp_counts[threadIdx.x] = val; + } + __syncthreads(); + + // Get block's starting position + int block_start = 0; + if (threadIdx.x == warps_per_block - 1) { + block_start = atomicAdd(output_count, warp_counts[warps_per_block - 1]); + *block_offset = block_start; + } + __syncthreads(); + + // Calculate this thread's output offset + int thread_offset = *block_offset; + if (warp_id > 0) { + thread_offset += warp_counts[warp_id - 1]; + } + + // Warp-level scan for lane offsets + int lane_offset = 0; + int lane_count = 0; +#pragma unroll 4 + for (int token_idx = start_token; token_idx < end_token; token_idx++) { + if (router_mask[token_idx * num_experts + expert_idx]) { + lane_count++; + } + } + +#pragma unroll + for (int i = 1; i < 32; i <<= 1) { + int n = __shfl_up_sync(0xffffffff, lane_count, i); + if (lane_id >= i) lane_offset += n; + } + + thread_offset += lane_offset; + + // Phase 2: Extract tokens with coalesced writes + // Use vector loads/stores for better bandwidth utilization + const int vec_size = 8; // Process 8 bf16 elements at once + using vec_t = float4; // 8 bf16s = 4 floats = 128 bits + +#pragma unroll 4 + for (int token_idx = start_token; token_idx < end_token; token_idx++) { + if (router_mask[token_idx * num_experts + expert_idx]) { + const int src_offset = token_idx * hidden_dim; + const int dst_offset = thread_offset * hidden_dim; + + // Vectorized copy + const vec_t* src_vec = + reinterpret_cast(&hidden_states[src_offset]); + vec_t* dst_vec = reinterpret_cast(&output[dst_offset]); + +// Copy in chunks of 8 bf16 elements +#pragma unroll 4 + for (int i = 0; i < hidden_dim / vec_size; i++) { + dst_vec[i] = src_vec[i]; + } + + // Handle remainder + const int remainder = hidden_dim % vec_size; + if (remainder > 0) { + const int base_idx = (hidden_dim / vec_size) * vec_size; +#pragma unroll + for (int i = 0; i < remainder; i++) { + output[dst_offset + base_idx + i] = + hidden_states[src_offset + base_idx + i]; + } + } + + thread_offset++; + } + } +} + +// // Optimized kernel for batch_size > 1 case +// void extract_expert_tokens_fused_cuda(torch::Tensor hidden_states, +// torch::Tensor router_mask, +// torch::Tensor output, +// torch::Tensor output_count, +// int expert_idx, int batch_size) { +// // Skip if batch_size == 1 +// if (batch_size == 1) { +// output.copy_(hidden_states); +// return; +// } + +// const int num_tokens = hidden_states.size(0); +// const int hidden_dim = hidden_states.size(1); +// const int num_experts = router_mask.size(1); + +// // Reset output count +// cudaMemset(output_count.data_ptr(), 0, sizeof(int)); + +// // Configure kernel launch +// const int threads = 256; +// const int warps_per_block = threads / 32; +// const int blocks = min(65535, 
(num_tokens + threads - 1) / threads); +// const int smem_size = sizeof(int) * (warps_per_block + 1); + +// // Launch fused kernel +// fused_extract_expert_tokens_bf16<<>>( +// reinterpret_cast( +// hidden_states.data_ptr()), +// router_mask.data_ptr(), +// reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), +// output_count.data_ptr(), num_tokens, hidden_dim, expert_idx, +// num_experts); +// } + +// Alternative simpler version without cooperative groups +__global__ void fused_extract_expert_tokens_bf16_simple( + const __nv_bfloat16* __restrict__ hidden_states, + const bool* __restrict__ router_mask, __nv_bfloat16* __restrict__ output, + int* __restrict__ output_count, const int num_tokens, const int hidden_dim, + const int expert_idx, const int num_experts) { + // Shared memory for block-level reduction + extern __shared__ char shared_mem[]; + __shared__ int num_idices = 0; + int* block_idx = reinterpret_cast(shared_mem); + + const int tid = threadIdx.x; + const int global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const int grid_size = blockDim.x * gridDim.x; + + if (global_tid < num_tokens) { + if (router_mask[global_tid * num_experts + expert_idx]) { + // Increment count for this token + int count = atomicAdd(&num_idices, 1); + block_idx[count] = global_tid; // Store index of selected token + } + } + __syncthreads(); + + // Initialize shared memory + if (tid == 0) { + block_count[0] = 0; + } + __syncthreads(); + + // Phase 1: Count selected tokens for this thread + int local_count = 0; + for (int token_idx = global_tid; token_idx < num_tokens; + token_idx += grid_size) { + if (router_mask[token_idx * num_experts + expert_idx]) { + local_count++; + } + } + + // Phase 2: Get block offset using atomic + __shared__ int block_offset; + if (local_count > 0) { + atomicAdd(&block_count[0], local_count); + } + __syncthreads(); + + if (tid == 0 && block_count[0] > 0) { + block_offset = atomicAdd(output_count, block_count[0]); + } + __syncthreads(); + + // Phase 3: Each thread writes its tokens independently + if (local_count > 0) { + // Get thread's offset within block + int thread_offset = atomicAdd(&block_count[0], local_count) - local_count; + int write_idx = block_offset + thread_offset; + + // Write tokens + for (int token_idx = global_tid; token_idx < num_tokens; + token_idx += grid_size) { + if (router_mask[token_idx * num_experts + expert_idx]) { + // Copy token data + for (int i = 0; i < hidden_dim; i++) { + output[write_idx * hidden_dim + i] = + hidden_states[token_idx * hidden_dim + i]; + } + write_idx++; + } + } + } +} + +// Host wrapper function +void extract_expert_tokens_fused_cuda(torch::Tensor hidden_states, + torch::Tensor router_mask, + torch::Tensor output, + torch::Tensor output_count, + int expert_idx, int batch_size) { + // Skip if batch_size == 1 + if (batch_size == 1) { + output.copy_(hidden_states); + return; + } + + const int num_tokens = hidden_states.size(0); + const int hidden_dim = hidden_states.size(1); + const int num_experts = router_mask.size(1); + + // Reset output count + cudaMemset(output_count.data_ptr(), 0, sizeof(int)); + + // Configure kernel launch + const int threads = 256; + const int blocks = min(65535, (num_tokens + threads - 1) / threads); + const int smem_size = sizeof(int) * threads; + + // Launch the simpler kernel (more robust) + fused_extract_expert_tokens_bf16_simple<<>>( + reinterpret_cast( + hidden_states.data_ptr()), + router_mask.data_ptr(), + reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), + output_count.data_ptr(), 
num_tokens, hidden_dim, expert_idx, + num_experts); +} + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +// using namespace cutlass; + +// // Simplified kernel using CUTLASS iterators and CUB for scan operations +// template +// __global__ void extract_expert_tokens_cutlass( +// TensorRef hidden_states, +// TensorRef router_mask, +// TensorRef output, int* output_count, +// int expert_idx) { +// using BlockScan = cub::BlockScan; +// __shared__ typename BlockScan::TempStorage temp_storage; +// __shared__ int block_offset; + +// const int num_tokens = hidden_states.extent(0); +// const int hidden_dim = hidden_states.extent(1); + +// // Phase 1: Efficient counting with CUB +// int thread_data[kElementsPerThread]; +// int thread_count = 0; + +// #pragma unroll +// for (int i = 0; i < kElementsPerThread; ++i) { +// int token_idx = blockIdx.x * kThreads * kElementsPerThread + +// threadIdx.x * kElementsPerThread + i; + +// bool selected = false; +// if (token_idx < num_tokens) { +// selected = router_mask.at({token_idx, expert_idx}); +// } +// thread_data[i] = selected ? 1 : 0; +// thread_count += thread_data[i]; +// } + +// // Block-wide exclusive scan +// int thread_offset; +// int block_total; +// BlockScan(temp_storage) +// .ExclusiveSum(thread_count, thread_offset, block_total); + +// // Get global offset +// if (threadIdx.x == 0) { +// block_offset = atomicAdd(output_count, block_total); +// } +// __syncthreads(); + +// thread_offset += block_offset; + +// // Phase 2: Extract using CUTLASS iterators for optimal memory access +// using ThreadMap = +// layout::PitchLinearThreadMap layout::PitchLinearShape, +// kThreads, layout::PitchLinearShape<8, 1>, // 8 elements per access +// 1 > ; + +// using Iterator = transform::threadblock::PredicatedTileIterator +// layout::PitchLinearShape, +// bfloat16_t, layout::RowMajor, 1, ThreadMap > ; + +// // Process selected tokens +// int local_offset = 0; +// #pragma unroll +// for (int i = 0; i < kElementsPerThread; ++i) { +// if (thread_data[i]) { +// int token_idx = blockIdx.x * kThreads * kElementsPerThread + +// threadIdx.x * kElementsPerThread + i; + +// // Use CUTLASS iterator for coalesced copy +// Iterator src_iterator(hidden_states.data() + token_idx * hidden_dim, +// {hidden_dim, 1}, threadIdx.x); + +// Iterator dst_iterator( +// output.data() + (thread_offset + local_offset) * hidden_dim, +// {hidden_dim, 1}, threadIdx.x); + +// // Vectorized copy using CUTLASS fragments +// CUTLASS_PRAGMA_UNROLL +// for (int j = 0; j < Iterator::kIterations; ++j) { +// typename Iterator::Fragment fragment; +// src_iterator.load(fragment); +// dst_iterator.store(fragment); +// ++src_iterator; +// ++dst_iterator; +// } + +// local_offset++; +// } +// } +// } + +// // Even simpler version using CUTLASS's DeviceSelect +// void extract_expert_tokens_cutlass_v2(torch::Tensor hidden_states, +// torch::Tensor router_mask, +// torch::Tensor output, +// torch::Tensor output_count, +// int expert_idx, int batch_size) { +// if (batch_size == 1) { +// output.copy_(hidden_states); +// return; +// } + +// const int num_tokens = hidden_states.size(0); +// const int hidden_dim = hidden_states.size(1); + +// // Create CUTLASS tensor refs +// TensorRef hidden_ref( +// reinterpret_cast(hidden_states.data_ptr()), +// layout::RowMajor(hidden_dim)); + +// TensorRef mask_ref( +// router_mask.data_ptr(), layout::RowMajor(router_mask.size(1))); + +// TensorRef output_ref( +// 
reinterpret_cast(output.data_ptr()), +// layout::RowMajor(hidden_dim)); + +// // Reset count +// cudaMemset(output_count.data_ptr(), 0, sizeof(int)); + +// // Launch optimized kernel +// const int kThreads = 128; +// const int kElementsPerThread = 4; +// const int blocks = (num_tokens + kThreads * kElementsPerThread - 1) / +// (kThreads * kElementsPerThread); + +// extract_expert_tokens_cutlass +// <<>>(hidden_ref, mask_ref, output_ref, +// output_count.data_ptr(), expert_idx); +// } + +// // Alternative: Use CUB DeviceSelect directly for maximum simplicity +// void extract_expert_tokens_cub(torch::Tensor hidden_states, +// torch::Tensor router_mask, torch::Tensor +// output, torch::Tensor output_count, int +// expert_idx, int batch_size) { +// if (batch_size == 1) { +// output.copy_(hidden_states); +// return; +// } + +// const int num_tokens = hidden_states.size(0); +// const int hidden_dim = hidden_states.size(1); + +// // Create index array +// auto indices = torch::arange( +// num_tokens, +// torch::dtype(torch::kInt32).device(hidden_states.device())); + +// // Get mask column for this expert +// auto expert_mask = router_mask.index({"...", expert_idx}); + +// // Allocate temporary storage for CUB +// size_t temp_storage_bytes = 0; +// cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, +// indices.data_ptr(), +// expert_mask.data_ptr(), +// indices.data_ptr(), // reuse for output +// output_count.data_ptr(), num_tokens); + +// auto temp_storage = +// torch::empty(temp_storage_bytes, +// torch::dtype(torch::kUInt8).device(hidden_states.device())); + +// // Select indices +// cub::DeviceSelect::Flagged( +// temp_storage.data_ptr(), temp_storage_bytes, indices.data_ptr(), +// expert_mask.data_ptr(), indices.data_ptr(), +// output_count.data_ptr(), num_tokens); + +// // Get selected count +// int num_selected; +// cudaMemcpy(&num_selected, output_count.data_ptr(), sizeof(int), +// cudaMemcpyDeviceToHost); + +// // Copy selected tokens using CUTLASS batched copy +// if (num_selected > 0) { +// // Simple kernel to copy using selected indices +// auto copy_kernel = [=] __device__(int idx) { +// if (idx < num_selected) { +// int src_idx = indices.data_ptr()[idx]; + +// // Use CUTLASS aligned memory copy +// using CopyOp = +// cutlass::AlignedCopy sizeof(float4), // 128-bit alignment +// layout::RowMajor > ; + +// CopyOp copy_op; +// for (int i = 0; i < hidden_dim; i += 8) { +// copy_op(output.data_ptr() + idx * hidden_dim + i, +// hidden_states.data_ptr() + +// src_idx * hidden_dim + i); +// } +// } +// }; + +// // Launch copy kernel +// int threads = 256; +// int blocks = (num_selected + threads - 1) / threads; +// copy_kernel<<>>(num_selected); +// } +// } diff --git a/core/kernel/ops.h b/core/kernel/ops.h new file mode 100644 index 0000000..424bfcd --- /dev/null +++ b/core/kernel/ops.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +// Activation and gating kernel functions +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., 2 * d] + +void mul_and_silu(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., 2 * d] + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., 2 * d] + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., 2 * d] + +void fatrelu_and_mul(torch::Tensor& out, // [..., d], + torch::Tensor& input, // [..., 2 * d] + double threshold); + +// Element-wise activation kernel functions +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., 
d] + +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., d] + +void gelu_quick(torch::Tensor& out, // [..., d] + torch::Tensor& input); // [..., d] + +// TopK softmax kernel functions +void topk_softmax(torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output); // [num_tokens, num_experts] diff --git a/core/kernel/topk_softmax_kernels.cu b/core/kernel/topk_softmax_kernels.cu new file mode 100644 index 0000000..31a4b32 --- /dev/null +++ b/core/kernel/topk_softmax_kernels.cu @@ -0,0 +1,536 @@ +/* + * Adapted from + * https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & + * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "ops.h" + +#include +#include + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +static constexpr int WARP_SIZE = 32; + +/// Aligned array type +template +class alignas(Alignment) AlignedArray { + float data[N]; +}; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing +// the output in the softmax kernel when we extend this module to support +// expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, + const int num_cols) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. 
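+  // One thread block owns one row (blockIdx.x indexes the token); rows
+  // already marked as finished are skipped entirely and their outputs are
+  // left untouched.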
+ if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float val = + exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ + void moeTopK(const float* inputs_after_softmax, const bool* finished, + float* output, IndType* indices, int* source_rows, + const int num_experts, const int k, const int start_expert, + const int end_expert) { + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = + BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = + expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the + MoE layers are a small power of 2. This allows us to cleanly share the rows + among the threads in a single warp and eliminate communication between warps + (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small + power of 2. 2) This implementation assumes k is small, but will work for any + k. 
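+  Worked example (illustrative numbers, not from the original comment): with
+  NUM_EXPERTS = 8, BYTES_PER_LDG = 16 and WARPS_PER_CTA = 4, each thread loads
+  ELTS_PER_LDG = 16 / sizeof(float) = 4 values, so VPT = 4,
+  THREADS_PER_ROW = 8 / 4 = 2, ROWS_PER_WARP = 32 / 2 = 16 and
+  ROWS_PER_CTA = 4 * 16 = 64, i.e. a single 128-thread block covers 64 token
+  rows without touching shared memory.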
+*/
+
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG,
+          typename IndType>
+__launch_bounds__(WARPS_PER_CTA * WARP_SIZE) __global__
+    void topkGatingSoftmax(const float* input, const bool* finished,
+                           float* output, const int num_rows, IndType* indices,
+                           int* source_rows, const int k,
+                           const int start_expert, const int end_expert) {
+  // We begin by enforcing compile time assertions and setting up compile time
+  // constants.
+  static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
+  static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS),
+                "NUM_EXPERTS must be power of 2");
+  static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG),
+                "BYTES_PER_LDG must be power of 2");
+  static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
+
+  // Number of bytes each thread pulls in per load
+  static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
+  static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
+  static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
+  static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
+
+  // Restrictions based on previous section.
+  static_assert(
+      VPT % ELTS_PER_LDG == 0,
+      "The elements per thread must be a multiple of the elements per ldg");
+  static_assert(WARP_SIZE % THREADS_PER_ROW == 0,
+                "The threads per row must cleanly divide the threads per warp");
+  static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW),
+                "THREADS_PER_ROW must be power of 2");
+  static_assert(THREADS_PER_ROW <= WARP_SIZE,
+                "THREADS_PER_ROW can be at most warp size");
+
+  // We have NUM_EXPERTS elements per row. We specialize for small #experts
+  static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+  static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
+  static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
+
+  // Restrictions for previous section.
+  static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0,
+                "The elts per row must cleanly divide the total elt per warp");
+
+  // ===================== From this point, we finally start computing run-time
+  // variables. ========================
+
+  // Compute CTA and warp rows. We pack multiple rows into a single warp, and a
+  // block contains WARPS_PER_CTA warps. Thus, each block processes a chunk of
+  // rows. We start by computing the start row for each block.
+  const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
+
+  // Now, using the base row per thread block, we compute the base row per warp.
+  const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
+
+  // The threads in a warp are split into sub-groups that will work on a row.
+  // We compute the row offset for each thread sub-group.
+  const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
+  const int thread_row = warp_base_row + thread_row_in_warp;
+
+  // Threads with indices out of bounds should early exit here.
+  if (thread_row >= num_rows) {
+    return;
+  }
+  const bool row_is_active = finished ? !finished[thread_row] : true;
+
+  // We finally start setting up the read pointers for each thread. First, each
+  // thread jumps to the start of the row it will read.
+  const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
+
+  // Now, we compute the group each thread belongs to in order to determine the
+  // first column to start loads.
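+  // (Illustrative: with ELTS_PER_LDG = 4 and THREADS_PER_ROW = 2, thread
+  // group 0 starts its loads at column 0 of its row and thread group 1 at
+  // column 4.)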
+ const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the + // BYTES_PER_LDG template param. In theory, this can support all powers of 2 + // up to 16. NOTE(woosuk): The original implementation uses CUTLASS aligned + // array here. We defined our own aligned array and use it here to avoid the + // dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = + reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 + // safely (I think) and just convert to float afterwards for the exp + sum + // reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the +// threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + thread_max = + fmaxf(thread_max, __shfl_xor_sync(0xffffffff, thread_max, mask)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. + // We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max +// reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xffffffff, row_sum, mask); + } + + // From this point, all threads have the max and the sum for their rows in the + // thread_max and thread_sum variables respectively. Finally, we can scale the + // rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only + // compute for the top-k values in the row. However, this kernel will likely + // not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find + // the topk elements in each row, along with the max index. 
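+  // (The k winners are found iteratively: every pass does a full argmax over
+  // the row, records the winner, and then blanks that winner out so the next
+  // pass picks up the runner-up.)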
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; + ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index + // are processed first and only updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads +// reach consensus about the max. This will be useful for K > 1 so that the +// threads can agree on "who" had the max value. That thread can then blank out +// their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = + __shfl_xor_sync(0xffffffff, max_val, mask, THREADS_PER_ROW); + int other_expert = + __shfl_xor_sync(0xffffffff, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this + // way + if (other_max > max_val || + (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = + expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to + // global memory. (This will be a single) thread per row of the + // input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + } + + // Finally, we clear the value in the thread with the current max if there + // is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = + (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the + // "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be + // between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = + -10000.f; + } + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at +// compile time. 
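+// (Illustrative instantiations with the launcher's MAX_BYTES_PER_LDG = 16:
+//  EXPERTS = 64 yields ELTS_PER_LDG = 4, VPT = 4, THREADS_PER_ROW = 16,
+//  ROWS_PER_WARP = 2; EXPERTS = 256 yields VPT = 8, THREADS_PER_ROW = 32,
+//  ROWS_PER_WARP = 1.)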
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || + EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, + ""); + static constexpr int VECs_PER_THREAD = + MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, + float* output, IndType* indices, + int* source_row, const int num_rows, + const int k, const int start_expert, + const int end_expert, + cudaStream_t stream) { + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = + MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax + <<>>(input, finished, output, num_rows, + indices, source_row, k, + start_expert, end_expert); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indices, \ + token_expert_indices, num_tokens, topk, 0, num_experts, stream); + +template +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, float* topk_weights, IndType* topk_indices, + int* token_expert_indices, float* softmax_workspace, const int num_tokens, + const int num_experts, const int topk, cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK(softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are " + "not a power of 2."); + static constexpr int TPB = 256; + moeSoftmax<<>>( + gating_output, nullptr, softmax_workspace, num_experts); + moeTopK<<>>( + softmax_workspace, nullptr, topk_weights, topk_indices, + token_expert_indices, num_experts, topk, 0, num_experts); + } + } +} + +void topk_softmax(torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output) // [num_tokens, num_experts] +{ + const int num_experts = gating_output.size(-1); + const auto num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + + const bool is_pow_2 = + (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? 
num_tokens * num_experts : 0; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::Tensor softmax_workspace = + torch::empty({workspace_size}, gating_output.options()); + + topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), topk_weights.data_ptr(), + topk_indices.data_ptr(), token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), num_tokens, num_experts, topk, + stream); + + // if (topk_indices.scalar_type() == at::ScalarType::Int) { + // topkGatingSoftmaxKernelLauncher( + // gating_output.data_ptr(), topk_weights.data_ptr(), + // topk_indices.data_ptr(), token_expert_indices.data_ptr(), + // softmax_workspace.data_ptr(), num_tokens, num_experts, topk, + // stream); + // } else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { + // topkGatingSoftmaxKernelLauncher( + // gating_output.data_ptr(), topk_weights.data_ptr(), + // topk_indices.data_ptr(), + // token_expert_indices.data_ptr(), + // softmax_workspace.data_ptr(), num_tokens, num_experts, topk, + // stream); + // } else { + // assert(topk_indices.scalar_type() == at::ScalarType::Int64); + // topkGatingSoftmaxKernelLauncher( + // gating_output.data_ptr(), topk_weights.data_ptr(), + // topk_indices.data_ptr(), + // token_expert_indices.data_ptr(), + // softmax_workspace.data_ptr(), num_tokens, num_experts, topk, + // stream); + // } +} diff --git a/core/kernel/torch_bindings.h b/core/kernel/torch_bindings.h new file mode 100644 index 0000000..e69de29 diff --git a/core/model/moe.h b/core/model/moe.h new file mode 100644 index 0000000..454f20c --- /dev/null +++ b/core/model/moe.h @@ -0,0 +1,181 @@ +// Copyright (c) EfficientMoE. +// SPDX-License-Identifier: Apache-2.0 + +// EfficientMoE Team + +#pragma once + +#include +#include + +#include "utils/cuda_utils.h" +#include "common/pytorch.h" +#include "kernel/ops.h" + +#define BUFFER_PTR(buf_type, ptr_type) \ + (buffer_[static_cast(BufferType::buf_type)]) + +#define CUDA_ALLOCATE_BUFFER(type, size) \ + CUDA_CHECK(cudaMalloc( \ + reinterpret_cast(&buffer_[static_cast(BufferType::type)]), \ + size * sizeof(param_t))); + +// The abstraction of MoE (Mixture of Experts) layer with fixed buffers. 
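+// A minimal usage sketch (illustrative only; the element type and the sizes
+// below are assumptions, not taken from this header):
+//
+//   MoELayer<__nv_bfloat16> moe(/*num_experts=*/8, /*topk=*/2,
+//                               /*max_tokens=*/4096, /*hidden_dim=*/4096,
+//                               /*intermediate_dim=*/14336);
+//   // ... fill the HiddenStates buffer with the current batch ...
+//   moe.ForwardGating();  // populates TopKWeights, TopKIndices and the
+//                         // router mask / routing-weight buffers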
+template +class MoELayer { + public: + enum class BufferType { + + // MoE buffers + HiddenStates = 0, // Buffer for hidden states + // GatingWeights, // Buffer for gate weights + FinalHiddenStates, // Buffer for final hidden states + GatingOutput, // Buffer for gating output + TopKWeights, // Buffer for top-k weights + TopKIndices, // Buffer for top-k indices + TokenExpertIndices, // Buffer for token expert indices + + // expert buffers + ExpertInput, // Buffer for input to experts + ExpertUpProjOutput, // Buffer for up projection output + ExpertGateProjInput, // Buffer for gate projection input + ExpertDownProjOutput, // Buffer for down projection output + ExpertActMulOutput, // Buffer for gated activation output + + // backward capability + ExpertRouterMask, // Buffer for router mask + ExpertRouterWeight, // Buffer for router weights + + NumBuffers // Total number of buffer types + }; + + explicit MoELayer(int num_experts, int topk, int max_tokens, + int64_t hidden_dim, int64_t intermediate_dim) + : num_experts_(num_experts), + topk_(topk), + max_tokens_(max_tokens), + hidden_dim_(hidden_dim), + intermediate_dim_(intermediate_dim), + buffer_(static_cast(BufferType::NumBuffers)) { + CUDA_ALLOCATE_BUFFER(HiddenStates, max_tokens * hidden_dim); + // CUDA_ALLOCATE_BUFFER(GatingWeights, num_experts * hidden_dim); + CUDA_ALLOCATE_BUFFER(FinalHiddenStates, max_tokens * hidden_dim); + CUDA_ALLOCATE_BUFFER(GatingOutput, max_tokens * num_experts); + CUDA_ALLOCATE_BUFFER(TopKWeights, max_tokens * topk); + CUDA_ALLOCATE_BUFFER(TopKIndices, max_tokens * topk); + CUDA_ALLOCATE_BUFFER(TokenExpertIndices, max_tokens * topk); + CUDA_ALLOCATE_BUFFER(ExpertInput, max_tokens * hidden_dim); + CUDA_ALLOCATE_BUFFER(ExpertUpProjOutput, max_tokens * intermediate_dim); + CUDA_ALLOCATE_BUFFER(ExpertGateProjInput, max_tokens * intermediate_dim); + CUDA_ALLOCATE_BUFFER(ExpertDownProjOutput, max_tokens * hidden_dim); + CUDA_ALLOCATE_BUFFER(ExpertActMulOutput, max_tokens * hidden_dim); + + CUDA_ALLOCATE_BUFFER(ExpertRouterMask, max_tokens * num_experts); + CUDA_ALLOCATE_BUFFER(ExpertRouterWeight, max_tokens * num_experts); + + device_id_ = c10::cuda::current_device(); + cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking); + } + + void ForwardGating() { + // Forward pass for gating mechanism + // This function will use the buffers to compute gating weights and outputs + + // create temperal wrappers as tensor + auto hidden_states = + torch::from_blob(BUFFER_PTR(HiddenStates, void), + {max_tokens_, hidden_dim_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::dtype()) + .device(CUDA_DEVICE(device_id_))); + + auto gating_weights = + torch::from_blob(BUFFER_PTR(GatingWeights, void), + {num_experts_, hidden_dim_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::dtype()) + .device(CUDA_DEVICE(device_id_))); + + auto gating_output = + torch::from_blob(BUFFER_PTR(GatingOutput, void), + {max_tokens_, num_experts_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::dtype()) + .device(CUDA_DEVICE(device_id_))); + + // Perform the gating operation on stream_ + c10::cuda::CUDAStream torch_stream = + c10::cuda::getStreamFromExternal(stream_, device_id_); + c10::cuda::setCurrentCUDAStream(torch_stream); + torch::matmul_out(gating_output, hidden_states, + gating_weights.t()); // [max_tokens, num_experts] + + auto topk_weights = + torch::from_blob(BUFFER_PTR(TopKWeights, void), {max_tokens_, topk_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kFloat32) + 
.device(CUDA_DEVICE(device_id_))); + + auto topk_indices = + torch::from_blob(BUFFER_PTR(TopKIndices, void), {max_tokens_, topk_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kUInt32) + .device(CUDA_DEVICE(device_id_))); + + auto token_expert_indices = + torch::from_blob(BUFFER_PTR(TokenExpertIndices, void), + {max_tokens_, topk_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kUInt32) + .device(CUDA_DEVICE(device_id_))); + + // Perform top-k softmax to get top-k gating weights and indices + topk_softmax(topk_weights, topk_indices, token_expert_indices, + gating_output); // [max_tokens, topk] + + auto router_mask = + torch::from_blob(BUFFER_PTR(ExpertRouterMask, void), + {max_tokens_, num_experts_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kBool) + .device(CUDA_DEVICE(device_id_))); + + router_mask.scatter_(1, token_expert_indices, + true); // Set router mask based on top-k indices + + auto routing_weights_mask = + torch::from_blob(BUFFER_PTR(ExpertRouterWeight, void), + {max_tokens_, num_experts_}, DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::dtype()) + .device(CUDA_DEVICE(device_id_))); + + routing_weights_mask.scatter_add_( + 1, token_expert_indices, + topk_weights); // Set routing weights mask + } + + ~MoELayer() { + // Clean up allocated buffers + for (auto* buffer : buffer_) { + if (buffer) { + CUDA_CHECK(cudaFree(buffer)); + } + } + if (stream_) { + CUDA_CHECK(cudaStreamDestroy(stream_)); + } + } + + private: + std::vector buffer_; // Vector of buffers + int num_experts_ = 0; // Number of experts in the MoE layer + int topk_ = 0; // Number of top-k experts to select + int max_tokens_ = 0; // Maximum number of tokens processed in a batch + int64_t hidden_dim_ = 0; // Dimension of hidden states + int64_t intermediate_dim_ = 0; // Dimension of intermediate states + cudaStream_t stream_ = 0; // CUDA stream for asynchronous operations + int device_id_ = 0; // Device ID for the MoE layer +}; diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index c797840..05de7ec 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -12,6 +12,7 @@ #include "utils/cuda_utils.h" #include "utils/logger.h" #include "model/model_topology.h" +#include "model/moe.h" #include #include @@ -660,73 +661,4 @@ void ExpertDispatcher::SetInputs(const torch::Tensor& hidden_states, router_mask_ = router_mask; router_weight_ = router_weight; final_hidden_states_ = torch::zeros_like(hidden_states, options); - - // auto stream = exec_streams_[device]; - // int64_t batch_size = hidden_states_.size(0); - // int64_t hidden_dim = hidden_states_.size(1); - - // DLOG_FATAL_IF(num_experts_ != router_weight.size(1), - // "ExpertDispatcher::SetInputs: num_experts ", num_experts_, - // " router_weight.size(1) ", router_weight.size(1)); - - // if (!final_hidden_states_.defined()) { - // // final_hidden_states is float type - // auto options = - // torch::TensorOptions().dtype(torch::kFloat32).device(CUDA_DEVICE(device)); - // auto allocator = c10::DeviceAllocator::get(device); - // void* ptr = allocator->allocate(batch_size * hidden_dim * sizeof(float)); - // final_hidden_states_ = torch::from_blob(ptr, {batch_size, hidden_dim}, - // DoNothingDeleter{}, - // options); - // } - - // if (!router_mask_.defined()) { - // // router mask is boolean type - // auto options = - // torch::TensorOptions().dtype(torch::kBool).device(CUDA_DEVICE(device)); - // auto allocator = 
c10::DeviceAllocator::get(device); - // void* ptr = allocator->allocate(batch_size * num_experts_ * - // sizeof(bool)); router_mask_ = torch::from_blob(ptr, {batch_size, - // num_experts_}, - // DoNothingDeleter{}, options); - // } - - // if (!router_weight_.defined()) { - // // router weight is float type - // auto options = - // torch::TensorOptions().dtype(torch::kFloat32).device(CUDA_DEVICE(device)); - // auto allocator = c10::DeviceAllocator::get(device); - // void* ptr = allocator->allocate(batch_size * num_experts_ * - // sizeof(float)); router_weight_ = torch::from_blob(ptr, {batch_size, - // num_experts_}, - // DoNothingDeleter{}, options); - // } - - // if (!hidden_states_.defined()) { - // // hidden states is float type - // auto options = - // torch::TensorOptions().dtype(hidden_states.dtype()).device(CUDA_DEVICE(device)); - // auto allocator = c10::DeviceAllocator::get(device); - // void* ptr = allocator->allocate(hidden_states.numel() * sizeof(float)); - // hidden_states_ = torch::from_blob(ptr, {batch_size, hidden_dim}, - // DoNothingDeleter{}, options); - // } - - // cudaMemsetAsync(final_hidden_states_.data_ptr(), 0, - // final_hidden_states_.numel() * sizeof(float), stream); - // cudaMemcpyAsync( - // router_mask_.data_ptr(), router_mask.data_ptr(), - // router_mask.numel() * sizeof(bool), cudaMemcpyDeviceToDevice, stream); - // cudaMemcpyAsync( - // router_weight_.data_ptr(), router_weight.data_ptr(), - // router_weight.numel() * sizeof(float), cudaMemcpyDeviceToDevice, - // stream); - // cudaMemcpyAsync( - // hidden_states_.data_ptr(), hidden_states.data_ptr(), - // hidden_states.numel() * sizeof(float), cudaMemcpyDeviceToDevice, - // stream); - // cudaMemcpyAsync( - // router_mask_.data_ptr(), router_mask.data_ptr(), - // router_mask.numel() * sizeof(bool), cudaMemcpyDeviceToDevice, stream); - // cudaStreamSynchronize(stream); } diff --git a/core/parallel/expert_module.cpp b/core/parallel/expert_module.cpp index 51ab673..5212ec7 100644 --- a/core/parallel/expert_module.cpp +++ b/core/parallel/expert_module.cpp @@ -344,15 +344,6 @@ void ExpertNode::SetTensorsFromBlob(const torch::Device& device) { MoEMLP::MoEMLP(int dtype, int expert_type) { auto tensor_dtype = dtype_to_torch(dtype); auto options = torch::TensorOptions().dtype(tensor_dtype).device(torch::kCPU); - // input_ = register_parameter("input", torch::zeros({1}, options)); - // output_ = register_parameter("output", torch::zeros({1}, options)); - // gate_proj_ = register_parameter("gate_proj", torch::zeros({1}, options)); - // up_proj_ = register_parameter("up_proj", torch::zeros({1}, options)); - // down_proj_ = register_parameter("down_proj", torch::zeros({1}, options)); - - // fc1_bias_ = register_parameter("fc1_bias", torch::zeros({1}, options)); - // fc2_bias_ = register_parameter("fc2_bias", torch::zeros({1}, options)); - // fc3_bias_ = register_parameter("fc3_bias", torch::zeros({1}, options)); expert_type_ = expert_type; dtype_ = dtype; @@ -511,12 +502,12 @@ void MoEMLP::ForwardHelper() { // gate step torch::matmul_out(gate_out, input, gate_proj.transpose(0, 1)); - // activation step - torch::silu_out(gate_act_out, gate_out); - // up step torch::matmul_out(up_out, input, up_proj.transpose(0, 1)); + // activation step + torch::silu_out(gate_act_out, gate_out); + // multiplication step, reuse gate_out torch::mul_out(gate_out, gate_act_out, up_out); diff --git a/core/parallel/expert_module.h b/core/parallel/expert_module.h index 663f482..141a82b 100644 --- a/core/parallel/expert_module.h +++ 
b/core/parallel/expert_module.h @@ -246,15 +246,6 @@ struct MoEMLP : public torch::nn::Module { private: std::vector buffer_; std::vector param_; - // torch::Tensor input_; - // torch::Tensor output_; - // torch::Tensor gate_proj_; - // torch::Tensor up_proj_; - // torch::Tensor down_proj_; - - // torch::Tensor fc1_bias_; - // torch::Tensor fc2_bias_; - // torch::Tensor fc3_bias_; at::cuda::CUDAGraph graph_; int warmup_count_ = 5; diff --git a/examples/readme_example.py b/examples/readme_example.py index edbe36b..5cdbc1a 100644 --- a/examples/readme_example.py +++ b/examples/readme_example.py @@ -6,7 +6,7 @@ user_home = os.path.expanduser("~") -checkpoint = "deepseek-ai/DeepSeek-V2-Lite-Chat" +checkpoint = "Qwen/Qwen3-30B-A3B" tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote=True) config = { diff --git a/op_builder/builder.py b/op_builder/builder.py index cad6607..c86585b 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -100,7 +100,17 @@ def get_default_compute_capabilities(): "11.7", "11.8", ], - 12: ["12.0", "12.1"], + 12: [ + "12.0", + "12.1", + "12.2", + "12.3", + "12.4", + "12.5", + "12.6", + "12.7", + "12.8", + ], } diff --git a/op_builder/prefetch.py b/op_builder/prefetch.py index 3d5623a..212ab9f 100644 --- a/op_builder/prefetch.py +++ b/op_builder/prefetch.py @@ -34,6 +34,8 @@ def sources(self): "core/utils/cuda_utils.cpp", "core/model/model_topology.cpp", "core/model/fused_mlp.cu", + "core/kernel/activation_kernels.cu", + "core/kernel/topk_softmax_kernels.cu", "core/prefetch/archer_prefetch_handle.cpp", "core/prefetch/task_scheduler.cpp", "core/prefetch/task_thread.cpp", diff --git a/tests/cuda/CMakeLists.txt b/tests/cuda/CMakeLists.txt index 1a553da..8f2b629 100644 --- a/tests/cuda/CMakeLists.txt +++ b/tests/cuda/CMakeLists.txt @@ -125,6 +125,7 @@ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -g -G -lineinfo -rdynamic -O3 -gencode a set(SRC_LIST test_uvm_kernel.cu test_fused_mlp.cu + test_fused_mlp_wmma.cu test_expert_fusion.cu test_expert_fusion_v2.cu # test_single_gemm_tiled.cu @@ -134,6 +135,7 @@ set(SRC_LIST test_autosize_tileload.cu test_autosize_tileload_stage.cu test_autotune_blocksize.cu + tests_masked_select.cu ) FOREACH(SRC ${SRC_LIST}) @@ -142,7 +144,7 @@ FOREACH(SRC ${SRC_LIST}) target_link_libraries(${SRC_NAME} cutlass ${CUDA_LIBRARIES}) # if file is test_expert_fusion or test_expert_fusion_v2, link torch_python - IF(${SRC_NAME} STREQUAL "test_expert_fusion" OR ${SRC_NAME} STREQUAL "test_expert_fusion_v2") + IF(${SRC_NAME} STREQUAL "test_expert_fusion" OR ${SRC_NAME} STREQUAL "test_expert_fusion_v2" OR ${SRC_NAME} STREQUAL "tests_masked_select" OR ${SRC_NAME} STREQUAL "test_fused_mlp_wmma") target_link_libraries(${SRC_NAME} ${torch_python_LIBRARY} ${Python3_LIBRARIES} ${TORCH_LIBRARIES}) target_include_directories(${SRC_NAME} PRIVATE ${CONDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} ${Python3_INCLUDE_DIRS}) ENDIF() diff --git a/tests/cuda/test_fused_mlp_wmma.cu b/tests/cuda/test_fused_mlp_wmma.cu new file mode 100644 index 0000000..63b7597 --- /dev/null +++ b/tests/cuda/test_fused_mlp_wmma.cu @@ -0,0 +1,405 @@ +#include +#include +#include +#include +#include +#include +#include + +using namespace nvcuda; + +// Device SiLU activation for bfloat16 +__device__ __forceinline__ __nv_bfloat16 silu(__nv_bfloat16 x) { + float x_f = __bfloat162float(x); + float result = x_f / (1.0f + expf(-x_f)); + return __float2bfloat16(result); +} + +// WMMA activation function for float accumulator, convert to bfloat16 precision +template +__device__ 
__forceinline__ void warp_silu_activation_bf16( + const fragment_t& frag, fragment_t& result) { +#pragma unroll + for (int t = 0; t < result.num_elements; t++) { + // Convert to bfloat16 precision, apply SiLU, then back to float for + // accumulator + __nv_bfloat16 bf16_val = __float2bfloat16(frag.x[t]); + __nv_bfloat16 silu_result = silu(bf16_val); + result.x[t] = __bfloat162float(silu_result); + } +} + +// Pipelined WMMA kernel for bfloat16 precision +__global__ void wmma_silu_add_kernel_bf16_pipelined( + const __nv_bfloat16* __restrict__ X, const __nv_bfloat16* __restrict__ A, + const __nv_bfloat16* __restrict__ B, __nv_bfloat16* __restrict__ C, int M, + int N, int K, int ldx, int lda, int ldb, int ldc) { + const int WMMA_M = 16; + const int WMMA_N = 16; + const int WMMA_K = 16; + + // Calculate warp indices within the block + int warp_id = threadIdx.x / 32; + int warps_per_block_x = blockDim.x / 32; + int warps_per_block_y = blockDim.y; + + int warp_row = warp_id % warps_per_block_x; + int warp_col = warp_id / warps_per_block_x; + + // Global warp coordinates + int warpM = blockIdx.x * warps_per_block_x + warp_row; + int warpN = blockIdx.y * warps_per_block_y + warp_col; + + if (warpM >= (M + WMMA_M - 1) / WMMA_M || + warpN >= (N + WMMA_N - 1) / WMMA_N) { + return; + } + + // Double-buffered fragments for pipelining + wmma::fragment + frag_x[2]; + wmma::fragment + frag_a[2]; + wmma::fragment + frag_b[2]; + + // Accumulator fragments (float precision for computation) + wmma::fragment frag_xa; + wmma::fragment frag_xb; + wmma::fragment frag_silu_xa; + wmma::fragment frag_result; + + // Shared memory for manual store using float (then convert during global + // store) + extern __shared__ float smem_float[]; + float* warp_smem = smem_float + warp_id * WMMA_M * WMMA_N; + + // Initialize accumulators + wmma::fill_fragment(frag_xa, 0.0f); + wmma::fill_fragment(frag_xb, 0.0f); + + int row = warpM * WMMA_M; + int col = warpN * WMMA_N; + + // Pipeline prologue - load first iteration + int buffer_idx = 0; + if (row < M && col < N && K > 0) { + // Load first tiles + wmma::load_matrix_sync(frag_x[buffer_idx], X + row * ldx, ldx); + wmma::load_matrix_sync(frag_a[buffer_idx], A + col, lda); + wmma::load_matrix_sync(frag_b[buffer_idx], B + col, ldb); + } + + // Pipelined main loop + for (int i = 0; i < K; i += WMMA_K) { + int next_buffer_idx = 1 - buffer_idx; + + // Prefetch next iteration (if not last) + if (i + WMMA_K < K && row < M && col < N) { + wmma::load_matrix_sync(frag_x[next_buffer_idx], + X + row * ldx + (i + WMMA_K), ldx); + wmma::load_matrix_sync(frag_a[next_buffer_idx], + A + (i + WMMA_K) * lda + col, lda); + wmma::load_matrix_sync(frag_b[next_buffer_idx], + B + (i + WMMA_K) * ldb + col, ldb); + } + + // Compute with current buffers + if (row < M && col < N && i < K) { + wmma::mma_sync(frag_xa, frag_x[buffer_idx], frag_a[buffer_idx], frag_xa); + wmma::mma_sync(frag_xb, frag_x[buffer_idx], frag_b[buffer_idx], frag_xb); + + // Convert accumulator results back to bfloat16 precision after each MM + // operation +#pragma unroll + for (int t = 0; t < frag_xa.num_elements; t++) { + frag_xa.x[t] = __bfloat162float(__float2bfloat16(frag_xa.x[t])); + frag_xb.x[t] = __bfloat162float(__float2bfloat16(frag_xb.x[t])); + } + } + + // Swap buffers + buffer_idx = next_buffer_idx; + } + + // Apply SiLU activation to X*A result (with bfloat16 precision behavior) + warp_silu_activation_bf16(frag_xa, frag_silu_xa); + + // Add silu(X*A) + X*B (convert to bfloat16 precision during addition) +#pragma unroll + 
for (int t = 0; t < frag_result.num_elements; t++) { + // Convert both operands to bfloat16, perform addition, convert back to + // float + __nv_bfloat16 silu_bf16 = __float2bfloat16(frag_silu_xa.x[t]); + __nv_bfloat16 xb_bf16 = __float2bfloat16(frag_xb.x[t]); + __nv_bfloat16 result_bf16 = __hadd(silu_bf16, xb_bf16); + frag_result.x[t] = __bfloat162float(result_bf16); + } + + // Use WMMA store to shared memory with float, then manually copy to global + // with conversion + if (row < M && col < N) { + // Store float result to shared memory using WMMA (this works!) + wmma::store_matrix_sync(warp_smem, frag_result, WMMA_N, + wmma::mem_row_major); + + // Synchronize warp + __syncwarp(); + + // Now cooperatively copy from shared memory to global memory with bfloat16 + // conversion + int lane_id = threadIdx.x % 32; + + // Each thread handles 8 elements (256 total / 32 threads = 8) + for (int i = 0; i < 8; i++) { + int elem_idx = lane_id * 8 + i; + if (elem_idx < WMMA_M * WMMA_N) { + int local_row = elem_idx / WMMA_N; + int local_col = elem_idx % WMMA_N; + int global_row = row + local_row; + int global_col = col + local_col; + + if (global_row < M && global_col < N) { + // Convert float from shared memory to bfloat16 for global memory + float val = warp_smem[local_row * WMMA_N + local_col]; + C[global_row * ldc + global_col] = __float2bfloat16(val); + } + } + } + } +} + +// Host function for PyTorch interface (bfloat16 only) +void wmma_silu_add_cuda(torch::Tensor X, torch::Tensor A, torch::Tensor B, + torch::Tensor C) { + // Check inputs + TORCH_CHECK(X.device().is_cuda(), "X must be a CUDA tensor"); + TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); + TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); + TORCH_CHECK(C.device().is_cuda(), "C must be a CUDA tensor"); + TORCH_CHECK(X.is_contiguous(), "X must be contiguous"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(B.is_contiguous(), "B must be contiguous"); + TORCH_CHECK(C.is_contiguous(), "C must be contiguous"); + TORCH_CHECK(X.dtype() == torch::kBFloat16, "X must be bfloat16"); + TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bfloat16"); + TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bfloat16"); + TORCH_CHECK(C.dtype() == torch::kBFloat16, "C must be bfloat16"); + + auto M = X.size(0); + auto K = X.size(1); + auto N = A.size(1); + + TORCH_CHECK(A.size(0) == K, "A.size(0) must equal X.size(1)"); + TORCH_CHECK(B.size(0) == K, "B.size(0) must equal X.size(1)"); + TORCH_CHECK(B.size(1) == N, "B.size(1) must equal A.size(1)"); + TORCH_CHECK(C.size(0) == M, "C.size(0) must equal X.size(0)"); + TORCH_CHECK(C.size(1) == N, "C.size(1) must equal A.size(1)"); + + // Calculate grid and block dimensions with multiple warps per block + const int WMMA_M = 16; + const int WMMA_N = 16; + + // Use 4x2 warps per block (8 warps total, 256 threads) + const int WARPS_PER_BLOCK_X = 4; + const int WARPS_PER_BLOCK_Y = 2; + const int THREADS_PER_BLOCK = WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y * 32; + + // Calculate shared memory requirements + // Each warp needs one 16x16 float tile for intermediate storage + const int FRAGMENT_SIZE_BYTES = + 16 * 16 * sizeof(float); // 1024 bytes per tile (float, not bfloat16) + const int WARPS_PER_BLOCK = WARPS_PER_BLOCK_X * WARPS_PER_BLOCK_Y; + const int SHARED_MEM_SIZE = + WARPS_PER_BLOCK * FRAGMENT_SIZE_BYTES; // 8 warps × 1024 = 8192 bytes + + dim3 blockDim(THREADS_PER_BLOCK, 1); + + // Calculate grid dimensions based on warps per block + int grid_x = + ((M + 
WMMA_M - 1) / WMMA_M + WARPS_PER_BLOCK_X - 1) / WARPS_PER_BLOCK_X; + int grid_y = + ((N + WMMA_N - 1) / WMMA_N + WARPS_PER_BLOCK_Y - 1) / WARPS_PER_BLOCK_Y; + + dim3 gridDim(grid_x, grid_y); + + // Get CUDA stream + cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); + + // Launch pipelined bfloat16 kernel with specified shared memory + wmma_silu_add_kernel_bf16_pipelined<<>>( + reinterpret_cast(X.data_ptr()), + reinterpret_cast(A.data_ptr()), + reinterpret_cast(B.data_ptr()), + reinterpret_cast<__nv_bfloat16*>(C.data_ptr()), M, N, K, K, + N, N, N); + + cudaStreamSynchronize(stream); +} + +// Test function to compare CUDA native result vs PyTorch native result +torch::Tensor test_wmma_vs_torch_native(torch::Tensor X, torch::Tensor A, + torch::Tensor B) { + // Validate inputs + TORCH_CHECK(X.device().is_cuda(), "X must be a CUDA tensor"); + TORCH_CHECK(A.device().is_cuda(), "A must be a CUDA tensor"); + TORCH_CHECK(B.device().is_cuda(), "B must be a CUDA tensor"); + TORCH_CHECK(X.dtype() == torch::kBFloat16, "X must be bfloat16"); + TORCH_CHECK(A.dtype() == torch::kBFloat16, "A must be bfloat16"); + TORCH_CHECK(B.dtype() == torch::kBFloat16, "B must be bfloat16"); + + auto M = X.size(0); + auto K = X.size(1); + auto N = A.size(1); + + // Create output tensors + auto C_wmma = + torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(X.device())); + auto C_torch = + torch::empty({M, N}, torch::dtype(torch::kBFloat16).device(X.device())); + + // Test 1: WMMA CUDA kernel result + auto start_wmma = std::chrono::high_resolution_clock::now(); + wmma_silu_add_cuda(X, A, B, C_wmma); + cudaDeviceSynchronize(); + auto end_wmma = std::chrono::high_resolution_clock::now(); + auto duration_wmma = std::chrono::duration_cast( + end_wmma - start_wmma); + + // Test 2: PyTorch native implementation + auto start_torch = std::chrono::high_resolution_clock::now(); + auto XA = torch::mm(X, A); + auto XB = torch::mm(X, B); + auto silu_XA = torch::sigmoid(XA) * XA; // SiLU activation + C_torch = silu_XA + XB; + cudaDeviceSynchronize(); + auto end_torch = std::chrono::high_resolution_clock::now(); + auto duration_torch = std::chrono::duration_cast( + end_torch - start_torch); + + // Calculate differences + auto diff = torch::abs(C_wmma - C_torch); + auto max_diff = torch::max(diff); + auto mean_diff = torch::mean(diff); + auto relative_diff = diff / (torch::abs(C_torch) + 1e-8f); + auto max_relative_diff = torch::max(relative_diff); + auto mean_relative_diff = torch::mean(relative_diff); + + // Convert to CPU for printing + auto max_diff_cpu = max_diff.cpu().item(); + auto mean_diff_cpu = mean_diff.cpu().item(); + auto max_rel_diff_cpu = max_relative_diff.cpu().item(); + auto mean_rel_diff_cpu = mean_relative_diff.cpu().item(); + + // Print test results + printf("\n=== WMMA vs PyTorch Native Comparison ===\n"); + printf("Matrix dimensions: M=%ld, K=%ld, N=%ld\n", M, K, N); + printf("WMMA kernel time: %ld μs\n", duration_wmma.count()); + printf("PyTorch time: %ld μs\n", duration_torch.count()); + printf("Speedup: %.2fx\n", + (float)duration_torch.count() / duration_wmma.count()); + printf("\nAccuracy metrics:\n"); + printf("Max absolute diff: %.8f\n", max_diff_cpu); + printf("Mean absolute diff: %.8f\n", mean_diff_cpu); + printf("Max relative diff: %.8f\n", max_rel_diff_cpu); + printf("Mean relative diff: %.8f\n", mean_rel_diff_cpu); + + // Determine test status + const float abs_tolerance = 1e-3f; // Relaxed for bfloat16 + const float rel_tolerance = 1e-2f; // 1% relative tolerance + + bool test_passed = + 
(max_diff_cpu < abs_tolerance) && (max_rel_diff_cpu < rel_tolerance); + + printf("\nTest result: %s\n", test_passed ? "PASSED" : "FAILED"); + if (!test_passed) { + printf("Tolerance: abs < %.6f, rel < %.6f\n", abs_tolerance, rel_tolerance); + } + printf("==========================================\n\n"); + + // Return results as a tensor for further analysis + auto results = + torch::tensor({max_diff_cpu, mean_diff_cpu, max_rel_diff_cpu, + mean_rel_diff_cpu, (float)duration_wmma.count(), + (float)duration_torch.count(), test_passed ? 1.0f : 0.0f}, + torch::dtype(torch::kFloat32).device(torch::kCPU)); + + return results; +} + +// Comprehensive benchmark function +torch::Tensor benchmark_wmma_multiple_sizes() { + printf("\n=== Comprehensive WMMA Benchmark ===\n"); + + // Test different matrix sizes + std::vector> test_sizes = { + {512, 512, 512}, // Small + {1024, 1024, 1024}, // Medium + {2048, 2048, 2048}, // Large + {4096, 4096, 4096}, // Very large + {1024, 512, 2048}, // Rectangular 1 + {2048, 1024, 512}, // Rectangular 2 + {8192, 8192, 8192}, // Huge (if memory allows) + }; + + std::vector all_results; + + for (auto& [M, K, N] : test_sizes) { + printf("\nTesting size M=%d, K=%d, N=%d\n", M, K, N); + + try { + // Create test tensors + auto X = torch::randn( + {M, K}, torch::dtype(torch::kBFloat16).device(torch::kCUDA)); + auto A = torch::randn( + {K, N}, torch::dtype(torch::kBFloat16).device(torch::kCUDA)); + auto B = torch::randn( + {K, N}, torch::dtype(torch::kBFloat16).device(torch::kCUDA)); + + // Run test + auto results = test_wmma_vs_torch_native(X, A, B); + auto results_vec = results.accessor(); + + // Store results + for (int i = 0; i < results.size(0); i++) { + all_results.push_back(results_vec[i]); + } + + // Add matrix dimensions + all_results.push_back((float)M); + all_results.push_back((float)K); + all_results.push_back((float)N); + + } catch (const std::exception& e) { + printf("Skipped size %dx%dx%d due to: %s\n", M, K, N, e.what()); + // Add zeros for failed test + for (int i = 0; i < 10; i++) { + all_results.push_back(0.0f); + } + } + } + + printf("=====================================\n"); + + // Return all results as tensor + return torch::tensor(all_results, + torch::dtype(torch::kFloat32).device(torch::kCPU)); +} + +int main() { + // Initialize CUDA + // at::cuda::init(); + + // Run comprehensive benchmark + auto results = benchmark_wmma_multiple_sizes(); + + // // Print final results + // printf("\n=== Final Benchmark Results ===\n"); + // printf("Results tensor shape: %s\n", results.sizes().vec()); + + return 0; +} diff --git a/tests/cuda/tests_masked_select.cu b/tests/cuda/tests_masked_select.cu new file mode 100644 index 0000000..2dd9e8e --- /dev/null +++ b/tests/cuda/tests_masked_select.cu @@ -0,0 +1,358 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Include your kernel headers here +#include "kernel/masked_select.h" + +// // For testing purposes, we'll define the function signatures +// extern void fused_extract_expert_tokens_bf16(const __nv_bfloat16* +// hidden_states, +// const bool* router_mask, +// __nv_bfloat16* output, +// int* output_count, int +// num_tokens, int hidden_dim, int +// expert_idx, int num_experts); + +// extern void extract_expert_tokens_cutlass_v2(const __nv_bfloat16* +// hidden_states, +// const bool* router_mask, +// __nv_bfloat16* output, +// int* output_count, int +// num_tokens, int hidden_dim, int +// expert_idx, int num_experts); + +// extern void 
extract_expert_tokens_cub(const __nv_bfloat16* hidden_states, +// const bool* router_mask, +// __nv_bfloat16* output, int* +// output_count, int num_tokens, int +// hidden_dim, int expert_idx, int +// num_experts); + +// Utility class for timing +class CudaTimer { + cudaEvent_t start, stop; + + public: + CudaTimer() { + cudaEventCreate(&start); + cudaEventCreate(&stop); + } + + ~CudaTimer() { + cudaEventDestroy(start); + cudaEventDestroy(stop); + } + + void Start() { cudaEventRecord(start); } + + float Stop() { + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + return milliseconds; + } +}; + +// Reference CPU implementation for correctness checking +void extract_expert_tokens_cpu(const std::vector<__nv_bfloat16>& hidden_states, + const std::vector& router_mask, + std::vector<__nv_bfloat16>& output, + int& output_count, int num_tokens, + int hidden_dim, int expert_idx, + int num_experts) { + output_count = 0; + + // Count selected tokens + for (int i = 0; i < num_tokens; i++) { + if (router_mask[i * num_experts + expert_idx]) { + output_count++; + } + } + + // Extract tokens + int out_idx = 0; + for (int i = 0; i < num_tokens; i++) { + if (router_mask[i * num_experts + expert_idx]) { + for (int j = 0; j < hidden_dim; j++) { + output[out_idx * hidden_dim + j] = hidden_states[i * hidden_dim + j]; + } + out_idx++; + } + } +} + +// Function to generate random test data +void generate_test_data(std::vector<__nv_bfloat16>& hidden_states, + std::vector& router_mask, int num_tokens, + int hidden_dim, int num_experts, + float sparsity = 0.1f) { + std::mt19937 gen(42); // Fixed seed for reproducibility and faster generation + std::uniform_real_distribution<> dis(-1.0, 1.0); + std::bernoulli_distribution mask_dis(sparsity); + + // Generate hidden states + hidden_states.resize(num_tokens * hidden_dim); + for (int i = 0; i < num_tokens * hidden_dim; i++) { + hidden_states[i] = __float2bfloat16(dis(gen)); + } + + // Generate router mask with specified sparsity + router_mask.resize(num_tokens * num_experts); + for (int i = 0; i < num_tokens * num_experts; i++) { + router_mask[i] = mask_dis(gen); + } +} + +// Function to check if two bf16 arrays are approximately equal +bool check_correctness(const std::vector<__nv_bfloat16>& ref, + const thrust::host_vector<__nv_bfloat16>& test, + int count, int hidden_dim, float tolerance = 1e-3f) { + for (int i = 0; i < count * hidden_dim; i++) { + float ref_val = __bfloat162float(ref[i]); + float test_val = __bfloat162float(test[i]); + + if (std::abs(ref_val - test_val) > tolerance) { + std::cout << "Mismatch at index " << i << ": ref=" << ref_val + << " test=" << test_val << std::endl; + return false; + } + } + return true; +} + +// Test configuration structure +struct TestConfig { + int num_tokens; + int hidden_dim; + int num_experts; + int expert_idx; + float sparsity; + const char* name; +}; + +// Main test function +void run_test(const TestConfig& config) { + std::cout << "\n=== Testing configuration: " << config.name + << " ===" << std::endl; + std::cout << "Tokens: " << config.num_tokens + << ", Hidden dim: " << config.hidden_dim + << ", Experts: " << config.num_experts + << ", Sparsity: " << config.sparsity << std::endl; + + // Generate test data + std::vector<__nv_bfloat16> h_hidden_states; + std::vector h_router_mask; + generate_test_data(h_hidden_states, h_router_mask, config.num_tokens, + config.hidden_dim, config.num_experts, config.sparsity); + + // Compute reference 
result on using torch tensor + std::vector<__nv_bfloat16> cpu_output(config.num_tokens * config.hidden_dim); + int cpu_count = 0; + extract_expert_tokens_cpu(h_hidden_states, h_router_mask, cpu_output, + cpu_count, config.num_tokens, config.hidden_dim, + config.expert_idx, config.num_experts); + + std::cout << "Selected tokens: " << cpu_count << " (" + << (100.0f * cpu_count / config.num_tokens) << "%)" << std::endl; + + // Copy data to GPU + thrust::device_vector<__nv_bfloat16> d_hidden_states = h_hidden_states; + thrust::device_vector d_router_mask = h_router_mask; + + // Prepare output buffers + thrust::device_vector<__nv_bfloat16> d_output_fused(config.num_tokens * + config.hidden_dim); + thrust::device_vector<__nv_bfloat16> d_output_cutlass(config.num_tokens * + config.hidden_dim); + thrust::device_vector<__nv_bfloat16> d_output_cub(config.num_tokens * + config.hidden_dim); + + thrust::device_vector d_count_fused(1); + thrust::device_vector d_count_cutlass(1); + thrust::device_vector d_count_cub(1); + + CudaTimer timer; + const int warmup_iters = 10; + const int bench_iters = 100; + + // Test 1: Fused kernel + std::cout << "\n1. Testing fused kernel..." << std::endl; + + // Warmup + for (int i = 0; i < warmup_iters; i++) { + thrust::fill(d_count_fused.begin(), d_count_fused.end(), 0); + + // Configure kernel launch + const int threads = 256; + const int warps_per_block = threads / 32; + const int blocks = + std::min(65535, (config.num_tokens + threads - 1) / threads); + const int smem_size = sizeof(int) * (warps_per_block + 1); + + fused_extract_expert_tokens_bf16_simple<<>>( + thrust::raw_pointer_cast(d_hidden_states.data()), + thrust::raw_pointer_cast(d_router_mask.data()), + thrust::raw_pointer_cast(d_output_fused.data()), + thrust::raw_pointer_cast(d_count_fused.data()), config.num_tokens, + config.hidden_dim, config.expert_idx, config.num_experts); + } + cudaDeviceSynchronize(); + std::cout << "Warmup complete." << std::endl; + + // Benchmark + timer.Start(); + for (int i = 0; i < bench_iters; i++) { + thrust::fill(d_count_fused.begin(), d_count_fused.end(), 0); + + const int threads = 256; + const int warps_per_block = threads / 32; + const int blocks = + std::min(65535, (config.num_tokens + threads - 1) / threads); + const int smem_size = sizeof(int) * (warps_per_block + 1); + + fused_extract_expert_tokens_bf16_simple<<>>( + thrust::raw_pointer_cast(d_hidden_states.data()), + thrust::raw_pointer_cast(d_router_mask.data()), + thrust::raw_pointer_cast(d_output_fused.data()), + thrust::raw_pointer_cast(d_count_fused.data()), config.num_tokens, + config.hidden_dim, config.expert_idx, config.num_experts); + } + float fused_time = timer.Stop() / bench_iters; + + // Check correctness + int fused_count = d_count_fused[0]; + thrust::host_vector<__nv_bfloat16> h_output_fused = d_output_fused; + + bool fused_correct = (fused_count == cpu_count) && + check_correctness(cpu_output, h_output_fused, cpu_count, + config.hidden_dim); + + std::cout << " Time: " << std::fixed << std::setprecision(3) << fused_time + << " ms" + << " Correctness: " << (fused_correct ? "PASSED" : "FAILED") + << std::endl; + + // Test 1: pytorch call + std::cout << "\n2. Testing PyTorch call..." 
<< std::endl; + torch::Tensor hidden_states_tensor = + torch::from_blob(h_hidden_states.data(), + {config.num_tokens, config.hidden_dim}, torch::kBFloat16) + .clone() + .cuda(); + torch::Tensor router_mask_tensor = + torch::from_blob(h_router_mask.data(), + {config.num_tokens, config.num_experts}, torch::kBool) + .clone() + .cuda(); + torch::Tensor output_tensor = + torch::empty({cpu_count, config.hidden_dim}, torch::kBFloat16).cuda(); + // Warmup + for (int i = 0; i < warmup_iters; i++) { + auto token_mask = + router_mask_tensor.index({torch::indexing::Slice(), config.expert_idx}); + output_tensor.index_put_( + {torch::indexing::Slice(0, cpu_count), torch::indexing::Slice()}, + hidden_states_tensor.index({token_mask})); + } + + // Benchmark + timer.Start(); + for (int i = 0; i < bench_iters; i++) { + auto token_mask = + router_mask_tensor.index({torch::indexing::Slice(), config.expert_idx}); + output_tensor.index_put_( + {torch::indexing::Slice(0, cpu_count), torch::indexing::Slice()}, + hidden_states_tensor.index({token_mask})); + } + float pytorch_time = timer.Stop() / bench_iters; + std::cout << " Time: " << std::fixed << std::setprecision(3) << pytorch_time + << " ms" << std::endl; + // Check correctness + thrust::host_vector<__nv_bfloat16> h_output_pytorch = + output_tensor.cpu().bfloat16(); + bool pytorch_correct = + (cpu_count == h_output_pytorch.size() / config.hidden_dim) && + check_correctness(cpu_output, h_output_pytorch, cpu_count, + config.hidden_dim); + std::cout << " Correctness: " << (pytorch_correct ? "PASSED" : "FAILED") + << std::endl; + + // Performance summary + std::cout << "\nPerformance Summary:" << std::endl; + std::cout << " Fused kernel: " << fused_time << " ms (1.00x)" << std::endl; + std::cout << " PyTorch call: " << pytorch_time << " ms (" << std::fixed + << std::setprecision(2) << fused_time / pytorch_time << "x)" + << std::endl; + + // Calculate bandwidth utilization + size_t bytes_read = + config.num_tokens * config.hidden_dim * sizeof(__nv_bfloat16) + + config.num_tokens * config.num_experts * sizeof(bool); + size_t bytes_written = cpu_count * config.hidden_dim * sizeof(__nv_bfloat16); + float bandwidth_gb = + (bytes_read + bytes_written) / (1024.0f * 1024.0f * 1024.0f); + + std::cout << "\nBandwidth utilization:" << std::endl; + std::cout << " Fused: " << std::fixed << std::setprecision(1) + << bandwidth_gb / (fused_time / 1000.0f) << " GB/s" << std::endl; + std::cout << " PyTorch: " << std::fixed << std::setprecision(1) + << bandwidth_gb / (pytorch_time / 1000.0f) << " GB/s" << std::endl; +} + +int main(int argc, char** argv) { + // Check CUDA device + int device; + cudaGetDevice(&device); + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device); + + std::cout << "Running on: " << prop.name << std::endl; + std::cout << "Compute capability: " << prop.major << "." 
<< prop.minor + << std::endl; + std::cout << "Memory bandwidth: " + << prop.memoryBusWidth * prop.memoryClockRate * 2 / 8e6 << " GB/s" + << std::endl; + + // Define test configurations + std::vector configs = { + // Small tests + {1024, 768, 8, 0, 0.1f, "Small (BERT-like)"}, + {1024, 768, 8, 0, 0.01f, "Small (sparse)"}, + {1024, 768, 8, 0, 0.5f, "Small (dense)"}, + + // Medium tests + {4096, 1024, 16, 0, 0.1f, "Medium"}, + {4096, 4096, 8, 0, 0.1f, "Medium (wide)"}, + + // Large tests + {16384, 1024, 32, 0, 0.1f, "Large"}, + {16384, 4096, 16, 0, 0.05f, "Large (LLM-like)"}, + + // Stress tests + {65536, 1024, 64, 0, 0.02f, "Stress test"}, + {32768, 8192, 8, 0, 0.125f, "Stress test (very wide)"}, + }; + + // Run all tests + for (const auto& config : configs) { + try { + run_test(config); + } catch (const std::exception& e) { + std::cerr << "Test failed with exception: " << e.what() << std::endl; + } + } + + std::cout << "\n=== All tests completed ===" << std::endl; + + return 0; +} From 7c5918ba52ce20b4868ae80741a2d896cc700e11 Mon Sep 17 00:00:00 2001 From: xly Date: Mon, 7 Jul 2025 13:51:54 +0100 Subject: [PATCH 3/3] stable topk --- core/kernel/fused_mlp.cu | 39 +-- core/kernel/topk_softmax_kernels.cu | 2 +- core/model/moe.cpp | 18 ++ core/model/moe.h | 346 ++++++++++++++++++-------- core/python/py_archer_prefetch.cpp | 6 + examples/interface_example.py | 1 - moe_infinity/models/qwen.py | 54 ++-- moe_infinity/runtime/model_offload.py | 13 + op_builder/prefetch.py | 1 + tests/cuda/CMakeLists.txt | 32 ++- tests/cuda/test_topk_softmax.cu | 252 +++++++++++++++++++ 11 files changed, 616 insertions(+), 148 deletions(-) create mode 100644 core/model/moe.cpp create mode 100644 tests/cuda/test_topk_softmax.cu diff --git a/core/kernel/fused_mlp.cu b/core/kernel/fused_mlp.cu index 46915a8..13e237f 100644 --- a/core/kernel/fused_mlp.cu +++ b/core/kernel/fused_mlp.cu @@ -30,10 +30,12 @@ * Supports online training and simultaneous inference. */ -#include -#include +#if 0 -#include "common_device.h" + #include + #include + + #include "common_device.h" void check_shmem_error(cudaError_t error) { if (error != cudaSuccess) { @@ -90,18 +92,18 @@ __device__ void threadblock_layer( __syncthreads(); // Load N_BLOCKS chunks of weights from global memory into registers. 
-#pragma unroll + #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { wmma::load_matrix_sync(weights_frag[i], weights_this_layer + 16 * i + weights_col * WIDTH, WIDTH); } -#pragma unroll + #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::fill_fragment(result_frag[l], 0.0f); -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { // Load a chunk of intermediate activations from shared memory and // multiply with chunk of weights @@ -117,7 +119,7 @@ __device__ void threadblock_layer( __syncthreads(); -#pragma unroll + #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::store_matrix_sync(act_shmem + weights_col + l * 16 * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); @@ -126,7 +128,7 @@ __device__ void threadblock_layer( if (out_intermediate_threadblock_this_layer != nullptr) { __syncthreads(); -#pragma unroll + #pragma unroll for (int l = 0; l < N_ITERS; ++l) { *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * l) * WIDTH] = @@ -150,7 +152,7 @@ __device__ void threadblock_load_input_static( const uint32_t lane_offset = (8 * li) % WIDTH; const uint32_t row = (8 * li + wi * 8 * 32) / WIDTH; -#pragma unroll + #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)] = *(int4*)&input_threadblock[lane_offset + (row + 16 * i) * WIDTH]; @@ -206,7 +208,7 @@ __device__ void threadblock_input_layer_forward_dynamic( const uint32_t n_elems_b = WIDTH * in_width; -#pragma unroll + #pragma unroll for (uint32_t idx = thread_elem_idx; idx < n_elems_b; idx += n_elems_per_load) { const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; @@ -219,7 +221,7 @@ __device__ void threadblock_input_layer_forward_dynamic( __syncthreads(); } -#pragma unroll + #pragma unroll for (int l = 0; l < N_ITERS; ++l) { if (std::is_same::value) { // Load chunk of inputs into shmem. @@ -227,7 +229,7 @@ __device__ void threadblock_input_layer_forward_dynamic( // only used once. (Possibly due to latency hiding through staging.) 
const uint32_t n_elems_a = 16 * in_width; -#pragma unroll + #pragma unroll for (uint32_t idx = thread_elem_idx; idx < n_elems_a; idx += n_elems_per_load) { const uint32_t idx_skewed = idx + idx / in_width * INPUT_SKEW; @@ -239,7 +241,7 @@ __device__ void threadblock_input_layer_forward_dynamic( } wmma::fill_fragment(result_frag[l], 0.0f); -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < n_tensor_ops; ++i) { // Load chunk of inputs and weights from shared memory and multiply them if (std::is_same::value) { @@ -268,7 +270,7 @@ __device__ void threadblock_input_layer_forward_dynamic( __syncthreads(); } -#pragma unroll + #pragma unroll for (int l = 0; l < N_ITERS; ++l) { wmma::store_matrix_sync(act_shmem + weights_col + (16 * l) * (WIDTH + SKEW), result_frag[l], WIDTH + SKEW, wmma::mem_row_major); @@ -277,7 +279,7 @@ __device__ void threadblock_input_layer_forward_dynamic( if (out_intermediate_threadblock_this_layer != nullptr) { __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4*)&out_intermediate_threadblock_this_layer[lane_offset + (row + 16 * i) * WIDTH] = @@ -329,7 +331,7 @@ __device__ void threadblock_last_layer_forward( __syncthreads(); -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) wmma::load_matrix_sync(weights_frag[i], weights_shmem + 16 * i, WIDTH + SKEW); @@ -337,7 +339,7 @@ __device__ void threadblock_last_layer_forward( // Perform last layer by parallelizing over iters for (uint32_t idx = wi; idx < N_ITERS; idx += N_BLOCKS) { wmma::fill_fragment(result_frag, 0.0f); -#pragma unroll + #pragma unroll for (uint32_t i = 0; i < N_BLOCKS; ++i) { // Load a chunk of intermediate activations from shared memory and // multiply with chunk of the weight matrix @@ -376,7 +378,7 @@ __device__ void threadblock_write_output_static( __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < N_ITERS; ++i) { *(int4*)&output_threadblock[lane_offset + (row + 16 * i) * WIDTH] = *(int4*)&act_shmem[lane_offset + (row + 16 * i) * (WIDTH + SKEW)]; @@ -826,3 +828,4 @@ template class FullyFusedMLP; template class FullyFusedMLP; template class FullyFusedMLP; template class FullyFusedMLP; +#endif diff --git a/core/kernel/topk_softmax_kernels.cu b/core/kernel/topk_softmax_kernels.cu index 31a4b32..8397ee8 100644 --- a/core/kernel/topk_softmax_kernels.cu +++ b/core/kernel/topk_softmax_kernels.cu @@ -507,7 +507,7 @@ void topk_softmax(torch::Tensor& topk_weights, // [num_tokens, topk] topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), - topk_indices.data_ptr(), token_expert_indices.data_ptr(), + topk_indices.data_ptr(), token_expert_indices.data_ptr(), softmax_workspace.data_ptr(), num_tokens, num_experts, topk, stream); diff --git a/core/model/moe.cpp b/core/model/moe.cpp new file mode 100644 index 0000000..a61ddc1 --- /dev/null +++ b/core/model/moe.cpp @@ -0,0 +1,18 @@ +#include "moe.h" + +void InitMoELayer(int num_experts, int topk, int max_tokens, int64_t hidden_dim, + int64_t intermediate_dim) { + std::call_once(moe_layer_init_flag, [&]() { + moe_layer_ptr = std::make_unique(num_experts, topk, max_tokens, + hidden_dim, intermediate_dim); + }); +} + +std::tuple TopKSoftmax( + torch::Tensor& gating_outputs) { + if (!moe_layer_ptr) { + throw std::runtime_error( + "MoELayer is not initialized. 
Call InitMoELayer first."); + } + return moe_layer_ptr->TopKSoftmax(gating_outputs); +} diff --git a/core/model/moe.h b/core/model/moe.h index 454f20c..102aca1 100644 --- a/core/model/moe.h +++ b/core/model/moe.h @@ -7,22 +7,27 @@ #include #include +#include #include "utils/cuda_utils.h" #include "common/pytorch.h" #include "kernel/ops.h" +#include "utils/logger.h" +#include "base/noncopyable.h" #define BUFFER_PTR(buf_type, ptr_type) \ (buffer_[static_cast(BufferType::buf_type)]) -#define CUDA_ALLOCATE_BUFFER(type, size) \ +#define TENSOR_INS(buf_type) tensors_[static_cast(BufferType::buf_type)] + +#define CUDA_ALLOCATE_BUFFER(type, size, dtype) \ CUDA_CHECK(cudaMalloc( \ reinterpret_cast(&buffer_[static_cast(BufferType::type)]), \ - size * sizeof(param_t))); + size * sizeof(dtype))); // always allocate max 4 bytes per element, can + // use less // The abstraction of MoE (Mixture of Experts) layer with fixed buffers. -template -class MoELayer { +class MoELayer : public base::noncopyable { public: enum class BufferType { @@ -50,113 +55,251 @@ class MoELayer { }; explicit MoELayer(int num_experts, int topk, int max_tokens, - int64_t hidden_dim, int64_t intermediate_dim) + int64_t hidden_dim, int64_t intermediate_dim, + bool use_bf16 = false, bool norm_topk_prob = false) : num_experts_(num_experts), topk_(topk), max_tokens_(max_tokens), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), buffer_(static_cast(BufferType::NumBuffers)) { - CUDA_ALLOCATE_BUFFER(HiddenStates, max_tokens * hidden_dim); + CUDA_ALLOCATE_BUFFER(HiddenStates, max_tokens * hidden_dim, float); // CUDA_ALLOCATE_BUFFER(GatingWeights, num_experts * hidden_dim); - CUDA_ALLOCATE_BUFFER(FinalHiddenStates, max_tokens * hidden_dim); - CUDA_ALLOCATE_BUFFER(GatingOutput, max_tokens * num_experts); - CUDA_ALLOCATE_BUFFER(TopKWeights, max_tokens * topk); - CUDA_ALLOCATE_BUFFER(TopKIndices, max_tokens * topk); - CUDA_ALLOCATE_BUFFER(TokenExpertIndices, max_tokens * topk); - CUDA_ALLOCATE_BUFFER(ExpertInput, max_tokens * hidden_dim); - CUDA_ALLOCATE_BUFFER(ExpertUpProjOutput, max_tokens * intermediate_dim); - CUDA_ALLOCATE_BUFFER(ExpertGateProjInput, max_tokens * intermediate_dim); - CUDA_ALLOCATE_BUFFER(ExpertDownProjOutput, max_tokens * hidden_dim); - CUDA_ALLOCATE_BUFFER(ExpertActMulOutput, max_tokens * hidden_dim); - - CUDA_ALLOCATE_BUFFER(ExpertRouterMask, max_tokens * num_experts); - CUDA_ALLOCATE_BUFFER(ExpertRouterWeight, max_tokens * num_experts); - - device_id_ = c10::cuda::current_device(); + CUDA_ALLOCATE_BUFFER(FinalHiddenStates, max_tokens * hidden_dim, float); + CUDA_ALLOCATE_BUFFER(GatingOutput, max_tokens * num_experts, float); + CUDA_ALLOCATE_BUFFER(TopKWeights, max_tokens * topk, float); + CUDA_ALLOCATE_BUFFER(TopKIndices, max_tokens * topk, int64_t); + CUDA_ALLOCATE_BUFFER(TokenExpertIndices, max_tokens * topk, int32_t); + CUDA_ALLOCATE_BUFFER(ExpertInput, max_tokens * hidden_dim, float); + CUDA_ALLOCATE_BUFFER(ExpertUpProjOutput, max_tokens * intermediate_dim, + float); + CUDA_ALLOCATE_BUFFER(ExpertGateProjInput, max_tokens * intermediate_dim, + float); + CUDA_ALLOCATE_BUFFER(ExpertDownProjOutput, max_tokens * hidden_dim, float); + CUDA_ALLOCATE_BUFFER(ExpertActMulOutput, max_tokens * hidden_dim, float); + + CUDA_ALLOCATE_BUFFER(ExpertRouterMask, max_tokens * num_experts, uint32_t); + CUDA_ALLOCATE_BUFFER(ExpertRouterWeight, max_tokens * num_experts, float); + + device_id_ = GetDevice(); cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking); + + DLOG_INFO( + "MoELayer initialized with num_experts:", 
num_experts_, " topk:", topk, + " max_tokens:", max_tokens, " hidden_dim:", hidden_dim, + " intermediate_dim:", intermediate_dim, " on device:", device_id_); + + scalar_types_ = { + torch::kFloat32, // HiddenStates + // torch::kFloat32, // GatingWeights + torch::kFloat32, // FinalHiddenStates + torch::kFloat32, // GatingOutput + torch::kFloat32, // TopKWeights + torch::kUInt32, // TopKIndices + torch::kInt32, // TokenExpertIndices + torch::kFloat32, // ExpertInput + torch::kFloat32, // ExpertUpProjOutput + torch::kFloat32, // ExpertGateProjInput + torch::kFloat32, // ExpertDownProjOutput + torch::kFloat32, // ExpertActMulOutput + torch::kBool, // ExpertRouterMask + torch::kFloat32 // ExpertRouterWeight + }; + + // Create tensors for each buffer for easy access + for (int i = 0; i < static_cast(BufferType::NumBuffers); ++i) { + tensors_.emplace_back( + torch::zeros({1}, torch::TensorOptions() + .dtype(scalar_types_[i]) + .device(CUDA_DEVICE(device_id_)))); + } + } + + torch::Tensor& _tensor(BufferType type) { + return tensors_[static_cast(type)]; } - void ForwardGating() { - // Forward pass for gating mechanism - // This function will use the buffers to compute gating weights and outputs - - // create temperal wrappers as tensor - auto hidden_states = - torch::from_blob(BUFFER_PTR(HiddenStates, void), - {max_tokens_, hidden_dim_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::dtype()) - .device(CUDA_DEVICE(device_id_))); - - auto gating_weights = - torch::from_blob(BUFFER_PTR(GatingWeights, void), - {num_experts_, hidden_dim_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::dtype()) - .device(CUDA_DEVICE(device_id_))); - - auto gating_output = - torch::from_blob(BUFFER_PTR(GatingOutput, void), - {max_tokens_, num_experts_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::dtype()) - .device(CUDA_DEVICE(device_id_))); + void* _buffer(BufferType type) { return buffer_[static_cast(type)]; } + std::tuple TopKSoftmax( + torch::Tensor& gating_outputs) { // Perform the gating operation on stream_ c10::cuda::CUDAStream torch_stream = c10::cuda::getStreamFromExternal(stream_, device_id_); c10::cuda::setCurrentCUDAStream(torch_stream); - torch::matmul_out(gating_output, hidden_states, - gating_weights.t()); // [max_tokens, num_experts] - - auto topk_weights = - torch::from_blob(BUFFER_PTR(TopKWeights, void), {max_tokens_, topk_}, - DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::kFloat32) - .device(CUDA_DEVICE(device_id_))); - - auto topk_indices = - torch::from_blob(BUFFER_PTR(TopKIndices, void), {max_tokens_, topk_}, - DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::kUInt32) - .device(CUDA_DEVICE(device_id_))); - - auto token_expert_indices = - torch::from_blob(BUFFER_PTR(TokenExpertIndices, void), - {max_tokens_, topk_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::kUInt32) - .device(CUDA_DEVICE(device_id_))); + + int64_t num_tokens = gating_outputs.size(0); + assert(num_tokens <= max_tokens_); + + auto logits = gating_outputs.to(torch::kFloat32); + // if (gating_outputs.dtype() != torch::kFloat32) { + // auto logits = torch::from_blob( + // BUFFER_PTR(GatingOutput, void), {num_tokens, num_experts_}, + // DoNothingDeleter{}, + // torch::TensorOptions() + // .dtype(torch::kFloat32) + // .device(CUDA_DEVICE(device_id_))); + // logits = gating_outputs.to(torch::kFloat32); + // } + + TENSOR_INS(TopKWeights) + .set_data(torch::from_blob(BUFFER_PTR(TopKWeights, void), + {num_tokens, topk_}, + DoNothingDeleter{}, + 
torch::TensorOptions() + .dtype(torch::kFloat32) + .device(CUDA_DEVICE(device_id_)))); + + TENSOR_INS(TopKIndices) + .set_data(torch::from_blob(BUFFER_PTR(TopKIndices, void), + {num_tokens, topk_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kInt64) + .device(CUDA_DEVICE(device_id_)))); + TENSOR_INS(TokenExpertIndices) + .set_data(torch::from_blob( + BUFFER_PTR(TokenExpertIndices, void), {num_tokens, topk_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kInt32) + .device(CUDA_DEVICE(device_id_)))); // Use Int32 for indices // Perform top-k softmax to get top-k gating weights and indices - topk_softmax(topk_weights, topk_indices, token_expert_indices, - gating_output); // [max_tokens, topk] - - auto router_mask = - torch::from_blob(BUFFER_PTR(ExpertRouterMask, void), - {max_tokens_, num_experts_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::kBool) - .device(CUDA_DEVICE(device_id_))); - - router_mask.scatter_(1, token_expert_indices, - true); // Set router mask based on top-k indices - - auto routing_weights_mask = - torch::from_blob(BUFFER_PTR(ExpertRouterWeight, void), - {max_tokens_, num_experts_}, DoNothingDeleter{}, - torch::TensorOptions() - .dtype(torch::dtype()) - .device(CUDA_DEVICE(device_id_))); - - routing_weights_mask.scatter_add_( - 1, token_expert_indices, - topk_weights); // Set routing weights mask + topk_softmax(TENSOR_INS(TopKWeights), TENSOR_INS(TopKIndices), + TENSOR_INS(TokenExpertIndices), + logits); // [max_tokens, topk] + // std::cout << "TopKIndices started on device: " + // << TENSOR_INS(TopKIndices) << std::endl; + + TENSOR_INS(TopKWeights) = + TENSOR_INS(TopKWeights) / + TENSOR_INS(TopKWeights).sum(1, true); // Normalize top-k weights + + auto routing_weights = TENSOR_INS(TopKWeights).to(torch::kBFloat16); + + TENSOR_INS(ExpertRouterMask) + .set_data(torch::from_blob(BUFFER_PTR(ExpertRouterMask, void), + {num_tokens, num_experts_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kBool) + .device(CUDA_DEVICE(device_id_)))); + // std::cout << "TokenExpertIndices started on device: " + // << TENSOR_INS(TokenExpertIndices) << std::endl; + + cudaMemsetAsync(BUFFER_PTR(ExpertRouterMask, void), 0, + num_tokens * num_experts_ * sizeof(uint32_t), + stream_); // Initialize router mask + cudaMemsetAsync(BUFFER_PTR(ExpertRouterWeight, void), 0, + num_tokens * num_experts_ * sizeof(float), + stream_); // Initialize router weights + + TENSOR_INS(ExpertRouterMask) + .scatter_(1, TENSOR_INS(TopKIndices), + true); // Set router mask based on top-k indices + + TENSOR_INS(ExpertRouterWeight) + .set_data(torch::from_blob(BUFFER_PTR(ExpertRouterWeight, void), + {num_tokens, num_experts_}, + DoNothingDeleter{}, + torch::TensorOptions() + .dtype(torch::kBFloat16) + .device(CUDA_DEVICE(device_id_)))); + + cudaStreamSynchronize(stream_); // Ensure all operations are complete + + TENSOR_INS(ExpertRouterWeight) + .scatter_add_(1, TENSOR_INS(TopKIndices), + routing_weights); // Set routing weights mask + // std::cout << "TopKSoftmax completed on device: " << + // TENSOR_INS(ExpertRouterMask) + // << std::endl; + return std::make_tuple(TENSOR_INS(ExpertRouterMask), + TENSOR_INS(ExpertRouterWeight)); } + // void ForwardGating() { + // // Forward pass for gating mechanism + // // This function will use the buffers to compute gating weights and + // outputs + + // // create temperal wrappers as tensor + // auto hidden_states = + // torch::from_blob(BUFFER_PTR(HiddenStates, void), + // {max_tokens_, hidden_dim_}, + // 
DoNothingDeleter{}, torch::TensorOptions() + // .dtype(torch::dtype()) + // .device(CUDA_DEVICE(device_id_))); + + // auto gating_weights = + // torch::from_blob(BUFFER_PTR(GatingWeights, void), + // {num_experts_, hidden_dim_}, + // DoNothingDeleter{}, torch::TensorOptions() + // .dtype(torch::dtype()) + // .device(CUDA_DEVICE(device_id_))); + + // auto gating_output = + // torch::from_blob(BUFFER_PTR(GatingOutput, void), + // {max_tokens_, num_experts_}, + // DoNothingDeleter{}, torch::TensorOptions() + // .dtype(torch::dtype()) + // .device(CUDA_DEVICE(device_id_))); + + // // Perform the gating operation on stream_ + // c10::cuda::CUDAStream torch_stream = + // c10::cuda::getStreamFromExternal(stream_, device_id_); + // c10::cuda::setCurrentCUDAStream(torch_stream); + // torch::matmul_out(gating_output, hidden_states, + // gating_weights.t()); // [max_tokens, num_experts] + + // auto topk_weights = + // torch::from_blob(BUFFER_PTR(TopKWeights, void), {max_tokens_, topk_}, + // DoNothingDeleter{}, + // torch::TensorOptions() + // .dtype(torch::kFloat32) + // .device(CUDA_DEVICE(device_id_))); + + // auto topk_indices = + // torch::from_blob(BUFFER_PTR(TopKIndices, void), {max_tokens_, topk_}, + // DoNothingDeleter{}, + // torch::TensorOptions() + // .dtype(torch::kUInt32) + // .device(CUDA_DEVICE(device_id_))); + + // auto token_expert_indices = + // torch::from_blob(BUFFER_PTR(TokenExpertIndices, void), + // {max_tokens_, topk_}, DoNothingDeleter{}, + // torch::TensorOptions() + // .dtype(torch::kUInt32) + // .device(CUDA_DEVICE(device_id_))); + + // // Perform top-k softmax to get top-k gating weights and indices + // topk_softmax(topk_weights, topk_indices, token_expert_indices, + // gating_output); // [max_tokens, topk] + + // auto router_mask = + // torch::from_blob(BUFFER_PTR(ExpertRouterMask, void), + // {max_tokens_, num_experts_}, + // DoNothingDeleter{}, torch::TensorOptions() + // .dtype(torch::kBool) + // .device(CUDA_DEVICE(device_id_))); + + // router_mask.scatter_(1, token_expert_indices, + // true); // Set router mask based on top-k indices + + // auto routing_weights_mask = + // torch::from_blob(BUFFER_PTR(ExpertRouterWeight, void), + // {max_tokens_, num_experts_}, + // DoNothingDeleter{}, torch::TensorOptions() + // .dtype(torch::dtype()) + // .device(CUDA_DEVICE(device_id_))); + + // routing_weights_mask.scatter_add_( + // 1, token_expert_indices, + // topk_weights); // Set routing weights mask + // } + ~MoELayer() { // Clean up allocated buffers for (auto* buffer : buffer_) { @@ -170,12 +313,21 @@ class MoELayer { } private: - std::vector buffer_; // Vector of buffers - int num_experts_ = 0; // Number of experts in the MoE layer - int topk_ = 0; // Number of top-k experts to select - int max_tokens_ = 0; // Maximum number of tokens processed in a batch - int64_t hidden_dim_ = 0; // Dimension of hidden states + std::vector buffer_; // Vector of buffers + std::vector tensors_; // Vector of tensors for easy access + std::vector scalar_types_; // Vector of scalar types + int64_t num_experts_ = 0; // Number of experts in the MoE layer + int64_t topk_ = 0; // Number of top-k experts to select + int64_t max_tokens_ = 0; // Maximum number of tokens processed in a batch + int64_t hidden_dim_ = 0; // Dimension of hidden states int64_t intermediate_dim_ = 0; // Dimension of intermediate states cudaStream_t stream_ = 0; // CUDA stream for asynchronous operations int device_id_ = 0; // Device ID for the MoE layer }; + +static std::unique_ptr moe_layer_ptr = nullptr; +static 
std::once_flag moe_layer_init_flag; +void InitMoELayer(int num_experts, int topk, int max_tokens, int64_t hidden_dim, + int64_t intermediate_dim); +std::tuple TopKSoftmax( + torch::Tensor& gating_outputs); diff --git a/core/python/py_archer_prefetch.cpp b/core/python/py_archer_prefetch.cpp index 8106da9..cb0a203 100644 --- a/core/python/py_archer_prefetch.cpp +++ b/core/python/py_archer_prefetch.cpp @@ -6,8 +6,14 @@ #include #include "parallel/expert_dispatcher.h" #include "prefetch/archer_prefetch_handle.h" +#include "model/moe.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("init_moe_layer", InitMoELayer, + "Initialize the MoE layer with the specified parameters."); + m.def("topk_softmax", TopKSoftmax, + "Perform top-k softmax operation on the MoE layer."); + py::class_(m, "prefetch_handle") .def(py::init()) diff --git a/examples/interface_example.py b/examples/interface_example.py index a3e4b7c..621c189 100644 --- a/examples/interface_example.py +++ b/examples/interface_example.py @@ -112,7 +112,6 @@ def end(self): # text for dataset in all_inputs for text in dataset["test"]["question"] if "test" in dataset # ] - custom_kwargs = {} if "switch" in args.model_name_or_path.lower(): custom_kwargs = {"decoder_start_token_id": 0} diff --git a/moe_infinity/models/qwen.py b/moe_infinity/models/qwen.py index eb8a6e8..a6b4f1b 100644 --- a/moe_infinity/models/qwen.py +++ b/moe_infinity/models/qwen.py @@ -4,6 +4,8 @@ import torch.nn.functional as F from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeMLP +from moe_infinity.ops.op_builder.prefetch import PrefetchBuilder + class Qwen3MoEBlock(nn.Module): def __init__(self, config): @@ -30,30 +32,36 @@ def __prepare_expert_route(self, hidden_states): # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - if self.norm_topk_prob: # only diff with mixtral sparse moe block! - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - # print(f"hidden_states shape: {hidden_states.shape}") - # print(f"routing_weights shape: {routing_weights.shape}") - - # Compute sparse mask via scatter - B, E = routing_weights.shape[0], self.num_experts - router_mask = torch.zeros( - B, E, dtype=torch.bool, device=selected_experts.device - ) - router_mask.scatter_(1, selected_experts, True) - - routing_weights_mask = torch.zeros( - B, E, dtype=routing_weights.dtype, device=routing_weights.device - ) - routing_weights_mask.scatter_add_(1, selected_experts, routing_weights) + router_mask, routing_weights_mask = self.lib.topk_softmax(router_logits) + # routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + # routing_weights, selected_experts = torch.topk( + # routing_weights, self.top_k, dim=-1 + # ) + # if self.norm_topk_prob: # only diff with mixtral sparse moe block! 
+ # routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # # we cast back to the input dtype + # routing_weights = routing_weights.to(hidden_states.dtype) + + # # print(f"hidden_states shape: {hidden_states.shape}") + # # print(f"routing_weights shape: {routing_weights.shape}") + + # # Compute sparse mask via scatter + # B, E = routing_weights.shape[0], self.num_experts + # router_mask = torch.zeros( + # B, E, dtype=torch.bool, device=selected_experts.device + # ) + # router_mask.scatter_(1, selected_experts, True) + # routing_weights_mask = torch.zeros( + # B, E, dtype=routing_weights.dtype, device=routing_weights.device + # ) + # routing_weights_mask.scatter_add_(1, selected_experts, routing_weights) + # assert (routing_weights_mask_t == routing_weights_mask).all(), "routing_weights_mask_t and routing_weights_mask should be equal, max diff: {}".format( + # (routing_weights_mask_t - routing_weights_mask).abs().max() + # ) + # assert (router_mask_t == router_mask).all(), "router_mask_t and router_mask should be equal, max diff: {}".format( + # (router_mask_t - router_mask).abs().max() + # ) return router_logits, router_mask, routing_weights_mask @nvtx.annotate("Qwen3MoEBlock", color="blue") diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 494043f..8f4e662 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -140,6 +140,7 @@ def init( # print("Distributed init done") self.prefetch_lib = PrefetchBuilder().load() if use_jit else prefetch_op + # new_alloc = torch.cuda.memory.CUDAPluggableAllocator( # self.prefetch_lib.__file__, "TorchAllocateDevice", "TorchFreeDevice" # ) @@ -341,6 +342,7 @@ def archer_from_pretrained(cls, *args, **kwargs): ) self.model_name = model_name = args[0] + # if "arctic" in model_name: # self.config = ArcticConfig.from_pretrained(*args, **kwargs) # else: @@ -349,6 +351,15 @@ def archer_from_pretrained(cls, *args, **kwargs): parse_moe_param(self.config) ) + if "qwen" in model_name.lower(): + self.prefetch_lib.init_moe_layer( + self.num_experts, + self.config.num_experts_per_tok, + 1024, + self.config.hidden_size, + self.config.moe_intermediate_size, + ) + self.dtype = parse_expert_dtype(self.config) self.dtype_cls = self.config.torch_dtype @@ -591,6 +602,8 @@ def archer_from_pretrained(cls, *args, **kwargs): module.expert_predictor = self.expert_predictor module.expert_tensor_map = self.expert_tensor_map + module.lib = self.prefetch_lib + self.expert_layer_modules.append(module) # module_experts = [ diff --git a/op_builder/prefetch.py b/op_builder/prefetch.py index 212ab9f..c322cf9 100644 --- a/op_builder/prefetch.py +++ b/op_builder/prefetch.py @@ -34,6 +34,7 @@ def sources(self): "core/utils/cuda_utils.cpp", "core/model/model_topology.cpp", "core/model/fused_mlp.cu", + "core/model/moe.cpp", "core/kernel/activation_kernels.cu", "core/kernel/topk_softmax_kernels.cu", "core/prefetch/archer_prefetch_handle.cpp", diff --git a/tests/cuda/CMakeLists.txt b/tests/cuda/CMakeLists.txt index 8f2b629..8c97356 100644 --- a/tests/cuda/CMakeLists.txt +++ b/tests/cuda/CMakeLists.txt @@ -125,9 +125,6 @@ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -g -G -lineinfo -rdynamic -O3 -gencode a set(SRC_LIST test_uvm_kernel.cu test_fused_mlp.cu - test_fused_mlp_wmma.cu - test_expert_fusion.cu - test_expert_fusion_v2.cu # test_single_gemm_tiled.cu test_load_tile.cu test_load_tile_templated.cu @@ -135,19 +132,38 @@ set(SRC_LIST test_autosize_tileload.cu test_autosize_tileload_stage.cu 
test_autotune_blocksize.cu - tests_masked_select.cu + +) + +set(TORCH_SRC_LIST + test_topk_softmax.cu + test_expert_fusion.cu + test_expert_fusion_v2.cu + # tests_masked_select.cu + test_fused_mlp_wmma.cu ) +file(GLOB KERNEL_SRC "${CMAKE_SOURCE_DIR}/../../core/kernel/*.cu") +message(STATUS "Using kernel source files: ${KERNEL_SRC}, ${CMAKE_SOURCE_DIR}") + FOREACH(SRC ${SRC_LIST}) get_filename_component(SRC_NAME ${SRC} NAME_WE) add_executable(${SRC_NAME} ${SRC}) target_link_libraries(${SRC_NAME} cutlass ${CUDA_LIBRARIES}) +ENDFOREACH() - # if file is test_expert_fusion or test_expert_fusion_v2, link torch_python - IF(${SRC_NAME} STREQUAL "test_expert_fusion" OR ${SRC_NAME} STREQUAL "test_expert_fusion_v2" OR ${SRC_NAME} STREQUAL "tests_masked_select" OR ${SRC_NAME} STREQUAL "test_fused_mlp_wmma") - target_link_libraries(${SRC_NAME} ${torch_python_LIBRARY} ${Python3_LIBRARIES} ${TORCH_LIBRARIES}) - target_include_directories(${SRC_NAME} PRIVATE ${CONDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} ${Python3_INCLUDE_DIRS}) + +FOREACH(SRC ${TORCH_SRC_LIST}) + get_filename_component(SRC_NAME ${SRC} NAME_WE) + IF(${SRC_NAME} STREQUAL "test_topk_softmax") + add_executable(${SRC_NAME} ${SRC} ${KERNEL_SRC}) + ELSE() + add_executable(${SRC_NAME} ${SRC}) ENDIF() + target_link_libraries(${SRC_NAME} cutlass ${CUDA_LIBRARIES}) + + target_link_libraries(${SRC_NAME} ${torch_python_LIBRARY} ${Python3_LIBRARIES} ${TORCH_LIBRARIES}) + target_include_directories(${SRC_NAME} PRIVATE ${CONDA_INCLUDE_DIRS} ${TORCH_INCLUDE_DIRS} ${Python3_INCLUDE_DIRS}) # IF(${SRC_NAME} STREQUAL "test_autosize_tileload") # target_link_libraries(${SRC_NAME} thrust) diff --git a/tests/cuda/test_topk_softmax.cu b/tests/cuda/test_topk_softmax.cu new file mode 100644 index 0000000..8d4411c --- /dev/null +++ b/tests/cuda/test_topk_softmax.cu @@ -0,0 +1,252 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include "kernel/ops.h" + +// PyTorch native implementation for comparison - same interface as custom +// kernel +void torch_topk_softmax(torch::Tensor& topk_weights, + torch::Tensor& topk_indices, + torch::Tensor& token_expert_indices, + torch::Tensor& gating_output) { + // Convert bf16 to float32 if needed for computation + torch::Tensor input_for_compute = gating_output; + if (gating_output.dtype() == torch::kBFloat16) { + input_for_compute = gating_output.to(torch::kFloat32); + } + + // Apply softmax + torch::Tensor softmax_output = torch::softmax(input_for_compute, -1); + + // Get topk values and indices + auto topk_result = torch::topk(softmax_output, topk_weights.size(-1), -1); + torch::Tensor temp_weights = std::get<0>(topk_result); + torch::Tensor temp_indices = std::get<1>(topk_result); + + // Copy results to pre-allocated tensors + topk_weights.copy_(temp_weights); + topk_indices.copy_(temp_indices.to(torch::kUInt32)); + + // Create token_expert_indices efficiently using PyTorch operations + int num_tokens = gating_output.size(0); + int topk = topk_weights.size(-1); + + // Create base indices for tokens: [0, 1, 2, ..., num_tokens-1] + torch::Tensor token_ids = + torch::arange(num_tokens, torch::TensorOptions() + .dtype(torch::kInt32) + .device(gating_output.device())); + + // Create k indices: [0, 1, 2, ..., topk-1] + torch::Tensor k_ids = + torch::arange(topk, torch::TensorOptions() + .dtype(torch::kInt32) + .device(gating_output.device())); + + // Broadcast and compute: k_ids[:, None] * num_tokens + token_ids[None, :] + // This creates the pattern: j * num_tokens + i for all (i,j) combinations + 
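+  // Worked example of the pattern above (for illustration only): with
+  // num_tokens = 3 and topk = 2, k_offset = [[0], [3]] and
+  // token_base = [[0, 1, 2]], so the broadcast sum is [[0, 1, 2], [3, 4, 5]]
+  // and the transpose yields token_expert_indices = [[0, 3], [1, 4], [2, 5]],
+  // i.e. row i holds {i, i + num_tokens} for its top-k slots.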
torch::Tensor k_offset = k_ids.unsqueeze(1) * num_tokens; // [topk, 1] + torch::Tensor token_base = token_ids.unsqueeze(0); // [1, num_tokens] + + // Final result: [topk, num_tokens] then transpose to [num_tokens, topk] + token_expert_indices.copy_((k_offset + token_base).transpose(0, 1)); +} + +class Timer { + private: + std::chrono::high_resolution_clock::time_point start_time; + + public: + void start() { start_time = std::chrono::high_resolution_clock::now(); } + + double elapsed_ms() { + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time); + return duration.count() / 1000.0; // Convert to milliseconds + } +}; + +void benchmark_kernel_vs_torch(int num_tokens, int num_experts, int topk, + int num_iterations = 100) { + std::cout << "\n=== Benchmark: " << num_tokens << " tokens, " << num_experts + << " experts, topk=" << topk << " ===" << std::endl; + + // Setup CUDA device + torch::Device device(torch::kCUDA); + const at::cuda::OptionalCUDAGuard device_guard(device); + + // Create input tensor + torch::Tensor gating_output = torch::randn( + {num_tokens, num_experts}, + torch::TensorOptions().dtype(torch::kFloat32).device(device)); + + // Pre-allocate output tensors for custom kernel + torch::Tensor custom_topk_weights = torch::zeros( + {num_tokens, topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor custom_topk_indices = + torch::zeros({num_tokens, topk}, + torch::TensorOptions().dtype(torch::kUInt32).device(device)); + torch::Tensor custom_token_expert_indices = + torch::zeros({num_tokens, topk}, + torch::TensorOptions().dtype(torch::kInt32).device(device)); + + // Pre-allocate output tensors for PyTorch native (same tensors used across + // iterations) + torch::Tensor torch_topk_weights = torch::zeros( + {num_tokens, topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(device)); + torch::Tensor torch_topk_indices = + torch::zeros({num_tokens, topk}, + torch::TensorOptions().dtype(torch::kUInt32).device(device)); + torch::Tensor torch_token_expert_indices = + torch::zeros({num_tokens, topk}, + torch::TensorOptions().dtype(torch::kInt32).device(device)); + + Timer timer; + std::vector custom_times, torch_times; + + // Warmup runs + std::cout << "Warming up..." << std::endl; + for (int i = 0; i < 10; i++) { + // Warmup custom kernel + topk_softmax(custom_topk_weights, custom_topk_indices, + custom_token_expert_indices, gating_output); + + // Warmup PyTorch + torch_topk_softmax(torch_topk_weights, torch_topk_indices, + torch_token_expert_indices, gating_output); + + torch::cuda::synchronize(); + } + + std::cout << "Running benchmark..." 
<< std::endl; + + // Benchmark custom kernel + for (int i = 0; i < num_iterations; i++) { + timer.start(); + topk_softmax(custom_topk_weights, custom_topk_indices, + custom_token_expert_indices, gating_output); + torch::cuda::synchronize(); + custom_times.push_back(timer.elapsed_ms()); + } + + // Benchmark PyTorch native + for (int i = 0; i < num_iterations; i++) { + timer.start(); + torch_topk_softmax(torch_topk_weights, torch_topk_indices, + torch_token_expert_indices, gating_output); + torch::cuda::synchronize(); + torch_times.push_back(timer.elapsed_ms()); + } + + // Calculate statistics + auto calc_stats = [](const std::vector& times) { + double sum = 0, min_time = times[0], max_time = times[0]; + for (double t : times) { + sum += t; + min_time = std::min(min_time, t); + max_time = std::max(max_time, t); + } + double mean = sum / times.size(); + + double variance = 0; + for (double t : times) { + variance += (t - mean) * (t - mean); + } + double stddev = std::sqrt(variance / times.size()); + + return std::make_tuple(mean, min_time, max_time, stddev); + }; + + auto [custom_mean, custom_min, custom_max, custom_std] = + calc_stats(custom_times); + auto [torch_mean, torch_min, torch_max, torch_std] = calc_stats(torch_times); + + // Print results + std::cout << std::fixed << std::setprecision(3); + std::cout << "\nCustom Kernel Results:" << std::endl; + std::cout << " Mean: " << custom_mean << " ms" << std::endl; + std::cout << " Min: " << custom_min << " ms" << std::endl; + std::cout << " Max: " << custom_max << " ms" << std::endl; + std::cout << " Std: " << custom_std << " ms" << std::endl; + + std::cout << "\nPyTorch Native Results:" << std::endl; + std::cout << " Mean: " << torch_mean << " ms" << std::endl; + std::cout << " Min: " << torch_min << " ms" << std::endl; + std::cout << " Max: " << torch_max << " ms" << std::endl; + std::cout << " Std: " << torch_std << " ms" << std::endl; + + double speedup = torch_mean / custom_mean; + std::cout << "\nSpeedup: " << speedup << "x "; + if (speedup > 1.0) { + std::cout << "(Custom kernel is faster)" << std::endl; + } else { + std::cout << "(PyTorch native is faster)" << std::endl; + } + + // Verify correctness (optional) + std::cout << "\nVerifying correctness..." << std::endl; + + // Create temporary tensors for verification + torch::Tensor verify_weights = torch::zeros_like(custom_topk_weights); + torch::Tensor verify_indices = torch::zeros_like(custom_topk_indices); + torch::Tensor verify_token_indices = + torch::zeros_like(custom_token_expert_indices); + + torch_topk_softmax(verify_weights, verify_indices, verify_token_indices, + gating_output); + + // Compare a few values + bool close = torch::allclose(custom_topk_weights, verify_weights, 1e-3, 1e-3); + std::cout << "Results match PyTorch: " << (close ? "YES" : "NO") << std::endl; + + if (!close) { + std::cout << "Max difference: " + << torch::max(torch::abs(custom_topk_weights - verify_weights)) + .item() + << std::endl; + } +} + +int main() { + std::cout << "MoE Kernel vs PyTorch Native Speed Comparison" << std::endl; + std::cout << "=============================================" << std::endl; + + if (!torch::cuda::is_available()) { + std::cerr << "CUDA is not available!" 
<< std::endl; + return -1; + } + + // std::cout << "CUDA Device: " << torch::cuda::get_device_name(0) << + // std::endl; + + // Test different configurations + std::vector<std::tuple<int, int, int>> test_configs = { + {32, 128, 8}, {1024, 128, 8}, // Small: 1K tokens, 128 experts, top-8 + {4096, 128, 8}, // Medium: 4K tokens, 128 experts, top-8 + {8192, 128, 8}, // Large: 8K tokens, 128 experts, top-8 + {16384, 128, 8}, // Very large: 16K tokens, 128 experts, top-8 + {32768, 128, 8}, // Extreme: 32K tokens, 128 experts, top-8 + }; + + for (auto [num_tokens, num_experts, topk] : test_configs) { + try { + benchmark_kernel_vs_torch(num_tokens, num_experts, topk, 50); + } catch (const std::exception& e) { + std::cerr << "Error in configuration (" << num_tokens << ", " + << num_experts << ", " << topk << "): " << e.what() + << std::endl; + } + } + + std::cout << "\nBenchmark completed!" << std::endl; + return 0; +}
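For context, the standalone sketch below shows how the new InitMoELayer / TopKSoftmax entry points from core/model/moe.{h,cpp} could be driven directly from C++, mirroring what the init_moe_layer / topk_softmax pybind bindings do from Python. It is only an illustration under assumptions: the dimensions, the run_topk_softmax_once wrapper, and the final sanity check are invented, and it presumes the file is linked against libtorch and core/model/moe.cpp in the same way tests/cuda/test_topk_softmax.cu is.

#include <torch/torch.h>

#include <tuple>

#include "model/moe.h"

// Hypothetical driver; the names and sizes below are illustrative only.
int run_topk_softmax_once() {
  const int num_experts = 128, topk = 8, max_tokens = 1024;
  const int64_t hidden_dim = 2048, intermediate_dim = 768;

  // One-time allocation of the fixed MoE buffers (guarded by std::call_once
  // inside InitMoELayer).
  InitMoELayer(num_experts, topk, max_tokens, hidden_dim, intermediate_dim);

  // Router logits for a small batch of tokens; TopKSoftmax converts them to
  // float32 before running the fused top-k softmax kernel.
  torch::Tensor gating = torch::randn(
      {64, num_experts},
      torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA));

  // Returns the boolean router mask [tokens, experts] and the routing weights
  // scattered from the normalized top-k probabilities.
  auto [router_mask, routing_weights] = TopKSoftmax(gating);

  // Rough sanity check: each token should be routed to exactly topk experts.
  return router_mask.sum().item<int64_t>() == 64 * topk ? 0 : 1;
}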