Skip to content

Commit 2eb67c7

Browse files
committed
feat: add flashinfer as kernel backend for cuda.
1 parent 1028207 commit 2eb67c7

37 files changed

+1434
-249
lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
212212
state_.q_seq_lens.insert(state_.q_seq_lens.end(),
213213
state.q_seq_lens.begin(),
214214
state.q_seq_lens.end());
215-
#elif defined(USE_MLU)
215+
#elif defined(USE_MLU) || defined(USE_CUDA)
216216
int32_t seq_len_offset = state_.seq_lens.back();
217217
// skip the first element which is 0
218218
for (size_t i = 1; i < state.seq_lens.size(); ++i) {
@@ -281,7 +281,7 @@ void BatchInputBuilder::process_single_sequence(
281281
#if defined(USE_NPU)
282282
state.seq_lens.push_back(seq_len);
283283
state.q_seq_lens.push_back(q_seq_len);
284-
#elif defined(USE_MLU)
284+
#elif defined(USE_MLU) || defined(USE_CUDA)
285285
state.seq_lens.push_back(state.seq_lens.back() + seq_len);
286286
state.q_seq_lens.push_back(state.q_seq_lens.back() + q_seq_len);
287287
#endif
@@ -425,7 +425,12 @@ void BatchInputBuilder::setup_kv_cache_info(
425425
block_size = block.size();
426426
block_ids.push_back(block.id());
427427
u_block_ids.emplace_back(block.id());
428+
state.paged_kv_indices.push_back(block.id());
428429
}
430+
state.paged_kv_indptr.push_back(state.paged_kv_indptr.back() + blocks.size());
431+
int32_t last_page_len =
432+
(seq_len % block_size == 0) ? block_size : seq_len % block_size;
433+
state.paged_kv_last_page_len.push_back(last_page_len);
429434

430435
int32_t kv_cache_block_idx = n_kv_cache_tokens / block_size;
431436
for (auto iter = block_ids.begin() + kv_cache_block_idx;
@@ -494,12 +499,15 @@ void BatchInputBuilder::padding_decode_batch_size(
494499
#if defined(USE_NPU)
495500
state_.seq_lens.push_back(num_decoding_tokens);
496501
state_.q_seq_lens.push_back(num_decoding_tokens);
497-
#elif defined(USE_MLU)
502+
#elif defined(USE_MLU) || defined(USE_CUDA)
498503
state_.seq_lens.push_back(state_.seq_lens.back() + num_decoding_tokens);
499504
state_.q_seq_lens.push_back(state_.q_seq_lens.back() +
500505
num_decoding_tokens);
501506
#endif
502507
state_.block_tables_vec.emplace_back();
508+
state_.paged_kv_indices.push_back(0);
509+
state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1);
510+
state_.paged_kv_last_page_len.push_back(1);
503511
}
504512
}
505513
}

xllm/core/framework/batch/batch_input_builder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,11 @@ class BatchInputBuilder {
103103
// for continuous kvcache
104104
std::vector<int64_t> new_cache_slot_offsets; //[n_tokens]
105105
std::vector<int64_t> kv_cache_start_offsets; //[n_seq]
106+
107+
// for flashinfer
108+
std::vector<int32_t> paged_kv_indptr = {0};
109+
std::vector<int32_t> paged_kv_indices;
110+
std::vector<int32_t> paged_kv_last_page_len;
106111
};
107112

