diff --git a/common/arg.cpp b/common/arg.cpp index 7e2b48aec44..a4e5bc66120 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3243,6 +3243,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.slot_prompt_similarity = std::stof(value); } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--slot-cache-key-similarity"}, "SIMILARITY", + string_format("how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: %.2f, 0.0 = disable ratio check)\n", params.slot_cache_key_similarity), + [](common_params & params, const std::string & value) { + params.slot_cache_key_similarity = std::stof(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--slot-cache-key-min-prefix"}, "N", + string_format("minimum common-prefix tokens required before reusing a cache_key slot (default: %d, 0 = disabled)\n", params.slot_cache_key_min_prefix), + [](common_params & params, const std::string & value) { + params.slot_cache_key_min_prefix = std::stoi(value); + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--lora-init-without-apply"}, string_format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"), diff --git a/common/common.h b/common/common.h index 1d3d788b2de..46efe749185 100644 --- a/common/common.h +++ b/common/common.h @@ -648,7 +648,9 @@ struct common_params { std::string slot_save_path; std::string media_path; // path to directory for loading media files - float slot_prompt_similarity = 0.1f; + float slot_prompt_similarity = 0.1f; + float slot_cache_key_similarity = 0.5f; + int32_t slot_cache_key_min_prefix = 32; // batched-bench params bool is_pp_shared = false; diff --git a/tools/server/README.md b/tools/server/README.md index 11098af2883..1ecfb9cdd81 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -227,6 +227,8 @@ For the full list of features, please refer to [server's changelog](https://gith | `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)
(env: LLAMA_ARG_SKIP_CHAT_PARSING) | | `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled

(env: LLAMA_ARG_PREFILL_ASSISTANT) | | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) | +| `--slot-cache-key-similarity SIMILARITY` | how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: 0.50, 0.0 = disable ratio check) | +| `--slot-cache-key-min-prefix N` | minimum common-prefix tokens required before reusing a cache_key slot (default: 32, 0 = disabled) | | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) | | `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft /[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) | diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 6096dd6b728..3dcc03d0290 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -22,6 +22,7 @@ #include #include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -668,6 +669,7 @@ struct server_context_impl { // slots / clients std::vector slots; + std::unordered_map cache_key_slots; int trace = 0; int slots_debug = 0; @@ -682,6 +684,8 @@ struct server_context_impl { // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; + float slot_cache_key_similarity = 0.0f; + size_t slot_cache_key_min_prefix = 0; std::string model_name; // name of the loaded model, to be used by API std::set model_aliases; // additional names for the model @@ -873,6 +877,8 @@ struct server_context_impl { // Necessary similarity of prompt for slot selection slot_prompt_similarity = params_base.slot_prompt_similarity; + slot_cache_key_similarity = params_base.slot_cache_key_similarity; + slot_cache_key_min_prefix = std::max(0, params_base.slot_cache_key_min_prefix); // setup slots SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); @@ -1106,13 +1112,82 @@ struct server_context_impl { return nullptr; } - server_slot * get_available_slot(const server_task & task) { + server_slot * get_slot_by_cache_key(const std::string & cache_key) { + if (cache_key.empty()) { + return nullptr; + } + + auto it = cache_key_slots.find(cache_key); + if (it == cache_key_slots.end()) { + return nullptr; + } + + server_slot * slot = get_slot_by_id(it->second); + if (slot == nullptr) { + cache_key_slots.erase(it); + return nullptr; + } + + if (slot->prompt.tokens.empty()) { + SLT_INF(*slot, "ignoring cache_key slot with empty prompt, key = %s\n", cache_key.c_str()); + cache_key_slots.erase(it); + return nullptr; + } + + if (slot->is_processing()) { + SLT_INF(*slot, "ignoring busy cache_key slot, key = %s\n", cache_key.c_str()); + return nullptr; + } + + return slot; + } + + bool cache_key_slot_has_enough_similarity(const server_slot & slot, const server_task & task) const { + if (slot.prompt.tokens.empty() || task.tokens.empty()) { + SLT_INF(slot, "ignoring cache_key slot with empty prompt or task, key = %s\n", task.cache_key.c_str()); + return false; + } + + const size_t n_common = slot.prompt.tokens.get_common_prefix(task.tokens); + const float sim_cur = float(n_common) / task.tokens.size(); + const bool enough_prefix = n_common >= slot_cache_key_min_prefix; + const bool enough_similarity = slot_cache_key_similarity <= 0.0f || sim_cur >= slot_cache_key_similarity; + if (enough_prefix && enough_similarity) { + SLT_INF(slot, "selected slot by cache_key, sim = %.3f (>= %.3f thold), common = %zu (>= %zu), key = %s\n", + sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str()); + return true; + } + + SLT_INF(slot, "ignoring cache_key slot, sim = %.3f (< %.3f thold) or common = %zu (< %zu), key = %s\n", + sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str()); + return false; + } + + void clear_cache_keys_for_slot(int id_slot) { + for (auto it = cache_key_slots.begin(); it != cache_key_slots.end(); ) { + if (it->second == id_slot) { + it = cache_key_slots.erase(it); + } else { + ++it; + } + } + } + + void bind_cache_key_to_slot(const std::string & cache_key, int id_slot) { + clear_cache_keys_for_slot(id_slot); + + if (!cache_key.empty()) { + cache_key_slots[cache_key] = id_slot; + } + } + + server_slot * get_available_slot(const server_task & task, bool allow_prompt_similarity = true) { server_slot * ret = nullptr; bool update_cache = false; // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { + if (allow_prompt_similarity && ret == nullptr && slot_prompt_similarity != 0.0f) { float sim_best = 0; for (server_slot & slot : slots) { @@ -1366,6 +1441,8 @@ struct server_context_impl { ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt : SLOT_STATE_STARTED; + bind_cache_key_to_slot(slot.task->cache_key, slot.id); + // reset server kill-switch counter n_empty_consecutive = 0; @@ -1898,7 +1975,18 @@ struct server_context_impl { const int id_slot = task.id_slot; const int id_task = task.id; - server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + server_slot * slot = nullptr; + if (id_slot != -1) { + slot = get_slot_by_id(id_slot); + } else if (!task.cache_key.empty()) { + server_slot * slot_cache_key = get_slot_by_cache_key(task.cache_key); + if (slot_cache_key != nullptr && cache_key_slot_has_enough_similarity(*slot_cache_key, task)) { + slot = slot_cache_key; + } + } + if (slot == nullptr) { + slot = get_available_slot(task, task.cache_key.empty()); + } // // slot scheduling logic @@ -3456,6 +3544,7 @@ std::unique_ptr server_routes::handle_completions_impl( task.params.res_type = res_type; task.params.oaicompat_cmpl_id = completion_id; task.params.oaicompat_model = meta->model_name; + task.cache_key = json_value(data, "cache_key", json_value(data, "session_id", std::string())); // prepare child tasks if (task.params.n_cmpl > 1) { diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 64bdecd794f..0fba3a92f47 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -135,6 +135,7 @@ struct server_task { // used by SERVER_TASK_TYPE_CANCEL int id_target = -1; int id_slot = -1; + std::string cache_key; // used by parallel sampling (multiple completions from same prompt) int id_parent = -1; @@ -234,6 +235,7 @@ struct server_task { copy.type = type; copy.tokens = tokens.clone(); copy.id_slot = -1; // child tasks cannot specify slot + copy.cache_key.clear(); // use different sampling seed for each child // note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723