diff --git a/server/src/common/model_backend.h b/server/src/common/model_backend.h index de439092d..b808d0c39 100644 --- a/server/src/common/model_backend.h +++ b/server/src/common/model_backend.h @@ -19,6 +19,7 @@ #include "ggml.h" #include "ggml-backend.h" #include "sampler.h" +#include "placement/draft_residency.h" namespace dflash::common { @@ -250,6 +251,7 @@ struct ModelBackend { std::string drafter_path; // GGUF path (for lazy-load) int drafter_gpu = 0; // backend-local GPU for PFlash drafter bool skip_park = false; // true on >=32GB GPUs + DraftResidencyAction residency_action = DraftResidencyAction::KeepLoaded; }; struct CompressResult { diff --git a/server/src/gemma4/gemma4_backend.cpp b/server/src/gemma4/gemma4_backend.cpp index e09ce575c..834a24d1b 100644 --- a/server/src/gemma4/gemma4_backend.cpp +++ b/server/src/gemma4/gemma4_backend.cpp @@ -55,89 +55,9 @@ bool Gemma4Backend::init() { } cache_.fa_window = cfg_.fa_window; - // Load draft model for speculative decode - if (cfg_.draft_path) { - const int draft_gpu = (cfg_.draft_gpu >= 0) ? 
cfg_.draft_gpu : cfg_.device.gpu; - draft_backend_ = ggml_backend_cuda_init(draft_gpu); - if (!draft_backend_) { - std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu); - } else { - // Load draft GGUF — pass nullptr for target (Gemma4 != TargetWeights) - if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) { - std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error()); - ggml_backend_free(draft_backend_); draft_backend_ = nullptr; - } else { - // Override mask_token_id for Gemma4 (token 4 per model card) - dw_.mask_token_id = 4; - - // Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong) - // fc.weight: [fc_in, draft_hidden] - const int draft_hidden = (int)dw_.fc->ne[1]; - const int fc_in = (int)dw_.fc->ne[0]; - const int n_capture = fc_in / w_.n_embd; - - if (draft_hidden != dw_.n_embd) { - std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n", - dw_.n_embd, draft_hidden); - dw_.n_embd = draft_hidden; - } - // Infer n_head from wq shape: wq.ne[1] = n_head * head_dim - if (dw_.n_layer > 0 && dw_.layers[0].wq) { - const int q_dim = (int)dw_.layers[0].wq->ne[1]; - const int inferred_n_head = q_dim / dw_.head_dim; - if (inferred_n_head != dw_.n_head) { - std::printf("[gemma4] draft: overriding n_head %d -> %d\n", - dw_.n_head, inferred_n_head); - dw_.n_head = inferred_n_head; - } - } - // Infer n_ff from ffn_gate shape - if (dw_.n_layer > 0 && dw_.layers[0].w_gate) { - const int inferred_ff = (int)dw_.layers[0].w_gate->ne[1]; - if (inferred_ff != dw_.n_ff) { - std::printf("[gemma4] draft: overriding n_ff %d -> %d\n", - dw_.n_ff, inferred_ff); - dw_.n_ff = inferred_ff; - } - } - // Override n_target_layers from fc shape - dw_.n_target_layers = n_capture; - - // Gemma4 DFlash draft: layers 0-3 are SWA (causal), layer 4 is full (non-causal) - // (from model card: layer_types = [sliding*4, full_attention]) - dw_.swa_window = 2048; - for (int i = 0; i < dw_.n_layer - 1 
&& i < (int)dw_.layers.size(); i++) - dw_.layers[i].is_swa = true; - - std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d " - "draft_hidden=%d n_capture_layers=%d swa=%d\n", - fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window); - - // Allocate target_feat ring buffer - constexpr int TARGET_FEAT_CAP = 4096; - const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP); - if (!create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) { - std::fprintf(stderr, "[gemma4] target_feat alloc failed\n"); - } else { - // Init feature mirror on draft GPU - const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap); - if (!draft_feature_mirror_init(feature_mirror_, draft_backend_, - draft_gpu, cfg_.device.gpu, mirror_cap, - n_capture, w_.n_embd)) { - std::fprintf(stderr, "[gemma4] feature mirror init failed\n"); - } else { - // Create DFlash target adapter - dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_); - std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n", - n_capture, mirror_cap); - std::printf("[gemma4] capture_layer_ids:"); - for (int k = 0; k < (int)cache_.capture_layer_ids.size(); k++) - std::printf(" %d", cache_.capture_layer_ids[k]); - std::printf("\n"); - } - } - } - } + // Load draft model for speculative decode. 
+    if (cfg_.draft_path && !load_decode_draft()) {
+        std::fprintf(stderr, "[gemma4] draft unavailable; speculative decode disabled\n");
     }
 
     std::printf("[gemma4] init ok: %d layers, embd=%d, vocab=%d, max_ctx=%d\n",
@@ -157,8 +77,15 @@ void Gemma4Backend::print_ready_banner() const {
 // ── Park / Unpark ──────────────────────────────────────────────────────
 
 bool Gemma4Backend::park(const std::string & what) {
-    (void)what;
-    if (parked_) return true;
+    const bool want_draft = (what.empty() || what == "all" || what == "draft");
+    const bool want_target = (what.empty() || what == "all" || what == "target");
+
+    if (want_draft && cfg_.draft_path && !draft_parked_) {
+        free_decode_draft();
+        draft_parked_ = true;
+        std::printf("[gemma4] draft released\n"); std::fflush(stdout);
+    }
+    if (!want_target || parked_) return true;
 
     // Free snapshots first (they reference the snap_backend buffer)
     for (int i = 0; i < PREFIX_SLOTS; ++i) {
@@ -177,24 +104,37 @@ bool Gemma4Backend::park(const std::string & what) {
 }
 
 bool Gemma4Backend::unpark(const std::string & what) {
-    (void)what;
-    if (!parked_) return true;
+    const bool want_draft = (what.empty() || what == "all" || what == "draft");
+    const bool want_target = (what.empty() || what == "all" || what == "target");
+
+    if (want_target && !parked_) {
+        // target already resident
+    } else if (want_target && parked_) {
+        // Reload weights from disk
+        if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) {
+            std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n");
+            return false;
+        }
 
-    // Reload weights from disk
-    if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) {
-        std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n");
-        return false;
-    }
+        // Recreate KV cache
+        if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) {
+            std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n");
+            free_gemma4_weights(w_);
+            return false;
+        }
+        cache_.fa_window = cfg_.fa_window;
 
-    // Recreate KV cache
-    if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) {
-        std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n");
-        free_gemma4_weights(w_);
-        return false;
+        parked_ = false;
+        std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout);
+        if (cfg_.draft_path && !draft_parked_ && draft_backend_) {
+            delete dflash_target_;
+            dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_);
+        }
     }
 