108113
// Helper methods for sequence processing
@@ -127,7 +132,6 @@ class BatchInputBuilder {
127132
uint32_t q_seq_len,
128133
BuilderState* state_ptr = nullptr,
129134
std::unordered_set<int32_t>* write_block_ids_ptr = nullptr);
130-
131135
void setup_continuous_kv_cache_info(Sequence* sequence,
132136
uint32_t n_kv_cache_tokens,
133137
uint32_t seq_len,

xllm/core/framework/model/model_input_params.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ struct ModelInputParams {
9292

9393
// Copy graph_buffer to device
9494
params.graph_buffer = safe_to(graph_buffer, device, true);
95+
96+
// params for flashinfer
97+
params.paged_kv_indptr = safe_to(paged_kv_indptr, device);
98+
params.paged_kv_indices = safe_to(paged_kv_indices, device);
99+
params.paged_kv_last_page_len = safe_to(paged_kv_last_page_len, device);
100+
95101
return params;
96102
}
97103

@@ -187,6 +193,21 @@ struct ModelInputParams {
187193
// Graph execution buffer for temporary tensor storage
188194
// Used by ACL Graph Executor to avoid repeated memory allocation
189195
torch::Tensor graph_buffer;
196+
197+
// the indptr of the paged kv-cache
198+
// used in flashinfer
199+
// IntTensor: [n_seq + 1]
200+
torch::Tensor paged_kv_indptr;
201+
202+
// the page indices of the paged kv cache
203+
// used in flashinfer
204+
torch::Tensor paged_kv_indices;
205+
206+
// the number of entries in the last page of each request in
207+
// the paged kv cache
208+
// used in flashinfer
209+
// IntTensor: [n_seq]
210+
torch::Tensor paged_kv_last_page_len;
190211
};
191212

192213
} // namespace xllm

