Commit e6f291d

server : fix context shift (ggml-org#5195)
* server : fix context shift + simplify self-extend
* server : take system_tokens into account
* server : more n_past fixes
* server : revert n_past_se changes
1 parent 4003be0 commit e6f291d

File tree

2 files changed: +60 -50 lines changed


examples/server/chat.sh (+1)

@@ -48,6 +48,7 @@ chat_completion() {
             top_p: 0.9,
             n_keep: $n_keep,
             n_predict: 256,
+            cache_prompt: true,
             stop: ["\n### Human:"],
             stream: true
         }')"

examples/server/server.cpp (+59 -50)

@@ -185,7 +185,7 @@ struct llama_client_slot
    llama_sampling_context *ctx_sampling = nullptr;

    int32_t ga_i = 0; // group-attention state
-    int32_t ga_n = 1;// group-attention factor
+    int32_t ga_n = 1; // group-attention factor
    int32_t ga_w = 512; // group-attention width

    int32_t n_past_se = 0; // self-extend
@@ -219,7 +219,8 @@ struct llama_client_slot
        sent_token_probs_index = 0;
        infill = false;
        ga_i = 0;
-        n_past_se = 0;
+        n_past_se = 0;
+
        generated_token_probs.clear();

        for (slot_image & img : images)
@@ -1227,7 +1228,7 @@ struct llama_server_context
                std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                for (int i = 0; i < (int) append_tokens.size(); ++i)
                {
-                    llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
+                    llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
                    slot.n_past += 1;
                }
            }
@@ -1295,6 +1296,8 @@ struct llama_server_context
            for (llama_client_slot &slot : slots)
            {
                slot.cache_tokens.clear();
+                slot.n_past = 0;
+                slot.n_past_se = 0;
            }
        }

@@ -1364,26 +1367,26 @@ struct llama_server_context
                kv_cache_clear();
            }
            return true;
-        } else {
-            task_server task;
-            task.type = TASK_TYPE_NEXT_RESPONSE;
-            task.target_id = -1;
-            queue_tasks.post(task);
        }

+        task_server task;
+        task.type = TASK_TYPE_NEXT_RESPONSE;
+        task.target_id = -1;
+        queue_tasks.post(task);
+
        for (llama_client_slot &slot : slots)
        {
            if (slot.ga_n == 1)
            {
-                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                {
                    // Shift context
-                    const int n_left = slot.n_past - slot.params.n_keep - 1;
+                    const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
                    const int n_discard = n_left / 2;

                    LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
                    llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1);
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);

                    for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                    {
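This hunk is the core of the fix: both the overflow check and the shift arithmetic are now measured from the start of the whole sequence, which includes the shared system prompt, instead of from the slot-local token count alone. A worked sketch of the arithmetic with made-up numbers (context size, token counts and n_keep are hypothetical; only the formulas mirror the diff):

// Worked sketch of the context-shift arithmetic in the hunk above.
// All numbers are hypothetical; only the formulas mirror the diff.
#include <cstdio>

int main() {
    const int n_ctx    = 512; // slot context size
    const int n_system = 32;  // system_tokens.size()
    const int n_past   = 490; // slot-local tokens already evaluated
    const int n_keep   = 10;  // slot.params.n_keep

    // After the fix, occupancy is measured including the system prompt.
    // (The diff checks system_tokens.size() + slot.cache_tokens.size();
    // here we assume the cached count equals n_past for simplicity.)
    const bool needs_shift = n_system + n_past >= n_ctx;   // 522 >= 512 -> true

    // Everything past the kept prefix (and the BOS token) is eligible for
    // discarding; half of it is dropped.
    const int n_left    = n_system + n_past - n_keep - 1;  // 511
    const int n_discard = n_left / 2;                      // 255

    // Mirrors the two KV-cache calls in the diff: remove the discarded range,
    // then shift the remainder back by n_discard, up to the absolute end of
    // the sequence (system tokens + slot tokens):
    //   llama_kv_cache_seq_rm   (ctx, id, n_keep + 1, n_keep + n_discard + 1);
    //   llama_kv_cache_seq_shift(ctx, id, n_keep + 1 + n_discard,
    //                            n_system + n_past, -n_discard);
    printf("needs_shift=%d n_left=%d n_discard=%d\n", needs_shift, n_left, n_discard);
    return 0;
}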
@@ -1429,8 +1432,10 @@ struct llama_server_context
            slot.i_batch = batch.n_tokens;

            const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);

+            // TODO: we always have to take into account the "system_tokens"
+            //       this is not great and needs to be improved somehow
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
            slot.n_past += 1;
        }
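As the new TODO notes, every position handed to llama_batch_add has to be the absolute KV-cache position: the system-prompt length plus the slot-local counter, where the counter is n_past_se when self-extend is active and n_past otherwise. A minimal sketch of that rule as a hypothetical helper (the function name is not part of the patch):

#include <cstddef>
#include <cstdint>

// Hypothetical helper, not part of the patch: the absolute KV-cache position
// of the next token for a slot. When self-extend (group attention) is active,
// the compressed counter n_past_se replaces n_past.
static int32_t next_token_pos(size_t n_system_tokens, int32_t n_past, int32_t n_past_se) {
    const int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;
    return (int32_t) n_system_tokens + slot_npast;
}

// Example with made-up values: 32 system tokens, n_past = 100, self-extend off
// -> the sampled token is decoded at position 132.
// const int32_t pos = next_token_pos(32, 100, 0);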

@@ -1481,8 +1486,8 @@ struct llama_server_context

                prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
                prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
-                prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
-                prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+                prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
+                prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
                prefix_tokens.push_back(llama_token_middle(model));
                prompt_tokens = prefix_tokens;
            }
@@ -1582,19 +1587,22 @@ struct llama_server_context
                }

                LOG_VERBOSE("prompt ingested", {
-                    {"n_past", slot.n_past},
-                    {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
+                    {"n_past", slot.n_past},
+                    {"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},
                    {"to_eval", tokens_to_str(ctx, slot.cache_tokens.cbegin() + slot.n_past, slot.cache_tokens.cend())},
                });

                const bool has_images = process_images(slot);

                // process the prefix of first image
                std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+
                int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-                int ga_i = slot.ga_i;
+
+                int32_t ga_i = slot.ga_i;
                int32_t ga_n = slot.ga_n;
                int32_t ga_w = slot.ga_w;
+
                for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                {
                    if (slot.ga_n != 1)
@@ -1606,7 +1614,7 @@ struct llama_server_context
                        }
                    }
                    llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
-                    slot_npast += 1;
+                    slot_npast++;
                }

                if (has_images && !ingest_images(slot, n_batch))
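The body of the `if (slot.ga_n != 1)` block sits between the two hunks above and is not part of this diff. For readability, a sketch of the usual self-extend bookkeeping in the style of llama.cpp's group-attention examples (assumed, not shown here): whenever the batch position would run past the end of the current window, it is pulled back and the window start advances, so effective positions stay compressed by the factor ga_n.

#include <cstdint>

// Sketch only (assumed from llama.cpp's self-extend examples, not part of this
// hunk): keep slot_npast inside the current group-attention window before the
// token is added to the batch.
static void ga_advance(int32_t & slot_npast, int32_t & ga_i,
                       const int32_t ga_n, const int32_t ga_w) {
    while (slot_npast >= ga_i + ga_w) {
        const int32_t bd = (ga_w / ga_n) * (ga_n - 1); // how far positions are pulled back
        slot_npast -= bd;                              // compress the batch position
        ga_i       += ga_w / ga_n;                     // advance the window start
    }
}

With ga_n = 4 and ga_w = 512, for example, each iteration pulls slot_npast back by 384 and advances ga_i by 128.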
@@ -1666,6 +1674,7 @@ struct llama_server_context
                    slot.n_past_se += n_tokens;
                }
            }
+
            llama_batch batch_view =
            {
                n_tokens,
@@ -1782,51 +1791,51 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
    if (llama_mlock_supported())
    {
-        printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
+        printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
    }
    if (llama_mmap_supported())
    {
-        printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+        printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
    }
-    printf(" --numa attempt optimizations that help on some NUMA systems\n");
+    printf(" --numa attempt optimizations that help on some NUMA systems\n");
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
    printf(" -ngl N, --n-gpu-layers N\n");
-    printf(" number of layers to store in VRAM\n");
+    printf(" number of layers to store in VRAM\n");
    printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
-    printf(" how to split the model across multiple GPUs, one of:\n");
-    printf(" - none: use one GPU only\n");
-    printf(" - layer (default): split layers and KV across GPUs\n");
-    printf(" - row: split rows across GPUs\n");
+    printf(" how to split the model across multiple GPUs, one of:\n");
+    printf(" - none: use one GPU only\n");
+    printf(" - layer (default): split layers and KV across GPUs\n");
+    printf(" - row: split rows across GPUs\n");
    printf(" -ts SPLIT --tensor-split SPLIT\n");
-    printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
-    printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
-    printf(" or for intermediate results and KV (with split-mode = row)\n");
+    printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
+    printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
+    printf(" or for intermediate results and KV (with split-mode = row)\n");
 #endif
    printf(" -m FNAME, --model FNAME\n");
-    printf(" model path (default: %s)\n", params.model.c_str());
+    printf(" model path (default: %s)\n", params.model.c_str());
    printf(" -a ALIAS, --alias ALIAS\n");
-    printf(" set an alias for the model, will be added as `model` field in completion response\n");
-    printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
-    printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
-    printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
-    printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
-    printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
-    printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
-    printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
-    printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
-    printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
-    printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
-    printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
-    printf(" -spf FNAME, --system-prompt-file FNAME\n");
-    printf(" Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
-    printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
-    printf(" --log-disable disables logging to a file.\n");
+    printf(" set an alias for the model, will be added as `model` field in completion response\n");
+    printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+    printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
+    printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
+    printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
+    printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
+    printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
+    printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
+    printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
+    printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
+    printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
+    printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf(" -spf FNAME, --system-prompt-file FNAME\n");
+    printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
+    printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA.\n");
+    printf(" --log-disable disables logging to a file.\n");
    printf("\n");
    printf(" --override-kv KEY=TYPE:VALUE\n");
-    printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
-    printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
-    printf(" -gan N, --grp-attn-n N Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
-    printf(" -gaw N, --grp-attn-w N Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
+    printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+    printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
    printf("\n");
 }
