From 7c8cda49baec61961637da93540a15c7194d9cc6 Mon Sep 17 00:00:00 2001 From: "zhangxu.709" Date: Thu, 16 Oct 2025 15:34:16 +0800 Subject: [PATCH] feat: add flashinfer as kernel backend for cuda device. --- .gitmodules | 12 ++ CMakeLists.txt | 74 ++++++++++- setup.py | 22 +++- third_party/CMakeLists.txt | 38 ++++++ third_party/cutlass | 1 + third_party/dlpack | 1 + third_party/flashinfer | 1 + third_party/tvm-ffi | 1 + xllm/core/common/CMakeLists.txt | 2 + xllm/core/common/flashinfer_workspace.cpp | 46 +++++++ xllm/core/common/flashinfer_workspace.h | 49 ++++++++ xllm/core/common/global_flags.cpp | 8 +- xllm/core/common/global_flags.h | 2 + .../framework/batch/batch_input_builder.cpp | 28 ++++- .../framework/batch/batch_input_builder.h | 7 +- .../core/framework/model/model_input_params.h | 20 +++ xllm/core/kernels/CMakeLists.txt | 4 + xllm/core/kernels/cuda/CMakeLists.txt | 25 ++++ xllm/core/kernels/cuda/activation.cpp | 32 +++++ xllm/core/kernels/cuda/batch_decode.cpp | 89 +++++++++++++ xllm/core/kernels/cuda/batch_prefill.cpp | 104 +++++++++++++++ xllm/core/kernels/cuda/cuda_ops_api.h | 79 ++++++++++++ xllm/core/kernels/cuda/matmul.cpp | 27 ++++ xllm/core/kernels/cuda/norm.cpp | 31 +++++ xllm/core/kernels/cuda/reshape_paged_cache.cu | 106 ++++++++++++++++ xllm/core/kernels/cuda/rope.cpp | 39 ++++++ xllm/core/kernels/cuda/utils.cpp | 119 ++++++++++++++++++ xllm/core/kernels/cuda/utils.h | 57 +++++++++ xllm/core/kernels/mlu/mlu_ops_api.h | 5 - xllm/core/kernels/ops_api.cpp | 62 ++++++++- xllm/core/kernels/ops_api.h | 13 +- xllm/core/kernels/param.h | 6 +- xllm/core/layers/CMakeLists.txt | 1 - xllm/core/layers/common/CMakeLists.txt | 1 + xllm/core/layers/common/attention.cpp | 24 ++++ xllm/core/layers/common/attention.h | 7 ++ xllm/core/layers/common/dense_mlp.h | 2 +- xllm/core/layers/common/fused_moe.h | 2 +- xllm/core/layers/{ => common}/linear.h | 6 +- xllm/core/layers/common/qwen3_attention.h | 2 +- xllm/core/platform/device.cpp | 2 + xllm/core/platform/stream.cpp | 6 +- xllm/core/platform/stream.h | 8 +- xllm/core/platform/vmm_api.cpp | 2 +- xllm/core/runtime/forward_params.h | 4 + xllm/core/runtime/llm_engine.cpp | 2 +- xllm/core/runtime/llm_worker_impl.cpp | 6 +- xllm/core/runtime/params_utils.cpp | 23 ++++ xllm/core/runtime/worker_impl.cpp | 10 +- xllm/models/llm/llm_model_base.h | 1 - 50 files changed, 1166 insertions(+), 53 deletions(-) create mode 160000 third_party/cutlass create mode 160000 third_party/dlpack create mode 160000 third_party/flashinfer create mode 160000 third_party/tvm-ffi create mode 100644 xllm/core/common/flashinfer_workspace.cpp create mode 100644 xllm/core/common/flashinfer_workspace.h create mode 100644 xllm/core/kernels/cuda/CMakeLists.txt create mode 100644 xllm/core/kernels/cuda/activation.cpp create mode 100644 xllm/core/kernels/cuda/batch_decode.cpp create mode 100644 xllm/core/kernels/cuda/batch_prefill.cpp create mode 100644 xllm/core/kernels/cuda/cuda_ops_api.h create mode 100644 xllm/core/kernels/cuda/matmul.cpp create mode 100644 xllm/core/kernels/cuda/norm.cpp create mode 100644 xllm/core/kernels/cuda/reshape_paged_cache.cu create mode 100644 xllm/core/kernels/cuda/rope.cpp create mode 100644 xllm/core/kernels/cuda/utils.cpp create mode 100644 xllm/core/kernels/cuda/utils.h rename xllm/core/layers/{ => common}/linear.h (98%) diff --git a/.gitmodules b/.gitmodules index a4254bed..6f5fc412 100755 --- a/.gitmodules +++ b/.gitmodules @@ -28,3 +28,15 @@ [submodule "third_party/Mooncake"] path = third_party/Mooncake url = 
https://gitcode.com/xLLM-AI/Mooncake.git +[submodule "third_party/flashinfer"] + path = third_party/flashinfer + url = https://gitcode.com/xLLM-AI/flashinfer.git +[submodule "third_party/cutlass"] + path = third_party/cutlass + url = https://gitcode.com/xLLM-AI/cutlass.git +[submodule "third_party/tvm-ffi"] + path = third_party/tvm-ffi + url = https://gitcode.com/xLLM-AI/tvm-ffi.git +[submodule "third_party/dlpack"] + path = third_party/dlpack + url = https://gitcode.com/xLLM-AI/dlpack.git diff --git a/CMakeLists.txt b/CMakeLists.txt index c7765ee7..924cb297 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,10 @@ cmake_minimum_required(VERSION 3.26) set_property(GLOBAL PROPERTY USE_FOLDERS ON) +set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc") option(USE_NPU "Enable NPU support" OFF) option(USE_MLU "Enable MLU support" OFF) +option(USE_CUDA "Enable CUDA support" OFF) if(DEVICE_ARCH STREQUAL "ARM") set(CMAKE_SYSTEM_PROCESSOR aarch64) @@ -101,7 +103,7 @@ set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS ON) -if(USE_NPU) +if(USE_NPU OR USE_CUDA) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0") add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) elseif(USE_MLU) @@ -178,6 +180,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT}) message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}") endif() +# set architecture for CUDA +if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND USE_CUDA) + set(CMAKE_CUDA_ARCHITECTURES 80) +endif() + +# Build TORCH_CUDA_ARCH_LIST +if(USE_CUDA) + # Build TORCH_CUDA_ARCH_LIST + set(TORCH_CUDA_ARCH_LIST "") + foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) + if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$") + set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a") + elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$") + set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}") + elseif(CUDA_ARCH STREQUAL "native") + set(TORCH_ARCH "Auto") + else() + message(FATAL_ERROR "${CUDA_ARCH} is not supported") + endif() + list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH}) + endforeach() + + message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}") + message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}") +endif() + # configure vcpkg # have to set CMAKE_TOOLCHAIN_FILE before first project call. # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE) @@ -217,7 +245,12 @@ endif() set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation") -project("xllm" LANGUAGES C CXX) +if(USE_CUDA) + project("xllm" LANGUAGES C CXX CUDA) + find_package(CUDAToolkit REQUIRED) +else() + project("xllm" LANGUAGES C CXX) +endif() # find_package(CUDAToolkit REQUIRED) @@ -352,6 +385,43 @@ if(USE_MLU) ) endif() +if(USE_CUDA) + add_definitions(-DUSE_CUDA) + add_compile_definitions(TORCH_CUDA=1) + set(CMAKE_VERBOSE_MAKEFILE ON) + include_directories( + $ENV{PYTHON_INCLUDE_PATH} + $ENV{PYTORCH_INSTALL_PATH}/include + $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include + ) + + link_directories( + $ENV{PYTHON_LIB_PATH} + $ENV{PYTORCH_INSTALL_PATH}/lib + $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64 + ) + + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3) + # The following definitions must be undefined since half-precision operation is required. 
+ set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} + -U__CUDA_NO_HALF_OPERATORS__ + -U__CUDA_NO_HALF_CONVERSIONS__ + -U__CUDA_NO_HALF2_OPERATORS__ + -U__CUDA_NO_BFLOAT16_CONVERSIONS__) + set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all) + message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}") + + # find_package(NCCL REQUIRED) + + # find cudnn + execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH) + get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY) + link_directories( + ${CUDNN_ROOT_DIR}/lib64 + ${CUDNN_ROOT_DIR}/lib + ) +endif() + # check if USE_CXX11_ABI is set correctly # if (DEFINED USE_CXX11_ABI) # parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS") diff --git a/setup.py b/setup.py index 62c56345..ec45dbde 100644 --- a/setup.py +++ b/setup.py @@ -212,7 +212,13 @@ def set_mlu_envs(): os.environ["LIBTORCH_ROOT"] = get_torch_root_path() os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path() - + +def set_cuda_envs(): + os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path() + os.environ["PYTHON_LIB_PATH"] = get_torch_root_path() + os.environ["LIBTORCH_ROOT"] = get_torch_root_path() + os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path() + class CMakeExtension(Extension): def __init__(self, name: str, path: str, sourcedir: str = "") -> None: super().__init__(name, sources=[]) @@ -223,7 +229,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None: class ExtBuild(build_ext): user_options = build_ext.user_options + [ ("base-dir=", None, "base directory of xLLM project"), - ("device=", None, "target device type (a3 or a2 or mlu)"), + ("device=", None, "target device type (a3 or a2 or mlu or cuda)"), ("arch=", None, "target arch type (x86 or arm)"), ("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"), ] @@ -302,8 +308,14 @@ def build_extension(self, ext: CMakeExtension): cmake_args += ["-DUSE_MLU=ON"] # set mlu environment variables set_mlu_envs() + elif self.device == "cuda": + cuda_architectures = "80;89;90" + cmake_args += ["-DUSE_CUDA=ON", + f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"] + # set cuda environment variables + set_cuda_envs() else: - raise ValueError("Please set --device to a2 or a3 or mlu.") + raise ValueError("Please set --device to a2 or a3 or mlu or cuda.") # Adding CMake arguments set as environment variable @@ -353,7 +365,7 @@ def build_extension(self, ext: CMakeExtension): class BuildDistWheel(bdist_wheel): user_options = bdist_wheel.user_options + [ - ("device=", None, "target device type (a3 or a2 or mlu)"), + ("device=", None, "target device type (a3 or a2 or mlu or cuda)"), ("arch=", None, "target arch type (x86 or arm)"), ] @@ -530,7 +542,7 @@ def apply_patch(): idx = sys.argv.index('--device') if idx + 1 < len(sys.argv): device = sys.argv[idx+1].lower() - if device not in ('a2', 'a3', 'mlu'): + if device not in ('a2', 'a3', 'mlu', 'cuda'): print("Error: --device must be a2 or a3 or mlu (case-insensitive)") sys.exit(1) # Remove the arguments so setup() doesn't see them diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7fb8fa93..815ab7e8 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -20,3 +20,41 @@ target_include_directories(mooncake_store PUBLIC ) target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator) + + +if(USE_CUDA) + cc_library( + 
NAME + cutlass + INCLUDES + cutlass/include + cutlass/tools/util/include + DEPS + torch # TODO: depends on CUDA instead of torch + ) + cc_library( + NAME + dlpack + INCLUDES + dlpack/include + ) + cc_library( + NAME + tvm-ffi + INCLUDES + tvm-ffi/include + DEPS + dlpack + ) + cc_library( + NAME + flashinfer + INCLUDES + flashinfer/include + flashinfer/csrc + DEPS + cutlass + tvm-ffi + dlpack + ) +endif() \ No newline at end of file diff --git a/third_party/cutlass b/third_party/cutlass new file mode 160000 index 00000000..e6e2cc29 --- /dev/null +++ b/third_party/cutlass @@ -0,0 +1 @@ +Subproject commit e6e2cc29f5e7611dfc6af0ed6409209df0068cf2 diff --git a/third_party/dlpack b/third_party/dlpack new file mode 160000 index 00000000..93c8f2a3 --- /dev/null +++ b/third_party/dlpack @@ -0,0 +1 @@ +Subproject commit 93c8f2a3c774b84af6f652b1992c48164fae60fc diff --git a/third_party/flashinfer b/third_party/flashinfer new file mode 160000 index 00000000..d4a3ff43 --- /dev/null +++ b/third_party/flashinfer @@ -0,0 +1 @@ +Subproject commit d4a3ff4356aeaeaa2e67a8b176d72a749d96a089 diff --git a/third_party/tvm-ffi b/third_party/tvm-ffi new file mode 160000 index 00000000..af898a2c --- /dev/null +++ b/third_party/tvm-ffi @@ -0,0 +1 @@ +Subproject commit af898a2c32f053806064ef7b679682f94b5569c1 diff --git a/xllm/core/common/CMakeLists.txt b/xllm/core/common/CMakeLists.txt index 3410b2e5..f1e49f48 100644 --- a/xllm/core/common/CMakeLists.txt +++ b/xllm/core/common/CMakeLists.txt @@ -15,6 +15,7 @@ cc_library( rate_limiter.h types.h device_monitor.h + flashinfer_workspace.h SRCS etcd_client.cpp global_flags.cpp @@ -23,6 +24,7 @@ cc_library( options.cpp rate_limiter.cpp device_monitor.cpp + flashinfer_workspace.cpp DEPS util absl::random_random diff --git a/xllm/core/common/flashinfer_workspace.cpp b/xllm/core/common/flashinfer_workspace.cpp new file mode 100644 index 00000000..eff340da --- /dev/null +++ b/xllm/core/common/flashinfer_workspace.cpp @@ -0,0 +1,46 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "flashinfer_workspace.h" + +#include "global_flags.h" + +namespace xllm { + +void FlashinferWorkspace::initialize(const torch::Device& device) { + float_workspace_buffer_ = + torch::empty({FLAGS_workspace_buffer_size}, + torch::dtype(torch::kUInt8).device(device)); + int_workspace_buffer_ = + torch::empty({FLAGS_workspace_buffer_size}, + torch::dtype(torch::kUInt8).device(device)); + page_locked_int_workspace_buffer_ = torch::empty( + {FLAGS_workspace_buffer_size}, + torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true)); +} + +torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() { + return float_workspace_buffer_; +} + +torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() { + return int_workspace_buffer_; +} + +torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() { + return page_locked_int_workspace_buffer_; +} + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/common/flashinfer_workspace.h b/xllm/core/common/flashinfer_workspace.h new file mode 100644 index 00000000..bbd875a3 --- /dev/null +++ b/xllm/core/common/flashinfer_workspace.h @@ -0,0 +1,49 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include + +#include + +#include "macros.h" + +namespace xllm { + +class FlashinferWorkspace { + public: + static FlashinferWorkspace& get_instance() { + static FlashinferWorkspace instance; + return instance; + }; + + void initialize(const torch::Device& device); + + torch::Tensor get_float_workspace_buffer(); + torch::Tensor get_int_workspace_buffer(); + torch::Tensor get_page_locked_int_workspace_buffer(); + + private: + FlashinferWorkspace() = default; + ~FlashinferWorkspace() = default; + DISALLOW_COPY_AND_ASSIGN(FlashinferWorkspace); + + torch::Tensor float_workspace_buffer_; + torch::Tensor int_workspace_buffer_; + torch::Tensor page_locked_int_workspace_buffer_; +}; + +} // namespace xllm \ No newline at end of file diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index dd95e2ea..9de87658 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -385,4 +385,10 @@ DEFINE_int64(buffer_size_per_seq, // --- beam search config --- DEFINE_bool(enable_beam_search_kernel, false, - "Whether to enable beam search kernel."); \ No newline at end of file + "Whether to enable beam search kernel."); + +// --- flashinfer config --- +DEFINE_int32(workspace_buffer_size, + 512 * 1024 * 1024, + "The user reserved workspace buffer used to store intermediate " + "attention results in split-k algorithm."); diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 10e1d794..e90eaa69 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -202,3 +202,5 @@ DECLARE_int64(buffer_size_per_seq); DECLARE_bool(enable_beam_search_kernel); DECLARE_bool(enable_shm); + +DECLARE_int32(workspace_buffer_size); diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 2ff34176..7f2fe880 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -216,7 +216,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, state_.q_seq_lens.insert(state_.q_seq_lens.end(), state.q_seq_lens.begin(), state.q_seq_lens.end()); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) int32_t seq_len_offset = state_.seq_lens.back(); // skip the first element which is 0 for (size_t i = 1; i < state.seq_lens.size(); ++i) { @@ -288,7 +288,7 @@ void BatchInputBuilder::process_single_sequence( #if defined(USE_NPU) state.seq_lens.push_back(seq_len); state.q_seq_lens.push_back(q_seq_len); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) state.seq_lens.push_back(state.seq_lens.back() + seq_len); state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len); #endif @@ -448,7 +448,12 @@ void BatchInputBuilder::setup_kv_cache_info( block_size = block.size(); block_ids.push_back(block.id()); u_block_ids.emplace_back(block.id()); + state.paged_kv_indices.push_back(block.id()); } + state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size()); + int32_t last_page_len = + (seq_len % block_size == 0) ? 
block_size : seq_len % block_size; + state.paged_kv_last_page_len.push_back(last_page_len); int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size; for (auto iter = block_ids.begin() + kv_cache_block_idx; @@ -517,12 +522,15 @@ void BatchInputBuilder::padding_decode_batch_size( #if defined(USE_NPU) state_.seq_lens.push_back(num_decoding_tokens); state_.q_seq_lens.push_back(num_decoding_tokens); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens); state_.q_seq_lens.push_back(state_.q_seq_lens.back() + num_decoding_tokens); #endif state_.block_tables_vec.emplace_back(); + state_.paged_kv_indices.push_back(0); + state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1); + state_.paged_kv_last_page_len.push_back(1); } } } @@ -560,6 +568,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.decode_seq_range = util::find_ones_indices(input_params.q_seq_lens_vec); + // for flashinfer + input_params.paged_kv_indptr = + torch::tensor(state_.paged_kv_indptr, torch::kInt); + input_params.paged_kv_indices = + torch::tensor(state_.paged_kv_indices, torch::kInt); + input_params.paged_kv_last_page_len = + torch::tensor(state_.paged_kv_last_page_len, torch::kInt); + // Setup multimodal data input_params.mm_data = MMData::batch(mm_data_vec_); @@ -634,6 +650,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos); raw_forward_input.prefill_seq_len = state_.prefill_seq_len; + // for flashinfer + raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr); + raw_forward_input.paged_kv_indices = std::move(state_.paged_kv_indices); + raw_forward_input.paged_kv_last_page_len = + std::move(state_.paged_kv_last_page_len); + raw_forward_input.embedding_ids = std::move(state_.embedding_ids); raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids); // beam search kernel input diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 9b76bfb1..508610fd 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -86,7 +86,7 @@ class BatchInputBuilder { #if defined(USE_NPU) std::vector seq_lens; std::vector q_seq_lens; -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) std::vector seq_lens = {0}; // cu_seq_lens std::vector q_seq_lens = {0}; // q_cu_seq_len #endif @@ -107,6 +107,11 @@ class BatchInputBuilder { // for continuous kvcache std::vector new_cache_slot_offsets; //[n_tokens] std::vector kv_cache_start_offsets; //[n_seq] + + // for flashinfer + std::vector paged_kv_indptr = {0}; + std::vector paged_kv_indices; + std::vector paged_kv_last_page_len; }; // Helper methods for sequence processing diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index aaaae36d..b9da468e 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -94,6 +94,11 @@ struct ModelInputParams { // Copy graph_buffer to device params.graph_buffer = safe_to(graph_buffer, device, true); + // params for flashinfer + params.paged_kv_indptr = safe_to(paged_kv_indptr, device); + params.paged_kv_indices = safe_to(paged_kv_indices, device); + params.paged_kv_last_page_len = safe_to(paged_kv_last_page_len, device); + return params; } @@ -193,6 +198,21 @@ struct 
ModelInputParams { // Graph execution buffer for temporary tensor storage // Used by ACL Graph Executor to avoid repeated memory allocation torch::Tensor graph_buffer; + + // the indptr of the paged kv-cache + // used in flashinfer + // IntTensor: [n_seq + 1] + torch::Tensor paged_kv_indptr; + + // the page indices of the paged kv cache + // used in flashinfer + torch::Tensor paged_kv_indices; + + // the number of entries in the last page of each request in + // the paged kv cache + // used in flashinfer + // IntTensor: [n_seq] + torch::Tensor paged_kv_last_page_len; }; } // namespace xllm diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt index 4aa1941b..3bba0e16 100644 --- a/xllm/core/kernels/CMakeLists.txt +++ b/xllm/core/kernels/CMakeLists.txt @@ -8,6 +8,9 @@ if(USE_MLU) add_subdirectory(mlu) endif() +if(USE_CUDA) + add_subdirectory(cuda) +endif() cc_library( NAME @@ -21,4 +24,5 @@ cc_library( torch $<$:npu_kernels> $<$:mlu_kernels> + $<$:cuda_kernels> ) \ No newline at end of file diff --git a/xllm/core/kernels/cuda/CMakeLists.txt b/xllm/core/kernels/cuda/CMakeLists.txt new file mode 100644 index 00000000..83277bd4 --- /dev/null +++ b/xllm/core/kernels/cuda/CMakeLists.txt @@ -0,0 +1,25 @@ +include(cc_library) + +include_directories( + ${CMAKE_SOURCE_DIR}/third_party/tvm-ffi/include +) + +file(GLOB_RECURSE CUDA_HEADER_FILES + "${CMAKE_CURRENT_LIST_DIR}/*.h" +) + +file(GLOB_RECURSE CUDA_SOURCE_FILES + "${CMAKE_CURRENT_LIST_DIR}/*.cpp" + "${CMAKE_CURRENT_LIST_DIR}/*.cu" +) + +cc_library( + NAME + cuda_kernels + HDRS + ${CUDA_HEADER_FILES} + SRCS + ${CUDA_SOURCE_FILES} + DEPS + flashinfer +) diff --git a/xllm/core/kernels/cuda/activation.cpp b/xllm/core/kernels/cuda/activation.cpp new file mode 100644 index 00000000..e892b7ad --- /dev/null +++ b/xllm/core/kernels/cuda/activation.cpp @@ -0,0 +1,32 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void act_and_mul(torch::Tensor& out, + torch::Tensor& input, + const std::string& act_mode) { + if (act_mode != "silu" && act_mode != "gelu" && act_mode != "gelu_tanh") { + throw std::runtime_error("Unsupported act mode: " + act_mode + + ", only support silu, gelu, gelu_tanh"); + } + + std::string uri = act_mode + "_and_mul"; + get_module(uri)->GetFunction(uri).value()( + to_ffi_tensor(out), to_ffi_tensor(input), support_pdl()); +} +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/batch_decode.cpp b/xllm/core/kernels/cuda/batch_decode.cpp new file mode 100644 index 00000000..c5751f0a --- /dev/null +++ b/xllm/core/kernels/cuda/batch_decode.cpp @@ -0,0 +1,89 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void batch_decode(torch::Tensor& float_workspace_buffer, + torch::Tensor& int_workspace_buffer, + torch::Tensor& page_locked_int_workspace_buffer, + const torch::Tensor& query, + const torch::Tensor& k_cache, + const torch::Tensor& v_cache, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& paged_kv_indptr, + const torch::Tensor& paged_kv_indices, + const torch::Tensor& paged_kv_last_page_len, + int64_t window_size_left, + torch::Tensor& output, + std::optional& output_lse, + bool enable_cuda_graph) { + std::string uri = get_batch_decode_uri(query.scalar_type(), + k_cache.scalar_type(), + output.scalar_type(), + paged_kv_indptr.scalar_type(), + query.size(-1), + v_cache.size(-1), + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false); + + ffi::Array plan_info; + torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU); + const int64_t batch_size = q_cu_seq_lens.size(0) - 1; + + // plan decode + plan_info = get_module(uri)->GetFunction("plan").value()( + to_ffi_tensor(float_workspace_buffer), + to_ffi_tensor(int_workspace_buffer), + to_ffi_tensor(page_locked_int_workspace_buffer), + to_ffi_tensor(qo_indptr_host), + batch_size, + query.size(1), // num_qo_heads + k_cache.size(2), // num_kv_heads + k_cache.size(1), // block_size + enable_cuda_graph, + window_size_left, + /* logits_soft_cap=*/0.0, + query.size(-1), // head_dim_qk + v_cache.size(-1), // head_dim_vo + /*empty_q_data=*/torch::Tensor(), + /*empty_kv_data=*/torch::Tensor()); + + // batch decode + get_module(uri)->GetFunction("run").value()( + to_ffi_tensor(float_workspace_buffer), + to_ffi_tensor(int_workspace_buffer), + plan_info, + to_ffi_tensor(query), + to_ffi_tensor(k_cache), + to_ffi_tensor(v_cache), + to_ffi_tensor(paged_kv_indptr), + to_ffi_tensor(paged_kv_indices), + to_ffi_tensor(paged_kv_last_page_len), + to_ffi_tensor(output), + to_ffi_tensor(output_lse), + /*kv_layout_code=*/0, + window_size_left, + support_pdl(), + /*maybe_alibi_slopes=*/torch::Tensor(), + /*logits_soft_cap=*/0.0, + /*sm_scale=*/1.0, + /*rope_rcp_scale=*/1.0, + /*rope_rcp_theta=*/1.0); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/batch_prefill.cpp b/xllm/core/kernels/cuda/batch_prefill.cpp new file mode 100644 index 00000000..e88143c3 --- /dev/null +++ b/xllm/core/kernels/cuda/batch_prefill.cpp @@ -0,0 +1,104 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void batch_prefill(torch::Tensor& float_workspace_buffer, + torch::Tensor& int_workspace_buffer, + torch::Tensor& page_locked_int_workspace_buffer, + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& kv_cu_seq_lens, + int64_t window_size_left, + torch::Tensor& output, + std::optional& output_lse, + bool enable_cuda_graph) { + std::string uri = get_batch_prefill_uri(/*backend=*/"fa2", + query.scalar_type(), + key.scalar_type(), + output.scalar_type(), + q_cu_seq_lens.scalar_type(), + query.size(-1), + value.size(-1), + /*pos_encoding_mode=*/0, + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, + /*use_fp16_qk_reduction=*/false); + + ffi::Array plan_info; + + torch::Tensor kv_indptr_host = kv_cu_seq_lens.to(torch::kCPU); + torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU); + torch::Tensor kv_len_arr = + kv_indptr_host.slice(0, 1) - kv_indptr_host.slice(0, 0, -1); + const int64_t total_num_rows = qo_indptr_host.size(0); + const int64_t batch_size = q_cu_seq_lens.size(0) - 1; + + // plan prefill + plan_info = get_module(uri)->GetFunction("plan").value()( + to_ffi_tensor(float_workspace_buffer), + to_ffi_tensor(int_workspace_buffer), + to_ffi_tensor(page_locked_int_workspace_buffer), + to_ffi_tensor(qo_indptr_host), + to_ffi_tensor(kv_indptr_host), + to_ffi_tensor(kv_len_arr), + total_num_rows, + batch_size, + query.size(1), // num_qo_heads + key.size(1), // num_kv_heads + /*page_size=*/1, + enable_cuda_graph, + query.size(-1), // head_dim_qk + value.size(-1), // head_dim_vo + /*causal=*/true, // causal + window_size_left, + /*fixed_split_size=*/-1, + /*disable_split_kv=*/false); + + // batch prefill + get_module(uri) + ->GetFunction("ragged_run") + .value()(to_ffi_tensor(float_workspace_buffer), + to_ffi_tensor(int_workspace_buffer), + plan_info, + to_ffi_tensor(query), + to_ffi_tensor(key), + to_ffi_tensor(value), + to_ffi_tensor(q_cu_seq_lens), + to_ffi_tensor(kv_cu_seq_lens), + to_ffi_tensor(output), + to_ffi_tensor(output_lse), + /*mask_mode_code=CAUSAL*/ 1, + /*layout=*/0, + window_size_left, + support_pdl(), + /*maybe_custom_mask=*/torch::Tensor(), + /*maybe_mask_indptr=*/torch::Tensor(), + /*maybe_alibi_slopes=*/torch::Tensor(), + /*maybe_prefix_len_ptr=*/torch::Tensor(), + /*maybe_token_pos_in_items_ptr=*/torch::Tensor(), + /*maybe_max_item_len_ptr=*/torch::Tensor(), + /*logits_soft_cap=*/0.0, + /*sm_scale=*/1.0, + /*rope_rcp_scale=*/1.0, + /*rope_rcp_theta=*/1.0, + /*token_pos_in_items_len=*/0); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/cuda_ops_api.h b/xllm/core/kernels/cuda/cuda_ops_api.h new file mode 100644 index 00000000..0a074ef1 --- /dev/null +++ b/xllm/core/kernels/cuda/cuda_ops_api.h @@ -0,0 +1,79 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include "utils.h" + +namespace xllm::kernel::cuda { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& q_rope, + torch::Tensor& k_rope, + torch::Tensor& cos_sin_cache, + torch::Tensor& pos_ids, + bool interleave); + +// act_mode only support silu, gelu, gelu_tanh +void act_and_mul(torch::Tensor& out, + torch::Tensor& input, + const std::string& act_mode); + +void reshape_paged_cache( + const torch::Tensor& slot_ids, // [n_tokens] + const torch::Tensor& keys, // [n_tokens, n_kv_heads, head_dim] + const torch::Tensor& values, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor& key_cache, // [n_blocks, block_size, n_heads, head_dim] + torch::Tensor& value_cache); + +void batch_prefill(torch::Tensor& float_workspace_buffer, + torch::Tensor& int_workspace_buffer, + torch::Tensor& page_locked_int_workspace_buffer, + const torch::Tensor& query, + const torch::Tensor& key, + const torch::Tensor& value, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& kv_cu_seq_lens, + int64_t window_size_left, + torch::Tensor& output, + std::optional& output_lse, + bool enable_cuda_graph); + +void batch_decode(torch::Tensor& float_workspace_buffer, + torch::Tensor& int_workspace_buffer, + torch::Tensor& page_locked_int_workspace_buffer, + const torch::Tensor& query, + const torch::Tensor& k_cache, + const torch::Tensor& v_cache, + const torch::Tensor& q_cu_seq_lens, + const torch::Tensor& paged_kv_indptr, + const torch::Tensor& paged_kv_indices, + const torch::Tensor& paged_kv_last_page_len, + int64_t window_size_left, + torch::Tensor& output, + std::optional& output_lse, + bool enable_cuda_graph); + +void rmsnorm(torch::Tensor& output, + torch::Tensor& input, + torch::Tensor& weight, + double eps); + +torch::Tensor matmul(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& bias); + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/matmul.cpp b/xllm/core/kernels/cuda/matmul.cpp new file mode 100644 index 00000000..12aaeb27 --- /dev/null +++ b/xllm/core/kernels/cuda/matmul.cpp @@ -0,0 +1,27 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +torch::Tensor matmul(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& bias) { + namespace F = torch::nn::functional; + return F::linear(a, b, bias.value_or(torch::Tensor())); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/norm.cpp b/xllm/core/kernels/cuda/norm.cpp new file mode 100644 index 00000000..7c21abdf --- /dev/null +++ b/xllm/core/kernels/cuda/norm.cpp @@ -0,0 +1,31 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void rmsnorm(torch::Tensor& output, + torch::Tensor& input, + torch::Tensor& weight, + double eps) { + get_module("norm")->GetFunction("rmsnorm").value()(to_ffi_tensor(output), + to_ffi_tensor(input), + to_ffi_tensor(weight), + eps, + support_pdl()); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/reshape_paged_cache.cu b/xllm/core/kernels/cuda/reshape_paged_cache.cu new file mode 100644 index 00000000..a64b19fc --- /dev/null +++ b/xllm/core/kernels/cuda/reshape_paged_cache.cu @@ -0,0 +1,106 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace { +// NOLINTBEGIN(cppcoreguidelines-macro-usage) +#define DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) +#define DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// NOLINTEND(cppcoreguidelines-macro-usage) +} // namespace + +namespace xllm::kernel::cuda { + +template +__global__ void reshape_paged_cache_kernel( + const int* __restrict__ slot_ids, // [n_tokens] + const T* __restrict__ keys, // [n_tokens, n_heads, head_dim] + const T* __restrict__ values, // [n_tokens, n_heads, head_dim] + T* __restrict__ key_cache, + T* __restrict__ value_cache, + int64_t k_stride, + int64_t v_stride, + int64_t n_kv_heads, + int64_t head_dim, + int64_t block_size) { + // block/token index + const int64_t bid = blockIdx.x; + // which slot to write to + const int64_t slot_id = slot_ids[bid]; + // block index + const int64_t block_idx = slot_id / block_size; + // offset within block + const int64_t block_offset = slot_id % block_size; + // base index for the block in cache + const int64_t block_base_idx = block_idx * block_size * n_kv_heads * head_dim; + // copy value one by one for the token + for (int64_t i = threadIdx.x; i < n_kv_heads * head_dim; i += blockDim.x) { + const int64_t k_src_idx = bid * k_stride + i; + const int64_t v_src_idx = bid * v_stride + i; + // cache: [n_blocks, block_size, n_heads, head_dim] + const int64_t head_base_idx = + block_base_idx + block_offset * n_kv_heads * head_dim; + // which head to write to + const int head_idx = i / head_dim; + // which dim within head to write to + const int head_offset = i % head_dim; + const int64_t dst_idx = head_base_idx + head_idx * head_dim + head_offset; + key_cache[dst_idx] = keys[k_src_idx]; + value_cache[dst_idx] = values[v_src_idx]; + } +} + +void reshape_paged_cache( + const torch::Tensor& slot_ids, // [n_tokens] + const torch::Tensor& keys, // [n_tokens, n_kv_heads, head_dim] + const torch::Tensor& values, // [n_tokens, n_kv_heads, head_dim] + torch::Tensor& key_cache, // [n_blocks, block_size, n_heads, head_dim] + torch::Tensor& value_cache) { + // keys and values should be continuous at n_kv_heads and head_dim dims + CHECK(keys.stride(-1) == 1 && keys.stride(-2) == keys.size(-1)); + CHECK(values.stride(-1) == 1 && values.stride(-2) == values.size(-1)); + const int64_t n_tokens = keys.size(-3); + const int64_t n_kv_heads = keys.size(-2); + const int64_t head_dim = keys.size(-1); + const int64_t block_size = key_cache.size(-3); + // it is possible that keys and values have different strides + const int64_t k_stride = keys.stride(-3); + const int64_t v_stride = values.stride(-3); + const int64_t n = n_kv_heads * head_dim; + dim3 grid(n_tokens); + dim3 block(std::min(n, 1024)); + DISPATCH_FLOATING_TYPES( + keys.scalar_type(), "reshape_paged_cache_kernel", [&] { + reshape_paged_cache_kernel + <<>>( + slot_ids.data_ptr(), + keys.data_ptr(), + values.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + k_stride, + v_stride, + n_kv_heads, + head_dim, + block_size); + }); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/rope.cpp b/xllm/core/kernels/cuda/rope.cpp new file mode 100644 index 00000000..7d96ac90 --- /dev/null +++ b/xllm/core/kernels/cuda/rope.cpp @@ -0,0 +1,39 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "cuda_ops_api.h" + +namespace xllm::kernel::cuda { + +void apply_rope_pos_ids_cos_sin_cache(torch::Tensor& q, + torch::Tensor& k, + torch::Tensor& q_rope, + torch::Tensor& k_rope, + torch::Tensor& cos_sin_cache, + torch::Tensor& pos_ids, + bool interleave) { + get_module("rope") + ->GetFunction("apply_rope_pos_ids_cos_sin_cache") + .value()(to_ffi_tensor(q), + to_ffi_tensor(k), + to_ffi_tensor(q_rope), + to_ffi_tensor(k_rope), + to_ffi_tensor(cos_sin_cache), + to_ffi_tensor(pos_ids), + interleave, + support_pdl()); +} + +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/utils.cpp b/xllm/core/kernels/cuda/utils.cpp new file mode 100644 index 00000000..dd7e2032 --- /dev/null +++ b/xllm/core/kernels/cuda/utils.cpp @@ -0,0 +1,119 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "utils.h" + +#include + +namespace { +const std::string base_ops_path = + "/root/.cache/flashinfer/0.4.1/80_89_90a/cached_ops"; + +const std::unordered_map + filename_safe_dtype_map = { + {torch::kFloat16, "f16"}, + {torch::kBFloat16, "bf16"}, + {torch::kFloat8_e4m3fn, "e4m3"}, + {torch::kFloat8_e5m2, "e5m2"}, + {torch::kInt8, "i8"}, + {torch::kUInt8, "u8"}, + {torch::kInt32, "i32"}, + {torch::kUInt32, "u32"}, + {torch::kInt64, "i64"}, + {torch::kUInt64, "u64"}, +}; + +std::string map_dtype(torch::ScalarType t) { + auto it = filename_safe_dtype_map.find(t); + if (it == filename_safe_dtype_map.end()) + throw std::invalid_argument("Unsupported dtype"); + return std::string(it->second); +} +} // namespace + +namespace xllm::kernel::cuda { +ffi::Tensor to_ffi_tensor(const torch::Tensor& torch_tensor) { + auto dlpack = at::toDLPackVersioned(torch_tensor); + return ffi::Tensor::FromDLPackVersioned(dlpack); +} + +bool support_pdl() { + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, /*device_id=*/0); + return prop.major >= 9; +} + +std::string get_batch_prefill_uri(const std::string& backend, + torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap, + bool use_fp16_qk_reduction) { + std::ostringstream oss; + oss << "batch_prefill_with_kv_cache_" + << "dtype_q_" << map_dtype(dtype_q) << "_" + << "dtype_kv_" << map_dtype(dtype_kv) << "_" + << "dtype_o_" << map_dtype(dtype_o) << "_" + << "dtype_idx_" << map_dtype(dtype_idx) << "_" + << "head_dim_qk_" << head_dim_qk << "_" + << "head_dim_vo_" << head_dim_vo << "_" + << "posenc_" << pos_encoding_mode << "_" + << "use_swa_" << (use_sliding_window ? "True" : "False") << "_" + << "use_logits_cap_" << (use_logits_soft_cap ? "True" : "False") << "_" + << "f16qk_" << (use_fp16_qk_reduction ? "True" : "False"); + + if (backend == "fa3") oss << "_sm90"; + + return oss.str(); +} + +std::string get_batch_decode_uri(torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap) { + std::ostringstream oss; + oss << "batch_decode_with_kv_cache_" + << "dtype_q_" << map_dtype(dtype_q) << "_" + << "dtype_kv_" << map_dtype(dtype_kv) << "_" + << "dtype_o_" << map_dtype(dtype_o) << "_" + << "dtype_idx_" << map_dtype(dtype_idx) << "_" + << "head_dim_qk_" << head_dim_qk << "_" + << "head_dim_vo_" << head_dim_vo << "_" + << "posenc_" << pos_encoding_mode << "_" + << "use_swa_" << (use_sliding_window ? "True" : "False") << "_" + << "use_logits_cap_" << (use_logits_soft_cap ? "True" : "False"); + + return oss.str(); +} + +std::string path_to_uri(const std::string& uri) { + return base_ops_path + "/" + uri + "/" + uri + ".so"; +} + +ffi::Module get_module(const std::string& uri) { + std::string uri_path = path_to_uri(uri); + return ffi::Module::LoadFromFile(uri_path); +} +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/cuda/utils.h b/xllm/core/kernels/cuda/utils.h new file mode 100644 index 00000000..2105a074 --- /dev/null +++ b/xllm/core/kernels/cuda/utils.h @@ -0,0 +1,57 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#include + +namespace ffi = tvm::ffi; + +namespace xllm::kernel::cuda { + +std::string get_batch_prefill_uri(const std::string& backend, + torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap, + bool use_fp16_qk_reduction); + +std::string get_batch_decode_uri(torch::ScalarType dtype_q, + torch::ScalarType dtype_kv, + torch::ScalarType dtype_o, + torch::ScalarType dtype_idx, + int64_t head_dim_qk, + int64_t head_dim_vo, + int64_t pos_encoding_mode, + bool use_sliding_window, + bool use_logits_soft_cap); + +std::string path_to_uri(const std::string& uri); + +ffi::Module get_module(const std::string& uri); + +ffi::Tensor to_ffi_tensor(const torch::Tensor& torch_tensor); + +bool support_pdl(); +} // namespace xllm::kernel::cuda \ No newline at end of file diff --git a/xllm/core/kernels/mlu/mlu_ops_api.h b/xllm/core/kernels/mlu/mlu_ops_api.h index fce078f7..09d49b9f 100644 --- a/xllm/core/kernels/mlu/mlu_ops_api.h +++ b/xllm/core/kernels/mlu/mlu_ops_api.h @@ -26,11 +26,6 @@ limitations under the License. namespace xllm::kernel::mlu { -static const std::string kActModeSilu = "silu"; -static const std::string kActModeGelu = "gelu"; -static const std::string kActModeQuickGelu = "quick_gelu"; -static const std::string kActModeSwish = "swish"; - void apply_rotary(torch::Tensor& q, torch::Tensor& k, const torch::Tensor& sin, diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index de7cebbe..29c1bd17 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -15,8 +15,13 @@ limitations under the License. 
#include "ops_api.h" -namespace xllm { -namespace kernel { +#if defined(USE_MLU) +#include "mlu/mlu_ops_api.h" +#elif defined(USE_CUDA) +#include "cuda/cuda_ops_api.h" +#endif + +namespace xllm::kernel { void apply_rotary(RotaryParams& params) { #if defined(USE_MLU) @@ -30,6 +35,14 @@ void apply_rotary(RotaryParams& params) { params.discrete, params.dynamic_ntk, params.max_query_len); +#elif defined(USE_CUDA) + cuda::apply_rope_pos_ids_cos_sin_cache(params.q, + params.k, + params.q, + params.k, + params.cos_sin, + params.position_ids, + params.interleaved); #else throw std::runtime_error("apply_rotary not implemented"); #endif @@ -45,6 +58,8 @@ void active(ActivationParams& params) { params.is_gated, params.start_expert_id, params.expert_size); +#elif defined(USE_CUDA) + cuda::act_and_mul(params.output, params.input, params.act_mode); #else throw std::runtime_error("active not implemented"); #endif @@ -58,6 +73,12 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) { params.v_cache, params.slot_mapping, params.direction); +#elif defined(USE_CUDA) + cuda::reshape_paged_cache(params.slot_mapping, + params.key, + params.value, + params.k_cache, + params.v_cache); #else throw std::runtime_error("reshape_paged_cache not implemented"); #endif @@ -87,6 +108,19 @@ void batch_prefill(AttentionParams& params) { params.window_size_right, params.compute_dtype, params.return_lse); +#elif defined(USE_CUDA) + cuda::batch_prefill(params.float_workspace_buffer, + params.int_workspace_buffer, + params.page_locked_int_workspace_buffer, + params.query, + params.key, + params.value, + params.q_cu_seq_lens, + params.kv_cu_seq_lens, + params.window_size_left, + params.output, + params.output_lse, + params.enable_cuda_graph); #else throw std::runtime_error("batch_prefill not implemented"); #endif @@ -114,6 +148,21 @@ void batch_decode(AttentionParams& params) { params.scale, params.return_lse, params.kv_cache_quant_bit_size); +#elif defined(USE_CUDA) + cuda::batch_decode(params.float_workspace_buffer, + params.int_workspace_buffer, + params.page_locked_int_workspace_buffer, + params.query, + params.k_cache, + params.v_cache, + params.q_cu_seq_lens, + params.paged_kv_indptr, + params.paged_kv_indices, + params.paged_kv_last_page_len, + params.window_size_left, + params.output, + params.output_lse, + params.enable_cuda_graph); #else throw std::runtime_error("batch_decode not implemented"); #endif @@ -136,6 +185,8 @@ void fused_layernorm(FusedLayerNormParams& params) { params.store_output_before_norm, params.store_output_after_norm, params.dynamic_quant); +#elif defined(USE_CUDA) + cuda::rmsnorm(params.output, params.input, params.weight, params.eps); #else throw std::runtime_error("fused_layernorm not implemented"); #endif @@ -145,6 +196,8 @@ torch::Tensor matmul(MatmulParams& params) { #if defined(USE_MLU) return mlu::matmul( params.a, params.b, params.bias, params.c, params.alpha, params.beta); +#elif defined(USE_CUDA) + return cuda::matmul(params.a, params.b, params.bias); #else throw std::runtime_error("matmul not implemented"); #endif @@ -179,6 +232,8 @@ torch::Tensor fused_moe(FusedMoEParams& params) { params.world_size, params.shared_expert_num, params.parallel_mode); +#elif defined(USE_CUDA) + throw std::runtime_error("fused_moe for cudanot implemented"); #else throw std::runtime_error("fused_moe not implemented"); #endif @@ -226,5 +281,4 @@ torch::Tensor scaled_matmul(ScaledMatmulParams& params) { throw std::runtime_error("scaled_matmul not implemented"); #endif } -} // namespace kernel -} // 
namespace xllm \ No newline at end of file +} // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index 3e11d352..510d47d0 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -17,12 +17,12 @@ limitations under the License. #include "param.h" -#if defined(USE_MLU) -#include "mlu/mlu_ops_api.h" -#endif +namespace xllm::kernel { -namespace xllm { -namespace kernel { +static const std::string kActModeSilu = "silu"; +static const std::string kActModeGelu = "gelu"; +static const std::string kActModeQuickGelu = "quick_gelu"; +static const std::string kActModeSwish = "swish"; void apply_rotary(RotaryParams& params); @@ -45,5 +45,4 @@ std::tuple scaled_quantize( torch::Tensor scaled_matmul(ScaledMatmulParams& params); -} // namespace kernel -} // namespace xllm +} // namespace xllm::kernel diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 1f920186..75b9f4df 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -21,8 +21,7 @@ limitations under the License. #include #include -namespace xllm { -namespace kernel { +namespace xllm::kernel { // Note: add default values for optional parameters in the struct definition @@ -206,5 +205,4 @@ struct ScaledMatmulParams { std::optional b_calib = std::nullopt; std::optional output = std::nullopt; }; -} // namespace kernel -} // namespace xllm \ No newline at end of file +} // namespace xllm::kernel diff --git a/xllm/core/layers/CMakeLists.txt b/xllm/core/layers/CMakeLists.txt index 6ad3d0c7..22763761 100644 --- a/xllm/core/layers/CMakeLists.txt +++ b/xllm/core/layers/CMakeLists.txt @@ -61,7 +61,6 @@ cc_library( word_embedding.h lm_head.h block_copy.h - linear.h SRCS multi_head_attention.cpp DEPS diff --git a/xllm/core/layers/common/CMakeLists.txt b/xllm/core/layers/common/CMakeLists.txt index 20b037d4..c13b21d6 100755 --- a/xllm/core/layers/common/CMakeLists.txt +++ b/xllm/core/layers/common/CMakeLists.txt @@ -15,6 +15,7 @@ cc_library( qwen3_decoder_layer.h qwen3_moe_decoder_layer.h linear_impl.h + linear.h word_embedding_impl.h SRCS qwen3_attention.cpp diff --git a/xllm/core/layers/common/attention.cpp b/xllm/core/layers/common/attention.cpp index fa7a7725..7146cb13 100644 --- a/xllm/core/layers/common/attention.cpp +++ b/xllm/core/layers/common/attention.cpp @@ -15,6 +15,7 @@ limitations under the License. 
#include "attention.h" +#include "common/flashinfer_workspace.h" #include "kernels/ops_api.h" DECLARE_bool(enable_chunked_prefill); @@ -37,6 +38,13 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params, attn_metadata.slot_mapping = params.new_cache_slots; attn_metadata.compute_dtype = compute_dtype; + // for flashinfer + attn_metadata.paged_kv_indptr = params.paged_kv_indptr; + attn_metadata.paged_kv_indices = params.paged_kv_indices; + attn_metadata.paged_kv_last_page_len = params.paged_kv_last_page_len; + attn_metadata.q_cu_seq_lens = params.q_seq_lens; + attn_metadata.kv_cu_seq_lens = params.kv_seq_lens; // cumulative kv seqlens + bool is_start_loc_match = (params.q_seq_lens_vec == params.kv_seq_lens_vec); attn_metadata.is_chunked_prefill = is_prefill && !is_start_loc_match; attn_metadata.is_prefill = is_prefill && !attn_metadata.is_chunked_prefill; @@ -92,6 +100,16 @@ std::tuple> AttentionImpl::forward( attention_params.window_size_left = sliding_window_; attention_params.scale = scale_; attention_params.compute_dtype = attn_metadata.compute_dtype; + // for flashinfer + attention_params.float_workspace_buffer = + FlashinferWorkspace::get_instance().get_float_workspace_buffer(); + attention_params.int_workspace_buffer = + FlashinferWorkspace::get_instance().get_int_workspace_buffer(); + attention_params.page_locked_int_workspace_buffer = + FlashinferWorkspace::get_instance() + .get_page_locked_int_workspace_buffer(); + attention_params.kv_cu_seq_lens = attn_metadata.kv_cu_seq_lens; + attention_params.q_cu_seq_lens = attn_metadata.q_cu_seq_lens; if (attn_metadata.is_prefill) { attention_params.key = key; @@ -123,6 +141,12 @@ std::tuple> AttentionImpl::forward( attention_params.block_table = attn_metadata.block_table; attention_params.kv_seq_lens = attn_metadata.kv_seq_lens; + // for flashinfer + attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr; + attention_params.paged_kv_indices = attn_metadata.paged_kv_indices; + attention_params.paged_kv_last_page_len = + attn_metadata.paged_kv_last_page_len; + xllm::kernel::batch_decode(attention_params); } diff --git a/xllm/core/layers/common/attention.h b/xllm/core/layers/common/attention.h index 7e210001..60fd92da 100644 --- a/xllm/core/layers/common/attention.h +++ b/xllm/core/layers/common/attention.h @@ -44,6 +44,13 @@ struct AttentionMetadata { std::string compute_dtype; bool is_prefill; bool is_chunked_prefill; + + // for flashinfer + torch::Tensor paged_kv_indptr; + torch::Tensor paged_kv_indices; + torch::Tensor paged_kv_last_page_len; + torch::Tensor q_cu_seq_lens; + torch::Tensor kv_cu_seq_lens; }; class AttentionImpl : public torch::nn::Module { diff --git a/xllm/core/layers/common/dense_mlp.h b/xllm/core/layers/common/dense_mlp.h index 11c7e487..9f02e764 100644 --- a/xllm/core/layers/common/dense_mlp.h +++ b/xllm/core/layers/common/dense_mlp.h @@ -21,7 +21,7 @@ limitations under the License. #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" -#include "layers/linear.h" +#include "linear.h" namespace xllm { namespace layer { diff --git a/xllm/core/layers/common/fused_moe.h b/xllm/core/layers/common/fused_moe.h index e23a2629..c7340a98 100644 --- a/xllm/core/layers/common/fused_moe.h +++ b/xllm/core/layers/common/fused_moe.h @@ -23,7 +23,7 @@ limitations under the License. 
#include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" #include "framework/state_dict/utils.h" -#include "layers/linear.h" +#include "linear.h" namespace xllm { namespace layer { diff --git a/xllm/core/layers/linear.h b/xllm/core/layers/common/linear.h similarity index 98% rename from xllm/core/layers/linear.h rename to xllm/core/layers/common/linear.h index 7870dbeb..a2b238ab 100644 --- a/xllm/core/layers/linear.h +++ b/xllm/core/layers/common/linear.h @@ -18,14 +18,11 @@ limitations under the License. #include #include -#if defined(USE_MLU) -#include "common/linear_impl.h" -#endif +#include "linear_impl.h" namespace xllm { namespace layer { -#if defined(USE_MLU) class ColumnParallelLinear : public torch::nn::ModuleHolder { public: @@ -123,7 +120,6 @@ class ReplicatedLinear : public torch::nn::ModuleHolder { quant_args, options)) {} }; -#endif } // namespace layer } // namespace xllm diff --git a/xllm/core/layers/common/qwen3_attention.h b/xllm/core/layers/common/qwen3_attention.h index 9d5536ce..6b2bc2ba 100644 --- a/xllm/core/layers/common/qwen3_attention.h +++ b/xllm/core/layers/common/qwen3_attention.h @@ -23,8 +23,8 @@ limitations under the License. #include "framework/parallel_state/parallel_args.h" #include "framework/quant_args.h" #include "framework/state_dict/state_dict.h" -#include "layers/linear.h" #include "layers/rms_norm.h" +#include "linear.h" #include "rotary_embedding.h" namespace xllm { diff --git a/xllm/core/platform/device.cpp b/xllm/core/platform/device.cpp index 6c3763c6..2f415637 100644 --- a/xllm/core/platform/device.cpp +++ b/xllm/core/platform/device.cpp @@ -64,6 +64,8 @@ const std::string Device::type() { return "npu"; #elif defined(USE_MLU) return "mlu"; +#elif defined(USE_CUDA) + return "cuda"; #endif } diff --git a/xllm/core/platform/stream.cpp b/xllm/core/platform/stream.cpp index 5cb15b48..6e69276d 100644 --- a/xllm/core/platform/stream.cpp +++ b/xllm/core/platform/stream.cpp @@ -21,15 +21,13 @@ namespace xllm { Stream::Stream() : stream_(c10_npu::getNPUStreamFromPool()) {} #elif defined(USE_MLU) Stream::Stream() : stream_(torch_mlu::getStreamFromPool()) {} +#elif defined(USE_CUDA) +Stream::Stream() : stream_(c10::cuda::getStreamFromPool()) {} #endif int Stream::synchronize() const { -#if defined(USE_NPU) - return aclrtSynchronizeStream(stream_.stream()); -#elif defined(USE_MLU) stream_.unwrap().synchronize(); return 0; -#endif } c10::StreamGuard Stream::set_stream_guard() const { diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h index 7cb65913..843105cb 100644 --- a/xllm/core/platform/stream.h +++ b/xllm/core/platform/stream.h @@ -21,13 +21,17 @@ limitations under the License. 
diff --git a/xllm/core/platform/stream.h b/xllm/core/platform/stream.h
index 7cb65913..843105cb 100644
--- a/xllm/core/platform/stream.h
+++ b/xllm/core/platform/stream.h
@@ -21,13 +21,17 @@ limitations under the License.
 #endif
 // clang-format on
 
+#include 
+#include 
+
 #include 
 
 #if defined(USE_NPU)
 #include 
 #include 
 #elif defined(USE_MLU)
-#include 
 #include 
+#elif defined(USE_CUDA)
+#include 
 #endif
 
 namespace xllm {
@@ -50,6 +54,8 @@ class Stream {
   c10_npu::NPUStream stream_;
 #elif defined(USE_MLU)
   torch_mlu::MLUStream stream_;
+#elif defined(USE_CUDA)
+  c10::cuda::CUDAStream stream_;
 #endif
 };
 
diff --git a/xllm/core/platform/vmm_api.cpp b/xllm/core/platform/vmm_api.cpp
index d129e603..e0342246 100644
--- a/xllm/core/platform/vmm_api.cpp
+++ b/xllm/core/platform/vmm_api.cpp
@@ -98,7 +98,7 @@ void create_vir_ptr(VirPtr vir_ptr, size_t aligned_size) {
 #elif defined(USE_MLU)
   ret = cnMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0);
 #elif defined(USE_CUDA)
-  ret = cuMemAddressReserve(&vir_ptr, aligned_size, 0, nullptr, 0);
+  ret = cuMemAddressReserve(&vir_ptr, aligned_size, 0, 0, 0);
 #endif
   CHECK_EQ(ret, 0) << "Failed to create virtual memory handle";
 }
diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h
index dd4a3d8f..8f8b63ad 100644
--- a/xllm/core/runtime/forward_params.h
+++ b/xllm/core/runtime/forward_params.h
@@ -180,6 +180,10 @@ struct RawForwardInput {
   std::vector kv_cache_start_offsets;  //[n_seq]
   // beam search kernel input
   std::vector acc_logprob_vec;
+  // for flashinfer
+  std::vector paged_kv_indptr;         //[n_seq + 1]
+  std::vector paged_kv_indices;        //[num_used_pages]
+  std::vector paged_kv_last_page_len;  //[n_seq]
 };
 
 struct RawSampleOutput {
diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp
index bdb30aa5..442c4454 100644
--- a/xllm/core/runtime/llm_engine.cpp
+++ b/xllm/core/runtime/llm_engine.cpp
@@ -297,7 +297,7 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
     kv_cache_shape.emplace_back(std::vector{
         kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
   } else {
-#if defined(USE_NPU)
+#if defined(USE_NPU) || defined(USE_CUDA)
     kv_cache_shape.emplace_back(std::vector{
         kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
     kv_cache_shape.emplace_back(std::vector{
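With the llm_engine.cpp change above, USE_CUDA shares the NPU-style [n_blocks, block_size, n_kv_heads, head_dim] cache shape. A small sketch of the memory this implies per layer (kv_cache_bytes is an illustrative helper; the figures in the example are assumptions, not values from the patch):

#include <cstddef>
#include <cstdint>

// Bytes for one layer's key cache plus value cache under the
// [n_blocks, block_size, n_kv_heads, head_dim] layout.
constexpr size_t kv_cache_bytes(int64_t n_blocks, int64_t block_size,
                                int64_t n_kv_heads, int64_t head_dim,
                                size_t dtype_bytes) {
  return 2ull * n_blocks * block_size * n_kv_heads * head_dim * dtype_bytes;
}

// Example: 4096 blocks of 16 tokens, 8 KV heads, head_dim 128, fp16
// -> 2 * 4096 * 16 * 8 * 128 * 2 bytes = 256 MiB per layer.
static_assert(kv_cache_bytes(4096, 16, 8, 128, 2) == 256ull * 1024 * 1024, "");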
#include #include "common/device_monitor.h" +#include "common/flashinfer_workspace.h" #include "common/metrics.h" #include "common/types.h" #include "core/common/global_flags.h" @@ -41,7 +42,10 @@ namespace xllm { LLMWorkerImpl::LLMWorkerImpl(const ParallelArgs& parallel_args, const torch::Device& device, const runtime::Options& options) - : WorkerImpl(parallel_args, device, options) {} + : WorkerImpl(parallel_args, device, options) { + // initialize flashinfer workspace + FlashinferWorkspace::get_instance().initialize(device_); +} bool LLMWorkerImpl::init_model(ModelContext& context) { CHECK(model_ == nullptr) << "Model is already initialized."; diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 428c0c3e..16bf6fff 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -64,6 +64,16 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector(pb_forward_input->q_seq_lens().begin(), pb_forward_input->q_seq_lens().end()); // aprint(q_seq_lens, "q_seq_lens", global_rank_); + // for flashinfer + std::vector paged_kv_indptr = + std::vector(pb_forward_input->paged_kv_indptr().begin(), + pb_forward_input->paged_kv_indptr().end()); + std::vector paged_kv_indices = + std::vector(pb_forward_input->paged_kv_indices().begin(), + pb_forward_input->paged_kv_indices().end()); + std::vector paged_kv_last_page_len = + std::vector(pb_forward_input->paged_kv_last_page_len().begin(), + pb_forward_input->paged_kv_last_page_len().end()); std::vector> block_tables_vec; for (size_t i = 0; i < pb_forward_input->block_tables_vec().size(); ++i) { block_tables_vec.emplace_back(std::vector( @@ -213,6 +223,12 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.kv_seq_lens_vec = std::move(seq_lens); input_params.q_seq_lens_vec = std::move(q_seq_lens); + input_params.paged_kv_indptr = torch::tensor(paged_kv_indptr, tensor_options); + input_params.paged_kv_indices = + torch::tensor(paged_kv_indices, tensor_options); + input_params.paged_kv_last_page_len = + torch::tensor(paged_kv_last_page_len, tensor_options); + input_params.new_cache_slots = torch::tensor(new_token_slot_ids, tensor_options); input_params.decode_seq_range = decode_seq_range; @@ -396,6 +412,13 @@ void forward_input_to_proto(const RawForwardInput& inputs, ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_seq_lens(), inputs.seq_lens); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_q_seq_lens(), inputs.q_seq_lens); + // for flashinfer + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_indptr(), + inputs.paged_kv_indptr); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_indices(), + inputs.paged_kv_indices); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_paged_kv_last_page_len(), + inputs.paged_kv_last_page_len); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_new_token_slot_ids(), inputs.new_token_slot_ids); pb_forward_input->mutable_block_tables_vec()->Reserve( diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index b6a8cf48..422eff49 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -24,6 +24,8 @@ limitations under the License. 
#include "kernels/npu/xllm_ops/replace_token.h" #elif defined(USE_MLU) #include +#elif defined(USE_CUDA) +#include #endif #include @@ -92,7 +94,7 @@ bool WorkerImpl::allocate_kv_cache( value_cache = at_npu::native::npu_format_cast( torch::empty(kv_cache_shape[1], torch::dtype(dtype_).device(device_)), 2); -#elif defined(USE_MLU) +#elif defined(USE_MLU) || defined(USE_CUDA) key_cache = torch::empty(kv_cache_shape[0], torch::dtype(dtype_).device(device_)); value_cache = @@ -300,6 +302,8 @@ std::tuple WorkerImpl::estimate_kv_cache_capacity() { device_id, &torch_cache, &torch_largest_block); #elif defined(USE_MLU) torch_mlu::MLUCachingAllocator::emptyCache(); +#elif defined(USE_CUDA) + c10::cuda::CUDACachingAllocator::emptyCache(); #endif const auto available_memory = device_.free_memory(); const auto total_memory = device_.total_memory(); @@ -351,14 +355,14 @@ void WorkerImpl::update_last_step_output( ForwardInput WorkerImpl::update_input_by_last_step_output( ForwardInput& inputs) { -#if defined(USE_A3) || defined(USE_MLU) +#if defined(USE_A3) || defined(USE_MLU) || defined(USE_CUDA) auto& flatten_tokens = inputs.token_ids; auto neg_mask = (flatten_tokens < 0); auto clamped_neg_indices = torch::clamp(-flatten_tokens, 0); auto replacement = last_step_output_.sample_output.next_tokens.index( {clamped_neg_indices - 1}); inputs.token_ids = torch::where(neg_mask, replacement, flatten_tokens); -#else +#elif defined(USE_NPU) xllm_ops::replace_token(inputs.token_ids, last_step_output_.sample_output.next_tokens); #endif diff --git a/xllm/models/llm/llm_model_base.h b/xllm/models/llm/llm_model_base.h index c77d71e1..46cd2787 100644 --- a/xllm/models/llm/llm_model_base.h +++ b/xllm/models/llm/llm_model_base.h @@ -32,7 +32,6 @@ limitations under the License. #include "core/framework/model_context.h" #include "core/layers/attention_mask.h" #include "core/layers/block_copy.h" -#include "core/layers/linear.h" #include "core/layers/lm_head.h" #include "core/layers/pos_embedding.h" #include "core/layers/rms_norm.h"