xllm/core/kernels/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ cc_library(
1717
kernels
1818
HDRS
1919
param.h
20-
torch_ops_api.h
20+
ops_api.h
2121
SRCS
22-
torch_ops_api.cpp
22+
ops_api.cpp
2323
DEPS
2424
torch
2525
$<$<BOOL:${USE_NPU}>:npu_kernels>
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Build description for the `cuda_kernels` library.
# Pulls in the project's cc_library() helper used below.
include(cc_library)
2+
3+
# Recursively collect every header (*.h) under this directory.
file(GLOB_RECURSE CUDA_HEADER_FILES
4+
"${CMAKE_CURRENT_LIST_DIR}/*.h"
5+
)
6+
7+
# Recursively collect every source file (*.cpp) under this directory.
# NOTE(review): only *.cpp is globbed; files with other extensions (e.g. *.cu)
# would not be picked up -- confirm this is intended.
file(GLOB_RECURSE CUDA_SOURCE_FILES
8+
"${CMAKE_CURRENT_LIST_DIR}/*.cpp"
9+
)
10+
11+
# Declare the `cuda_kernels` target from the globbed headers and sources,
# with `flashinfer` listed as its dependency.
cc_library(
12+
NAME
13+
cuda_kernels
14+
HDRS
15+
${CUDA_HEADER_FILES}
16+
SRCS
17+
${CUDA_SOURCE_FILES}
18+
DEPS
19+
flashinfer
20+
)

xllm/core/kernels/cuda/active.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <cuda_runtime.h>
17+
18+
#include <flashinfer/activation.cuh>
19+
20+
#include "cuda_ops_api.h"
21+
22+
using namespace flashinfer;
23+
24+
namespace xllm::kernel::cuda {
25+
26+
// SiLU activation: val / (1 + exp(-val)), using the fast `__expf` intrinsic.
__device__ __forceinline__ float silu(const float& val) {
  const float denom = 1.0f + __expf(-val);
  return val / denom;
}
29+
30+
// Erf-based GELU activation: val * 0.5 * (1 + erf(val / sqrt(2))).
__device__ __forceinline__ float gelu(const float& val) {
  // 1 / sqrt(2), narrowed to float at compile time.
  constexpr float kInvSqrt2 = M_SQRT1_2;
  const float erf_term = ::erf(val * kInvSqrt2);
  return val * 0.5f * (1.0f + erf_term);
}
34+
35+
// Tanh-based GELU activation:
//   val * 0.5 * (1 + tanh(0.7978845608 * (val + 0.044715 * val^3))).
__device__ __forceinline__ float gelu_tanh(const float& val) {
  const float inner = val + 0.044715f * val * val * val;
  const float cdf = 0.5f * (1.0f + math::tanh((0.7978845608028654f * inner)));
  return val * cdf;
}
41+
42+
// Launches flashinfer's `activation::act_and_mul_kernel` on `input`, writing
// into `out`.
//
// Args:
//   out:        destination tensor; its device selects the CUDA device and the
//               stream used for the launch.
//   input:      source tensor. `d = last_dim / 2` is passed to the kernel and
//               one thread block is launched per `numel / last_dim` row.
//               (presumably `out` has a last dimension of `d` -- not checked
//               here, TODO confirm against the caller)
//   act_mode:   name of the activation to apply: "silu", "gelu" or
//               "gelu_tanh" (mapped to the same-named device functions in
//               this file).
//   enable_pdl: forwarded to the programmatic-stream-serialization launch
//               attribute.
//
// An empty `input` is a no-op. An unsupported `act_mode` or any CUDA error
// fails a TVM_FFI_ICHECK.
void act_and_mul(TensorView out,
                 TensorView input,
                 const std::string& act_mode,
                 bool enable_pdl) {
  const int64_t last_dim = input->shape[input->ndim - 1];
  // Nothing to do for an empty tensor. Returning here also avoids dividing by
  // a zero-sized last dimension and launching a zero-sized grid.
  if (input.numel() == 0) {
    return;
  }
  const int d = static_cast<int>(last_dim / 2);
  const int64_t num_tokens = input.numel() / last_dim;

  const cudaError_t set_device_err = cudaSetDevice(out->device.device_id);
  TVM_FFI_ICHECK(set_device_err == cudaSuccess)
      << "Failed to set device: " << cudaGetErrorString(set_device_err);
  const cudaStream_t stream = get_stream(out->device);

  DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] {
    // Number of elements that fit in one 16-byte vector access.
    const uint32_t vec_size = 16 / sizeof(c_type);

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;

    // Value-initialize so no field of the config is left indeterminate.
    cudaLaunchConfig_t config = {};
    config.gridDim = num_tokens;
    config.blockDim = std::min(d / vec_size, 1024U);
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    config.numAttrs = 1;
    config.attrs = attrs;

    // Launches one concrete kernel instantiation and checks both the launch
    // return code and the runtime's last-error state.
    const auto launch = [&](auto kernel) {
      const cudaError_t launch_err =
          cudaLaunchKernelEx(&config,
                             kernel,
                             static_cast<c_type*>(out->data),
                             static_cast<c_type*>(input->data),
                             d);
      const cudaError_t last_err = cudaGetLastError();
      const cudaError_t err =
          (launch_err != cudaSuccess) ? launch_err : last_err;
      TVM_FFI_ICHECK(err == cudaSuccess)
          << "Failed to launch kernel: " << cudaGetErrorString(err);
    };

    // The activation is a compile-time template argument of the kernel, so
    // the runtime string has to be mapped onto a fixed set of instantiations.
    if (act_mode == "silu") {
      launch(activation::act_and_mul_kernel<c_type, silu>);
    } else if (act_mode == "gelu") {
      launch(activation::act_and_mul_kernel<c_type, gelu>);
    } else if (act_mode == "gelu_tanh") {
      launch(activation::act_and_mul_kernel<c_type, gelu_tanh>);
    } else {
      TVM_FFI_ICHECK(false) << "Unsupported act_mode: " << act_mode;
    }

    return true;
  });
}
80+
81+
} // namespace xllm::kernel::cuda
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <flashinfer/page.cuh>
17+
18+
#include "cuda_ops_api.h"
19+
20+
using namespace flashinfer;
21+
22+
using tvm::ffi::Tensor;
23+
24+
namespace xllm::kernel::cuda {
25+
26+
// Appends new key/value rows into a paged KV cache by calling flashinfer's
// `AppendPagedKVCache`.
//
// Args:
//   append_key / append_value: 3-D tensors [nnz, num_heads, head_dim] holding
//       the rows to append; only the last dimension must be contiguous.
//   batch_indices, positions:  1-D int32 tensors of length nnz.
//   paged_k_cache / paged_v_cache: 4-D caches with identical strides. With
//       QKVLayout::kHND the shape is read as
//       [*, num_heads, page_size, head_dim]; otherwise as
//       [*, page_size, num_heads, head_dim].
//   kv_indices:       1-D int32 tensor of page indices.
//   kv_indptr:        1-D int32 tensor of length batch_size + 1.
//   kv_last_page_len: 1-D int32 tensor of length batch_size.
//   layout:           integer value converted to `QKVLayout`.
//
// Every violated precondition and any CUDA error fails a TVM_FFI check.
void append_paged_kv_cache(TensorView append_key,
                           TensorView append_value,
                           TensorView batch_indices,
                           TensorView positions,
                           TensorView paged_k_cache,
                           TensorView paged_v_cache,
                           TensorView kv_indices,
                           TensorView kv_indptr,
                           TensorView kv_last_page_len,
                           int64_t layout) {
  CHECK_LAST_DIM_CONTIGUOUS(append_key);
  CHECK_LAST_DIM_CONTIGUOUS(append_value);
  CHECK_INPUT(batch_indices);
  CHECK_INPUT(positions);
  // NOTE(Zihao): doesn't have to be contiguous
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache);
  CHECK_INPUT(kv_indices);
  CHECK_INPUT(kv_indptr);
  CHECK_INPUT(kv_last_page_len);
  CHECK_DIM(3, append_key);
  CHECK_DIM(3, append_value);
  CHECK_DIM(1, batch_indices);
  CHECK_DIM(1, positions);
  CHECK_DIM(4, paged_k_cache);
  CHECK_DIM(4, paged_v_cache);
  CHECK_DIM(1, kv_indices);
  CHECK_DIM(1, kv_indptr);
  CHECK_DIM(1, kv_last_page_len);
  const unsigned int nnz = static_cast<unsigned int>(append_key->shape[0]);
  const unsigned int batch_size =
      static_cast<unsigned int>(kv_last_page_len->shape[0]);
  TVM_FFI_ICHECK_EQ(kv_indptr->shape[0], batch_size + 1);
  TVM_FFI_ICHECK_EQ(batch_indices->shape[0], nnz);
  TVM_FFI_ICHECK_EQ(positions->shape[0], nnz);