-    parked_ = false;
-    std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout);
+    if (want_draft && draft_parked_ && cfg_.draft_path) {
+        if (!load_decode_draft()) return false;
+    }
     return true;
 }
 
@@ -1104,16 +1044,144 @@ bool Gemma4Backend::try_handle_command(const std::string & line,
     return false; // no arch-specific commands
 }
 
+// Load the DFlash decode draft and wire up speculative decode.
+// Returns true when the draft is resident afterwards (already loaded or freshly
+// loaded); every failure path releases whatever was partially created.
+bool Gemma4Backend::load_decode_draft() {
+    if (!cfg_.draft_path) return false;
+    if (draft_backend_ && feature_mirror_.target_feat) {
+        draft_parked_ = false;
+        return true;
+    }
+
+    const int draft_gpu = (cfg_.draft_gpu >= 0) ? cfg_.draft_gpu : cfg_.device.gpu;
+    draft_backend_ = ggml_backend_cuda_init(draft_gpu);
+    if (!draft_backend_) {
+        std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu);
+        return false;
+    }
+    // Load draft GGUF — pass nullptr for target (Gemma4 != TargetWeights)
+    if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) {
+        std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error());
+        ggml_backend_free(draft_backend_);
+        draft_backend_ = nullptr;
+        return false;
+    }
+
+    // The shape fix-ups below dereference dw_.fc and divide by w_.n_embd; bail
+    // out cleanly instead of crashing when either is missing (malformed draft,
+    // or target weights not resident).
+    if (!dw_.fc || w_.n_embd <= 0) {
+        std::fprintf(stderr, "[gemma4] draft load failed: missing fc weight or target n_embd\n");
+        free_decode_draft();
+        return false;
+    }
+
+    // Override mask_token_id for Gemma4 (token 4 per model card)
+    dw_.mask_token_id = 4;
+
+    // Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong)
+    // fc.weight: [fc_in, draft_hidden]
+    const int draft_hidden = (int)dw_.fc->ne[1];
+    const int fc_in = (int)dw_.fc->ne[0];
+    const int n_capture = fc_in / w_.n_embd;
+
+    if (draft_hidden != dw_.n_embd) {
+        std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n",
+                    dw_.n_embd, draft_hidden);
+        dw_.n_embd = draft_hidden;
+    }
+    // Infer n_head from wq shape: wq.ne[1] = n_head * head_dim
+    // (skipped when head_dim is not positive, to avoid dividing by zero)
+    if (dw_.n_layer > 0 && dw_.head_dim > 0 && dw_.layers[0].wq) {
+        const int q_dim = (int)dw_.layers[0].wq->ne[1];
+        const int inferred_n_head = q_dim / dw_.head_dim;
+        if (inferred_n_head != dw_.n_head) {
+            std::printf("[gemma4] draft: overriding n_head %d -> %d\n",
+                        dw_.n_head, inferred_n_head);
+            dw_.n_head = inferred_n_head;
+        }
+    }
+    // Infer n_ff from ffn_gate shape
+    if (dw_.n_layer > 0 && dw_.layers[0].w_gate) {
+        const int inferred_ff = (int)dw_.layers[0].w_gate->ne[1];
+        if (inferred_ff != dw_.n_ff) {
+            std::printf("[gemma4] draft: overriding n_ff %d -> %d\n",
+                        dw_.n_ff, inferred_ff);
+            dw_.n_ff = inferred_ff;
+        }
+    }
+    // Override n_target_layers from fc shape
+    dw_.n_target_layers = n_capture;
+
+    // Gemma4 DFlash draft: layers 0-3 are SWA (causal), layer 4 is full (non-causal)
+    // (from model card: layer_types = [sliding*4, full_attention])
+    dw_.swa_window = 2048;
+    for (int i = 0; i < dw_.n_layer - 1 && i < (int)dw_.layers.size(); i++) {
+        dw_.layers[i].is_swa = true;
+    }
+
+    std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d "
+                "draft_hidden=%d n_capture_layers=%d swa=%d\n",
+                fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window);
+
+    // Allocate target_feat ring buffer (reused if one is already present)
+    constexpr int TARGET_FEAT_CAP = 4096;
+    const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP);
+    if (!cache_.target_feat &&
+        !create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) {
+        std::fprintf(stderr, "[gemma4] target_feat alloc failed\n");
+        free_decode_draft();
+        return false;
+    }
+
+    // Init feature mirror on draft GPU
+    const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap);
+    if (!draft_feature_mirror_init(feature_mirror_, draft_backend_,
+                                   draft_gpu, cfg_.device.gpu, mirror_cap,
+                                   n_capture, w_.n_embd)) {
+        std::fprintf(stderr, "[gemma4] feature mirror init failed\n");
+        free_decode_draft();
+        return false;
+    }
+
+    // Create DFlash target adapter
+    delete dflash_target_;
+    dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_);
+    draft_parked_ = false;
+    std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n",
+                n_capture, mirror_cap);
+    std::printf("[gemma4] capture_layer_ids:");
+    for (int k = 0; k < (int)cache_.capture_layer_ids.size(); k++) {
+        std::printf(" %d", cache_.capture_layer_ids[k]);
+    }
+    std::printf("\n");
+    return true;
+}
+
+// Release the DFlash decode-draft resources: target adapter, feature mirror,
+// target_feat buffer, draft weights and draft backend. Pointers are nulled or
+// checked before freeing (draft_feature_mirror_free is assumed idempotent).
+void Gemma4Backend::free_decode_draft() {
+    delete dflash_target_;
+    dflash_target_ = nullptr;
+    draft_feature_mirror_free(feature_mirror_);
+    free_gemma4_target_feat(cache_);
+    if (dw_.ctx) {
+        free_draft_weights(dw_);
+    }
+    if (draft_backend_) {
+        ggml_backend_free(draft_backend_);
+        draft_backend_ = nullptr;
+    }
+}
+
 // ── Shutdown ───────────────────────────────────────────────────────────
 
 void Gemma4Backend::shutdown() {
     for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i);
     free_drafter();
