diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 129d013ac75f7..47f2b2026121c 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1267,17 +1267,13 @@ struct server_slot {
     int64_t t_last_used = -1;
 
     // generation props
-    int32_t n_ctx       = 0; // context size per slot
-    int32_t n_past      = 0;
-    int32_t n_decoded   = 0;
+    int32_t n_ctx       = 0; // context size per slot
+    int32_t n_past      = 0; // current position (note: it is not affected by context shift)
+    int32_t n_decoded   = 0; // number of tokens generated
     int32_t n_remaining = -1;
     int32_t i_batch     = -1;
     int32_t n_predict   = -1; // TODO: disambiguate from params.n_predict
 
-    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
-    int32_t n_prompt_tokens           = 0;
-    int32_t n_prompt_tokens_processed = 0;
-
     // input prompt tokens
     server_tokens prompt_tokens;
@@ -1307,11 +1303,12 @@ struct server_slot {
     common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
 
     // stats
-    size_t n_sent_text = 0; // number of sent text character
+    size_t n_sent_text = 0; // number of sent text character
 
     int64_t t_start_process_prompt;
    int64_t t_start_generation;
 
+    size_t n_prompt_processing = 0; // number of decoded prompt tokens (may be less than prompt_tokens.n_tokens(), in case we are using cache)
     double t_prompt_processing; // ms
     double t_token_generation; // ms
 
@@ -1324,7 +1321,6 @@ struct server_slot {
     void reset() {
         SLT_DBG(*this, "%s", "\n");
 
-        n_prompt_tokens = 0;
         last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
@@ -1334,6 +1330,7 @@ struct server_slot {
         n_past = 0;
         n_sent_text = 0;
         task_type = SERVER_TASK_TYPE_COMPLETION;
+        n_prompt_processing = 0;
 
         generated_tokens.clear();
         generated_token_probs.clear();
@@ -1384,6 +1381,19 @@ struct server_slot {
         generated_token_probs.push_back(token);
     }
 
+    int32_t n_prompt_tokens() const {
+        return prompt_tokens.n_tokens();
+    }
+
+    int32_t n_cache_tokens() const {
+        return cache_tokens.n_tokens();
+    }
+
+    // different from n_past if context is shifted
+    llama_pos curr_pos() const {
+        return cache_tokens.n_pos();
+    }
+
     void release() {
         if (is_processing()) {
             SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
@@ -1397,10 +1407,10 @@ struct server_slot {
     result_timings get_timings() const {
         result_timings timings;
 
-        timings.prompt_n = n_prompt_tokens_processed;
+        timings.prompt_n = n_prompt_processing;
         timings.prompt_ms = t_prompt_processing;
-        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
-        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_processing;
+        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_processing;
 
         timings.predicted_n = n_decoded;
         timings.predicted_ms = t_token_generation;
@@ -1446,8 +1456,8 @@ struct server_slot {
     }
 
     void print_timings() const {
-        const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
-        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+        const double t_prompt = t_prompt_processing / n_prompt_processing;
+        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_processing;
 
         const double t_gen = t_token_generation / n_decoded;
         const double n_gen_second = 1e3 / t_token_generation * n_decoded;
@@ -1457,9 +1467,9 @@ struct server_slot {
                 "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
                 "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
                 "      total time = %10.2f ms / %5d tokens\n",
-                t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+                t_prompt_processing, (int)n_prompt_processing, t_prompt, n_prompt_second,
                 t_token_generation, n_decoded, t_gen, n_gen_second,
-                t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
+                t_prompt_processing + t_token_generation, (int)n_prompt_processing + n_decoded);
 
         if (n_draft_total > 0) {
             const float draft_ratio = (float) n_draft_accepted / n_draft_total;
@@ -1516,8 +1526,8 @@ struct server_metrics {
     }
 
     void on_prompt_eval(const server_slot & slot) {
-        n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
-        n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
+        n_prompt_tokens_processed_total += slot.n_prompt_tokens();
+        n_prompt_tokens_processed += slot.n_prompt_processing;
         t_prompt_processing += slot.t_prompt_processing;
         t_prompt_processing_total += slot.t_prompt_processing;
     }
@@ -2096,7 +2106,7 @@ struct server_context {
                 int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
 
                 // fraction of the common subsequence length compared to the current slot's prompt length
-                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<float>(slot.cache_tokens.size());
+                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<float>(slot.cache_tokens.n_pos());
 
                 // select the current slot if the criteria match
                 if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
@@ -2142,6 +2152,7 @@ struct server_context {
        slot.task_type = task.type;
        slot.params = std::move(task.params);
        slot.prompt_tokens = std::move(task.prompt_tokens);
+        slot.n_past = 0;
 
        if (!are_lora_equal(slot.params.lora, slot.lora)) {
            // if lora is changed, we cannot reuse cached tokens
@@ -2251,12 +2262,14 @@ struct server_context {
            slot.has_next_token = true;
        }
 
-        // if context shifting is disabled, make sure that we don't run out of context
-        if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) {
+        // if context shift is disabled, we stop when it reaches the context limit
+        if (!params_base.ctx_shift && slot.n_cache_tokens() + 1 >= slot.n_ctx) {
+            slot.truncated = true;
            slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false;
 
-            SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx);
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_cache_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                    slot.n_cache_tokens(), slot.n_prompt_tokens(), slot.n_decoded, slot.n_ctx);
        }
 
        // check the limits
@@ -2316,16 +2329,6 @@ struct server_context {
            }
        }
 
-        // if context shift is disabled, we stop when it reaches the context limit
-        if (slot.n_past >= slot.n_ctx) {
-            slot.truncated = true;
-            slot.stop = STOP_TYPE_LIMIT;
-            slot.has_next_token = false;
-
-            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
-                    slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
-        }
-
        if (llama_vocab_is_eog(vocab, result.tok)) {
            slot.stop = STOP_TYPE_EOS;
            slot.has_next_token = false;
@@ -2335,7 +2338,7 @@ struct server_context {
 
        const auto n_ctx_train = llama_model_n_ctx_train(model);
 
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) {
            slot.truncated = true;
            slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false; // stop prediction
@@ -2437,7 +2440,7 @@ struct server_context {
        res->tokens = { tkn.tok };
 
        res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_prompt_tokens = slot.n_prompt_tokens();
        res->post_sampling_probs = slot.params.post_sampling_probs;
 
        res->verbose = slot.params.verbose;
@@ -2472,8 +2475,8 @@ struct server_context {
        res->truncated = slot.truncated;
        res->n_decoded = slot.n_decoded;
-        res->n_prompt_tokens = slot.n_prompt_tokens;
-        res->n_tokens_cached = slot.n_past;
+        res->n_prompt_tokens = slot.n_prompt_tokens();
+        res->n_tokens_cached = slot.n_cache_tokens();
        res->has_new_line = slot.has_new_line;
        res->stopping_word = slot.stopping_word;
        res->stop = slot.stop;
@@ -2510,7 +2513,7 @@ struct server_context {
        auto res = std::make_unique<server_task_result_embd>();
        res->id = slot.id_task;
        res->index = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->n_tokens = slot.n_prompt_tokens();
        res->oaicompat = slot.params.oaicompat;
 
        const int n_embd = llama_model_n_embd(model);
@@ -2553,7 +2556,7 @@ struct server_context {
        auto res = std::make_unique<server_task_result_rerank>();
        res->id = slot.id_task;
        res->index = slot.index;
-        res->n_tokens = slot.n_prompt_tokens;
+        res->n_tokens = slot.n_prompt_tokens();
 
        for (int i = 0; i < batch.n_tokens; ++i) {
            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -2794,7 +2797,7 @@ struct server_context {
                        break;
                    }
 
-                    const size_t token_count = slot->cache_tokens.size();
+                    const size_t token_count = slot->n_cache_tokens();
                    const int64_t t_start = ggml_time_us();
 
                    std::string filename = task.slot_action.filename;
@@ -2880,7 +2883,7 @@ struct server_context {
                    }
 
                    // Erase token cache
-                    const size_t n_erased = slot->cache_tokens.size();
+                    const size_t n_erased = slot->n_cache_tokens();
                    llama_kv_self_seq_rm(ctx, slot->id, -1, -1);
                    slot->cache_tokens.clear();
 
@@ -2934,7 +2937,7 @@ struct server_context {
        // apply context-shift if needed
        // TODO: simplify and improve
        for (server_slot & slot : slots) {
-            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+            if (slot.is_processing() && slot.n_cache_tokens() + 1 >= slot.n_ctx) {
                if (!params_base.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
@@ -2950,14 +2953,15 @@ struct server_context {
                }
 
                // Shift context
+                const int n_pos_cur = slot.cache_tokens.n_pos();
                const int n_keep    = slot.params.n_keep + add_bos_token;
-                const int n_left    = slot.n_past - n_keep;
+                const int n_left    = n_pos_cur - n_keep;
                const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
                SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
                llama_kv_self_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-                llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
+                llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, n_pos_cur, -n_discard);
 
                // add generated tokens to cache
                {
@@ -2966,13 +2970,11 @@ struct server_context {
                        new_tokens[i - n_discard] = new_tokens[i];
                    }
 
-                    new_tokens.resize(slot.cache_tokens.size() - n_discard);
+                    new_tokens.resize(slot.n_cache_tokens() - n_discard);
 
                    slot.cache_tokens.clear();
                    slot.cache_tokens.insert(new_tokens);
                }
 
-                slot.n_past -= n_discard;
-
                slot.truncated = true;
            }
        }
@@ -3002,13 +3004,13 @@ struct server_context {
 
            slot.i_batch = batch.n_tokens;
 
-            common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
+            common_batch_add(batch, slot.sampled, slot.curr_pos(), { slot.id }, true);
 
            slot.n_past += 1;
            slot.cache_tokens.push_back(slot.sampled);
 
-            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
-                    slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_pos = %d, n_cache_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.n_past, slot.cache_tokens.n_pos(), (int) slot.n_cache_tokens(), slot.truncated);
        }
 
        // process in chunks of params.n_batch
@@ -3018,6 +3020,8 @@ struct server_context {
        // next, batch any pending prompts without exceeding n_batch
        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
+                auto & prompt_tokens = slot.prompt_tokens;
+
                // check if we can batch this slot with the previous one
                if (slot.is_processing()) {
                    if (!slot_batched) {
@@ -3029,18 +3033,14 @@ struct server_context {
 
                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
-                    auto & prompt_tokens = slot.prompt_tokens;
-
                    // TODO: maybe move branch to outside of this loop in the future
                    if (slot.state == SLOT_STATE_STARTED) {
                        slot.t_start_process_prompt = ggml_time_us();
                        slot.t_start_generation = 0;
 
-                        slot.n_past = 0;
-                        slot.n_prompt_tokens = prompt_tokens.size();
                        slot.state = SLOT_STATE_PROCESSING_PROMPT;
 
-                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
+                        SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens());
 
                        // print prompt tokens (for debugging)
                        /*if (1) {
@@ -3066,13 +3066,13 @@ struct server_context {
                        }
 
                        if (slot.is_non_causal()) {
-                            if (slot.n_prompt_tokens > n_ubatch) {
+                            if (slot.n_prompt_tokens() > n_ubatch) {
                                slot.release();
                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                continue;
                            }
 
-                            if (slot.n_prompt_tokens > slot.n_ctx) {
+                            if (slot.n_prompt_tokens() > slot.n_ctx) {
                                slot.release();
                                send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
                                continue;
@@ -3082,52 +3082,21 @@ struct server_context {
                            // if context shift is disabled, we make sure prompt size is smaller than KV size
                            // TODO: there should be a separate parameter that control prompt truncation
                            //       context shift should be applied only during the generation phase
-                            if (slot.n_prompt_tokens >= slot.n_ctx) {
+                            if (slot.n_prompt_tokens() >= slot.n_ctx) {
                                slot.release();
                                send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
                                continue;
                            }
                        }
 
                        if (slot.params.n_keep < 0) {
-                            slot.params.n_keep = slot.n_prompt_tokens;
+                            slot.params.n_keep = slot.n_prompt_tokens();
                        }
                        slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                        // if input prompt is too big, truncate it
-                        if (slot.n_prompt_tokens >= slot.n_ctx) {
-                            if (mctx) {
-                                // we should never reach this
-                                GGML_ABORT("not supported by multimodal");
-                            }
-                            const int n_left = slot.n_ctx - slot.params.n_keep;
-
-                            const int n_block_size = n_left / 2;
-                            const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
-
-                            const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens();
-                            llama_tokens new_tokens(
-                                    curr_tokens.begin(),
-                                    curr_tokens.begin() + slot.params.n_keep);
-
-                            new_tokens.insert(
-                                    new_tokens.end(),
-                                    curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
-                                    curr_tokens.end());
-
-                            prompt_tokens.clear();
-                            prompt_tokens.insert(new_tokens);
-
-                            slot.truncated = true;
-                            slot.n_prompt_tokens = prompt_tokens.size();
-
-                            SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
-
-                            GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
-                        }
-
                        if (slot.params.cache_prompt) {
                            // reuse any previously computed tokens that are common with the new prompt
                            slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
+                            slot.cache_tokens.keep_first(slot.n_past);
 
                            // reuse chunks from the cached prompt by shifting their KV cache in the new position
                            if (params_base.n_cache_reuse > 0) {
@@ -3141,12 +3110,12 @@ struct server_context {
 
                                SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
-                                while (head_c < slot.cache_tokens.size() &&
-                                       head_p < prompt_tokens.size()) {
+                                while (head_c < (size_t)slot.n_cache_tokens() &&
+                                       head_p < prompt_tokens.n_tokens()) {
 
                                    size_t n_match = 0;
-                                    while (head_c + n_match < slot.cache_tokens.size() &&
-                                           head_p + n_match < prompt_tokens.size() &&
+                                    while (head_c + n_match < (size_t)slot.n_cache_tokens() &&
+                                           head_p + n_match < prompt_tokens.n_tokens() &&
                                           slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
                                        n_match++;
@@ -3176,6 +3145,7 @@ struct server_context {
                                }
 
                                SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                                slot.cache_tokens.keep_first(slot.n_past);
                            }
                        } else {
                            // if we don't cache the prompt, we have to remove the entire KV cache
@@ -3185,44 +3155,45 @@ struct server_context {
                            }
                        }
 
-                        if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
+                        if (slot.n_past == slot.prompt_tokens.n_pos() && slot.n_past > 0) {
                            // we have to evaluate at least 1 token to generate logits.
-                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
+                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens());
 
+                            slot.cache_tokens.rm_last(1);
                            slot.n_past--;
                        }
 
-                        slot.n_prompt_tokens_processed = 0;
-                    }
-
-                    // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.is_non_causal()) {
-                        // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
-                            continue;
+                        // non-causal tasks require to fit the entire prompt in the physical batch
+                        if (slot.is_non_causal()) {
+                            // cannot fit the prompt in the current batch - will try next iter
+                            if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) {
+                                continue;
+                            }
                        }
-                    }
 
-                    // keep only the common part
-                    if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
+                        // keep only the common part
+                        if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) {
+                            // could not partially delete (likely using a non-Transformer model)
+                            llama_kv_self_seq_rm(ctx, slot.id, -1, -1);
 
-                        // there is no common part left
-                        slot.n_past = 0;
-                    }
+                            // there is no common part left
+                            slot.n_past = 0;
+                            slot.cache_tokens.clear();
+                        }
 
-                    SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+                        SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
 
-                    // remove the non-common part from the cache
-                    slot.cache_tokens.keep_first(slot.n_past);
+                        // remove the non-common part from the cache
+                        slot.cache_tokens.keep_first(slot.n_past);
+                    }
 
                    // check if we should process the image
-                    if (slot.n_past < slot.n_prompt_tokens
+                    if (slot.n_past < prompt_tokens.n_pos()
                        && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) {
                        // process the image
                        int32_t new_n_past;
-                        int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past);
+                        size_t n_tok = 0;
+                        int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, n_tok);
                        int32_t n_pos = new_n_past - slot.n_past;
 
                        if (res != 0) {
@@ -3238,12 +3209,12 @@ struct server_context {
                            slot.cache_tokens.push_back(chunk.get()); // copy
                        }
 
-                        slot.n_past += n_pos;
-                        slot.n_prompt_tokens_processed += n_pos;
+                        slot.n_past += n_pos;
+                        slot.n_prompt_processing += n_tok; // for stats only
                    }
 
                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                    while (slot.n_past < prompt_tokens.n_pos() && batch.n_tokens < n_batch) {
                        // get next token to process
                        llama_token cur_tok = slot.prompt_tokens[slot.n_past];
                        if (cur_tok == LLAMA_TOKEN_NULL) {
@@ -3251,30 +3222,30 @@ struct server_context {
                        }
 
                        // without pooling, we want to output the embeddings for all the tokens in the batch
-                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING
+                                            && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
+                        common_batch_add(batch, cur_tok, slot.curr_pos(), { slot.id }, need_embd);
                        slot.cache_tokens.push_back(cur_tok);
 
-                        slot.n_prompt_tokens_processed++;
                        slot.n_past++;
+                        slot.n_prompt_processing++; // for stats only
                    }
 
                    // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, batch.n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / prompt_tokens.n_pos());
 
                    // entire prompt has been processed
-                    if (slot.n_past == slot.n_prompt_tokens) {
+                    if (slot.n_past == prompt_tokens.n_pos()) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
                        GGML_ASSERT(batch.n_tokens > 0);
-                        GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size());
 
                        common_sampler_reset(slot.smpl);
 
                        // Process all prompt tokens through sampler system
-                        for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                        for (int i = 0; i < slot.n_prompt_tokens(); ++i) {
                            llama_token id = slot.prompt_tokens[i];
                            if (id != LLAMA_TOKEN_NULL) {
                                common_sampler_accept(slot.smpl, id, false);
@@ -3287,7 +3258,7 @@ struct server_context {
                        slot.n_decoded = 0;
                        slot.i_batch = batch.n_tokens - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
+                        SLT_INF(slot, "prompt done, n_past = %d, batch.n_tokens = %d\n", slot.n_past, batch.n_tokens);
                    }
                }
 
@@ -3302,7 +3273,7 @@ struct server_context {
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
+        SRV_DBG("decoding batch, batch.n_tokens = %d\n", batch.n_tokens);
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3439,9 +3410,9 @@ struct server_context {
                // determine the max draft that fits the current slot state
                int n_draft_max = slot.params.speculative.n_max;
 
-                // note: n_past is not yet increased for the `id` token sampled above
+                // note: slot.curr_pos() is not yet increased for the `id` token sampled above
                //       also, need to leave space for 1 extra token to allow context shifts
-                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+                n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.curr_pos() - 2);
 
                if (slot.n_remaining > 0) {
                    n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
@@ -3477,10 +3448,10 @@ struct server_context {
 
                // construct the speculation batch
                common_batch_clear(slot.batch_spec);
-                common_batch_add  (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+                common_batch_add  (slot.batch_spec, id, slot.curr_pos(), { slot.id }, true);
 
                for (size_t i = 0; i < draft.size(); ++i) {
-                    common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+                    common_batch_add(slot.batch_spec, draft[i], slot.curr_pos() + 1 + i, { slot.id }, true);
                }
 
                SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
@@ -3499,7 +3470,7 @@ struct server_context {
                slot.cache_tokens.push_back(id);
                slot.cache_tokens.insert({ids.begin(), ids.end() - 1});
 
-                llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1);
+                llama_kv_self_seq_rm(ctx, slot.id, slot.curr_pos(), -1);
 
                for (size_t i = 0; i < ids.size(); ++i) {
                    completion_token_output result;
diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py
index 2431ac70882d7..dee774b7f667d 100644
--- a/tools/server/tests/unit/test_ctx_shift.py
+++ b/tools/server/tests/unit/test_ctx_shift.py
@@ -31,7 +31,7 @@ def test_ctx_shift_enabled():
        "prompt": LONG_TEXT,
    })
    assert res.status_code == 200
-    assert res.body["timings"]["prompt_n"] == 109
+    assert res.body["timings"]["prompt_n"] == 301
    assert res.body["timings"]["predicted_n"] == 64
    assert res.body["truncated"] is True
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index 232eef195437f..f3919f3421eb3 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1052,6 +1052,11 @@ struct server_tokens {
    // pos  0  1  2  3  4  5  6  7  8  9
    // map_pos_to_image will contain: {5, img0}, {8, img1}
 
+    // number of tokens contained in this object
+    // note that the number of tokens can be larger than the number of positions
+    // for example, models using m-rope can have multiple tokens that share a position
+    size_t n_tok = 0;
+
public:
    server_tokens() = default;
    ~server_tokens() = default;
@@ -1074,7 +1079,7 @@ struct server_tokens {
        }
    }
 
-    server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {}
+    server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens), n_tok(tokens.size()) {}
 
    // for debugging
    std::string str() const {
@@ -1108,6 +1113,7 @@ struct server_tokens {
        if (tok == LLAMA_TOKEN_NULL) {
            throw std::runtime_error("Invalid token");
        }
+        n_tok++;
        tokens.emplace_back(tok);
    }
 
@@ -1122,6 +1128,7 @@ struct server_tokens {
            for (int i = 0; i < n_pos; ++i) {
                tokens.emplace_back(LLAMA_TOKEN_NULL);
            }
+            n_tok += mtmd_image_tokens_get_n_tokens(img_tokens);
            mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk));
            map_pos_to_image[start_pos] = std::move(new_chunk);
        } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) {
@@ -1138,6 +1145,7 @@ struct server_tokens {
    // for compatibility with context shift and prompt truncation
    void insert(const llama_tokens & inp_tokens) {
        GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled
+        n_tok += inp_tokens.size();
        tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end());
    }
 
@@ -1153,7 +1161,11 @@ struct server_tokens {
        tokens[pos] = id;
    }
 
-    size_t size() const {
+    size_t n_tokens() const {
+        return n_tok;
+    }
+
+    llama_pos n_pos() const {
        return tokens.size();
    }
 
@@ -1162,11 +1174,17 @@ struct server_tokens {
    }
 
    void clear() {
+        n_tok = 0;
        tokens.clear();
    }
 
-    void keep_first(size_t n) {
-        GGML_ASSERT(n <= tokens.size());
+    void keep_first(size_t n_pos) {
+        GGML_ASSERT(n_pos <= tokens.size());
+        size_t n_pos_rm = tokens.size() - n_pos;
+        // num of tokens to remove = n_tok_text + n_tok_img
+        //                         = (n_pos_rm - n_pos_img) + n_tok_img
+        size_t n_pos_img = 0;
+        size_t n_tok_img = 0;
        if (has_mtmd) {
            // we throw an error if we try to remove a token in the middle of an image
            // for ex. with input of 5 text tokens and 2 images:
@@ -1174,24 +1192,32 @@
            // n                     1   2   3   4   5   6      7      8      9      10
            // allowed to resize             ^                                ^
            // disallowed to resize                  ^      ^                        ^
-            if (n > 0) {
-                llama_token last_token = tokens[n - 1];
+            if (n_pos > 0) {
+                llama_token last_token = tokens[n_pos - 1];
                // make sure we never remove tokens in the middle of an image
                if (last_token == LLAMA_TOKEN_NULL) {
-                    find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk
+                    find_chunk(n_pos - 1); // will throw an error if the token is not begin-of-chunk
                }
            }
 
            // remove all image chunks that are not used anymore
            for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) {
                llama_pos pos = it->first;
-                if (pos >= (llama_pos)n) {
+                if (pos >= (llama_pos)n_pos) {
+                    auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get());
+                    n_pos_img += mtmd_image_tokens_get_n_pos(img_tokens);
+                    n_tok_img += mtmd_image_tokens_get_n_tokens(img_tokens);
                    it = map_pos_to_image.erase(it);
                } else {
                    ++it;
                }
            }
        }
-        tokens.resize(n);
+        n_tok -= (n_pos_rm - n_pos_img) + n_tok_img;
+        tokens.resize(n_pos);
+    }
+
+    void rm_last(size_t n_pos) {
+        keep_first(tokens.size() - n_pos);
    }
 
    std::string detokenize(const llama_context * ctx, bool special) const {
@@ -1205,7 +1231,8 @@ struct server_tokens {
        return common_detokenize(ctx, text_tokens, special);
    }
 
-    size_t get_common_prefix(const server_tokens & b) const {
+    // returns the first position where the tokens differ
+    llama_pos get_common_prefix(const server_tokens & b) const {
        size_t max_idx = std::min(tokens.size(), b.tokens.size());
        for (size_t i = 0; i < max_idx; ++i) {
            auto & ai = tokens[i];
@@ -1268,7 +1295,8 @@ struct server_tokens {
            mtmd_context * mctx,
            llama_pos n_past,
            int32_t seq_id,
-            llama_pos & n_pos_out) {
+            llama_pos & n_pos_out,
+            size_t & n_tokens_out) {
        auto it = map_pos_to_image.find(n_past);
        if (it == map_pos_to_image.end()) {
            throw std::runtime_error("Chunk not found");
        }
@@ -1290,6 +1318,8 @@ struct server_tokens {
            n_pos_out = n_past;
            return result;
        }
+        auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get());
+        n_tokens_out = mtmd_image_tokens_get_n_tokens(img_tokens);
        n_pos_out = new_n_past;
        return 0;
    }
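Note appended after the patch (not part of the diff itself): the utils.hpp change splits the old single size() into n_tokens() and n_pos() because a multimodal chunk (e.g. with m-rope) can contribute more tokens than positions, and keep_first()/rm_last() must keep both counts in sync. The following standalone C++ sketch is illustrative only, under that assumption; toy_tokens is a hypothetical simplification and not the actual server_tokens API.

    // toy_tokens: minimal sketch of tracking token count and position count separately.
    // Assumption: an "image chunk" occupies n_pos positions but contributes n_tokens tokens.
    #include <cassert>
    #include <cstdio>
    #include <vector>

    struct toy_tokens {
        std::vector<int> per_pos; // one entry per position; -1 marks an image-owned position
        size_t n_tok = 0;         // total tokens; can exceed per_pos.size()

        void push_text(int id) {
            per_pos.push_back(id);
            n_tok += 1;
        }

        void push_image(size_t n_pos, size_t n_tokens) {
            per_pos.insert(per_pos.end(), n_pos, -1); // reserve n_pos positions
            n_tok += n_tokens;                        // but account for all tokens
        }

        size_t n_tokens() const { return n_tok; }
        size_t n_pos()    const { return per_pos.size(); }
    };

    int main() {
        toy_tokens t;
        t.push_text(10);
        t.push_text(11);
        t.push_image(/*n_pos=*/4, /*n_tokens=*/16); // 16 tokens sharing 4 positions
        t.push_text(12);

        // prints: n_pos = 7, n_tokens = 19
        std::printf("n_pos = %zu, n_tokens = %zu\n", t.n_pos(), t.n_tokens());
        assert(t.n_tokens() >= t.n_pos());
        return 0;
    }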