  // Every tensor is dereferenced by the same kernel, so all of them must live
  // on the device of `append_key`.
  CHECK_DEVICE(append_value, append_key);
  CHECK_DEVICE(batch_indices, append_key);
  CHECK_DEVICE(positions, append_key);
  CHECK_DEVICE(paged_k_cache, append_key);
  CHECK_DEVICE(paged_v_cache, append_key);
  CHECK_DEVICE(kv_indices, append_key);
  CHECK_DEVICE(kv_indptr, append_key);
  CHECK_DEVICE(kv_last_page_len, append_key);

  // All data pointers below are reinterpreted through the element type derived
  // from `paged_k_cache->dtype`; a mismatching dtype would silently corrupt
  // memory, so reject it up front.
  const auto same_dtype = [](const DLDataType& a, const DLDataType& b) {
    return a.code == b.code && a.bits == b.bits && a.lanes == b.lanes;
  };
  TVM_FFI_ICHECK(same_dtype(paged_v_cache->dtype, paged_k_cache->dtype))
      << "paged_v_cache must have the same dtype as paged_k_cache";
  TVM_FFI_ICHECK(same_dtype(append_key->dtype, paged_k_cache->dtype))
      << "append_key must have the same dtype as paged_k_cache";
  TVM_FFI_ICHECK(same_dtype(append_value->dtype, paged_k_cache->dtype))
      << "append_value must have the same dtype as paged_k_cache";