-    // Clean up DFlash spec-decode resources
-    delete dflash_target_; dflash_target_ = nullptr;
-    draft_feature_mirror_free(feature_mirror_);
-    if (dw_.ctx) { free_draft_weights(dw_); }
-    if (draft_backend_) { ggml_backend_free(draft_backend_); draft_backend_ = nullptr; }
+    free_decode_draft();
     free_gemma4_cache(cache_);
     free_gemma4_weights(w_);
     free_snapshot_backend(snap_backend_, backend_);
diff --git a/server/src/gemma4/gemma4_backend.h b/server/src/gemma4/gemma4_backend.h
index 4a92a607d..ec14b75f3 100644
--- a/server/src/gemma4/gemma4_backend.h
+++ b/server/src/gemma4/gemma4_backend.h
@@ -127,6 +127,11 @@ class Gemma4Backend : public ModelBackend {
                         const DaemonIO & io,
                         const BudgetHook * budget_hook = nullptr,
                         bool * forced_close_out = nullptr);
+
+    // Load the DFlash decode draft; returns false when it is unavailable.
+    bool load_decode_draft();
+    // Release all DFlash decode-draft resources.
+    void free_decode_draft();
 };
 
 } // namespace dflash::common
diff --git a/server/src/gemma4/gemma4_internal.h b/server/src/gemma4/gemma4_internal.h
index 5e643060f..076fb81ce 100644
--- a/server/src/gemma4/gemma4_internal.h
+++ b/server/src/gemma4/gemma4_internal.h
@@ -199,6 +199,9 @@ bool create_gemma4_cache_partial(ggml_backend_t backend,
 void free_gemma4_cache(Gemma4Cache & c);
 
+// Free the target_feat ring buffer and reset the capture bookkeeping on the cache.
+void free_gemma4_target_feat(Gemma4Cache & c);
+
 // Allocate target_feat ring buffer (call after draft load determines n_capture_layers).
 bool create_gemma4_target_feat(ggml_backend_t backend,
                                Gemma4Cache & cache,
                                int n_capture_layers, int hidden_size, int cap);
diff --git a/server/src/gemma4/gemma4_loader.cpp b/server/src/gemma4/gemma4_loader.cpp
index 84c0544dc..77d4799e1 100644
--- a/server/src/gemma4/gemma4_loader.cpp
+++ b/server/src/gemma4/gemma4_loader.cpp
@@ -568,6 +568,16 @@ void free_gemma4_cache(Gemma4Cache & c) {
     c.cur_pos = 0;
 }
 
+// Free the target_feat buffer/context (if any) and clear the capture fields.
+void free_gemma4_target_feat(Gemma4Cache & c) {
+    if (c.feat_buf) { ggml_backend_buffer_free(c.feat_buf); c.feat_buf = nullptr; }
+    if (c.feat_ctx) { ggml_free(c.feat_ctx); c.feat_ctx = nullptr; }
+    c.target_feat = nullptr;
+    c.target_feat_cap = 0;
+    c.n_capture_layers = 0;
+    c.capture_layer_ids.clear();
+}
+
 bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache,
                                int n_capture_layers, int hidden_size, int cap) {
     if (n_capture_layers <= 0 || hidden_size <= 0 || cap <= 0) return false;
diff --git a/server/src/placement/draft_residency.h b/server/src/placement/draft_residency.h
new file mode 100644
index 000000000..53bf4baf6
--- /dev/null
+++ b/server/src/placement/draft_residency.h
@@ -0,0 +1,105 @@
+// Drafter residency policy shared by draft-style runtime paths.
+//
+// The policy is intentionally scoped by draft use-case. PFlash compression can
+// release its drafter immediately after prompt compression, while DFlash decode
+// draft may need to stay resident across requests for latency.
+
+#pragma once
+
+#include <string>
+
+namespace dflash::common {
+
+// Operator-selected lifetime policy for draft models.
+enum class DraftResidencyPolicy {
+    Auto,
+    Persistent,
+    RequestScoped,
+};
+
+// Which draft-style path is asking for a residency decision.
+enum class DraftResidencyUse {
+    PFlashCompress,
+    DFlashDecode,
+    MtpDecode,
+};
+
+// What the caller should do with the drafter once the current use completes.
+enum class DraftResidencyAction {
+    KeepLoaded,
+    ReleaseAfterUse,
+};
+
+// Inputs consulted by resolve_draft_residency_action() when the policy is Auto.
+struct DraftResidencyContext {
+    DraftResidencyUse use = DraftResidencyUse::PFlashCompress;
+    bool low_vram_hint = false;
+    bool has_decode_draft = false;
+};
+
+// Returns the textual name of `policy`; unhandled values map to "auto".
+inline const char * draft_residency_policy_name(DraftResidencyPolicy policy) {
+    switch (policy) {
+        case DraftResidencyPolicy::Auto: return "auto";
+        case DraftResidencyPolicy::Persistent: return "persistent";
+        case DraftResidencyPolicy::RequestScoped: return "request-scoped";
+    }
+    return "auto";
+}
+
+// Parses a policy name ("auto", "persistent", "request-scoped" or
+// "request_scoped") into `out`. Returns false and leaves `out` untouched when
+// `value` is not recognised.
+inline bool parse_draft_residency_policy(const std::string & value,
+                                         DraftResidencyPolicy & out) {
+    if (value == "auto") {
+        out = DraftResidencyPolicy::Auto;
+        return true;
+    }
+    if (value == "persistent") {
+        out = DraftResidencyPolicy::Persistent;
+        return true;
+    }
+    if (value == "request-scoped" || value == "request_scoped") {
+        out = DraftResidencyPolicy::RequestScoped;
+        return true;
+    }
+    return false;
+}
+
+// Maps a policy plus usage context to the action to take after the draft has
+// been used. Persistent and RequestScoped are unconditional; Auto decides per
+// use-case (see the case comments below).
+inline DraftResidencyAction resolve_draft_residency_action(
+        DraftResidencyPolicy policy,
+        const DraftResidencyContext & ctx) {
+    if (policy == DraftResidencyPolicy::Persistent) {
+        return DraftResidencyAction::KeepLoaded;
+    }
+    if (policy == DraftResidencyPolicy::RequestScoped) {
+        return DraftResidencyAction::ReleaseAfterUse;
+    }
+
+    switch (ctx.use) {
+        case DraftResidencyUse::PFlashCompress:
+            // In auto mode, only release the PFlash drafter when the operator gave
+            // a low-VRAM hint. That preserves the existing fast resident path while
+            // allowing small-card setups to make room for decode draft/target state.
+            return ctx.low_vram_hint
+                ? DraftResidencyAction::ReleaseAfterUse
+                : DraftResidencyAction::KeepLoaded;
+        case DraftResidencyUse::DFlashDecode:
+            // DFlash draft is latency-sensitive; keep it resident unless the
+            // operator explicitly opted into the low-VRAM/request-scoped path.
+            return (ctx.low_vram_hint && ctx.has_decode_draft)
+                ? DraftResidencyAction::ReleaseAfterUse
+                : DraftResidencyAction::KeepLoaded;
+        case DraftResidencyUse::MtpDecode:
+            // Placeholder use-case for future draft-style decode paths. Default to
+            // persistent until a concrete MTP residency lifecycle is wired.
+            return DraftResidencyAction::KeepLoaded;
+    }
+    return DraftResidencyAction::KeepLoaded;
+}
+
+} // namespace dflash::common
diff --git a/server/src/qwen3/qwen3_backend.cpp b/server/src/qwen3/qwen3_backend.cpp
index e2adc7f65..253886978 100644
--- a/server/src/qwen3/qwen3_backend.cpp
+++ b/server/src/qwen3/qwen3_backend.cpp
@@ -955,6 +955,10 @@ ModelBackend::CompressResult Qwen3Backend::compress(const CompressRequest & req)
         drafter_ctx_, req.input_ids, req.keep_ratio);
     result.ok = true;
 
