diff --git a/common/arg.cpp b/common/arg.cpp
index 7e2b48aec44..a4e5bc66120 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -3243,6 +3243,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.slot_prompt_similarity = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
+ add_opt(common_arg(
+ {"--slot-cache-key-similarity"}, "SIMILARITY",
+ string_format("how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: %.2f, 0.0 = disable ratio check)\n", params.slot_cache_key_similarity),
+ [](common_params & params, const std::string & value) {
+ params.slot_cache_key_similarity = std::stof(value);
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
+ add_opt(common_arg(
+ {"--slot-cache-key-min-prefix"}, "N",
+ string_format("minimum common-prefix tokens required before reusing a cache_key slot (default: %d, 0 = disabled)\n", params.slot_cache_key_min_prefix),
+ [](common_params & params, const std::string & value) {
+ params.slot_cache_key_min_prefix = std::stoi(value);
+ }
+ ).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--lora-init-without-apply"},
string_format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"),
diff --git a/common/common.h b/common/common.h
index 1d3d788b2de..46efe749185 100644
--- a/common/common.h
+++ b/common/common.h
@@ -648,7 +648,9 @@ struct common_params {
std::string slot_save_path;
std::string media_path; // path to directory for loading media files
- float slot_prompt_similarity = 0.1f;
+ float slot_prompt_similarity = 0.1f;
+ float slot_cache_key_similarity = 0.5f;
+ int32_t slot_cache_key_min_prefix = 32;
// batched-bench params
bool is_pp_shared = false;
diff --git a/tools/server/README.md b/tools/server/README.md
index 11098af2883..1ecfb9cdd81 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -227,6 +227,8 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)
(env: LLAMA_ARG_SKIP_CHAT_PARSING) |
| `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)
when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled
(env: LLAMA_ARG_PREFILL_ASSISTANT) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) |
+| `--slot-cache-key-similarity SIMILARITY` | how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: 0.50, 0.0 = disable ratio check) |
+| `--slot-cache-key-min-prefix N` | minimum common-prefix tokens required before reusing a cache_key slot (default: 32, 0 = disabled) |
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
| `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) |
| `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft /[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) |
diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 6096dd6b728..3dcc03d0290 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -22,6 +22,7 @@
#include
#include
#include
+#include
// fix problem with std::min and std::max
#if defined(_WIN32)
@@ -668,6 +669,7 @@ struct server_context_impl {
// slots / clients
std::vector slots;
+ std::unordered_map cache_key_slots;
int trace = 0;
int slots_debug = 0;
@@ -682,6 +684,8 @@ struct server_context_impl {
// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
+ float slot_cache_key_similarity = 0.0f;
+ size_t slot_cache_key_min_prefix = 0;
std::string model_name; // name of the loaded model, to be used by API
std::set model_aliases; // additional names for the model
@@ -873,6 +877,8 @@ struct server_context_impl {
// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
+ slot_cache_key_similarity = params_base.slot_cache_key_similarity;
+ slot_cache_key_min_prefix = std::max(0, params_base.slot_cache_key_min_prefix);
// setup slots
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
@@ -1106,13 +1112,82 @@ struct server_context_impl {
return nullptr;
}
- server_slot * get_available_slot(const server_task & task) {
+ server_slot * get_slot_by_cache_key(const std::string & cache_key) {
+ if (cache_key.empty()) {
+ return nullptr;
+ }
+
+ auto it = cache_key_slots.find(cache_key);
+ if (it == cache_key_slots.end()) {
+ return nullptr;
+ }
+
+ server_slot * slot = get_slot_by_id(it->second);
+ if (slot == nullptr) {
+ cache_key_slots.erase(it);
+ return nullptr;
+ }
+
+ if (slot->prompt.tokens.empty()) {
+ SLT_INF(*slot, "ignoring cache_key slot with empty prompt, key = %s\n", cache_key.c_str());
+ cache_key_slots.erase(it);
+ return nullptr;
+ }
+
+ if (slot->is_processing()) {
+ SLT_INF(*slot, "ignoring busy cache_key slot, key = %s\n", cache_key.c_str());
+ return nullptr;
+ }
+
+ return slot;
+ }
+
+ bool cache_key_slot_has_enough_similarity(const server_slot & slot, const server_task & task) const {
+ if (slot.prompt.tokens.empty() || task.tokens.empty()) {
+ SLT_INF(slot, "ignoring cache_key slot with empty prompt or task, key = %s\n", task.cache_key.c_str());
+ return false;
+ }
+
+ const size_t n_common = slot.prompt.tokens.get_common_prefix(task.tokens);
+ const float sim_cur = float(n_common) / task.tokens.size();
+ const bool enough_prefix = n_common >= slot_cache_key_min_prefix;
+ const bool enough_similarity = slot_cache_key_similarity <= 0.0f || sim_cur >= slot_cache_key_similarity;
+ if (enough_prefix && enough_similarity) {
+ SLT_INF(slot, "selected slot by cache_key, sim = %.3f (>= %.3f thold), common = %zu (>= %zu), key = %s\n",
+ sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str());
+ return true;
+ }
+
+ SLT_INF(slot, "ignoring cache_key slot, sim = %.3f (< %.3f thold) or common = %zu (< %zu), key = %s\n",
+ sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str());
+ return false;
+ }
+
+ void clear_cache_keys_for_slot(int id_slot) {
+ for (auto it = cache_key_slots.begin(); it != cache_key_slots.end(); ) {
+ if (it->second == id_slot) {
+ it = cache_key_slots.erase(it);
+ } else {
+ ++it;
+ }
+ }
+ }
+
+ void bind_cache_key_to_slot(const std::string & cache_key, int id_slot) {
+ clear_cache_keys_for_slot(id_slot);
+
+ if (!cache_key.empty()) {
+ cache_key_slots[cache_key] = id_slot;
+ }
+ }
+
+ server_slot * get_available_slot(const server_task & task, bool allow_prompt_similarity = true) {
server_slot * ret = nullptr;
bool update_cache = false;
// find the slot that has at least n% prompt similarity
- if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+ if (allow_prompt_similarity && ret == nullptr && slot_prompt_similarity != 0.0f) {
float sim_best = 0;
for (server_slot & slot : slots) {
@@ -1366,6 +1441,8 @@ struct server_context_impl {
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;
+ bind_cache_key_to_slot(slot.task->cache_key, slot.id);
+
// reset server kill-switch counter
n_empty_consecutive = 0;
@@ -1898,7 +1975,18 @@ struct server_context_impl {
const int id_slot = task.id_slot;
const int id_task = task.id;
- server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
+ server_slot * slot = nullptr;
+ if (id_slot != -1) {
+ slot = get_slot_by_id(id_slot);
+ } else if (!task.cache_key.empty()) {
+ server_slot * slot_cache_key = get_slot_by_cache_key(task.cache_key);
+ if (slot_cache_key != nullptr && cache_key_slot_has_enough_similarity(*slot_cache_key, task)) {
+ slot = slot_cache_key;
+ }
+ }
+ if (slot == nullptr) {
+ slot = get_available_slot(task, task.cache_key.empty());
+ }
//
// slot scheduling logic
@@ -3456,6 +3544,7 @@ std::unique_ptr server_routes::handle_completions_impl(
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = meta->model_name;
+ task.cache_key = json_value(data, "cache_key", json_value(data, "session_id", std::string()));
// prepare child tasks
if (task.params.n_cmpl > 1) {
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index 64bdecd794f..0fba3a92f47 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -135,6 +135,7 @@ struct server_task {
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
int id_slot = -1;
+ std::string cache_key;
// used by parallel sampling (multiple completions from same prompt)
int id_parent = -1;
@@ -234,6 +235,7 @@ struct server_task {
copy.type = type;
copy.tokens = tokens.clone();
copy.id_slot = -1; // child tasks cannot specify slot
+ copy.cache_key.clear();
// use different sampling seed for each child
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723