  // The index tensors are read through `int32_t*`.
  const auto is_int32 = [](const DLDataType& t) {
    return t.code == kDLInt && t.bits == 32 && t.lanes == 1;
  };
  TVM_FFI_ICHECK(is_int32(batch_indices->dtype))
      << "batch_indices must be int32";
  TVM_FFI_ICHECK(is_int32(positions->dtype)) << "positions must be int32";
  TVM_FFI_ICHECK(is_int32(kv_indices->dtype)) << "kv_indices must be int32";
  TVM_FFI_ICHECK(is_int32(kv_indptr->dtype)) << "kv_indptr must be int32";
  TVM_FFI_ICHECK(is_int32(kv_last_page_len->dtype))
      << "kv_last_page_len must be int32";

  const QKVLayout kv_layout = QKVLayout(layout);

  // Dimensions 1 and 2 of the cache swap roles depending on the layout.
  unsigned int num_heads, page_size, head_dim;
  head_dim = paged_k_cache->shape[3];
  if (kv_layout == QKVLayout::kHND) {
    num_heads = paged_k_cache->shape[1];
    page_size = paged_k_cache->shape[2];
  } else {
    page_size = paged_k_cache->shape[1];
    num_heads = paged_k_cache->shape[2];
  }

  // get kv_cache_strides
  auto k_strides = paged_k_cache->strides;
  auto v_strides = paged_v_cache->strides;
  auto k_dim = paged_k_cache->ndim;
  // Only the K strides are passed on to `paged_kv_t`, so V must match them.
  TVM_FFI_ICHECK(std::equal(k_strides, k_strides + k_dim, v_strides))
      << "k/v strides must be identical";

  auto append_k_strides = append_key->strides;
  auto append_k_stride_n = append_k_strides[0];
  auto append_k_stride_h = append_k_strides[1];
  auto append_v_strides = append_value->strides;
  auto append_v_stride_n = append_v_strides[0];
  auto append_v_stride_h = append_v_strides[1];

  TVM_FFI_ICHECK_EQ(append_key->shape[1], num_heads);
  TVM_FFI_ICHECK_EQ(append_key->shape[2], head_dim);
  TVM_FFI_ICHECK_EQ(append_value->shape[1], num_heads);
  TVM_FFI_ICHECK_EQ(append_value->shape[2], head_dim);

  const cudaError_t set_device_err =
      cudaSetDevice(append_key->device.device_id);
  TVM_FFI_ICHECK(set_device_err == cudaSuccess)
      << "Failed to set device: " << cudaGetErrorString(set_device_err);
  const cudaStream_t stream = get_stream(append_key->device);
  bool success =
      DISPATCH_DLPACK_DTYPE_TO_CTYPE(paged_k_cache->dtype, c_type, [&] {
        paged_kv_t<c_type, int32_t> paged_kv(
            num_heads,
            page_size,
            head_dim,
            batch_size,
            kv_layout,
            static_cast<c_type*>(paged_k_cache->data),
            static_cast<c_type*>(paged_v_cache->data),
            k_strides,
            static_cast<int32_t*>(kv_indices->data),
            static_cast<int32_t*>(kv_indptr->data),
            static_cast<int32_t*>(kv_last_page_len->data));
        cudaError_t status =
            AppendPagedKVCache(paged_kv,
                               static_cast<c_type*>(append_key->data),
                               static_cast<c_type*>(append_value->data),
                               static_cast<int32_t*>(batch_indices->data),
                               static_cast<int32_t*>(positions->data),
                               nnz,
                               append_k_stride_n,
                               append_k_stride_h,
                               append_v_stride_n,
                               append_v_stride_h,
                               stream);
        TVM_FFI_ICHECK(status == cudaSuccess)
            << "AppendPagedKVCache failed with error: "
            << cudaGetErrorString(status);
        return true;
      });

  TVM_FFI_ICHECK(success) << "AppendPagedKVCache failed to dispatch with dtype "
                          << paged_k_cache->dtype;
}
136+
137+
//
138+
139+
//
140+
141+
} // namespace xllm::kernel::cuda

0 commit comments

Comments
 (0)