From a9aedf7d905d1b2810a29749ae4d66951acd900e Mon Sep 17 00:00:00 2001 From: weicj Date: Fri, 29 May 2026 01:53:50 +0800 Subject: [PATCH 1/3] fix(server): enable sampling for target layer split --- server/src/common/layer_split_backend.cpp | 3 +- server/src/common/layer_split_backend.h | 1 + server/src/gemma4/gemma4_graph.cpp | 43 ++++++++++++--- server/src/gemma4/gemma4_internal.h | 9 ++++ .../src/gemma4/gemma4_layer_split_adapter.cpp | 37 ++++++++++--- .../src/gemma4/gemma4_layer_split_adapter.h | 8 ++- server/src/qwen35/layer_split_forward.cpp | 54 ++++++++++++++----- server/src/qwen35/layer_split_forward.h | 12 +++++ .../src/qwen35/qwen35_layer_split_adapter.cpp | 41 ++++++++++++-- .../src/qwen35/qwen35_layer_split_adapter.h | 3 ++ server/test/test_server_unit.cpp | 39 ++++++++++++++ 11 files changed, 216 insertions(+), 34 deletions(-) diff --git a/server/src/common/layer_split_backend.cpp b/server/src/common/layer_split_backend.cpp index 11e75e0a7..431b2ec79 100644 --- a/server/src/common/layer_split_backend.cpp +++ b/server/src/common/layer_split_backend.cpp @@ -57,7 +57,8 @@ GenerateResult LayerSplitBackend::run_from_state(const GenerateRequest & req, result.error = "context"; return result; } - if (req.do_sample && req.sampler.temp > 0.0f) { + if (req.do_sample && req.sampler.needs_logit_processing() && + !adapter_->supports_cpu_sampling()) { result.error = "sampling_unsupported"; return result; } diff --git a/server/src/common/layer_split_backend.h b/server/src/common/layer_split_backend.h index 4aeda23f6..0386936aa 100644 --- a/server/src/common/layer_split_backend.h +++ b/server/src/common/layer_split_backend.h @@ -30,6 +30,7 @@ class LayerSplitAdapter { virtual bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) = 0; + virtual bool supports_cpu_sampling() const { return false; } virtual bool can_dflash_decode() const { return false; } virtual bool decode_dflash(const std::vector & prompt, diff --git a/server/src/gemma4/gemma4_graph.cpp b/server/src/gemma4/gemma4_graph.cpp index bf8f8ce7c..bc2276555 100644 --- a/server/src/gemma4/gemma4_graph.cpp +++ b/server/src/gemma4/gemma4_graph.cpp @@ -515,13 +515,14 @@ bool build_gemma4_layer_step( return ggml_gallocr_alloc_graph(sg.alloc, sg.gf); } -bool compute_gemma4_split_argmax( +bool compute_gemma4_split_projection( ggml_backend_t backend, const Gemma4Weights & w, ggml_tensor * act, int token_offset, int n_tokens, - std::vector & out_argmax) { + std::vector * out_argmax, + std::vector * out_logits) { ggml_init_params ip{}; ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024; ip.no_alloc = true; @@ -539,9 +540,17 @@ bool compute_gemma4_split_argmax( cur = ggml_tanh(ctx, cur); cur = ggml_scale(ctx, cur, w.final_logit_softcap); } - cur = ggml_argmax(ctx, cur); - ggml_set_output(cur); - ggml_build_forward_expand(gf, cur); + ggml_tensor * logits = cur; + ggml_tensor * argmax = nullptr; + if (out_logits) { + ggml_set_output(logits); + ggml_build_forward_expand(gf, logits); + } + if (out_argmax) { + argmax = ggml_argmax(ctx, logits); + ggml_set_output(argmax); + ggml_build_forward_expand(gf, argmax); + } ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) { @@ -554,14 +563,32 @@ bool compute_gemma4_split_argmax( ggml_free(ctx); return false; } - out_argmax.resize((size_t)n_tokens); - ggml_backend_tensor_get(cur, out_argmax.data(), 0, - sizeof(int32_t) * (size_t)n_tokens); + if (out_argmax) { + out_argmax->resize((size_t)n_tokens); + ggml_backend_tensor_get(argmax, out_argmax->data(), 0, + sizeof(int32_t) * (size_t)n_tokens); + } + if (out_logits) { + out_logits->resize((size_t)w.n_vocab * (size_t)n_tokens); + ggml_backend_tensor_get(logits, out_logits->data(), 0, + sizeof(float) * (size_t)w.n_vocab * (size_t)n_tokens); + } ggml_gallocr_free(alloc); ggml_free(ctx); return true; } +bool compute_gemma4_split_argmax( + ggml_backend_t backend, + const Gemma4Weights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector & out_argmax) { + return compute_gemma4_split_projection( + backend, w, act, token_offset, n_tokens, &out_argmax, nullptr); +} + bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, diff --git a/server/src/gemma4/gemma4_internal.h b/server/src/gemma4/gemma4_internal.h index 5e643060f..80f81cd11 100644 --- a/server/src/gemma4/gemma4_internal.h +++ b/server/src/gemma4/gemma4_internal.h @@ -286,6 +286,15 @@ bool compute_gemma4_split_argmax( int n_tokens, std::vector & out_argmax); +bool compute_gemma4_split_projection( + ggml_backend_t backend, + const Gemma4Weights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector * out_argmax, + std::vector * out_logits); + // BSA sparse-FA prefill: process the full prompt at once using block-sparse // attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers // use dense FA. Returns logits for the last token. Populates the KV cache diff --git a/server/src/gemma4/gemma4_layer_split_adapter.cpp b/server/src/gemma4/gemma4_layer_split_adapter.cpp index 2562faec8..5a6be860a 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.cpp +++ b/server/src/gemma4/gemma4_layer_split_adapter.cpp @@ -140,7 +140,10 @@ bool Gemma4LayerSplitAdapter::init() { } void Gemma4LayerSplitAdapter::begin_request(const GenerateRequest & req) { - (void)req; + sampler_ = req.sampler; + if (req.do_sample && sampler_.seed != 0) { + sampler_rng_.seed(sampler_.seed); + } } void Gemma4LayerSplitAdapter::reset_request_state() { @@ -148,12 +151,14 @@ void Gemma4LayerSplitAdapter::reset_request_state() { shard.cache.cur_pos = 0; shard.cache.last_tok = -1; } + prefill_last_logits_.clear(); } bool Gemma4LayerSplitAdapter::run_forward( const std::vector & tokens, int base_pos, - int & last_tok) { + int & last_tok, + std::vector * logits_out) { if (shards_.empty() || tokens.empty()) return false; const Gemma4Weights & ref = shards_.front().weights; const int hidden = ref.n_embd; @@ -336,9 +341,9 @@ bool Gemma4LayerSplitAdapter::run_forward( std::vector argmax; Gemma4LayerSplitShard & last = shards_.back(); - const bool ok = compute_gemma4_split_argmax( + const bool ok = compute_gemma4_split_projection( last.backend, last.weights, act_in, - n_tokens_total - 1, 1, argmax); + n_tokens_total - 1, 1, &argmax, logits_out); activation_buffer_free(orig); activation_pair_free(acts); if (!ok || argmax.empty()) return false; @@ -353,7 +358,7 @@ bool Gemma4LayerSplitAdapter::run_forward( bool Gemma4LayerSplitAdapter::prefill(const std::vector & prompt, int base_pos, int & last_tok) { - return run_forward(prompt, base_pos, last_tok); + return run_forward(prompt, base_pos, last_tok, &prefill_last_logits_); } bool Gemma4LayerSplitAdapter::decode_ar( @@ -366,6 +371,13 @@ bool Gemma4LayerSplitAdapter::decode_ar( if (shards_.empty()) return false; const auto & w = shards_.front().weights; + const int vocab = w.n_vocab; + std::vector logits_buf; + if (sampler_.needs_logit_processing()) { + if ((int)prefill_last_logits_.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(last_tok); io.emit(last_tok); if (io.cancelled) { @@ -381,7 +393,16 @@ bool Gemma4LayerSplitAdapter::decode_ar( for (int i = 1; i < n_gen; ++i) { std::vector one(1, last_tok); int next_tok = -1; - if (!run_forward(one, committed - 1, next_tok)) return false; + logits_buf.clear(); + if (!run_forward(one, committed - 1, next_tok, + sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { + return false; + } + if (sampler_.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } last_tok = next_tok; out_tokens.push_back(last_tok); io.emit(last_tok); @@ -455,6 +476,7 @@ bool Gemma4LayerSplitAdapter::snapshot_save(int slot) { } snap.cur_pos = snap_pos; snap.last_tok = shards_.front().cache.last_tok; + snap.prefill_last_logits = prefill_last_logits_; return true; } @@ -466,6 +488,7 @@ void Gemma4LayerSplitAdapter::snapshot_free(int slot) { } snap.cur_pos = 0; snap.last_tok = -1; + snap.prefill_last_logits.clear(); if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size()); } @@ -476,6 +499,7 @@ bool Gemma4LayerSplitAdapter::snapshot_used(int slot) const { } const auto & snap = snapshots_[(size_t)slot]; if (snap.cur_pos <= 0 || snap.shards.size() != shards_.size()) return false; + if (snap.prefill_last_logits.empty()) return false; for (const auto & ss : snap.shards) { if (!ss.ctx) return false; } @@ -515,6 +539,7 @@ bool Gemma4LayerSplitAdapter::snapshot_restore(int slot) { shards_[i].cache.cur_pos = snap.cur_pos; shards_[i].cache.last_tok = snap.last_tok; } + prefill_last_logits_ = snap.prefill_last_logits; return true; } diff --git a/server/src/gemma4/gemma4_layer_split_adapter.h b/server/src/gemma4/gemma4_layer_split_adapter.h index 430918b57..17f2f21c9 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.h +++ b/server/src/gemma4/gemma4_layer_split_adapter.h @@ -30,6 +30,7 @@ struct Gemma4LayerSplitSnapshot { int cur_pos = 0; int32_t last_tok = -1; std::vector shards; + std::vector prefill_last_logits; }; class Gemma4LayerSplitAdapter : public LayerSplitAdapter { @@ -51,6 +52,7 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter { bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) override; + bool supports_cpu_sampling() const override { return true; } bool snapshot_save(int slot) override; void snapshot_free(int slot) override; @@ -65,13 +67,17 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter { private: bool run_forward(const std::vector & tokens, int base_pos, - int & last_tok); + int & last_tok, + std::vector * logits_out = nullptr); Gemma4LayerSplitAdapterConfig cfg_; std::vector shards_; std::vector snapshot_backends_; std::vector snapshots_; static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots; + SamplerCfg sampler_; + std::mt19937_64 sampler_rng_{std::random_device{}()}; + std::vector prefill_last_logits_; }; void free_gemma4_layer_split_shards(std::vector & shards); diff --git a/server/src/qwen35/layer_split_forward.cpp b/server/src/qwen35/layer_split_forward.cpp index d1ab66587..5fd774cc0 100644 --- a/server/src/qwen35/layer_split_forward.cpp +++ b/server/src/qwen35/layer_split_forward.cpp @@ -17,7 +17,7 @@ namespace dflash::common { -bool compute_target_split_argmax( +bool compute_target_split_projection( StepGraph & sg, const TargetWeights & w, ggml_backend_t backend, @@ -26,7 +26,8 @@ bool compute_target_split_argmax( int n_tokens, int hidden, int vocab, - std::vector & argmax_out) { + std::vector * argmax_out, + std::vector * logits_out) { step_graph_free(sg); ggml_init_params ip{}; ip.mem_size = 256 * 1024 * 1024; @@ -43,24 +44,51 @@ bool compute_target_split_argmax( ggml_tensor * logits = ggml_mul_mat(sg.ctx, w.output, normed); ggml_set_name(logits, "target_split_logits"); sg.logits = logits; - sg.argmax_tokens = ggml_argmax(sg.ctx, logits); - ggml_set_name(sg.argmax_tokens, "target_split_argmax"); - ggml_set_output(sg.argmax_tokens); + if (argmax_out) { + sg.argmax_tokens = ggml_argmax(sg.ctx, logits); + ggml_set_name(sg.argmax_tokens, "target_split_argmax"); + ggml_set_output(sg.argmax_tokens); + } + if (logits_out) { + ggml_set_output(sg.logits); + } sg.gf = ggml_new_graph_custom(sg.ctx, 1024, false); - ggml_build_forward_expand(sg.gf, sg.argmax_tokens); + if (argmax_out) ggml_build_forward_expand(sg.gf, sg.argmax_tokens); + if (logits_out) ggml_build_forward_expand(sg.gf, sg.logits); if (!sg.alloc) { sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); } if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) return false; auto st = ggml_backend_graph_compute(backend, sg.gf); if (st != GGML_STATUS_SUCCESS) return false; - (void)vocab; - argmax_out.assign((size_t)n_tokens, 0); - ggml_backend_tensor_get(sg.argmax_tokens, argmax_out.data(), 0, - sizeof(int32_t) * (size_t)n_tokens); + if (argmax_out) { + argmax_out->assign((size_t)n_tokens, 0); + ggml_backend_tensor_get(sg.argmax_tokens, argmax_out->data(), 0, + sizeof(int32_t) * (size_t)n_tokens); + } + if (logits_out) { + logits_out->assign((size_t)vocab * (size_t)n_tokens, 0.0f); + ggml_backend_tensor_get(sg.logits, logits_out->data(), 0, + sizeof(float) * (size_t)vocab * (size_t)n_tokens); + } return true; } +bool compute_target_split_argmax( + StepGraph & sg, + const TargetWeights & w, + ggml_backend_t backend, + ggml_tensor * act, + int token_offset, + int n_tokens, + int hidden, + int vocab, + std::vector & argmax_out) { + return compute_target_split_projection( + sg, w, backend, act, token_offset, n_tokens, hidden, vocab, + &argmax_out, nullptr); +} + bool run_qwen35_layer_split_forward( std::vector & shards, const TargetWeights & embed_source, @@ -208,9 +236,10 @@ bool run_qwen35_layer_split_forward( const bool need_all_argmax = argmax_out != nullptr; const int argmax_offset = need_all_argmax ? 0 : (n_tokens_total - 1); const int argmax_count = need_all_argmax ? n_tokens_total : 1; - const bool ok = compute_target_split_argmax( + const bool ok = compute_target_split_projection( final_sg, last_shard.weights, last_shard.backend, act_in, - argmax_offset, argmax_count, hidden, vocab, argmax_tokens); + argmax_offset, argmax_count, hidden, vocab, + &argmax_tokens, logits_out); step_graph_destroy(final_sg); activation_pair_free(acts); if (!ok) return false; @@ -220,7 +249,6 @@ bool run_qwen35_layer_split_forward( shard.cache.last_tok = last_tok; } if (argmax_out) *argmax_out = std::move(argmax_tokens); - if (logits_out) logits_out->clear(); return true; } diff --git a/server/src/qwen35/layer_split_forward.h b/server/src/qwen35/layer_split_forward.h index c04680fe4..bb01bff09 100644 --- a/server/src/qwen35/layer_split_forward.h +++ b/server/src/qwen35/layer_split_forward.h @@ -32,6 +32,18 @@ bool compute_target_split_argmax( int vocab, std::vector & argmax_out); +bool compute_target_split_projection( + StepGraph & sg, + const TargetWeights & w, + ggml_backend_t backend, + ggml_tensor * act, + int token_offset, + int n_tokens, + int hidden, + int vocab, + std::vector * argmax_out, + std::vector * logits_out); + // Run a full forward pass through all shards, writing K/V into each shard's // cache. Returns the argmax of the last token in `last_tok`. // Optionally captures features into `feature_ring` / remote draft. diff --git a/server/src/qwen35/qwen35_layer_split_adapter.cpp b/server/src/qwen35/qwen35_layer_split_adapter.cpp index a9ece2c37..98f8b0259 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.cpp +++ b/server/src/qwen35/qwen35_layer_split_adapter.cpp @@ -86,6 +86,7 @@ bool Qwen35LayerSplitAdapter::init() { for (auto & slot : prefix_snapshots_) { slot.resize(shards_.size()); } + snapshot_prefill_logits_.resize(PREFIX_SLOTS); draft_feature_snapshots_.resize(PREFIX_SLOTS); return true; @@ -171,6 +172,7 @@ void Qwen35LayerSplitAdapter::begin_request(const GenerateRequest & req) { void Qwen35LayerSplitAdapter::reset_request_state() { for (auto & shard : shards_) reset_target_cache(shard.cache); + prefill_last_logits_.clear(); } bool Qwen35LayerSplitAdapter::prefill(const std::vector & prompt, @@ -190,7 +192,8 @@ bool Qwen35LayerSplitAdapter::prefill(const std::vector & prompt, shards_, shards_.front().weights, prompt, base_pos, ubatch, last_tok, cfg_.kq_stride_pad, /*fa_window=*/0, (cfg_.run_dflash && !remote_draft_.active()) ? &feature_ring_ : nullptr, - /*argmax_out=*/nullptr, /*logits_out=*/nullptr, + /*argmax_out=*/nullptr, + &prefill_last_logits_, cfg_.run_dflash ? &remote_draft_ : nullptr); } @@ -215,6 +218,8 @@ bool Qwen35LayerSplitAdapter::snapshot_save(int slot) { return false; } } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS) return false; + snapshot_prefill_logits_[(size_t)slot] = prefill_last_logits_; if (!snapshot_draft_features(slot)) { snapshot_free(slot); return false; @@ -227,6 +232,9 @@ void Qwen35LayerSplitAdapter::snapshot_free(int slot) { for (auto & snap : prefix_snapshots_[(size_t)slot]) { free_prefix_snapshot(snap); } + if (snapshot_prefill_logits_.size() == (size_t)PREFIX_SLOTS) { + snapshot_prefill_logits_[(size_t)slot].clear(); + } free_draft_feature_snapshot(slot); } @@ -237,6 +245,10 @@ bool Qwen35LayerSplitAdapter::snapshot_used(int slot) const { for (const auto & snap : snaps) { if (!snap.ctx) return false; } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS || + snapshot_prefill_logits_[(size_t)slot].empty()) { + return false; + } if (cfg_.run_dflash && cfg_.draft_path) { if (draft_feature_snapshots_.size() != (size_t)PREFIX_SLOTS) return false; const auto & draft_snap = draft_feature_snapshots_[(size_t)slot]; @@ -261,6 +273,8 @@ bool Qwen35LayerSplitAdapter::snapshot_restore(int slot) { return false; } } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS) return false; + prefill_last_logits_ = snapshot_prefill_logits_[(size_t)slot]; if (!restore_draft_features(slot)) return false; return true; } @@ -377,14 +391,22 @@ bool Qwen35LayerSplitAdapter::decode_ar( std::vector & out_tokens, const DaemonIO & io) { if (n_gen <= 0) return true; + const auto & w = shards_.front().weights; + const int vocab = w.n_vocab; + std::vector logits_buf; + if (sampler_.needs_logit_processing()) { + if ((int)prefill_last_logits_.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(last_tok); io.emit(last_tok); if (io.cancelled) { io.emit(-1); return true; } - if (is_eos_tok(last_tok, shards_.front().weights)) { + if (is_eos_tok(last_tok, w)) { io.emit(-1); return true; } @@ -393,16 +415,24 @@ bool Qwen35LayerSplitAdapter::decode_ar( for (int i = 1; i < n_gen; ++i) { std::vector one(1, last_tok); int next_tok = -1; + logits_buf.clear(); if (!run_qwen35_layer_split_forward( shards_, shards_.front().weights, one, committed, 1, next_tok, cfg_.kq_stride_pad, cfg_.fa_window, - cfg_.run_dflash ? &feature_ring_ : nullptr)) { + cfg_.run_dflash ? &feature_ring_ : nullptr, + /*argmax_out=*/nullptr, + sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { return false; } + if (sampler_.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(next_tok); io.emit(next_tok); if (io.cancelled) break; - if (is_eos_tok(next_tok, shards_.front().weights)) break; + if (is_eos_tok(next_tok, w)) break; last_tok = next_tok; ++committed; } @@ -411,7 +441,7 @@ bool Qwen35LayerSplitAdapter::decode_ar( } bool Qwen35LayerSplitAdapter::can_dflash_decode() const { - return cfg_.run_dflash && cfg_.draft_path && sampler_.temp == 0.0f; + return cfg_.run_dflash && cfg_.draft_path && !sampler_.needs_logit_processing(); } bool Qwen35LayerSplitAdapter::decode_dflash( @@ -499,6 +529,7 @@ void Qwen35LayerSplitAdapter::shutdown() { for (auto & snap : slot) free_prefix_snapshot(snap); } prefix_snapshots_.clear(); + snapshot_prefill_logits_.clear(); draft_feature_snapshots_.clear(); auto shard_metas = layer_split_shard_metas(shards_); free_layer_split_snapshot_backends(shard_metas, snapshot_backends_); diff --git a/server/src/qwen35/qwen35_layer_split_adapter.h b/server/src/qwen35/qwen35_layer_split_adapter.h index 4011d837a..68ce39fd7 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.h +++ b/server/src/qwen35/qwen35_layer_split_adapter.h @@ -55,6 +55,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) override; + bool supports_cpu_sampling() const override { return true; } bool can_dflash_decode() const override; bool decode_dflash(const std::vector & prompt, int base_pos, @@ -99,6 +100,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { bool pflash_drafter_loaded_ = false; static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots; std::vector> prefix_snapshots_; + std::vector> snapshot_prefill_logits_; std::vector snapshot_backends_; struct DraftFeatureSnapshot { int cur_pos = 0; @@ -114,6 +116,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { SamplerCfg sampler_; std::mt19937_64 sampler_rng_{std::random_device{}()}; std::unique_ptr dflash_target_; + std::vector prefill_last_logits_; }; } // namespace dflash::common diff --git a/server/test/test_server_unit.cpp b/server/test/test_server_unit.cpp index 1415aab30..26dd81696 100644 --- a/server/test/test_server_unit.cpp +++ b/server/test/test_server_unit.cpp @@ -1246,6 +1246,7 @@ struct MockLayerSplitAdapter : LayerSplitAdapter { std::vector emitted_tokens; bool dflash_enabled = false; bool dflash_called = false; + bool sampling_enabled = false; int shutdown_calls = 0; ModelBackend::CompressRequest last_compress_req; @@ -1280,6 +1281,7 @@ struct MockLayerSplitAdapter : LayerSplitAdapter { return true; } bool can_dflash_decode() const override { return dflash_enabled; } + bool supports_cpu_sampling() const override { return sampling_enabled; } bool decode_dflash(const std::vector & prompt, int base_pos, int last_tok, int n_gen, std::vector & out_tokens, const DaemonIO & io) override { @@ -1374,6 +1376,42 @@ static void test_layer_split_backend_inline_snapshot_and_restore_delta() { TEST_ASSERT(raw->dflash_last == 99); } +static void test_layer_split_backend_sampling_capability_gate() { + { + auto * raw = new MockLayerSplitAdapter(); + LayerSplitBackend backend{std::unique_ptr(raw)}; + + GenerateRequest req; + req.prompt = {10, 11}; + req.n_gen = 1; + req.do_sample = true; + req.sampler.temp = 0.8f; + DaemonIO io; + GenerateResult result = backend.generate(req, io); + + TEST_ASSERT(!result.ok); + TEST_ASSERT(result.error == "sampling_unsupported"); + } + + { + auto * raw = new MockLayerSplitAdapter(); + raw->sampling_enabled = true; + LayerSplitBackend backend{std::unique_ptr(raw)}; + + GenerateRequest req; + req.prompt = {10, 11}; + req.n_gen = 1; + req.do_sample = true; + req.sampler.temp = 0.8f; + DaemonIO io; + GenerateResult result = backend.generate(req, io); + + TEST_ASSERT(result.ok); + TEST_ASSERT(result.tokens.size() == 1); + TEST_ASSERT(result.tokens[0] == 12); + } +} + static void test_layer_split_compress_nopark_uses_default_drafter_path() { const std::string ids_path = "/tmp/dflash_test_layer_split_compress_ids.bin"; unlink(ids_path.c_str()); @@ -2548,6 +2586,7 @@ int main() { RUN_TEST(test_parse_target_device_list_single_gpu_is_not_layer_split); RUN_TEST(test_validate_layer_split_weights_shape); RUN_TEST(test_layer_split_backend_inline_snapshot_and_restore_delta); + RUN_TEST(test_layer_split_backend_sampling_capability_gate); RUN_TEST(test_layer_split_compress_nopark_uses_default_drafter_path); RUN_TEST(test_layer_split_compress_rejects_bad_keep_ratio); RUN_TEST(test_layer_split_backend_shutdown_is_idempotent); From 53dd1686a685a054ac5714a1e8a41e759d894ebc Mon Sep 17 00:00:00 2001 From: weicj Date: Fri, 29 May 2026 03:20:25 +0800 Subject: [PATCH 2/3] feat(server): add Laguna target-layer-split adapter --- server/CMakeLists.txt | 1 + server/src/common/backend_factory.cpp | 17 + server/src/laguna/laguna_backend.cpp | 4 +- server/src/laguna/laguna_backend.h | 2 + server/src/laguna/laguna_daemon.cpp | 1 + server/src/laguna/laguna_internal.h | 54 +++ .../src/laguna/laguna_layer_split_adapter.cpp | 419 ++++++++++++++++++ .../src/laguna/laguna_layer_split_adapter.h | 85 ++++ server/src/laguna/laguna_target_graph.cpp | 166 +++++++ server/src/laguna/laguna_target_loader.cpp | 119 ++++- 10 files changed, 852 insertions(+), 16 deletions(-) create mode 100644 server/src/laguna/laguna_layer_split_adapter.cpp create mode 100644 server/src/laguna/laguna_layer_split_adapter.h diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index 345ee8aee..a433d512c 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -237,6 +237,7 @@ add_library(dflash_common STATIC src/laguna/laguna_target_graph.cpp src/laguna/laguna_daemon.cpp src/laguna/laguna_backend.cpp + src/laguna/laguna_layer_split_adapter.cpp src/common/backend_ipc.cpp src/common/dflash_feature_ring.cpp src/common/dflash_capture.cpp diff --git a/server/src/common/backend_factory.cpp b/server/src/common/backend_factory.cpp index b57c93c03..9a9fd5a4e 100644 --- a/server/src/common/backend_factory.cpp +++ b/server/src/common/backend_factory.cpp @@ -6,6 +6,7 @@ #include "qwen35_backend.h" #include "qwen35moe_backend.h" #include "laguna_backend.h" +#include "laguna_layer_split_adapter.h" #include "qwen3_backend.h" #include "gemma4_backend.h" #include "gemma4_layer_split_adapter.h" @@ -124,8 +125,24 @@ std::unique_ptr create_backend(const BackendArgs & args) { return backend; } else if (arch == "laguna") { + if (args.device.is_layer_split()) { + LagunaLayerSplitAdapterConfig cfg; + cfg.target_path = args.model_path; + cfg.device = args.device; + cfg.chunk = args.chunk; + + auto adapter = std::make_unique(cfg); + auto backend = std::make_unique(std::move(adapter)); + if (!backend->init()) { + std::fprintf(stderr, "[backend_factory] LayerSplitBackend(laguna) init failed\n"); + return nullptr; + } + return backend; + } + LagunaBackendArgs lcfg; lcfg.target_path = args.model_path; + lcfg.device = args.device; lcfg.max_ctx = args.device.max_ctx; lcfg.chunk = args.chunk; // kv_type defaults to Q8_0 in LagunaBackendArgs diff --git a/server/src/laguna/laguna_backend.cpp b/server/src/laguna/laguna_backend.cpp index d6108e4e0..87723f596 100644 --- a/server/src/laguna/laguna_backend.cpp +++ b/server/src/laguna/laguna_backend.cpp @@ -31,9 +31,9 @@ LagunaBackend::LagunaBackend(const LagunaBackendArgs & args) LagunaBackend::~LagunaBackend() { shutdown(); } bool LagunaBackend::init() { - backend_ = ggml_backend_cuda_init(0); + backend_ = ggml_backend_cuda_init(args_.device.gpu); if (!backend_) { - std::fprintf(stderr, "cuda init failed\n"); + std::fprintf(stderr, "cuda init failed gpu=%d\n", args_.device.gpu); return false; } diff --git a/server/src/laguna/laguna_backend.h b/server/src/laguna/laguna_backend.h index 7e487d558..afdaf8f63 100644 --- a/server/src/laguna/laguna_backend.h +++ b/server/src/laguna/laguna_backend.h @@ -8,6 +8,7 @@ #include "model_backend.h" #include "laguna_internal.h" +#include "placement/placement_config.h" #include "qwen3_drafter.h" #include "ggml.h" @@ -22,6 +23,7 @@ namespace dflash::common { struct LagunaBackendArgs { std::string target_path; + DevicePlacement device; int max_ctx = 16384; int chunk = 2048; ggml_type kv_type = GGML_TYPE_Q8_0; diff --git a/server/src/laguna/laguna_daemon.cpp b/server/src/laguna/laguna_daemon.cpp index 3a64e0b8d..952526581 100644 --- a/server/src/laguna/laguna_daemon.cpp +++ b/server/src/laguna/laguna_daemon.cpp @@ -21,6 +21,7 @@ namespace dflash::common { int run_laguna_daemon(const LagunaDaemonArgs & args) { LagunaBackendArgs bargs; bargs.target_path = args.target_path; + bargs.device = args.device; bargs.max_ctx = args.device.max_ctx; bargs.chunk = args.chunk; bargs.kv_type = args.kv_type; diff --git a/server/src/laguna/laguna_internal.h b/server/src/laguna/laguna_internal.h index 505325042..32c981300 100644 --- a/server/src/laguna/laguna_internal.h +++ b/server/src/laguna/laguna_internal.h @@ -26,8 +26,10 @@ #include #include "ggml.h" +#include "ggml-alloc.h" #include "ggml-backend.h" +#include "common/layer_split_utils.h" #include "internal.h" // for CpuEmbedder namespace dflash::common { @@ -134,6 +136,11 @@ bool load_target_gguf_laguna(const std::string & path, ggml_backend_t backend, LagunaTargetWeights & out); +bool load_target_gguf_laguna_partial(const std::string & path, + ggml_backend_t backend, + const TargetLoadPlan & plan, + LagunaTargetWeights & out); + void free_laguna_target_weights(LagunaTargetWeights & w); // ---- Forward graph (Phase 2; signatures only for now) ------------------- @@ -160,6 +167,12 @@ bool create_laguna_target_cache(const LagunaTargetWeights & w, int max_ctx, ggml_backend_t backend, LagunaTargetCache & out); +bool create_laguna_target_cache_partial(const LagunaTargetWeights & w, + int max_ctx, + ggml_backend_t backend, + int layer_begin, + int layer_end, + LagunaTargetCache & out); void free_laguna_target_cache(LagunaTargetCache & c); void reset_laguna_target_cache(LagunaTargetCache & c); @@ -252,4 +265,45 @@ bool laguna_step( bool no_mask, std::vector & out_logits); +struct LagunaLayerStepGraph { + ggml_context * ctx = nullptr; + ggml_cgraph * gf = nullptr; + ggml_gallocr_t alloc = nullptr; + ggml_tensor * positions = nullptr; + ggml_tensor * attn_mask = nullptr; + ggml_tensor * attn_mask_swa = nullptr; +}; + +void laguna_layer_step_graph_free(LagunaLayerStepGraph & sg); +void laguna_layer_step_graph_destroy(LagunaLayerStepGraph & sg); + +bool build_laguna_layer_step( + LagunaLayerStepGraph & sg, + const LagunaTargetWeights & w, + LagunaTargetCache & cache, + ggml_backend_t backend, + int layer_idx, + ggml_tensor * act_in, + ggml_tensor * act_out, + int chunk_start, + int n_tokens, + int kv_start); + +bool compute_laguna_split_argmax( + ggml_backend_t backend, + const LagunaTargetWeights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector & out_argmax); + +bool compute_laguna_split_projection( + ggml_backend_t backend, + const LagunaTargetWeights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector * out_argmax, + std::vector * out_logits); + } // namespace dflash::common diff --git a/server/src/laguna/laguna_layer_split_adapter.cpp b/server/src/laguna/laguna_layer_split_adapter.cpp new file mode 100644 index 000000000..1f6482e7e --- /dev/null +++ b/server/src/laguna/laguna_layer_split_adapter.cpp @@ -0,0 +1,419 @@ +// Laguna target layer-split adapter. + +#include "laguna_layer_split_adapter.h" + +#include "common/dflash_layer_split_runtime.h" +#include "common/gguf_inspect.h" +#include "common/layer_split_utils.h" +#include "common/sampler.h" +#include "dflash27b.h" + +#include "ggml-cuda.h" + +#include +#include +#include +#include +#include + +namespace dflash::common { + +namespace { + +static bool tensor_ready(const ggml_tensor * t) { + return t && t->buffer; +} + +} // namespace + +LagunaLayerSplitAdapter::LagunaLayerSplitAdapter( + const LagunaLayerSplitAdapterConfig & cfg) + : cfg_(cfg) {} + +LagunaLayerSplitAdapter::~LagunaLayerSplitAdapter() { shutdown(); } + +bool LagunaLayerSplitAdapter::init() { + if (!cfg_.target_path || cfg_.device.layer_split_gpus.size() < 2) { + std::fprintf(stderr, "[laguna-target-split] invalid layer-split config\n"); + return false; + } + + const auto info = inspect_gguf_model_info(cfg_.target_path); + const int n_layer = info.n_layer; + if (n_layer <= 0) { + std::fprintf(stderr, "[laguna-target-split] failed to inspect layer count\n"); + return false; + } + + const auto ranges = compute_layer_ranges( + n_layer, + (int)cfg_.device.layer_split_gpus.size(), + cfg_.device.layer_split_weights); + if (ranges.size() != cfg_.device.layer_split_gpus.size()) { + std::fprintf(stderr, + "[laguna-target-split] bad layer split for %zu GPUs and %d layers\n", + cfg_.device.layer_split_gpus.size(), n_layer); + return false; + } + + shards_.resize(cfg_.device.layer_split_gpus.size()); + auto shard_metas = layer_split_shard_metas(shards_); + if (!init_layer_split_shard_metas( + shard_metas, cfg_.device.layer_split_gpus, ranges, + "laguna-target-split")) { + return false; + } + + (void)enable_layer_split_peer_access( + cfg_.device.layer_split_gpus, cfg_.device.peer_access); + + if (!init_layer_split_snapshot_backends( + shard_metas, snapshot_backends_, "laguna-target-split")) return false; + + for (size_t i = 0; i < shards_.size(); ++i) { + auto & shard = shards_[i]; + const TargetLoadPlan plan = + make_layer_split_load_plan(shard, i + 1 == shards_.size()); + if (!load_target_gguf_laguna_partial( + cfg_.target_path, shard.backend, plan, shard.weights) || + !create_laguna_target_cache_partial( + shard.weights, cfg_.device.max_ctx, shard.backend, + shard.layer_begin, shard.layer_end, shard.cache)) { + std::fprintf(stderr, + "[laguna-target-split] load/cache gpu=%d: %s\n", + shard.gpu, dflash27b_last_error()); + return false; + } + std::fprintf(stderr, "[laguna-target-split] gpu=%d layers=[%d,%d)\n", + shard.gpu, shard.layer_begin, shard.layer_end); + } + + snapshots_.resize(PREFIX_SLOTS); + for (auto & slot : snapshots_) { + slot.shards.resize(shards_.size()); + } + return true; +} + +void LagunaLayerSplitAdapter::begin_request(const GenerateRequest & req) { + sampler_ = req.sampler; + if (req.do_sample && sampler_.seed != 0) { + sampler_rng_.seed(sampler_.seed); + } +} + +void LagunaLayerSplitAdapter::reset_request_state() { + for (auto & shard : shards_) { + reset_laguna_target_cache(shard.cache); + } + prefill_last_logits_.clear(); +} + +bool LagunaLayerSplitAdapter::run_forward( + const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * logits_out) { + if (shards_.empty() || tokens.empty()) return false; + const LagunaTargetWeights & ref = shards_.front().weights; + const int hidden = ref.n_embd; + const int n_tokens_total = (int)tokens.size(); + int ubatch = cfg_.chunk > 0 ? cfg_.chunk : 2048; + if (const char * e = std::getenv("DFLASH_LAGUNA_LAYER_SPLIT_UBATCH")) { + ubatch = std::max(1, std::atoi(e)); + } + + if (base_pos < 0 || base_pos + n_tokens_total > cfg_.device.max_ctx) { + std::fprintf(stderr, + "[laguna-target-split] range [%d,%d) exceeds max_ctx=%d\n", + base_pos, base_pos + n_tokens_total, cfg_.device.max_ctx); + return false; + } + + ActivationPair acts; + if (!activation_pair_init(acts, shards_.front().backend, hidden, + n_tokens_total)) { + std::fprintf(stderr, "[laguna-target-split] activation alloc failed gpu=%d\n", + shards_.front().gpu); + return false; + } + + { + constexpr int kEmbedBatch = 4096; + std::vector emb((size_t)hidden * std::min(kEmbedBatch, n_tokens_total)); + for (int i = 0; i < n_tokens_total; i += kEmbedBatch) { + const int n = std::min(kEmbedBatch, n_tokens_total - i); + if ((int)emb.size() < hidden * n) emb.resize((size_t)hidden * n); + if (!ref.embedder.embed(tokens.data() + i, n, emb.data())) { + activation_pair_free(acts); + return false; + } + const size_t off = (size_t)i * acts.a->nb[1]; + const size_t bytes = sizeof(float) * (size_t)hidden * n; + ggml_backend_tensor_set(acts.a, emb.data(), off, bytes); + } + } + + ggml_tensor * act_in = acts.a; + ggml_tensor * act_out = acts.b; + LagunaLayerSplitShard * current_shard = &shards_.front(); + for (int il = 0; il < ref.n_layer; ++il) { + LagunaLayerSplitShard * shard = find_layer_split_shard(shards_, il); + if (!shard) { + std::fprintf(stderr, + "[laguna-target-split] missing owner for layer %d\n", il); + activation_pair_free(acts); + return false; + } + if (shard != current_shard) { + ActivationPair next_acts; + if (!activation_pair_init(next_acts, shard->backend, hidden, + n_tokens_total)) { + activation_pair_free(acts); + return false; + } + ggml_backend_synchronize(current_shard->backend); + ggml_backend_tensor_copy(act_in, next_acts.a); + ggml_backend_synchronize(shard->backend); + activation_pair_free(acts); + acts = next_acts; + act_in = acts.a; + act_out = acts.b; + current_shard = shard; + } + + for (int start = 0; start < n_tokens_total;) { + const int n = std::min(ubatch, n_tokens_total - start); + const int kv_start = base_pos + start; + if (!build_laguna_layer_step( + shard->layer_graph, shard->weights, shard->cache, + shard->backend, il, act_in, act_out, start, n, kv_start)) { + std::fprintf(stderr, + "[laguna-target-split] build layer=%d @%d gpu=%d\n", + il, start, shard->gpu); + activation_pair_free(acts); + return false; + } + + std::vector pos((size_t)n); + for (int i = 0; i < n; ++i) pos[(size_t)i] = kv_start + i; + if (!tensor_ready(shard->layer_graph.positions)) { + activation_pair_free(acts); + return false; + } + ggml_backend_tensor_set(shard->layer_graph.positions, pos.data(), 0, + sizeof(int32_t) * pos.size()); + + const int kv_len = kv_start + n; + std::vector mfull((size_t)kv_len * n, -INFINITY); + for (int q = 0; q < n; ++q) { + const int abs_q = kv_start + q; + for (int k = 0; k <= abs_q && k < kv_len; ++k) { + mfull[(size_t)q * kv_len + k] = 0.0f; + } + } + if (tensor_ready(shard->layer_graph.attn_mask)) { + ggml_backend_tensor_set(shard->layer_graph.attn_mask, + mfull.data(), 0, + ggml_nbytes(shard->layer_graph.attn_mask)); + } + + std::vector mswa((size_t)kv_len * n, -INFINITY); + const int W = ref.sliding_window; + for (int q = 0; q < n; ++q) { + const int abs_q = kv_start + q; + const int win_lo = std::max(0, abs_q - W + 1); + for (int k = win_lo; k <= abs_q && k < kv_len; ++k) { + mswa[(size_t)q * kv_len + k] = 0.0f; + } + } + if (tensor_ready(shard->layer_graph.attn_mask_swa)) { + ggml_backend_tensor_set(shard->layer_graph.attn_mask_swa, + mswa.data(), 0, + ggml_nbytes(shard->layer_graph.attn_mask_swa)); + } + + auto st = ggml_backend_graph_compute(shard->backend, + shard->layer_graph.gf); + if (st != GGML_STATUS_SUCCESS) { + std::fprintf(stderr, + "[laguna-target-split] compute layer=%d @%d gpu=%d status=%d\n", + il, start, shard->gpu, (int)st); + activation_pair_free(acts); + return false; + } + start += n; + } + std::swap(act_in, act_out); + } + + std::vector argmax; + LagunaLayerSplitShard & last = shards_.back(); + const bool ok = compute_laguna_split_projection( + last.backend, last.weights, act_in, + n_tokens_total - 1, 1, &argmax, logits_out); + activation_pair_free(acts); + if (!ok || argmax.empty()) return false; + last_tok = argmax.back(); + for (auto & shard : shards_) { + shard.cache.cur_pos = base_pos + n_tokens_total; + shard.cache.last_tok = last_tok; + } + return true; +} + +bool LagunaLayerSplitAdapter::prefill(const std::vector & prompt, + int base_pos, + int & last_tok) { + return run_forward(prompt, base_pos, last_tok, &prefill_last_logits_); +} + +bool LagunaLayerSplitAdapter::decode_ar( + int last_tok, + int committed, + int n_gen, + std::vector & out_tokens, + const DaemonIO & io) { + if (n_gen <= 0) return true; + if (shards_.empty()) return false; + + const auto & w = shards_.front().weights; + const int vocab = (int)w.embedder.n_vocab; + std::vector logits_buf; + if (sampler_.needs_logit_processing()) { + if ((int)prefill_last_logits_.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } + out_tokens.push_back(last_tok); + io.emit(last_tok); + if (io.cancelled) { + io.emit(-1); + return true; + } + if (last_tok == w.eos_id || last_tok == w.eos_chat_id) { + io.emit(-1); + return true; + } + ++committed; + + for (int i = 1; i < n_gen; ++i) { + std::vector one(1, last_tok); + int next_tok = -1; + logits_buf.clear(); + if (!run_forward(one, committed - 1, next_tok, + sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { + return false; + } + if (sampler_.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } + last_tok = next_tok; + out_tokens.push_back(last_tok); + io.emit(last_tok); + ++committed; + if (io.cancelled) break; + if (last_tok == w.eos_id || last_tok == w.eos_chat_id) break; + } + io.emit(-1); + return true; +} + +bool LagunaLayerSplitAdapter::snapshot_save(int slot) { + if (slot < 0 || slot >= PREFIX_SLOTS || shards_.empty()) return false; + if (snapshot_backends_.size() != shards_.size()) return false; + auto & snap = snapshots_[(size_t)slot]; + const int snap_pos = shards_.front().cache.cur_pos; + if (snap_pos <= 0) return false; + + snapshot_free(slot); + if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size()); + for (size_t i = 0; i < shards_.size(); ++i) { + if (!laguna_snapshot_save(shards_[i].cache, snapshot_backends_[i], + shards_[i].weights.n_layer, + shards_[i].weights.n_head_kv, + shards_[i].weights.head_dim, + snap.shards[i])) { + snapshot_free(slot); + return false; + } + } + snap.cur_pos = snap_pos; + snap.last_tok = shards_.front().cache.last_tok; + snap.prefill_last_logits = prefill_last_logits_; + return true; +} + +void LagunaLayerSplitAdapter::snapshot_free(int slot) { + if (slot < 0 || slot >= PREFIX_SLOTS || snapshots_.empty()) return; + auto & snap = snapshots_[(size_t)slot]; + for (auto & ss : snap.shards) { + laguna_snapshot_free(ss); + } + snap.cur_pos = 0; + snap.last_tok = -1; + snap.prefill_last_logits.clear(); + if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size()); +} + +bool LagunaLayerSplitAdapter::snapshot_used(int slot) const { + if (slot < 0 || slot >= PREFIX_SLOTS || + snapshots_.size() != (size_t)PREFIX_SLOTS) { + return false; + } + const auto & snap = snapshots_[(size_t)slot]; + if (snap.cur_pos <= 0 || snap.shards.size() != shards_.size()) return false; + if (snap.prefill_last_logits.empty()) return false; + for (const auto & ss : snap.shards) { + if (!ss.used) return false; + } + return true; +} + +int LagunaLayerSplitAdapter::snapshot_cur_pos(int slot) const { + return snapshot_used(slot) ? snapshots_[(size_t)slot].cur_pos : 0; +} + +bool LagunaLayerSplitAdapter::snapshot_restore(int slot) { + if (!snapshot_used(slot)) return false; + auto & snap = snapshots_[(size_t)slot]; + for (size_t i = 0; i < shards_.size(); ++i) { + if (snap.shards[i].cur_pos != snap.cur_pos) return false; + if (!laguna_snapshot_restore(snap.shards[i], shards_[i].cache)) { + return false; + } + shards_[i].cache.last_tok = snap.last_tok; + } + prefill_last_logits_ = snap.prefill_last_logits; + return true; +} + +int LagunaLayerSplitAdapter::current_last_token() const { + if (shards_.empty()) return -1; + return shards_.front().cache.last_tok; +} + +void LagunaLayerSplitAdapter::shutdown() { + for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i); + auto shard_metas = layer_split_shard_metas(shards_); + free_layer_split_snapshot_backends(shard_metas, snapshot_backends_); + free_laguna_layer_split_shards(shards_); +} + +void free_laguna_layer_split_shards( + std::vector & shards) { + for (auto & shard : shards) { + laguna_layer_step_graph_destroy(shard.layer_graph); + free_laguna_target_cache(shard.cache); + free_laguna_target_weights(shard.weights); + if (shard.backend) { + ggml_backend_free(shard.backend); + shard.backend = nullptr; + } + } + shards.clear(); +} + +} // namespace dflash::common diff --git a/server/src/laguna/laguna_layer_split_adapter.h b/server/src/laguna/laguna_layer_split_adapter.h new file mode 100644 index 000000000..882432488 --- /dev/null +++ b/server/src/laguna/laguna_layer_split_adapter.h @@ -0,0 +1,85 @@ +// Laguna target layer-split adapter. + +#pragma once + +#include "common/layer_split_backend.h" +#include "common/layer_split_utils.h" +#include "laguna_internal.h" +#include "placement/placement_config.h" + +#include "ggml-backend.h" + +#include +#include + +namespace dflash::common { + +struct LagunaLayerSplitAdapterConfig { + const char * target_path = nullptr; + DevicePlacement device; + int chunk = 2048; +}; + +struct LagunaLayerSplitShard : LayerSplitShardMeta { + LagunaTargetWeights weights; + LagunaTargetCache cache; + LagunaLayerStepGraph layer_graph; +}; + +struct LagunaLayerSplitSnapshot { + int cur_pos = 0; + int32_t last_tok = -1; + std::vector shards; + std::vector prefill_last_logits; +}; + +class LagunaLayerSplitAdapter : public LayerSplitAdapter { +public: + explicit LagunaLayerSplitAdapter(const LagunaLayerSplitAdapterConfig & cfg); + ~LagunaLayerSplitAdapter() override; + + LagunaLayerSplitAdapter(const LagunaLayerSplitAdapter &) = delete; + LagunaLayerSplitAdapter & operator=(const LagunaLayerSplitAdapter &) = delete; + + const char * name() const override { return "laguna"; } + bool init() override; + int max_context() const override { return cfg_.device.max_ctx; } + + void begin_request(const GenerateRequest & req) override; + void reset_request_state() override; + bool prefill(const std::vector & prompt, + int base_pos, int & last_tok) override; + bool decode_ar(int last_tok, int committed, int n_gen, + std::vector & out_tokens, + const DaemonIO & io) override; + bool supports_cpu_sampling() const override { return true; } + + bool snapshot_save(int slot) override; + void snapshot_free(int slot) override; + bool snapshot_used(int slot) const override; + int snapshot_cur_pos(int slot) const override; + bool snapshot_restore(int slot) override; + int current_last_token() const override; + + void free_drafter() override {} + void shutdown() override; + +private: + bool run_forward(const std::vector & tokens, + int base_pos, + int & last_tok, + std::vector * logits_out = nullptr); + + LagunaLayerSplitAdapterConfig cfg_; + std::vector shards_; + std::vector snapshot_backends_; + std::vector snapshots_; + static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots; + SamplerCfg sampler_; + std::mt19937_64 sampler_rng_{std::random_device{}()}; + std::vector prefill_last_logits_; +}; + +void free_laguna_layer_split_shards(std::vector & shards); + +} // namespace dflash::common diff --git a/server/src/laguna/laguna_target_graph.cpp b/server/src/laguna/laguna_target_graph.cpp index 8f2e3c638..89d460990 100644 --- a/server/src/laguna/laguna_target_graph.cpp +++ b/server/src/laguna/laguna_target_graph.cpp @@ -43,6 +43,23 @@ bool create_laguna_target_cache(const LagunaTargetWeights & w, int max_ctx, ggml_backend_t backend, LagunaTargetCache & out) { + return create_laguna_target_cache_partial( + w, max_ctx, backend, /*layer_begin=*/0, /*layer_end=*/w.n_layer, out); +} + +bool create_laguna_target_cache_partial(const LagunaTargetWeights & w, + int max_ctx, + ggml_backend_t backend, + int layer_begin, + int layer_end, + LagunaTargetCache & out) { + if (layer_begin < 0) layer_begin = 0; + if (layer_end < 0) layer_end = w.n_layer; + if (layer_begin > layer_end || layer_end > w.n_layer) { + set_last_error("laguna cache: invalid layer range"); + return false; + } + out.backend = backend; out.max_ctx = max_ctx; out.cur_pos = 0; @@ -66,6 +83,7 @@ bool create_laguna_target_cache(const LagunaTargetWeights & w, out.attn_k.resize(w.n_layer, nullptr); out.attn_v.resize(w.n_layer, nullptr); for (int il = 0; il < w.n_layer; ++il) { + if (il < layer_begin || il >= layer_end) continue; char nm[32]; std::snprintf(nm, sizeof(nm), "k_l%d", il); ggml_tensor * k = ggml_new_tensor_3d(out.base_ctx, k_type, w.head_dim, max_ctx, w.n_head_kv); @@ -89,6 +107,7 @@ bool create_laguna_target_cache(const LagunaTargetWeights & w, std::vector zeros(std::min(buf_sz, 64 * 1024 * 1024), 0); for (int il = 0; il < w.n_layer; ++il) { for (auto * t : { out.attn_k[il], out.attn_v[il] }) { + if (!t) continue; const size_t sz = ggml_nbytes(t); for (size_t off = 0; off < sz; off += zeros.size()) { const size_t chunk = std::min(zeros.size(), sz - off); @@ -119,6 +138,7 @@ bool laguna_snapshot_alloc(const LagunaTargetCache & cache, out.attn_k.assign((size_t)n_layer, nullptr); out.attn_v.assign((size_t)n_layer, nullptr); for (int il = 0; il < n_layer; ++il) { + if (!cache.attn_k[il] || !cache.attn_v[il]) continue; char nm[32]; std::snprintf(nm, sizeof(nm), "snap_k_l%d", il); // Right-sized: [head_dim, snap_pos, n_head_kv] @@ -179,6 +199,7 @@ bool laguna_snapshot_save(const LagunaTargetCache & cache, ggml_tensor * dk = snap.attn_k[il]; ggml_tensor * sv = cache.attn_v[il]; ggml_tensor * dv = snap.attn_v[il]; + if (!sk || !dk || !sv || !dv) continue; const size_t k_strip = (size_t)snap_pos * sk->nb[1]; const size_t v_strip = (size_t)snap_pos * sv->nb[1]; for (int kh = 0; kh < n_head_kv; kh++) { @@ -210,6 +231,7 @@ bool laguna_snapshot_restore(const LagunaCacheSnapshot & snap, ggml_tensor * dk = cache.attn_k[il]; ggml_tensor * sv = snap.attn_v[il]; ggml_tensor * dv = cache.attn_v[il]; + if (!sk || !dk || !sv || !dv) continue; const size_t k_strip = (size_t)snap_pos * sk->nb[1]; const size_t v_strip = (size_t)snap_pos * sv->nb[1]; for (int kh = 0; kh < (int)sk->ne[2]; kh++) { @@ -639,6 +661,150 @@ static ggml_tensor * build_laguna_layer( return ggml_add(ctx, cur, ffn_inp); } +void laguna_layer_step_graph_free(LagunaLayerStepGraph & sg) { + if (sg.ctx) { + ggml_free(sg.ctx); + sg.ctx = nullptr; + } + sg.gf = nullptr; + sg.positions = nullptr; + sg.attn_mask = nullptr; + sg.attn_mask_swa = nullptr; +} + +void laguna_layer_step_graph_destroy(LagunaLayerStepGraph & sg) { + if (sg.alloc) { + ggml_gallocr_free(sg.alloc); + sg.alloc = nullptr; + } + laguna_layer_step_graph_free(sg); +} + +bool build_laguna_layer_step( + LagunaLayerStepGraph & sg, + const LagunaTargetWeights & w, + LagunaTargetCache & cache, + ggml_backend_t backend, + int layer_idx, + ggml_tensor * act_in, + ggml_tensor * act_out, + int chunk_start, + int n_tokens, + int kv_start) { + laguna_layer_step_graph_free(sg); + if (layer_idx < 0 || layer_idx >= w.n_layer) return false; + if (!cache.attn_k[layer_idx] || !cache.attn_v[layer_idx]) return false; + + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 16384 + ggml_graph_overhead() + 16 * 1024 * 1024; + ip.no_alloc = true; + sg.ctx = ggml_init(ip); + if (!sg.ctx) return false; + sg.gf = ggml_new_graph_custom(sg.ctx, 16384, false); + + ggml_tensor * inp = ggml_view_2d( + sg.ctx, act_in, w.n_embd, n_tokens, + act_in->nb[1], (size_t)chunk_start * act_in->nb[1]); + ggml_set_input(inp); + + sg.positions = ggml_new_tensor_1d(sg.ctx, GGML_TYPE_I32, n_tokens); + ggml_set_input(sg.positions); + + const int kv_len = kv_start + n_tokens; + sg.attn_mask = ggml_new_tensor_4d(sg.ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + ggml_set_input(sg.attn_mask); + ggml_tensor * mask_full_f16 = ggml_cast(sg.ctx, sg.attn_mask, GGML_TYPE_F16); + + sg.attn_mask_swa = ggml_new_tensor_4d(sg.ctx, GGML_TYPE_F32, kv_len, n_tokens, 1, 1); + ggml_set_input(sg.attn_mask_swa); + ggml_tensor * mask_swa_f16 = ggml_cast(sg.ctx, sg.attn_mask_swa, GGML_TYPE_F16); + + ggml_tensor * layer_out = build_laguna_layer( + sg.ctx, sg.gf, w, cache, layer_idx, inp, sg.positions, + mask_full_f16, kv_start, n_tokens, mask_swa_f16); + if (!layer_out) return false; + + ggml_tensor * out_view = ggml_view_2d( + sg.ctx, act_out, w.n_embd, n_tokens, + act_out->nb[1], (size_t)chunk_start * act_out->nb[1]); + ggml_build_forward_expand(sg.gf, ggml_cpy(sg.ctx, layer_out, out_view)); + + if (!sg.alloc) { + sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + } + return ggml_gallocr_alloc_graph(sg.alloc, sg.gf); +} + +bool compute_laguna_split_projection( + ggml_backend_t backend, + const LagunaTargetWeights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector * out_argmax, + std::vector * out_logits) { + ggml_init_params ip{}; + ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024; + ip.no_alloc = true; + ggml_context * ctx = ggml_init(ip); + if (!ctx) return false; + ggml_cgraph * gf = ggml_new_graph(ctx); + + ggml_tensor * act_view = ggml_view_2d( + ctx, act, w.n_embd, n_tokens, act->nb[1], + (size_t)token_offset * act->nb[1]); + ggml_tensor * cur = laguna_rms_norm_mul(ctx, act_view, w.out_norm); + cur = ggml_mul_mat(ctx, w.output, cur); + ggml_tensor * logits = cur; + ggml_tensor * argmax = nullptr; + if (out_logits) { + ggml_set_output(logits); + ggml_build_forward_expand(gf, logits); + } + if (out_argmax) { + argmax = ggml_argmax(ctx, logits); + ggml_set_output(argmax); + ggml_build_forward_expand(gf, argmax); + } + + ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) { + if (alloc) ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + if (ggml_backend_graph_compute(backend, gf) != GGML_STATUS_SUCCESS) { + ggml_gallocr_free(alloc); + ggml_free(ctx); + return false; + } + if (out_argmax) { + out_argmax->resize((size_t)n_tokens); + ggml_backend_tensor_get(argmax, out_argmax->data(), 0, + sizeof(int32_t) * (size_t)n_tokens); + } + if (out_logits) { + const int vocab = (int)w.embedder.n_vocab; + out_logits->resize((size_t)vocab * (size_t)n_tokens); + ggml_backend_tensor_get(logits, out_logits->data(), 0, + sizeof(float) * (size_t)vocab * (size_t)n_tokens); + } + ggml_gallocr_free(alloc); + ggml_free(ctx); + return true; +} + +bool compute_laguna_split_argmax( + ggml_backend_t backend, + const LagunaTargetWeights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector & out_argmax) { + return compute_laguna_split_projection( + backend, w, act, token_offset, n_tokens, &out_argmax, nullptr); +} + LagunaGraphOutputs build_laguna_graph( ggml_context * ctx, ggml_cgraph * gf, diff --git a/server/src/laguna/laguna_target_loader.cpp b/server/src/laguna/laguna_target_loader.cpp index 495565f94..cd9b617a2 100644 --- a/server/src/laguna/laguna_target_loader.cpp +++ b/server/src/laguna/laguna_target_loader.cpp @@ -42,10 +42,13 @@ #include "dflash27b.h" #include +#include #include #include +#include #include #include +#include #if !defined(_WIN32) #include @@ -129,11 +132,58 @@ bool get_bool_or(const gguf_context * g, const char * key, bool fallback) { int64_t id = gguf_find_key(g, key); return (id < 0) ? fallback : (bool)gguf_get_val_bool(g, id); } +size_t align_up_size(size_t x, size_t a) { + if (a == 0) return x; + const size_t r = x % a; + return r == 0 ? x : x + (a - r); +} + +bool parse_block_tensor_name(const char * name, int & layer_id) { + const char prefix[] = "blk."; + const size_t prefix_len = sizeof(prefix) - 1; + if (std::strncmp(name, prefix, prefix_len) != 0) return false; + const char * p = name + prefix_len; + if (*p < '0' || *p > '9') return false; + char * end = nullptr; + const long v = std::strtol(p, &end, 10); + if (!end || *end != '.' || v < 0 || v > INT_MAX) return false; + layer_id = (int)v; + return true; +} + +bool should_load_laguna_tensor(const char * name, const TargetLoadPlan & plan) { + if (std::strcmp(name, "token_embd.weight") == 0) return false; + if (std::strcmp(name, "output_norm.weight") == 0 || + std::strcmp(name, "output.weight") == 0) { + return plan.load_output; + } + int layer_id = -1; + if (parse_block_tensor_name(name, layer_id)) { + return layer_id >= plan.layer_begin && layer_id < plan.layer_end; + } + return false; +} + +struct LagunaTensorAlloc { + ggml_tensor * tensor = nullptr; + size_t file_offset = 0; + size_t file_size = 0; + size_t buffer_offset = 0; +}; + } // namespace bool load_target_gguf_laguna(const std::string & path, ggml_backend_t backend, LagunaTargetWeights & out) { + TargetLoadPlan plan; + return load_target_gguf_laguna_partial(path, backend, plan, out); +} + +bool load_target_gguf_laguna_partial(const std::string & path, + ggml_backend_t backend, + const TargetLoadPlan & plan_in, + LagunaTargetWeights & out) { // ── 1. Parse metadata ──────────────────────────────────────────────── ggml_context * meta_ctx = nullptr; @@ -393,28 +443,67 @@ bool load_target_gguf_laguna(const std::string & path, } } - // ── 3. Allocate CUDA buffer for tensors. Pre-pin tok_embd to host memory - // so the allocator skips it (only allocates tensors with data == NULL). - // Saves ~110 MiB VRAM; the embedder reads tok_embd directly from mmap. + TargetLoadPlan plan = plan_in; + if (plan.layer_begin < 0) plan.layer_begin = 0; + if (plan.layer_end < 0) plan.layer_end = (int)n_layer; + if (plan.layer_begin > plan.layer_end || plan.layer_end > (int)n_layer) { + char e[160]; + std::snprintf(e, sizeof(e), + "laguna: invalid layer range [%d,%d) for n_layer=%u", + plan.layer_begin, plan.layer_end, n_layer); + set_last_error(e); + gguf_free(gctx); + return false; + } + + // ── 3. Allocate backend buffer only for selected tensors. Token embedding + // stays CPU-only and is owned by the CpuEmbedder mmap. LagunaMmap mm; std::string err; if (!mm.open_ro(path, err)) { set_last_error(err); gguf_free(gctx); return false; } - const size_t data_start = gguf_get_data_offset(gctx); + const size_t data_start = gguf_get_data_offset(gctx); const int64_t n_tensors = gguf_get_n_tensors(gctx); + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + const size_t alignment = ggml_backend_buft_get_alignment(buft); + std::vector allocs; + size_t alloc_total = 0; for (int64_t tid = 0; tid < n_tensors; ++tid) { - if (std::strcmp(gguf_get_tensor_name(gctx, tid), "token_embd.weight") == 0) { - out.tok_embd->data = (uint8_t *)mm.addr + - data_start + gguf_get_tensor_offset(gctx, tid); - break; - } + const char * tname = gguf_get_tensor_name(gctx, tid); + ggml_tensor * t = ggml_get_tensor(meta_ctx, tname); + if (!t || !should_load_laguna_tensor(tname, plan)) continue; + alloc_total = align_up_size(alloc_total, alignment); + LagunaTensorAlloc a; + a.tensor = t; + a.file_offset = data_start + gguf_get_tensor_offset(gctx, tid); + a.file_size = gguf_get_tensor_size(gctx, tid); + a.buffer_offset = alloc_total; + alloc_total += ggml_backend_buft_get_alloc_size(buft, t); + allocs.push_back(a); + } + if (allocs.empty()) { + set_last_error("laguna: load plan selected no GPU tensors"); + gguf_free(gctx); + return false; } - out.buf = ggml_backend_alloc_ctx_tensors(meta_ctx, backend); + + out.buf = ggml_backend_alloc_buffer(backend, alloc_total); if (!out.buf) { - set_last_error("ggml_backend_alloc_ctx_tensors failed (laguna target)"); + set_last_error("ggml_backend_alloc_buffer failed (laguna target)"); gguf_free(gctx); return false; } + ggml_backend_buffer_set_usage(out.buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + + char * base = (char *)ggml_backend_buffer_get_base(out.buf); + for (const LagunaTensorAlloc & a : allocs) { + if (ggml_backend_tensor_alloc(out.buf, a.tensor, + base + a.buffer_offset) != GGML_STATUS_SUCCESS) { + set_last_error("ggml_backend_tensor_alloc failed (laguna target)"); + gguf_free(gctx); + return false; + } + } - // ── 4. Copy tensor bytes to GPU; remember tok_embd offset for the embedder ─ + // ── 4. Copy selected tensor bytes to GPU; remember tok_embd for embedder ─ size_t total = 0; size_t tok_embd_off = 0, tok_embd_sz = 0; ggml_type tok_embd_type = GGML_TYPE_COUNT; @@ -434,6 +523,7 @@ bool load_target_gguf_laguna(const std::string & path, tok_embd_type = gguf_get_tensor_type(gctx, tid); continue; } + if (!should_load_laguna_tensor(tname, plan)) continue; ggml_backend_tensor_set(t, (const uint8_t *)mm.addr + off, 0, sz); total += sz; } @@ -463,8 +553,9 @@ bool load_target_gguf_laguna(const std::string & path, char summary[224]; std::snprintf(summary, sizeof(summary), - "laguna target loaded: %" PRId64 " tensors on GPU %.2f GiB, tok_embd %.0f MiB CPU-only (%s)", - n_tensors, total / (1024.0 * 1024.0 * 1024.0), + "laguna target loaded: layers [%d,%d) output=%d tensors=%zu GPU %.2f GiB, tok_embd %.0f MiB CPU-only (%s)", + plan.layer_begin, plan.layer_end, plan.load_output ? 1 : 0, + allocs.size(), total / (1024.0 * 1024.0 * 1024.0), tok_embd_sz / (1024.0 * 1024.0), ggml_type_name(tok_embd_type)); set_last_error(summary); std::printf("[laguna-loader] %s\n", summary); From 988fc93350de2b6dce60d365097f7c68ea63540a Mon Sep 17 00:00:00 2001 From: weicj Date: Fri, 29 May 2026 15:44:46 +0800 Subject: [PATCH 3/3] refactor(server): share target layer-split runtime helpers --- server/CMakeLists.txt | 1 + server/src/common/layer_split_runtime.cpp | 64 +++++++++++++ server/src/common/layer_split_runtime.h | 91 ++++++++++++++++++ .../src/gemma4/gemma4_layer_split_adapter.cpp | 90 ++++-------------- .../src/laguna/laguna_layer_split_adapter.cpp | 90 ++++-------------- .../src/qwen35/qwen35_layer_split_adapter.cpp | 92 ++++--------------- 6 files changed, 206 insertions(+), 222 deletions(-) create mode 100644 server/src/common/layer_split_runtime.cpp create mode 100644 server/src/common/layer_split_runtime.h diff --git a/server/CMakeLists.txt b/server/CMakeLists.txt index a433d512c..519c47de5 100644 --- a/server/CMakeLists.txt +++ b/server/CMakeLists.txt @@ -247,6 +247,7 @@ add_library(dflash_common STATIC src/common/dflash_draft_graph.cpp src/common/dflash_spec_decode.cpp src/common/layer_split_backend.cpp + src/common/layer_split_runtime.cpp src/qwen35/graph_builders.cpp src/qwen35moe/qwen35moe_ffn.cpp src/qwen35moe/qwen35moe_backend.cpp diff --git a/server/src/common/layer_split_runtime.cpp b/server/src/common/layer_split_runtime.cpp new file mode 100644 index 000000000..66941ba36 --- /dev/null +++ b/server/src/common/layer_split_runtime.cpp @@ -0,0 +1,64 @@ +#include "layer_split_runtime.h" + +namespace dflash::common { + +bool run_layer_split_ar_decode( + int last_tok, + int committed, + int n_gen, + int vocab, + const std::vector & prefill_last_logits, + const SamplerCfg & sampler, + std::mt19937_64 & rng, + const LayerSplitForwardStep & forward_one, + const std::function & is_eos, + std::vector & out_tokens, + const DaemonIO & io) { + if (n_gen <= 0) return true; + + if (sampler.needs_logit_processing()) { + if ((int)prefill_last_logits.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits.data(), vocab, sampler, + out_tokens, rng); + } + + out_tokens.push_back(last_tok); + io.emit(last_tok); + if (io.cancelled) { + io.emit(-1); + return true; + } + if (is_eos(last_tok)) { + io.emit(-1); + return true; + } + ++committed; + + std::vector logits_buf; + for (int i = 1; i < n_gen; ++i) { + std::vector one(1, last_tok); + int next_tok = -1; + logits_buf.clear(); + if (!forward_one(one, committed, next_tok, + sampler.needs_logit_processing() ? &logits_buf : nullptr)) { + return false; + } + if (sampler.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler, + out_tokens, rng); + } + + last_tok = next_tok; + out_tokens.push_back(last_tok); + io.emit(last_tok); + ++committed; + if (io.cancelled) break; + if (is_eos(last_tok)) break; + } + + io.emit(-1); + return true; +} + +} // namespace dflash::common diff --git a/server/src/common/layer_split_runtime.h b/server/src/common/layer_split_runtime.h new file mode 100644 index 000000000..79566d2fe --- /dev/null +++ b/server/src/common/layer_split_runtime.h @@ -0,0 +1,91 @@ +// Shared target layer-split runtime helpers. +// +// Model adapters still own their partial loaders, graph builders, caches, and +// snapshot payloads. This file keeps shared adapter runtime flow in one place +// so new adapters do not copy the same shell. + +#pragma once + +#include "gguf_inspect.h" +#include "layer_split_utils.h" +#include "model_backend.h" +#include "placement/placement_config.h" +#include "sampler.h" + +#include +#include +#include +#include + +namespace dflash::common { + +struct LayerSplitRuntimeInit { + const char * target_path = nullptr; + const DevicePlacement * device = nullptr; + const char * log_prefix = "target-split"; +}; + +template +bool init_layer_split_runtime(const LayerSplitRuntimeInit & cfg, + std::vector & shards, + std::vector & snapshot_backends) { + if (!cfg.target_path || !cfg.device || + cfg.device->layer_split_gpus.size() < 2) { + std::fprintf(stderr, "[%s] invalid layer-split config\n", cfg.log_prefix); + return false; + } + + const auto info = inspect_gguf_model_info(cfg.target_path); + const int n_layer = info.n_layer; + if (n_layer <= 0) { + std::fprintf(stderr, "[%s] failed to inspect target layer count\n", + cfg.log_prefix); + return false; + } + + const auto ranges = compute_layer_ranges( + n_layer, + (int)cfg.device->layer_split_gpus.size(), + cfg.device->layer_split_weights); + if (ranges.size() != cfg.device->layer_split_gpus.size()) { + std::fprintf(stderr, + "[%s] bad layer split for %zu GPUs and %d layers\n", + cfg.log_prefix, cfg.device->layer_split_gpus.size(), n_layer); + return false; + } + + shards.resize(cfg.device->layer_split_gpus.size()); + auto shard_metas = layer_split_shard_metas(shards); + if (!init_layer_split_shard_metas( + shard_metas, cfg.device->layer_split_gpus, ranges, + cfg.log_prefix)) { + return false; + } + + (void)enable_layer_split_peer_access( + cfg.device->layer_split_gpus, cfg.device->peer_access); + + return init_layer_split_snapshot_backends( + shard_metas, snapshot_backends, cfg.log_prefix); +} + +using LayerSplitForwardStep = std::function & tokens, + int committed, + int & next_tok, + std::vector * logits_out)>; + +bool run_layer_split_ar_decode( + int last_tok, + int committed, + int n_gen, + int vocab, + const std::vector & prefill_last_logits, + const SamplerCfg & sampler, + std::mt19937_64 & rng, + const LayerSplitForwardStep & forward_one, + const std::function & is_eos, + std::vector & out_tokens, + const DaemonIO & io); + +} // namespace dflash::common diff --git a/server/src/gemma4/gemma4_layer_split_adapter.cpp b/server/src/gemma4/gemma4_layer_split_adapter.cpp index 5a6be860a..dea8aca2f 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.cpp +++ b/server/src/gemma4/gemma4_layer_split_adapter.cpp @@ -5,6 +5,7 @@ #include "common/dflash_layer_split_runtime.h" #include "common/gguf_inspect.h" #include "common/layer_split_utils.h" +#include "common/layer_split_runtime.h" #include "dflash27b.h" #include "ggml-cuda.h" @@ -74,43 +75,15 @@ Gemma4LayerSplitAdapter::Gemma4LayerSplitAdapter( Gemma4LayerSplitAdapter::~Gemma4LayerSplitAdapter() { shutdown(); } bool Gemma4LayerSplitAdapter::init() { - if (!cfg_.target_path || cfg_.device.layer_split_gpus.size() < 2) { - std::fprintf(stderr, "[gemma4-target-split] invalid layer-split config\n"); + const LayerSplitRuntimeInit runtime_cfg{ + cfg_.target_path, + &cfg_.device, + "gemma4-target-split", + }; + if (!init_layer_split_runtime(runtime_cfg, shards_, snapshot_backends_)) { return false; } - const auto info = inspect_gguf_model_info(cfg_.target_path); - const int n_layer = info.n_layer; - if (n_layer <= 0) { - std::fprintf(stderr, "[gemma4-target-split] failed to inspect layer count\n"); - return false; - } - - const auto ranges = compute_layer_ranges( - n_layer, - (int)cfg_.device.layer_split_gpus.size(), - cfg_.device.layer_split_weights); - if (ranges.size() != cfg_.device.layer_split_gpus.size()) { - std::fprintf(stderr, - "[gemma4-target-split] bad layer split for %zu GPUs and %d layers\n", - cfg_.device.layer_split_gpus.size(), n_layer); - return false; - } - - shards_.resize(cfg_.device.layer_split_gpus.size()); - auto shard_metas = layer_split_shard_metas(shards_); - if (!init_layer_split_shard_metas( - shard_metas, cfg_.device.layer_split_gpus, ranges, - "gemma4-target-split")) { - return false; - } - - (void)enable_layer_split_peer_access( - cfg_.device.layer_split_gpus, cfg_.device.peer_access); - - if (!init_layer_split_snapshot_backends( - shard_metas, snapshot_backends_, "gemma4-target-split")) return false; - for (size_t i = 0; i < shards_.size(); ++i) { auto & shard = shards_[i]; const TargetLoadPlan plan = @@ -372,46 +345,15 @@ bool Gemma4LayerSplitAdapter::decode_ar( const auto & w = shards_.front().weights; const int vocab = w.n_vocab; - std::vector logits_buf; - if (sampler_.needs_logit_processing()) { - if ((int)prefill_last_logits_.size() != vocab) return false; - last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - out_tokens.push_back(last_tok); - io.emit(last_tok); - if (io.cancelled) { - io.emit(-1); - return true; - } - if (last_tok == w.eos_id || last_tok == w.eos_chat_id) { - io.emit(-1); - return true; - } - ++committed; - - for (int i = 1; i < n_gen; ++i) { - std::vector one(1, last_tok); - int next_tok = -1; - logits_buf.clear(); - if (!run_forward(one, committed - 1, next_tok, - sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { - return false; - } - if (sampler_.needs_logit_processing()) { - if ((int)logits_buf.size() != vocab) return false; - next_tok = sample_logits(logits_buf.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - last_tok = next_tok; - out_tokens.push_back(last_tok); - io.emit(last_tok); - ++committed; - if (io.cancelled) break; - if (last_tok == w.eos_id || last_tok == w.eos_chat_id) break; - } - io.emit(-1); - return true; + return run_layer_split_ar_decode( + last_tok, committed, n_gen, vocab, prefill_last_logits_, sampler_, + sampler_rng_, + [&](const std::vector & one, int pos, int & next_tok, + std::vector * logits_out) { + return run_forward(one, pos - 1, next_tok, logits_out); + }, + [&](int tok) { return tok == w.eos_id || tok == w.eos_chat_id; }, + out_tokens, io); } bool Gemma4LayerSplitAdapter::snapshot_save(int slot) { diff --git a/server/src/laguna/laguna_layer_split_adapter.cpp b/server/src/laguna/laguna_layer_split_adapter.cpp index 1f6482e7e..c06d35900 100644 --- a/server/src/laguna/laguna_layer_split_adapter.cpp +++ b/server/src/laguna/laguna_layer_split_adapter.cpp @@ -6,6 +6,7 @@ #include "common/gguf_inspect.h" #include "common/layer_split_utils.h" #include "common/sampler.h" +#include "common/layer_split_runtime.h" #include "dflash27b.h" #include "ggml-cuda.h" @@ -33,43 +34,15 @@ LagunaLayerSplitAdapter::LagunaLayerSplitAdapter( LagunaLayerSplitAdapter::~LagunaLayerSplitAdapter() { shutdown(); } bool LagunaLayerSplitAdapter::init() { - if (!cfg_.target_path || cfg_.device.layer_split_gpus.size() < 2) { - std::fprintf(stderr, "[laguna-target-split] invalid layer-split config\n"); + const LayerSplitRuntimeInit runtime_cfg{ + cfg_.target_path, + &cfg_.device, + "laguna-target-split", + }; + if (!init_layer_split_runtime(runtime_cfg, shards_, snapshot_backends_)) { return false; } - const auto info = inspect_gguf_model_info(cfg_.target_path); - const int n_layer = info.n_layer; - if (n_layer <= 0) { - std::fprintf(stderr, "[laguna-target-split] failed to inspect layer count\n"); - return false; - } - - const auto ranges = compute_layer_ranges( - n_layer, - (int)cfg_.device.layer_split_gpus.size(), - cfg_.device.layer_split_weights); - if (ranges.size() != cfg_.device.layer_split_gpus.size()) { - std::fprintf(stderr, - "[laguna-target-split] bad layer split for %zu GPUs and %d layers\n", - cfg_.device.layer_split_gpus.size(), n_layer); - return false; - } - - shards_.resize(cfg_.device.layer_split_gpus.size()); - auto shard_metas = layer_split_shard_metas(shards_); - if (!init_layer_split_shard_metas( - shard_metas, cfg_.device.layer_split_gpus, ranges, - "laguna-target-split")) { - return false; - } - - (void)enable_layer_split_peer_access( - cfg_.device.layer_split_gpus, cfg_.device.peer_access); - - if (!init_layer_split_snapshot_backends( - shard_metas, snapshot_backends_, "laguna-target-split")) return false; - for (size_t i = 0; i < shards_.size(); ++i) { auto & shard = shards_[i]; const TargetLoadPlan plan = @@ -279,46 +252,15 @@ bool LagunaLayerSplitAdapter::decode_ar( const auto & w = shards_.front().weights; const int vocab = (int)w.embedder.n_vocab; - std::vector logits_buf; - if (sampler_.needs_logit_processing()) { - if ((int)prefill_last_logits_.size() != vocab) return false; - last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - out_tokens.push_back(last_tok); - io.emit(last_tok); - if (io.cancelled) { - io.emit(-1); - return true; - } - if (last_tok == w.eos_id || last_tok == w.eos_chat_id) { - io.emit(-1); - return true; - } - ++committed; - - for (int i = 1; i < n_gen; ++i) { - std::vector one(1, last_tok); - int next_tok = -1; - logits_buf.clear(); - if (!run_forward(one, committed - 1, next_tok, - sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { - return false; - } - if (sampler_.needs_logit_processing()) { - if ((int)logits_buf.size() != vocab) return false; - next_tok = sample_logits(logits_buf.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - last_tok = next_tok; - out_tokens.push_back(last_tok); - io.emit(last_tok); - ++committed; - if (io.cancelled) break; - if (last_tok == w.eos_id || last_tok == w.eos_chat_id) break; - } - io.emit(-1); - return true; + return run_layer_split_ar_decode( + last_tok, committed, n_gen, vocab, prefill_last_logits_, sampler_, + sampler_rng_, + [&](const std::vector & one, int pos, int & next_tok, + std::vector * logits_out) { + return run_forward(one, pos - 1, next_tok, logits_out); + }, + [&](int tok) { return tok == w.eos_id || tok == w.eos_chat_id; }, + out_tokens, io); } bool LagunaLayerSplitAdapter::snapshot_save(int slot) { diff --git a/server/src/qwen35/qwen35_layer_split_adapter.cpp b/server/src/qwen35/qwen35_layer_split_adapter.cpp index 98f8b0259..cb095a136 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.cpp +++ b/server/src/qwen35/qwen35_layer_split_adapter.cpp @@ -6,6 +6,7 @@ #include "common/gguf_inspect.h" #include "common/layer_split_utils.h" #include "common/sampler.h" +#include "common/layer_split_runtime.h" #include "qwen35/layer_split_forward.h" #include "qwen35/qwen35_layer_split_dflash_target.h" #include "qwen3/qwen3_drafter.h" @@ -25,41 +26,15 @@ Qwen35LayerSplitAdapter::Qwen35LayerSplitAdapter( Qwen35LayerSplitAdapter::~Qwen35LayerSplitAdapter() { shutdown(); } bool Qwen35LayerSplitAdapter::init() { - if (!cfg_.target_path || cfg_.device.layer_split_gpus.size() < 2) { - std::fprintf(stderr, "[target-split] invalid layer-split config\n"); + const LayerSplitRuntimeInit runtime_cfg{ + cfg_.target_path, + &cfg_.device, + "target-split", + }; + if (!init_layer_split_runtime(runtime_cfg, shards_, snapshot_backends_)) { return false; } - const auto info = inspect_gguf_model_info(cfg_.target_path); - const int n_layer = info.n_layer; - if (n_layer <= 0) { - std::fprintf(stderr, "[target-split] failed to inspect target layer count\n"); - return false; - } - const auto ranges = compute_layer_ranges( - n_layer, - (int)cfg_.device.layer_split_gpus.size(), - cfg_.device.layer_split_weights); - if (ranges.size() != cfg_.device.layer_split_gpus.size()) { - std::fprintf(stderr, - "[target-split] bad layer split for %zu GPUs and %d layers\n", - cfg_.device.layer_split_gpus.size(), n_layer); - return false; - } - - shards_.resize(cfg_.device.layer_split_gpus.size()); - auto shard_metas = layer_split_shard_metas(shards_); - if (!init_layer_split_shard_metas( - shard_metas, cfg_.device.layer_split_gpus, ranges, "target-split")) { - return false; - } - - (void)enable_layer_split_peer_access( - cfg_.device.layer_split_gpus, cfg_.device.peer_access); - - if (!init_layer_split_snapshot_backends( - shard_metas, snapshot_backends_, "target-split")) return false; - for (auto & shard : shards_) { const TargetLoadPlan plan = make_layer_split_load_plan(shard, &shard == &shards_.back()); @@ -393,51 +368,20 @@ bool Qwen35LayerSplitAdapter::decode_ar( if (n_gen <= 0) return true; const auto & w = shards_.front().weights; const int vocab = w.n_vocab; - std::vector logits_buf; - - if (sampler_.needs_logit_processing()) { - if ((int)prefill_last_logits_.size() != vocab) return false; - last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - out_tokens.push_back(last_tok); - io.emit(last_tok); - if (io.cancelled) { - io.emit(-1); - return true; - } - if (is_eos_tok(last_tok, w)) { - io.emit(-1); - return true; - } - ++committed; - - for (int i = 1; i < n_gen; ++i) { - std::vector one(1, last_tok); - int next_tok = -1; - logits_buf.clear(); - if (!run_qwen35_layer_split_forward( - shards_, shards_.front().weights, one, committed, 1, next_tok, + return run_layer_split_ar_decode( + last_tok, committed, n_gen, vocab, prefill_last_logits_, sampler_, + sampler_rng_, + [&](const std::vector & one, int pos, int & next_tok, + std::vector * logits_out) { + return run_qwen35_layer_split_forward( + shards_, shards_.front().weights, one, pos, 1, next_tok, cfg_.kq_stride_pad, cfg_.fa_window, cfg_.run_dflash ? &feature_ring_ : nullptr, /*argmax_out=*/nullptr, - sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { - return false; - } - if (sampler_.needs_logit_processing()) { - if ((int)logits_buf.size() != vocab) return false; - next_tok = sample_logits(logits_buf.data(), vocab, sampler_, - out_tokens, sampler_rng_); - } - out_tokens.push_back(next_tok); - io.emit(next_tok); - if (io.cancelled) break; - if (is_eos_tok(next_tok, w)) break; - last_tok = next_tok; - ++committed; - } - io.emit(-1); - return true; + logits_out); + }, + [&](int tok) { return is_eos_tok(tok, w); }, + out_tokens, io); } bool Qwen35LayerSplitAdapter::can_dflash_decode() const {