Commit 348b4de

feat: add flashinfer as kernel backend for cuda device.
1 parent fcec9c9 commit 348b4de


50 files changed: 1,992 additions and 56 deletions.

.gitmodules

Lines changed: 12 additions & 0 deletions
@@ -28,3 +28,15 @@
 [submodule "third_party/Mooncake"]
 	path = third_party/Mooncake
 	url = https://gitcode.com/xLLM-AI/Mooncake.git
+[submodule "third_party/flashinfer"]
+	path = third_party/flashinfer
+	url = https://gitcode.com/xLLM-AI/flashinfer.git
+[submodule "third_party/cutlass"]
+	path = third_party/cutlass
+	url = https://gitcode.com/xLLM-AI/cutlass.git
+[submodule "third_party/tvm-ffi"]
+	path = third_party/tvm-ffi
+	url = https://gitcode.com/xLLM-AI/tvm-ffi.git
+[submodule "third_party/dlpack"]
+	path = third_party/dlpack
+	url = https://gitcode.com/xLLM-AI/dlpack.git

CMakeLists.txt

Lines changed: 70 additions & 2 deletions
@@ -101,7 +101,7 @@ set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS ON)
 
-if(USE_NPU)
+if(USE_NPU OR USE_CUDA)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
   add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
 elseif(USE_MLU)
@@ -178,6 +178,32 @@ if (DEFINED ENV{DEPENDENCES_ROOT})
   message(STATUS "Using DEPENDENCES_ROOT: $ENV{DEPENDENCES_ROOT}")
 endif()
 
+# set architecture for CUDA
+if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES AND USE_CUDA)
+  set(CMAKE_CUDA_ARCHITECTURES 80)
+endif()
+
+# Build TORCH_CUDA_ARCH_LIST
+if(USE_CUDA)
+  # Build TORCH_CUDA_ARCH_LIST
+  set(TORCH_CUDA_ARCH_LIST "")
+  foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES)
+    if(CUDA_ARCH MATCHES "^([0-9])([0-9])a$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}a")
+    elseif(CUDA_ARCH MATCHES "^([0-9])([0-9])*$")
+      set(TORCH_ARCH "${CMAKE_MATCH_1}.${CMAKE_MATCH_2}")
+    elseif(CUDA_ARCH STREQUAL "native")
+      set(TORCH_ARCH "Auto")
+    else()
+      message(FATAL_ERROR "${CUDA_ARCH} is not supported")
+    endif()
+    list(APPEND TORCH_CUDA_ARCH_LIST ${TORCH_ARCH})
+  endforeach()
+
+  message(STATUS "CMAKE_CUDA_ARCHITECTURES: ${CMAKE_CUDA_ARCHITECTURES}")
+  message(STATUS "TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}")
+endif()
+
 # configure vcpkg
 # have to set CMAKE_TOOLCHAIN_FILE before first project call.
 # if (DEFINED ENV{VCPKG_ROOT} AND NOT DEFINED CMAKE_TOOLCHAIN_FILE)
@@ -217,7 +243,12 @@ endif()
 set(CPPREST_EXCLUDE_WEBSOCKETS ON CACHE BOOL "Exclude websockets functionality." FORCE)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format-truncation")
 
-project("xllm" LANGUAGES C CXX)
+if(USE_CUDA)
+  project("xllm" LANGUAGES C CXX CUDA)
+  find_package(CUDAToolkit REQUIRED)
+else()
+  project("xllm" LANGUAGES C CXX)
+endif()
 
 # find_package(CUDAToolkit REQUIRED)
 
@@ -352,6 +383,43 @@ if(USE_MLU)
   )
 endif()
 
+if(USE_CUDA)
+  add_definitions(-DUSE_CUDA)
+  add_compile_definitions(TORCH_CUDA=1)
+  set(CMAKE_VERBOSE_MAKEFILE ON)
+  include_directories(
+    $ENV{PYTHON_INCLUDE_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/include
+    $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include
+  )
+
+  link_directories(
+    $ENV{PYTHON_LIB_PATH}
+    $ENV{PYTORCH_INSTALL_PATH}/lib
+    $ENV{CUDA_TOOLKIT_ROOT_DIR}/lib64
+  )
+
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} -O3)
+  # The following definitions must be undefined since half-precision operation is required.
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS}
+      -U__CUDA_NO_HALF_OPERATORS__
+      -U__CUDA_NO_HALF_CONVERSIONS__
+      -U__CUDA_NO_HALF2_OPERATORS__
+      -U__CUDA_NO_BFLOAT16_CONVERSIONS__)
+  set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} --use_fast_math -Xfatbin -compress-all)
+  message(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
+
+  # find_package(NCCL REQUIRED)
+
+  # # find cudnn
+  # execute_process(COMMAND python -c "import nvidia.cudnn; print(nvidia.cudnn.__file__)" OUTPUT_VARIABLE CUDNN_PYTHON_PATH)
+  # get_filename_component(CUDNN_ROOT_DIR "${CUDNN_PYTHON_PATH}" DIRECTORY)
+  # link_directories(
+  #   ${CUDNN_ROOT_DIR}/lib64
+  #   ${CUDNN_ROOT_DIR}/lib
+  # )
+endif()
+
 # check if USE_CXX11_ABI is set correctly
 # if (DEFINED USE_CXX11_ABI)
 #   parse_make_options(${TORCH_CXX_FLAGS} "TORCH_CXX_FLAGS")
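
Note on the -U__CUDA_NO_HALF_OPERATORS__ family of flags above: those guard macros (which PyTorch's default extension build flags commonly define) disable the __half/__nv_bfloat16 operator overloads in the CUDA headers, so they have to be undefined before compiling kernels that do half-precision arithmetic. A minimal illustration of the kind of device code this enables; the kernel below is hypothetical and not part of this commit:

    // Illustration only: __half operator+ from cuda_fp16.h is available only
    // when __CUDA_NO_HALF_OPERATORS__ is NOT defined, which the -U flags ensure.
    #include <cuda_fp16.h>

    __global__ void add_half(const __half* a, const __half* b, __half* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = a[i] + b[i];  // half-precision arithmetic in device code
      }
    }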

setup.py

Lines changed: 17 additions & 5 deletions
@@ -212,7 +212,13 @@ def set_mlu_envs():
     os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
     os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
     os.environ["PYTORCH_MLU_INSTALL_PATH"] = get_torch_mlu_root_path()
-
+
+def set_cuda_envs():
+    os.environ["PYTHON_INCLUDE_PATH"] = get_python_include_path()
+    os.environ["PYTHON_LIB_PATH"] = get_torch_root_path()
+    os.environ["LIBTORCH_ROOT"] = get_torch_root_path()
+    os.environ["PYTORCH_INSTALL_PATH"] = get_torch_root_path()
+
 class CMakeExtension(Extension):
     def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
         super().__init__(name, sources=[])
@@ -223,7 +229,7 @@ def __init__(self, name: str, path: str, sourcedir: str = "") -> None:
 class ExtBuild(build_ext):
     user_options = build_ext.user_options + [
         ("base-dir=", None, "base directory of xLLM project"),
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
         ("install-xllm-kernels=", None, "install xllm_kernels RPM package (true/false)"),
     ]
@@ -302,8 +308,14 @@ def build_extension(self, ext: CMakeExtension):
             cmake_args += ["-DUSE_MLU=ON"]
             # set mlu environment variables
             set_mlu_envs()
+        elif self.device == "cuda":
+            cuda_architectures = "80;89;90"
+            cmake_args += ["-DUSE_CUDA=ON",
+                           f"-DCMAKE_CUDA_ARCHITECTURES={cuda_architectures}"]
+            # set cuda environment variables
+            set_cuda_envs()
         else:
-            raise ValueError("Please set --device to a2 or a3 or mlu.")
+            raise ValueError("Please set --device to a2 or a3 or mlu or cuda.")
 
 
         # Adding CMake arguments set as environment variable
@@ -353,7 +365,7 @@ def build_extension(self, ext: CMakeExtension):
 
 class BuildDistWheel(bdist_wheel):
     user_options = bdist_wheel.user_options + [
-        ("device=", None, "target device type (a3 or a2 or mlu)"),
+        ("device=", None, "target device type (a3 or a2 or mlu or cuda)"),
         ("arch=", None, "target arch type (x86 or arm)"),
     ]
 
@@ -530,7 +542,7 @@ def apply_patch():
     idx = sys.argv.index('--device')
     if idx + 1 < len(sys.argv):
         device = sys.argv[idx+1].lower()
-        if device not in ('a2', 'a3', 'mlu'):
+        if device not in ('a2', 'a3', 'mlu', 'cuda'):
             print("Error: --device must be a2 or a3 or mlu (case-insensitive)")
             sys.exit(1)
     # Remove the arguments so setup() doesn't see them

third_party/CMakeLists.txt

Lines changed: 38 additions & 0 deletions
@@ -20,3 +20,41 @@ target_include_directories(mooncake_store PUBLIC
 )
 
 target_link_libraries(mooncake_store PUBLIC transfer_engine cachelib_memory_allocator)
+
+
+if(USE_CUDA)
+  cc_library(
+    NAME
+      cutlass
+    INCLUDES
+      cutlass/include
+      cutlass/tools/util/include
+    DEPS
+      torch # TODO: depends on CUDA instead of torch
+  )
+  cc_library(
+    NAME
+      dlpack
+    INCLUDES
+      dlpack/include
+  )
+  cc_library(
+    NAME
+      tvm-ffi
+    INCLUDES
+      tvm-ffi/include
+    DEPS
+      dlpack
+  )
+  cc_library(
+    NAME
+      flashinfer
+    INCLUDES
+      flashinfer/include
+      flashinfer/csrc
+    DEPS
+      cutlass
+      tvm-ffi
+      dlpack
+  )
+endif()

third_party/cutlass

Submodule cutlass added at e6e2cc2

third_party/dlpack

Submodule dlpack added at 93c8f2a

third_party/flashinfer

Submodule flashinfer added at d4a3ff4

third_party/tvm-ffi

Submodule tvm-ffi added at af898a2

xllm/core/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@ cc_library(
     rate_limiter.h
     types.h
     device_monitor.h
+    flashinfer_workspace.h
   SRCS
     etcd_client.cpp
     global_flags.cpp
@@ -23,6 +24,7 @@ cc_library(
     options.cpp
     rate_limiter.cpp
    device_monitor.cpp
+    flashinfer_workspace.cpp
   DEPS
     util
     absl::random_random

xllm/core/common/flashinfer_workspace.cpp

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "flashinfer_workspace.h"
+
+#include "global_flags.h"
+
+namespace xllm {
+
+void FlashinferWorkspace::initialize(const torch::Device& device) {
+  float_workspace_buffer_ =
+      torch::empty({FLAGS_workspace_buffer_size},
+                   torch::dtype(torch::kUInt8).device(device));
+  int_workspace_buffer_ =
+      torch::empty({FLAGS_workspace_buffer_size},
+                   torch::dtype(torch::kUInt8).device(device));
+  page_locked_int_workspace_buffer_ = torch::empty(
+      {FLAGS_workspace_buffer_size},
+      torch::dtype(torch::kUInt8).device(torch::kCPU).pinned_memory(true));
+}
+
+torch::Tensor FlashinferWorkspace::get_float_workspace_buffer() {
+  return float_workspace_buffer_;
+}
+
+torch::Tensor FlashinferWorkspace::get_int_workspace_buffer() {
+  return int_workspace_buffer_;
+}
+
+torch::Tensor FlashinferWorkspace::get_page_locked_int_workspace_buffer() {
+  return page_locked_int_workspace_buffer_;
+}
+
+}  // namespace xllm
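
The class is declared in flashinfer_workspace.h, which is listed in the xllm/core/common CMakeLists change above but not excerpted here, so the exact declaration (plain object vs. per-device singleton) is not visible. A minimal usage sketch, assuming the class is default-constructible and the four methods above are ordinary instance methods; the function name below is hypothetical:

    // Hypothetical sketch, not part of the commit: how a caller could obtain
    // the workspace tensors intended for flashinfer's plan/run entry points.
    #include <torch/torch.h>

    #include "flashinfer_workspace.h"

    void prepare_flashinfer_buffers(const torch::Device& device) {
      xllm::FlashinferWorkspace workspace;
      workspace.initialize(device);  // device buffers plus a pinned host buffer

      torch::Tensor float_ws = workspace.get_float_workspace_buffer();
      torch::Tensor int_ws = workspace.get_int_workspace_buffer();
      torch::Tensor pinned_ws = workspace.get_page_locked_int_workspace_buffer();
    }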
