12 changes: 12 additions & 0 deletions .gitmodules
@@ -28,3 +28,15 @@
[submodule "third_party/Mooncake"]
path = third_party/Mooncake
url = https://gitcode.com/xLLM-AI/Mooncake.git
[submodule "third_party/flashinfer"]
path = third_party/flashinfer
url = https://gitcode.com/xLLM-AI/flashinfer.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://gitcode.com/xLLM-AI/cutlass.git
[submodule "third_party/tvm-ffi"]
path = third_party/tvm-ffi
url = https://gitcode.com/xLLM-AI/tvm-ffi.git
[submodule "third_party/dlpack"]
path = third_party/dlpack
url = https://gitcode.com/xLLM-AI/dlpack.git
74 changes: 72 additions & 2 deletions CMakeLists.txt
@@ -1,8 +1,10 @@
cmake_minimum_required(VERSION 3.26)
set_property(GLOBAL PROPERTY USE_FOLDERS ON)
set(CMAKE_CUDA_COMPILER "/usr/local/cuda/bin/nvcc")

option(USE_NPU "Enable NPU support" OFF)
option(USE_MLU "Enable MLU support" OFF)
option(USE_CUDA "Enable CUDA support" OFF)

if(DEVICE_ARCH STREQUAL "ARM")
set(CMAKE_SYSTEM_PROCESSOR aarch64)
@@ -101,7 +103,7 @@ set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)

if(USE_NPU)
if(USE_NPU OR USE_CUDA)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
elseif(USE_MLU)
@@ -178,6 +180,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT})
message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}")
endif()

# Default CUDA architecture: Ampere (sm_80) unless the caller specifies one
if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND USE_CUDA)
set(CMAKE_CUDA_ARCHITECTURES 80)
endif()

# Build TORCH_CUDA_ARCH_LIST
if(USE_CUDA)
# Build TORCH_CUDA_ARCH_LIST
set(TORCH_CUDA_ARCH_LIST "")
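# Map each entry to PyTorch's TORCH_CUDA_ARCH_LIST notation,
# e.g. "80" -> "8.0", "90a" -> "9.0a", "native" -> "Auto".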
foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
if(CUDA_ARCH MATCHES "^([0-9]+)([0-9])a$")
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a")
elseif(CUDA_ARCH MATCHES "^([0-9]+)([0-9])$")
set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
elseif(CUDA_ARCH STREQUAL "native")
set(TORCH_ARCH "Auto")
else()
message(FATAL_ERROR "${CUDA_ARCH} is not supported")
endif()
list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
endforeach()

message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
endif()

# configure vcpkg
# have to set CMAKE_TOOLCHAIN_FILE before first project call.
# if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
@@ -217,7 +245,12 @@ endif()
set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation")

project("xllm" LANGUAGES C CXX)
if(USE_CUDA)
project("xllm" LANGUAGES C CXX CUDA)
find_package(CUDAToolkit REQUIRED)
else()
project("xllm" LANGUAGES C CXX)
endif()

# find_package(CUDAToolkit REQUIRED)

@@ -352,6 +385,43 @@ if(USE_MLU)
)
endif()

if(USE_CUDA)
add_definitions(-DUSE_CUDA)
add_compile_definitions(TORCH_CUDA=1)
set(CMAKE_VERBOSE_MAKEFILE ON)
include_directories(
$ENV{PYTHON_INCLUDE_PATH}
$ENV{PYTORCH_INSTALL_PATH}/include
$ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
)

link_directories(
$ENV{PYTHON_LIB_PATH}
$ENV{PYTORCH_INSTALL_PATH}/lib
$ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64
)

set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3)
# Undefine these macros so half-precision (half/bfloat16) operators and conversions are available.
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
-U__CUDA_NO_HALF_OPERATORS__
-U__CUDA_NO_HALF_CONVERSIONS__
-U__CUDA_NO_HALF2_OPERATORS__
-U__CUDA_NO_BFLOAT16_CONVERSIONS__)
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all)
message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")

# find_package(NCCL REQUIRED)

# find cudnn
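# (uses the cuDNN bundled with the pip nvidia-cudnn package; assumes the
# build environment's python can import nvidia.cudnn)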
execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)"
OUTPUT_VARIABLE CUDNN_PYTHON_PATH
OUTPUT_STRIP_TRAILING_WHITESPACE)
get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY)
link_directories(
${CUDNN_ROOT_DIR}/lib64
${CUDNN_ROOT_DIR}/lib
)
endif()

# check if USE_CXX11_ABI is set correctly
# if (DEFINED USE_CXX11_ABI)
# parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
22 changes: 17 additions & 5 deletions setup.py
@@ -212,7 +212,13 @@ def set_mlu_envs():
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()


def set_cuda_envs():
os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()

class CMakeExtension(Extension):
def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
super().__init__(name, sources=[])
@@ -223,7 +229,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
class ExtBuild(build_ext):
user_options = build_ext.user_options + [
("base-dir=", None, "base directory of xLLM project"),
("device=", None, "target device type (a3 or a2 or mlu)"),
("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
("arch=", None, "target arch type (x86 or arm)"),
("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"),
]
@@ -302,8 +308,14 @@ def build_extension(self, ext: CMakeExtension):
cmake_args += ["-DUSE_MLU=ON"]
# set mlu environment variables
set_mlu_envs()
elif self.device == "cuda":
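# Ampere (sm_80), Ada (sm_89) and Hopper (sm_90) by default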
cuda_architectures = "80;89;90"
cmake_args += ["-DUSE_CUDA=ON",
f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
# set cuda environment variables
set_cuda_envs()
else:
raise ValueError("Please set --device to a2 or a3 or mlu.")
raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")


# Adding CMake arguments set as environment variable
@@ -353,7 +365,7 @@ def build_extension(self, ext: CMakeExtension):

class BuildDistWheel(bdist_wheel):
user_options = bdist_wheel.user_options + [
("device=", None, "target device type (a3 or a2 or mlu)"),
("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
("arch=", None, "target arch type (x86 or arm)"),
]

@@ -530,7 +542,7 @@ def apply_patch():
idx = sys.argv.index('--device')
if idx + 1 < len(sys.argv):
device = sys.argv[idx+1].lower()
if device not in ('a2', 'a3', 'mlu'):
if device not in ('a2', 'a3', 'mlu', 'cuda'):
print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
sys.exit(1)
# Remove the arguments so setup() doesn't see them
38 changes: 38 additions & 0 deletions third_party/CMakeLists.txt
@@ -20,3 +20,41 @@ target_include_directories(mooncake_store PUBLIC
)

target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator)


if(USE_CUDA)
cc_library(
NAME
cutlass
INCLUDES
cutlass/include
cutlass/tools/util/include
DEPS
torch # TODO: depends on CUDA instead of torch
)
cc_library(
NAME
dlpack
INCLUDES
dlpack/include
)
cc_library(
NAME
tvm-ffi
INCLUDES
tvm-ffi/include
DEPS
dlpack
)
cc_library(
NAME
flashinfer
INCLUDES
flashinfer/include
flashinfer/csrc
DEPS
cutlass
tvm-ffi
dlpack
)
endif()
1 change: 1 addition & 0 deletions third_party/cutlass
Submodule cutlass added at e6e2cc
1 change: 1 addition & 0 deletions third_party/dlpack
Submodule dlpack added at 93c8f2
1 change: 1 addition & 0 deletions third_party/flashinfer
Submodule flashinfer added at d4a3ff
1 change: 1 addition & 0 deletions third_party/tvm-ffi
Submodule tvm-ffi added at af898a
2 changes: 2 additions & 0 deletions xllm/core/common/CMakeLists.txt
@@ -15,6 +15,7 @@ cc_library(
rate_limiter.h
types.h
device_monitor.h
flashinfer_workspace.h
SRCS
etcd_client.cpp
global_flags.cpp
@@ -23,6 +24,7 @@ cc_library(
options.cpp
rate_limiter.cpp
device_monitor.cpp
flashinfer_workspace.cpp
DEPS
util
absl::random_random
46 changes: 46 additions & 0 deletions xllm/core/common/flashinfer_workspace.cpp
@@ -0,0 +1,46 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "flashinfer_workspace.h"

#include "global_flags.h"

namespace xllm {

void FlashinferWorkspace::initialize(const torch::Device& device) {
float_workspace_buffer_ =
torch::empty({FLAGS_workspace_buffer_size},
torch::dtype(torch::kUInt8).device(device));
int_workspace_buffer_ =
torch::empty({FLAGS_workspace_buffer_size},
torch::dtype(torch::kUInt8).device(device));
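// Page-locked host mirror of the int workspace: the planner fills it on
// the CPU, then copies it to the device asynchronously.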
page_locked_int_workspace_buffer_ = torch::empty(
{FLAGS_workspace_buffer_size},
torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
}

torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() {
return float_workspace_buffer_;
}

torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() {
return int_workspace_buffer_;
}

torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() {
return page_locked_int_workspace_buffer_;
}

} // namespace xllm
49 changes: 49 additions & 0 deletions xllm/core/common/flashinfer_workspace.h
@@ -0,0 +1,49 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include <cstdint>

#include "macros.h"

namespace xllm {

class FlashinferWorkspace {
public:
static FlashinferWorkspace& get_instance() {
static FlashinferWorkspace instance;
return instance;
}

void initialize(const torch::Device& device);

torch::Tensor get_float_workspace_buffer();
torch::Tensor get_int_workspace_buffer();
torch::Tensor get_page_locked_int_workspace_buffer();

private:
FlashinferWorkspace() = default;
~FlashinferWorkspace() = default;
DISALLOW_COPY_AND_ASSIGN(FlashinferWorkspace);

torch::Tensor float_workspace_buffer_;
torch::Tensor int_workspace_buffer_;
torch::Tensor page_locked_int_workspace_buffer_;
};

} // namespace xllm
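
A minimal usage sketch of the new singleton (the `init_flashinfer_for_device` wrapper is hypothetical and not part of this PR; it assumes a CUDA-enabled libtorch build):

#include <torch/torch.h>

#include "flashinfer_workspace.h"

void init_flashinfer_for_device(int device_index) {
  // One-time allocation of the shared workspace buffers on this device,
  // e.g. during worker startup.
  torch::Device device(torch::kCUDA, device_index);
  xllm::FlashinferWorkspace::get_instance().initialize(device);

  // Any FlashInfer attention wrapper can later fetch the shared buffers.
  auto& ws = xllm::FlashinferWorkspace::get_instance();
  torch::Tensor float_buf = ws.get_float_workspace_buffer();
  torch::Tensor int_buf = ws.get_int_workspace_buffer();
  torch::Tensor pinned_buf = ws.get_page_locked_int_workspace_buffer();
  // float_buf and int_buf live on `device`; pinned_buf is page-locked host
  // memory for staging planner metadata ahead of async host-to-device copies.
}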
8 changes: 7 additions & 1 deletion xllm/core/common/global_flags.cpp
@@ -382,4 +382,10 @@ DEFINE_int64(buffer_size_per_seq,
// --- beam search config ---
DEFINE_bool(enable_beam_search_kernel,
false,
"Whether to enable beam search kernel.");
"Whether to enable beam search kernel.");

// --- flashinfer config ---
DEFINE_int32(workspace_buffer_size,
512 * 1024 * 1024,
"The user reserved workspace buffer used to store intermediate "
"attention results in split-k algorithm.");
4 changes: 3 additions & 1 deletion xllm/core/common/global_flags.h
@@ -199,4 +199,6 @@ DECLARE_int64(cache_size_per_token);

DECLARE_int64(buffer_size_per_seq);

DECLARE_bool(enable_beam_search_kernel);
DECLARE_bool(enable_beam_search_kernel);

DECLARE_int32(workspace_buffer_size);