1+ /* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+ Licensed under the Apache License, Version 2.0 (the "License");
4+ you may not use this file except in compliance with the License.
5+ You may obtain a copy of the License at
6+
7+ https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+ Unless required by applicable law or agreed to in writing, software
10+ distributed under the License is distributed on an "AS IS" BASIS,
11+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ See the License for the specific language governing permissions and
13+ limitations under the License.
14+ ==============================================================================*/
15+
#include <algorithm>

#include <flashinfer/page.cuh>

#include "cuda_ops_api.h"
19+
20+ using namespace flashinfer ;
21+
22+ using tvm::ffi::Tensor;
23+
24+ namespace xllm ::kernel::cuda {
25+
// Appends new key/value tokens to a paged KV cache using flashinfer's
// AppendPagedKVCache kernel.
//
// Preconditions (all checked below):
//   - append_key / append_value: [nnz, num_heads, head_dim], last dim
//     contiguous, same dtype as the caches.
//   - batch_indices / positions: [nnz] int32, contiguous — batch index and
//     in-sequence position of each appended token.
//   - paged_k_cache / paged_v_cache: 4-D caches with identical strides;
//     dimension order depends on `layout` (a flashinfer::QKVLayout value:
//     kHND or kNHD).
//   - kv_indptr: [batch_size + 1]; kv_last_page_len: [batch_size];
//     kv_indices: 1-D page table.
// Every tensor must reside on append_key's CUDA device. The launch is issued
// on the stream returned by get_stream() for that device.
void append_paged_kv_cache(TensorView append_key,
                           TensorView append_value,
                           TensorView batch_indices,
                           TensorView positions,
                           TensorView paged_k_cache,
                           TensorView paged_v_cache,
                           TensorView kv_indices,
                           TensorView kv_indptr,
                           TensorView kv_last_page_len,
                           int64_t layout) {
  CHECK_LAST_DIM_CONTIGUOUS(append_key);
  CHECK_LAST_DIM_CONTIGUOUS(append_value);
  CHECK_INPUT(batch_indices);
  CHECK_INPUT(positions);
  // NOTE(Zihao): the paged caches don't have to be fully contiguous; only the
  // innermost (head_dim) dimension must be.
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_k_cache);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(paged_v_cache);
  CHECK_INPUT(kv_indices);
  CHECK_INPUT(kv_indptr);
  CHECK_INPUT(kv_last_page_len);
  CHECK_DIM(3, append_key);    // [nnz, num_heads, head_dim]
  CHECK_DIM(3, append_value);  // [nnz, num_heads, head_dim]
  CHECK_DIM(1, batch_indices);
  CHECK_DIM(1, positions);
  CHECK_DIM(4, paged_k_cache);
  CHECK_DIM(4, paged_v_cache);
  CHECK_DIM(1, kv_indices);
  CHECK_DIM(1, kv_indptr);
  CHECK_DIM(1, kv_last_page_len);

  unsigned int nnz = append_key->shape[0];
  unsigned int batch_size = kv_last_page_len->shape[0];
  TVM_FFI_ICHECK_EQ(kv_indptr->shape[0], batch_size + 1);
  TVM_FFI_ICHECK_EQ(batch_indices->shape[0], nnz);
  TVM_FFI_ICHECK_EQ(positions->shape[0], nnz);

  // Every tensor the kernel dereferences must live on append_key's device.
  // batch_indices/positions were previously not checked even though the
  // kernel reads them on the device; the redundant append_key self-check
  // has been dropped.
  CHECK_DEVICE(append_value, append_key);
  CHECK_DEVICE(batch_indices, append_key);
  CHECK_DEVICE(positions, append_key);
  CHECK_DEVICE(paged_k_cache, append_key);
  CHECK_DEVICE(paged_v_cache, append_key);
  CHECK_DEVICE(kv_indices, append_key);
  CHECK_DEVICE(kv_indptr, append_key);
  CHECK_DEVICE(kv_last_page_len, append_key);

  // The dispatch below reinterprets the appended K/V buffers as the cache's
  // element type, so their dtypes must match the cache's (compare the DLPack
  // code/bits/lanes fields directly).
  TVM_FFI_ICHECK(append_key->dtype.code == paged_k_cache->dtype.code &&
                 append_key->dtype.bits == paged_k_cache->dtype.bits &&
                 append_key->dtype.lanes == paged_k_cache->dtype.lanes)
      << "append_key dtype must match paged_k_cache dtype";
  TVM_FFI_ICHECK(append_value->dtype.code == paged_k_cache->dtype.code &&
                 append_value->dtype.bits == paged_k_cache->dtype.bits &&
                 append_value->dtype.lanes == paged_k_cache->dtype.lanes)
      << "append_value dtype must match paged_k_cache dtype";

  QKVLayout kv_layout = QKVLayout(layout);

  // Decode cache geometry from the layout: HND puts heads before the page
  // dimension, NHD puts the page dimension first; head_dim is always last.
  unsigned int num_heads, page_size, head_dim;
  head_dim = paged_k_cache->shape[3];
  if (kv_layout == QKVLayout::kHND) {
    num_heads = paged_k_cache->shape[1];
    page_size = paged_k_cache->shape[2];
  } else {
    page_size = paged_k_cache->shape[1];
    num_heads = paged_k_cache->shape[2];
  }

  // paged_kv_t takes a single stride array for both caches, so K and V
  // strides must be identical.
  auto k_strides = paged_k_cache->strides;
  auto v_strides = paged_v_cache->strides;
  auto k_dim = paged_k_cache->ndim;
  TVM_FFI_ICHECK(std::equal(k_strides, k_strides + k_dim, v_strides))
      << "k/v strides must be identical";

  auto append_k_strides = append_key->strides;
  auto append_k_stride_n = append_k_strides[0];
  auto append_k_stride_h = append_k_strides[1];
  auto append_v_strides = append_value->strides;
  auto append_v_stride_n = append_v_strides[0];
  auto append_v_stride_h = append_v_strides[1];

  TVM_FFI_ICHECK_EQ(append_key->shape[1], num_heads);
  TVM_FFI_ICHECK_EQ(append_key->shape[2], head_dim);
  TVM_FFI_ICHECK_EQ(append_value->shape[1], num_heads);
  TVM_FFI_ICHECK_EQ(append_value->shape[2], head_dim);

  // Switch to the tensors' device before launching; cudaSetDevice returns an
  // error code that was previously ignored.
  cudaError_t set_device_status = cudaSetDevice(append_key->device.device_id);
  TVM_FFI_ICHECK(set_device_status == cudaSuccess)
      << "cudaSetDevice failed with error: "
      << cudaGetErrorString(set_device_status);
  const cudaStream_t stream = get_stream(append_key->device);
  bool success =
      DISPATCH_DLPACK_DTYPE_TO_CTYPE(paged_k_cache->dtype, c_type, [&] {
        paged_kv_t<c_type, int32_t> paged_kv(
            num_heads,
            page_size,
            head_dim,
            batch_size,
            kv_layout,
            static_cast<c_type*>(paged_k_cache->data),
            static_cast<c_type*>(paged_v_cache->data),
            k_strides,
            static_cast<int32_t*>(kv_indices->data),
            static_cast<int32_t*>(kv_indptr->data),
            static_cast<int32_t*>(kv_last_page_len->data));
        cudaError_t status =
            AppendPagedKVCache(paged_kv,
                               static_cast<c_type*>(append_key->data),
                               static_cast<c_type*>(append_value->data),
                               static_cast<int32_t*>(batch_indices->data),
                               static_cast<int32_t*>(positions->data),
                               nnz,
                               append_k_stride_n,
                               append_k_stride_h,
                               append_v_stride_n,
                               append_v_stride_h,
                               stream);
        TVM_FFI_ICHECK(status == cudaSuccess)
            << "AppendPagedKVCache failed with error: "
            << cudaGetErrorString(status);
        return true;
      });

  TVM_FFI_ICHECK(success) << "AppendPagedKVCache failed to dispatch with dtype "
                          << paged_k_cache->dtype;
}
136+
141+ } // namespace xllm::kernel::cuda