+    if (req.residency_action == DraftResidencyAction::ReleaseAfterUse) {
+        free_drafter();
+    }
+
     if (!req.skip_park && !was_parked) unpark("target");
     return result;
 }
diff --git a/server/src/qwen35/qwen35_backend.cpp b/server/src/qwen35/qwen35_backend.cpp
index e3b161d8c..4a3d9674e 100644
--- a/server/src/qwen35/qwen35_backend.cpp
+++ b/server/src/qwen35/qwen35_backend.cpp
@@ -402,8 +402,9 @@ ModelBackend::CompressResult Qwen35Backend::compress(const CompressRequest & req
                     req.input_ids.size(), result.compressed_ids.size());
     }
 
-    // Keep drafter loaded (own backend + weights persist), matching test_dflash.
-    // ~1.4 GB stays resident but avoids reload cost on subsequent compresses.
+ if (req.residency_action == DraftResidencyAction::ReleaseAfterUse) { + free_drafter(); + } // Restore park state if (!req.skip_park) { diff --git a/server/src/server/http_server.cpp b/server/src/server/http_server.cpp index 362c2f4d8..e751cad70 100644 --- a/server/src/server/http_server.cpp +++ b/server/src/server/http_server.cpp @@ -135,6 +135,7 @@ json build_props_body(const ServerConfig & config, {"bsa_enabled", nullptr}, {"bsa_alpha", nullptr}, {"lm_head_fix", nullptr}, + {"draft_residency", draft_residency_policy_name(config.draft_residency)}, }; } else { const char * bsa_env = std::getenv("DFLASH_FP_USE_BSA"); @@ -160,6 +161,7 @@ json build_props_body(const ServerConfig & config, {"bsa_enabled", (bsa_env != nullptr && *bsa_env && std::strcmp(bsa_env, "0") != 0)}, {"bsa_alpha", bsa_alpha}, {"lm_head_fix", (lmfix_env != nullptr && *lmfix_env && std::strcmp(lmfix_env, "0") != 0)}, + {"draft_residency", draft_residency_policy_name(config.draft_residency)}, }; } @@ -197,6 +199,7 @@ json build_props_body(const ServerConfig & config, {"kv_cache_k", config.kv_cache_k}, {"kv_cache_v", config.kv_cache_v}, {"lazy_draft", config.lazy_draft}, + {"draft_residency", draft_residency_policy_name(config.draft_residency)}, {"target_sharding", config.target_sharding}, // Prefill chunk size (bargs.chunk). 
Surfaced so snapshot // tooling captures the full config — bench consumers @@ -1216,6 +1219,15 @@ void HttpServer::worker_loop() { creq.drafter_path = config_.pflash_drafter_path; creq.drafter_gpu = config_.pflash_drafter_gpu; creq.skip_park = config_.pflash_skip_park; + const auto pflash_residency = + resolve_draft_residency_action( + config_.draft_residency, + DraftResidencyContext{ + DraftResidencyUse::PFlashCompress, + config_.lazy_draft, + !config_.draft_path.empty(), + }); + creq.residency_action = pflash_residency; ModelBackend::CompressResult cresult; if (config_.pflash_remote_drafter) { @@ -1229,6 +1241,9 @@ void HttpServer::worker_loop() { cresult.ok = pflash_remote_.compress( creq.input_ids, creq.keep_ratio, cresult.compressed_ids); + if (pflash_residency == DraftResidencyAction::ReleaseAfterUse) { + pflash_remote_.close(); + } } } else { cresult = backend_.compress(creq); @@ -1501,9 +1516,20 @@ void HttpServer::worker_loop() { return true; }; + const auto dflash_residency = + resolve_draft_residency_action( + config_.draft_residency, + DraftResidencyContext{ + DraftResidencyUse::DFlashDecode, + config_.lazy_draft, + !config_.draft_path.empty(), + }); + // Run generation (with or without restore). - // Lazy-draft: ensure decode draft is loaded before generate. - if (config_.lazy_draft) { + // Request-scoped draft residency ensures decode draft is loaded only + // around the generation window, leaving room for PFlash/target state. + if (dflash_residency == DraftResidencyAction::ReleaseAfterUse && + !config_.draft_path.empty()) { backend_.free_drafter(); // free pflash drafter (~1.4 GB) if loaded backend_.unpark("draft"); // reload decode draft (~3.3 GB) } @@ -1515,8 +1541,8 @@ void HttpServer::worker_loop() { result = backend_.generate_with_empty_spec_fallback(gen_req, io); } - // Lazy-draft: park decode draft after generate to free VRAM. 
- if (config_.lazy_draft) { + if (dflash_residency == DraftResidencyAction::ReleaseAfterUse && + !config_.draft_path.empty()) { backend_.park("draft"); } diff --git a/server/src/server/http_server.h b/server/src/server/http_server.h index 999eb5d99..71c544acb 100644 --- a/server/src/server/http_server.h +++ b/server/src/server/http_server.h @@ -18,6 +18,7 @@ #include "prefix_cache.h" #include "disk_prefix_cache.h" #include "api_types.h" +#include "placement/draft_residency.h" #include "placement/remote_draft_config.h" #include "common/pflash_drafter_ipc.h" #include "model_card.h" @@ -149,7 +150,8 @@ struct ServerConfig { bool pflash_remote_drafter = false; // use IPC drafter for mixed backends RemoteDraftConfig pflash_remote; // IPC binary/work-dir for remote PFlash drafter bool pflash_skip_park = false; // skip park/unpark for >=32GB GPUs - bool lazy_draft = false; // park decode draft when idle to save VRAM + bool lazy_draft = false; // legacy alias for request-scoped draft residency + DraftResidencyPolicy draft_residency = DraftResidencyPolicy::Auto; // Disk prefix cache std::string disk_cache_dir; // empty = disabled diff --git a/server/src/server/server_main.cpp b/server/src/server/server_main.cpp index 0f31739ed..5f00d4dfe 100644 --- a/server/src/server/server_main.cpp +++ b/server/src/server/server_main.cpp @@ -19,6 +19,7 @@ #include "common/layer_split_utils.h" #include "common/peer_access.h" #include "placement/pflash_placement.h" +#include "placement/draft_residency.h" #include "gguf.h" @@ -211,7 +212,9 @@ static void print_usage(const char * prog) { " --prefill-keep-ratio Fraction of tokens to keep (default: 0.05)\n" " --prefill-drafter Drafter GGUF for compression (Qwen3-0.6B)\n" " --prefill-skip-park Skip park/unpark (for >=32GB GPUs)\n" - " --lazy-draft Park decode draft when idle to save VRAM\n" + " --draft-residency auto|persistent|request-scoped\n" + " Drafter lifetime policy (default: auto)\n" + " --lazy-draft Legacy alias for 
--draft-residency=request-scoped\n" "\n" "Disk KV cache:\n" " --kv-cache-dir Directory for ondisk KV cache (enables feature)\n" @@ -389,8 +392,19 @@ int main(int argc, char ** argv) { sconfig.pflash_drafter_path = argv[++i]; } else if (std::strcmp(argv[i], "--prefill-skip-park") == 0) { sconfig.pflash_skip_park = true; + } else if (std::strcmp(argv[i], "--draft-residency") == 0 && i + 1 < argc) { + if (!parse_draft_residency_policy(argv[++i], sconfig.draft_residency)) { + std::fprintf(stderr, + "[server] unknown --draft-residency policy: '%s' " + "(expected: auto, persistent, request-scoped)\n", argv[i]); + print_usage(argv[0]); + return 1; + } + sconfig.lazy_draft = + (sconfig.draft_residency == DraftResidencyPolicy::RequestScoped); } else if (std::strcmp(argv[i], "--lazy-draft") == 0) { sconfig.lazy_draft = true; + sconfig.draft_residency = DraftResidencyPolicy::RequestScoped; } else if (std::strcmp(argv[i], "--chat-template-file") == 0 && i + 1 < argc) { const char * path = argv[++i]; std::FILE * f = std::fopen(path, "rb"); @@ -499,9 +513,12 @@ int main(int argc, char ** argv) { setenv("DFLASH27B_FA_WINDOW", "0", 0); } - // Lazy-draft requires both prefill-drafter AND decode draft to be useful. - if (sconfig.lazy_draft && !(pflash_enabled && bargs.draft_path)) { - std::fprintf(stderr, "[server] --lazy-draft ignored: requires both --prefill-drafter and --draft\n"); + if (sconfig.draft_residency == DraftResidencyPolicy::RequestScoped && + !(pflash_enabled || bargs.draft_path)) { + std::fprintf(stderr, + "[server] --draft-residency=request-scoped ignored: requires " + "--prefill-compression or --draft\n"); + sconfig.draft_residency = DraftResidencyPolicy::Auto; sconfig.lazy_draft = false; } @@ -784,6 +801,8 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[server] │ fp_use_bsa = %s\n", getenv("DFLASH_FP_USE_BSA") ? "ON" : "off"); std::fprintf(stderr, "[server] │ fp_alpha = %s\n", getenv("DFLASH_FP_ALPHA") ? 
getenv("DFLASH_FP_ALPHA") : "0.12 (default)"); } + std::fprintf(stderr, "[server] │ draft_residency = %s\n", + draft_residency_policy_name(sconfig.draft_residency)); if (bargs.draft_path) { std::fprintf(stderr, "[server] │ lazy_draft = %s\n", sconfig.lazy_draft ? "ON" : "off"); } diff --git a/server/test/test_server_unit.cpp b/server/test/test_server_unit.cpp index 275ec935b..57f998acd 100644 --- a/server/test/test_server_unit.cpp +++ b/server/test/test_server_unit.cpp @@ -23,6 +23,7 @@ #include "placement/placement_config.h" #include "common/layer_split_backend.h" #include "common/layer_split_utils.h" +#include "placement/draft_residency.h" #include #include @@ -892,6 +893,7 @@ static void test_pflash_config_defaults() { TEST_ASSERT(cfg.pflash_keep_ratio > 0.04f && cfg.pflash_keep_ratio < 0.06f); TEST_ASSERT(cfg.pflash_drafter_path.empty()); TEST_ASSERT(!cfg.pflash_skip_park); + TEST_ASSERT(cfg.draft_residency == DraftResidencyPolicy::Auto); } static void test_pflash_config_modes() { @@ -1038,6 +1040,77 @@ static void test_pflash_placement_usage_gate() { /*pflash_enabled=*/true, /*has_decode_draft=*/true)); } +static void test_draft_residency_parse() { + DraftResidencyPolicy policy = DraftResidencyPolicy::Auto; + TEST_ASSERT(parse_draft_residency_policy("auto", policy)); + TEST_ASSERT(policy == DraftResidencyPolicy::Auto); + TEST_ASSERT(parse_draft_residency_policy("persistent", policy)); + TEST_ASSERT(policy == DraftResidencyPolicy::Persistent); + TEST_ASSERT(parse_draft_residency_policy("request-scoped", policy)); + TEST_ASSERT(policy == DraftResidencyPolicy::RequestScoped); + TEST_ASSERT(parse_draft_residency_policy("request_scoped", policy)); + TEST_ASSERT(policy == DraftResidencyPolicy::RequestScoped); + TEST_ASSERT(!parse_draft_residency_policy("request", policy)); +} + +static void test_draft_residency_pflash_auto() { + auto action = resolve_draft_residency_action( + DraftResidencyPolicy::Auto, + DraftResidencyContext{ + DraftResidencyUse::PFlashCompress, + 
/*low_vram_hint=*/false, + /*has_decode_draft=*/false, + }); + TEST_ASSERT(action == DraftResidencyAction::KeepLoaded); + + action = resolve_draft_residency_action( + DraftResidencyPolicy::Auto, + DraftResidencyContext{ + DraftResidencyUse::PFlashCompress, + /*low_vram_hint=*/true, + /*has_decode_draft=*/true, + }); + TEST_ASSERT(action == DraftResidencyAction::ReleaseAfterUse); +} + +static void test_draft_residency_dflash_auto_and_request_scoped() { + auto action = resolve_draft_residency_action( + DraftResidencyPolicy::Auto, + DraftResidencyContext{ + DraftResidencyUse::DFlashDecode, + /*low_vram_hint=*/false, + /*has_decode_draft=*/true, + }); + TEST_ASSERT(action == DraftResidencyAction::KeepLoaded); + + action = resolve_draft_residency_action( + DraftResidencyPolicy::Auto, + DraftResidencyContext{ + DraftResidencyUse::DFlashDecode, + /*low_vram_hint=*/true, + /*has_decode_draft=*/true, + }); + TEST_ASSERT(action == DraftResidencyAction::ReleaseAfterUse); + + action = resolve_draft_residency_action( + DraftResidencyPolicy::RequestScoped, + DraftResidencyContext{ + DraftResidencyUse::DFlashDecode, + /*low_vram_hint=*/false, + /*has_decode_draft=*/true, + }); + TEST_ASSERT(action == DraftResidencyAction::ReleaseAfterUse); + + action = resolve_draft_residency_action( + DraftResidencyPolicy::Persistent, + DraftResidencyContext{ + DraftResidencyUse::DFlashDecode, + /*low_vram_hint=*/true, + /*has_decode_draft=*/true, + }); + TEST_ASSERT(action == DraftResidencyAction::KeepLoaded); +} + // ═══════════════════════════════════════════════════════════════════════ // Jinja chat template // ═══════════════════════════════════════════════════════════════════════ @@ -2292,6 +2365,7 @@ static void test_props_runtime_shape() { cfg.kv_cache_k = "tq3_0"; cfg.kv_cache_v = "tq3_0"; cfg.lazy_draft = false; + cfg.draft_residency = DraftResidencyPolicy::Persistent; cfg.target_sharding = false; cfg.chunk = 512; cfg.target_device = "auto:0"; @@ -2309,10 +2383,12 @@ static void 
test_props_runtime_shape() { TEST_ASSERT(rt["kv_cache_k"].get() == "tq3_0"); TEST_ASSERT(rt["kv_cache_v"].get() == "tq3_0"); TEST_ASSERT(rt["lazy_draft"].get() == false); + TEST_ASSERT(rt["draft_residency"].get() == "persistent"); TEST_ASSERT(rt["target_sharding"].get() == false); TEST_ASSERT(rt["chunk"].get() == 512); TEST_ASSERT(rt["target_device"].get() == "auto:0"); TEST_ASSERT(rt["draft_device"].get() == "auto:0"); + TEST_ASSERT(body["pflash"]["draft_residency"].get() == "persistent"); // draft_device is null when no draft model is loaded. cfg.draft_device.clear(); @@ -2643,6 +2719,9 @@ int main() { RUN_TEST(test_pflash_placement_auto_draft_follows_target); RUN_TEST(test_pflash_placement_disabled_never_remote); RUN_TEST(test_pflash_placement_usage_gate); + RUN_TEST(test_draft_residency_parse); + RUN_TEST(test_draft_residency_pflash_auto); + RUN_TEST(test_draft_residency_dflash_auto_and_request_scoped); std::fprintf(stderr, "\n── Jinja chat template ──\n"); RUN_TEST(test_jinja_render_basic);