Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3243,6 +3243,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.slot_prompt_similarity = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--slot-cache-key-similarity"}, "SIMILARITY",
string_format("how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: %.2f, 0.0 = disable ratio check)\n", params.slot_cache_key_similarity),
[](common_params & params, const std::string & value) {
params.slot_cache_key_similarity = std::stof(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--slot-cache-key-min-prefix"}, "N",
string_format("minimum common-prefix tokens required before reusing a cache_key slot (default: %d, 0 = disabled)\n", params.slot_cache_key_min_prefix),
[](common_params & params, const std::string & value) {
params.slot_cache_key_min_prefix = std::stoi(value);
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--lora-init-without-apply"},
string_format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"),
Expand Down
4 changes: 3 additions & 1 deletion common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ struct common_params {
std::string slot_save_path;
std::string media_path; // path to directory for loading media files

float slot_prompt_similarity = 0.1f;
float slot_prompt_similarity = 0.1f;
float slot_cache_key_similarity = 0.5f;
int32_t slot_cache_key_min_prefix = 32;

// batched-bench params
bool is_pp_shared = false;
Expand Down
2 changes: 2 additions & 0 deletions tools/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ For the full list of features, please refer to [server's changelog](https://gith
| `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)<br/>(env: LLAMA_ARG_SKIP_CHAT_PARSING) |
| `--prefill-assistant, --no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_PREFILL_ASSISTANT) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) |
| `--slot-cache-key-similarity SIMILARITY` | how much the prompt of a cache_key request must match the cached slot prompt before reusing it (default: 0.50, 0.0 = disable ratio check) |
| `--slot-cache-key-min-prefix N` | minimum common-prefix tokens required before reusing a cache_key slot (default: 32, 0 = disabled) |
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
| `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) |
| `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)<br/>(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) |
Expand Down
95 changes: 92 additions & 3 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <memory>
#include <filesystem>
#include <utility>
#include <unordered_map>

// fix problem with std::min and std::max
#if defined(_WIN32)
Expand Down Expand Up @@ -668,6 +669,7 @@ struct server_context_impl {

// slots / clients
std::vector<server_slot> slots;
std::unordered_map<std::string, int> cache_key_slots;

int trace = 0;
int slots_debug = 0;
Expand All @@ -682,6 +684,8 @@ struct server_context_impl {

// Necessary similarity of prompt for slot selection
float slot_prompt_similarity = 0.0f;
float slot_cache_key_similarity = 0.0f;
size_t slot_cache_key_min_prefix = 0;

std::string model_name; // name of the loaded model, to be used by API
std::set<std::string> model_aliases; // additional names for the model
Expand Down Expand Up @@ -873,6 +877,8 @@ struct server_context_impl {

// Necessary similarity of prompt for slot selection
slot_prompt_similarity = params_base.slot_prompt_similarity;
slot_cache_key_similarity = params_base.slot_cache_key_similarity;
slot_cache_key_min_prefix = std::max<int32_t>(0, params_base.slot_cache_key_min_prefix);

// setup slots
SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
Expand Down Expand Up @@ -1106,13 +1112,82 @@ struct server_context_impl {
return nullptr;
}

server_slot * get_available_slot(const server_task & task) {
server_slot * get_slot_by_cache_key(const std::string & cache_key) {
if (cache_key.empty()) {
return nullptr;
}

auto it = cache_key_slots.find(cache_key);
if (it == cache_key_slots.end()) {
return nullptr;
}

server_slot * slot = get_slot_by_id(it->second);
if (slot == nullptr) {
cache_key_slots.erase(it);
return nullptr;
}

if (slot->prompt.tokens.empty()) {
SLT_INF(*slot, "ignoring cache_key slot with empty prompt, key = %s\n", cache_key.c_str());
cache_key_slots.erase(it);
return nullptr;
}

if (slot->is_processing()) {
SLT_INF(*slot, "ignoring busy cache_key slot, key = %s\n", cache_key.c_str());
return nullptr;
}

return slot;
}

bool cache_key_slot_has_enough_similarity(const server_slot & slot, const server_task & task) const {
if (slot.prompt.tokens.empty() || task.tokens.empty()) {
SLT_INF(slot, "ignoring cache_key slot with empty prompt or task, key = %s\n", task.cache_key.c_str());
return false;
}

const size_t n_common = slot.prompt.tokens.get_common_prefix(task.tokens);
const float sim_cur = float(n_common) / task.tokens.size();
const bool enough_prefix = n_common >= slot_cache_key_min_prefix;
const bool enough_similarity = slot_cache_key_similarity <= 0.0f || sim_cur >= slot_cache_key_similarity;
if (enough_prefix && enough_similarity) {
SLT_INF(slot, "selected slot by cache_key, sim = %.3f (>= %.3f thold), common = %zu (>= %zu), key = %s\n",
sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str());
return true;
}

SLT_INF(slot, "ignoring cache_key slot, sim = %.3f (< %.3f thold) or common = %zu (< %zu), key = %s\n",
sim_cur, slot_cache_key_similarity, n_common, slot_cache_key_min_prefix, task.cache_key.c_str());
return false;
}

void clear_cache_keys_for_slot(int id_slot) {
for (auto it = cache_key_slots.begin(); it != cache_key_slots.end(); ) {
if (it->second == id_slot) {
it = cache_key_slots.erase(it);
} else {
++it;
}
}
}

void bind_cache_key_to_slot(const std::string & cache_key, int id_slot) {
clear_cache_keys_for_slot(id_slot);

if (!cache_key.empty()) {
cache_key_slots[cache_key] = id_slot;
}
}

server_slot * get_available_slot(const server_task & task, bool allow_prompt_similarity = true) {
server_slot * ret = nullptr;

bool update_cache = false;

// find the slot that has at least n% prompt similarity
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
if (allow_prompt_similarity && ret == nullptr && slot_prompt_similarity != 0.0f) {
float sim_best = 0;

for (server_slot & slot : slots) {
Expand Down Expand Up @@ -1366,6 +1441,8 @@ struct server_context_impl {
? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
: SLOT_STATE_STARTED;

bind_cache_key_to_slot(slot.task->cache_key, slot.id);

// reset server kill-switch counter
n_empty_consecutive = 0;

Expand Down Expand Up @@ -1898,7 +1975,18 @@ struct server_context_impl {
const int id_slot = task.id_slot;
const int id_task = task.id;

server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
server_slot * slot = nullptr;
if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else if (!task.cache_key.empty()) {
server_slot * slot_cache_key = get_slot_by_cache_key(task.cache_key);
if (slot_cache_key != nullptr && cache_key_slot_has_enough_similarity(*slot_cache_key, task)) {
slot = slot_cache_key;
}
}
if (slot == nullptr) {
slot = get_available_slot(task, task.cache_key.empty());
}

//
// slot scheduling logic
Expand Down Expand Up @@ -3456,6 +3544,7 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_model = meta->model_name;
task.cache_key = json_value(data, "cache_key", json_value(data, "session_id", std::string()));

// prepare child tasks
if (task.params.n_cmpl > 1) {
Expand Down
2 changes: 2 additions & 0 deletions tools/server/server-task.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ struct server_task {
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
int id_slot = -1;
std::string cache_key;

// used by parallel sampling (multiple completions from same prompt)
int id_parent = -1;
Expand Down Expand Up @@ -234,6 +235,7 @@ struct server_task {
copy.type = type;
copy.tokens = tokens.clone();
copy.id_slot = -1; // child tasks cannot specify slot
copy.cache_key.clear();

// use different sampling seed for each child
// note: https://github.com/ggml-org/llama.cpp/pull/18700#discussion_r2675115723
Expand Down