Commit 38f4863

Abstract into GGML
1 parent a3d48e4 commit 38f4863

3 files changed: +22 -35 lines

Diff for: ggml/include/ggml-backend.h

+1 -2

@@ -233,8 +233,7 @@ extern "C" {
     GGML_API void ggml_backend_view_init(struct ggml_tensor * tensor);
 
     // Copy K and V cache pointers to backend
-    GGML_API void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size);
-    GGML_API void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size);
+    GGML_API void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_self_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn);
 
 #ifdef __cplusplus
 }

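As a sketch of what this consolidation means for users of the public header, the hypothetical wrapper below (not part of the commit; the name set_kv_cache_ptrs and its signature are illustrative) shows the single remaining call, where callers previously staged the per-layer pointer vectors themselves and invoked the two removed functions:

// Hypothetical wrapper, for illustration only: forwards the per-layer K/V cache
// tensors to the CUDA backend through the new single entry point.
#include "ggml-backend.h"
#include "ggml.h"
#include <cstdint>

static void set_kv_cache_ptrs(int64_t n_layer, int64_t kv_head,
                              ggml_tensor ** k_l, ggml_tensor ** v_l,
                              int64_t n_embd_k_gqa, int64_t n_embd_v_gqa,
                              bool flash_attn) {
#ifdef GGML_USE_CUDA
    // Previously: build k_cache_ptrs / v_cache_ptrs host vectors here, then call
    // ggml_backend_copy_k_cache_ptrs(...) and ggml_backend_copy_v_cache_ptrs(...).
    // Now the staging loop lives inside the CUDA backend:
    ggml_backend_copy_kv_cache_ptrs(n_layer, kv_head, k_l, v_l,
                                    n_embd_k_gqa, n_embd_v_gqa, flash_attn);
#endif
}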
Diff for: ggml/src/ggml-cuda/cpy.cu

+20 -6

@@ -302,12 +302,26 @@ static void ggml_backend_copy_cache_ptrs(char **& backend_cache_ptrs, const char
     cudaMemcpy(backend_cache_ptrs, host_cache_ptrs, size*sizeof(char *), cudaMemcpyHostToDevice);
 }
 
-void ggml_backend_copy_k_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_cache_ptrs, size);
-}
-
-void ggml_backend_copy_v_cache_ptrs(const char ** host_cache_ptrs, size_t size) {
-    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_cache_ptrs, size);
+void ggml_backend_copy_kv_cache_ptrs(const int64_t n_layer, const int64_t kv_head, struct ggml_tensor ** kv_kl, struct ggml_tensor ** kv_vl, const int64_t n_embd_k_gqa, const int64_t n_embd_v_gqa, const bool flash_attn) {
+
+    std::vector<const char *> host_k_cache_ptrs;
+    std::vector<const char *> host_v_cache_ptrs;
+    for (int il = 0; il < n_layer; ++il) {
+        // K cache pointer for this layer
+        ggml_tensor * tmp_tensor = kv_kl[il];
+        size_t tmp_offset = (ggml_row_size(kv_kl[il]->type, n_embd_k_gqa))*kv_head;
+        host_k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+        // V cache pointer for this layer
+        tmp_tensor = kv_vl[il];
+        if (flash_attn) {
+            tmp_offset = (kv_head)*ggml_row_size(kv_vl[il]->type, n_embd_v_gqa);
+        } else {
+            tmp_offset = (kv_head)*ggml_element_size(kv_vl[il]);
+        }
+        host_v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
+    }
+    ggml_backend_copy_cache_ptrs(k_cache_ptrs, host_k_cache_ptrs.data(), host_k_cache_ptrs.size());
+    ggml_backend_copy_cache_ptrs(v_cache_ptrs, host_v_cache_ptrs.data(), host_v_cache_ptrs.size());
 }
 
 static void ggml_cpy_f16_f32_cuda(

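The context line at the top of this hunk is the piece that actually moves data: ggml_backend_copy_cache_ptrs mirrors a host array of device pointers into device memory so CUDA kernels can index any layer's K or V cache by layer id. Below is a minimal standalone sketch of that pointer-table pattern, under the assumption of a one-time allocation; the names are hypothetical and not the backend's own.

// Standalone sketch (assumptions, not from the commit) of the pointer-table
// pattern: assemble device pointers on the host, then copy the table to the GPU.
#include <cuda_runtime.h>
#include <vector>

static char ** device_ptr_table = nullptr;   // device-side copy of the table

static void upload_ptr_table(const std::vector<const char *> & host_ptrs) {
    if (device_ptr_table == nullptr) {
        // allocate once; the real helper may manage (re)allocation differently
        cudaMalloc(&device_ptr_table, host_ptrs.size() * sizeof(char *));
    }
    // same transfer as the context line above: size * sizeof(char *) bytes,
    // host -> device
    cudaMemcpy(device_ptr_table, host_ptrs.data(),
               host_ptrs.size() * sizeof(char *), cudaMemcpyHostToDevice);
}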
Diff for: src/llama.cpp

+1 -27

@@ -14735,33 +14735,7 @@ static int llama_decode_internal(
     ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
 #ifdef GGML_USE_CUDA
-    // Copy K and V cache pointers to backend
-
-    // Stage pointers for each layer in host vectors
-    std::vector<const char *> k_cache_ptrs;
-    std::vector<const char *> v_cache_ptrs;
-    const int64_t n_layer = model.hparams.n_layer;
-    const int64_t kv_head = kv_self.head;
-    for (int il = 0; il < n_layer; ++il) {
-        // K cache pointer for this layer
-        const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
-        ggml_tensor * tmp_tensor = kv_self.k_l[il];
-        size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
-        k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-        // V cache pointer for this layer
-        const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-        tmp_tensor = kv_self.v_l[il];
-        if (cparams.flash_attn) {
-            tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
-        } else {
-            tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
-        }
-        v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
-    }
-
-    // copy host vector data to backend
-    ggml_backend_copy_k_cache_ptrs(k_cache_ptrs.data(), k_cache_ptrs.size());
-    ggml_backend_copy_v_cache_ptrs(v_cache_ptrs.data(), v_cache_ptrs.size());
+    ggml_backend_copy_kv_cache_ptrs(model.hparams.n_layer, kv_self.head, kv_self.k_l.data(), kv_self.v_l.data(), hparams.n_embd_k_gqa(), hparams.n_embd_v_gqa(), cparams.flash_attn);
 #endif
 
     llama_set_inputs(lctx, u_batch);

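The argument list of the new call packs exactly what the backend needs to rebuild the per-layer write offsets that llama_decode_internal previously computed inline. Below is a hedged sketch of that arithmetic, derived from the hunks above; the helper name kv_write_offsets is hypothetical.

// Sketch only: the byte offsets at which the current batch's K and V data land
// in one layer's cache tensors, given the first free cache slot kv_head.
#include "ggml.h"
#include <cstddef>
#include <cstdint>
#include <utility>

static std::pair<size_t, size_t> kv_write_offsets(const ggml_tensor * k_l,
                                                  const ggml_tensor * v_l,
                                                  int64_t n_embd_k_gqa,
                                                  int64_t n_embd_v_gqa,
                                                  int64_t kv_head,
                                                  bool flash_attn) {
    // K cache: one contiguous row of n_embd_k_gqa elements per cached token,
    // so the write position is kv_head rows into the buffer.
    const size_t k_offset = ggml_row_size(k_l->type, n_embd_k_gqa) * kv_head;
    // V cache: with flash attention the layout matches K; otherwise V is stored
    // transposed, so the write position advances by one element per token.
    const size_t v_offset = flash_attn
        ? ggml_row_size(v_l->type, n_embd_v_gqa) * kv_head
        : ggml_element_size(v_l) * kv_head;
    return {k_offset, v_offset};
}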