diff --git a/server/src/common/layer_split_backend.cpp b/server/src/common/layer_split_backend.cpp index 11e75e0a..431b2ec7 100644 --- a/server/src/common/layer_split_backend.cpp +++ b/server/src/common/layer_split_backend.cpp @@ -57,7 +57,8 @@ GenerateResult LayerSplitBackend::run_from_state(const GenerateRequest & req, result.error = "context"; return result; } - if (req.do_sample && req.sampler.temp > 0.0f) { + if (req.do_sample && req.sampler.needs_logit_processing() && + !adapter_->supports_cpu_sampling()) { result.error = "sampling_unsupported"; return result; } diff --git a/server/src/common/layer_split_backend.h b/server/src/common/layer_split_backend.h index 4aeda23f..0386936a 100644 --- a/server/src/common/layer_split_backend.h +++ b/server/src/common/layer_split_backend.h @@ -30,6 +30,7 @@ class LayerSplitAdapter { virtual bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) = 0; + virtual bool supports_cpu_sampling() const { return false; } virtual bool can_dflash_decode() const { return false; } virtual bool decode_dflash(const std::vector & prompt, diff --git a/server/src/gemma4/gemma4_graph.cpp b/server/src/gemma4/gemma4_graph.cpp index bf8f8ce7..bc227655 100644 --- a/server/src/gemma4/gemma4_graph.cpp +++ b/server/src/gemma4/gemma4_graph.cpp @@ -515,13 +515,14 @@ bool build_gemma4_layer_step( return ggml_gallocr_alloc_graph(sg.alloc, sg.gf); } -bool compute_gemma4_split_argmax( +bool compute_gemma4_split_projection( ggml_backend_t backend, const Gemma4Weights & w, ggml_tensor * act, int token_offset, int n_tokens, - std::vector & out_argmax) { + std::vector * out_argmax, + std::vector * out_logits) { ggml_init_params ip{}; ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024; ip.no_alloc = true; @@ -539,9 +540,17 @@ bool compute_gemma4_split_argmax( cur = ggml_tanh(ctx, cur); cur = ggml_scale(ctx, cur, w.final_logit_softcap); } - cur = ggml_argmax(ctx, cur); - ggml_set_output(cur); - ggml_build_forward_expand(gf, cur); + ggml_tensor * logits = cur; + ggml_tensor * argmax = nullptr; + if (out_logits) { + ggml_set_output(logits); + ggml_build_forward_expand(gf, logits); + } + if (out_argmax) { + argmax = ggml_argmax(ctx, logits); + ggml_set_output(argmax); + ggml_build_forward_expand(gf, argmax); + } ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) { @@ -554,14 +563,32 @@ bool compute_gemma4_split_argmax( ggml_free(ctx); return false; } - out_argmax.resize((size_t)n_tokens); - ggml_backend_tensor_get(cur, out_argmax.data(), 0, - sizeof(int32_t) * (size_t)n_tokens); + if (out_argmax) { + out_argmax->resize((size_t)n_tokens); + ggml_backend_tensor_get(argmax, out_argmax->data(), 0, + sizeof(int32_t) * (size_t)n_tokens); + } + if (out_logits) { + out_logits->resize((size_t)w.n_vocab * (size_t)n_tokens); + ggml_backend_tensor_get(logits, out_logits->data(), 0, + sizeof(float) * (size_t)w.n_vocab * (size_t)n_tokens); + } ggml_gallocr_free(alloc); ggml_free(ctx); return true; } +bool compute_gemma4_split_argmax( + ggml_backend_t backend, + const Gemma4Weights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector & out_argmax) { + return compute_gemma4_split_projection( + backend, w, act, token_offset, n_tokens, &out_argmax, nullptr); +} + bool gemma4_step( ggml_backend_t backend, const Gemma4Weights & w, diff --git a/server/src/gemma4/gemma4_internal.h b/server/src/gemma4/gemma4_internal.h index 5e643060..80f81cd1 100644 --- a/server/src/gemma4/gemma4_internal.h +++ b/server/src/gemma4/gemma4_internal.h @@ -286,6 +286,15 @@ bool compute_gemma4_split_argmax( int n_tokens, std::vector & out_argmax); +bool compute_gemma4_split_projection( + ggml_backend_t backend, + const Gemma4Weights & w, + ggml_tensor * act, + int token_offset, + int n_tokens, + std::vector * out_argmax, + std::vector * out_logits); + // BSA sparse-FA prefill: process the full prompt at once using block-sparse // attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers // use dense FA. Returns logits for the last token. Populates the KV cache diff --git a/server/src/gemma4/gemma4_layer_split_adapter.cpp b/server/src/gemma4/gemma4_layer_split_adapter.cpp index 2562faec..5a6be860 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.cpp +++ b/server/src/gemma4/gemma4_layer_split_adapter.cpp @@ -140,7 +140,10 @@ bool Gemma4LayerSplitAdapter::init() { } void Gemma4LayerSplitAdapter::begin_request(const GenerateRequest & req) { - (void)req; + sampler_ = req.sampler; + if (req.do_sample && sampler_.seed != 0) { + sampler_rng_.seed(sampler_.seed); + } } void Gemma4LayerSplitAdapter::reset_request_state() { @@ -148,12 +151,14 @@ void Gemma4LayerSplitAdapter::reset_request_state() { shard.cache.cur_pos = 0; shard.cache.last_tok = -1; } + prefill_last_logits_.clear(); } bool Gemma4LayerSplitAdapter::run_forward( const std::vector & tokens, int base_pos, - int & last_tok) { + int & last_tok, + std::vector * logits_out) { if (shards_.empty() || tokens.empty()) return false; const Gemma4Weights & ref = shards_.front().weights; const int hidden = ref.n_embd; @@ -336,9 +341,9 @@ bool Gemma4LayerSplitAdapter::run_forward( std::vector argmax; Gemma4LayerSplitShard & last = shards_.back(); - const bool ok = compute_gemma4_split_argmax( + const bool ok = compute_gemma4_split_projection( last.backend, last.weights, act_in, - n_tokens_total - 1, 1, argmax); + n_tokens_total - 1, 1, &argmax, logits_out); activation_buffer_free(orig); activation_pair_free(acts); if (!ok || argmax.empty()) return false; @@ -353,7 +358,7 @@ bool Gemma4LayerSplitAdapter::run_forward( bool Gemma4LayerSplitAdapter::prefill(const std::vector & prompt, int base_pos, int & last_tok) { - return run_forward(prompt, base_pos, last_tok); + return run_forward(prompt, base_pos, last_tok, &prefill_last_logits_); } bool Gemma4LayerSplitAdapter::decode_ar( @@ -366,6 +371,13 @@ bool Gemma4LayerSplitAdapter::decode_ar( if (shards_.empty()) return false; const auto & w = shards_.front().weights; + const int vocab = w.n_vocab; + std::vector logits_buf; + if (sampler_.needs_logit_processing()) { + if ((int)prefill_last_logits_.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(last_tok); io.emit(last_tok); if (io.cancelled) { @@ -381,7 +393,16 @@ bool Gemma4LayerSplitAdapter::decode_ar( for (int i = 1; i < n_gen; ++i) { std::vector one(1, last_tok); int next_tok = -1; - if (!run_forward(one, committed - 1, next_tok)) return false; + logits_buf.clear(); + if (!run_forward(one, committed - 1, next_tok, + sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { + return false; + } + if (sampler_.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } last_tok = next_tok; out_tokens.push_back(last_tok); io.emit(last_tok); @@ -455,6 +476,7 @@ bool Gemma4LayerSplitAdapter::snapshot_save(int slot) { } snap.cur_pos = snap_pos; snap.last_tok = shards_.front().cache.last_tok; + snap.prefill_last_logits = prefill_last_logits_; return true; } @@ -466,6 +488,7 @@ void Gemma4LayerSplitAdapter::snapshot_free(int slot) { } snap.cur_pos = 0; snap.last_tok = -1; + snap.prefill_last_logits.clear(); if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size()); } @@ -476,6 +499,7 @@ bool Gemma4LayerSplitAdapter::snapshot_used(int slot) const { } const auto & snap = snapshots_[(size_t)slot]; if (snap.cur_pos <= 0 || snap.shards.size() != shards_.size()) return false; + if (snap.prefill_last_logits.empty()) return false; for (const auto & ss : snap.shards) { if (!ss.ctx) return false; } @@ -515,6 +539,7 @@ bool Gemma4LayerSplitAdapter::snapshot_restore(int slot) { shards_[i].cache.cur_pos = snap.cur_pos; shards_[i].cache.last_tok = snap.last_tok; } + prefill_last_logits_ = snap.prefill_last_logits; return true; } diff --git a/server/src/gemma4/gemma4_layer_split_adapter.h b/server/src/gemma4/gemma4_layer_split_adapter.h index 430918b5..17f2f21c 100644 --- a/server/src/gemma4/gemma4_layer_split_adapter.h +++ b/server/src/gemma4/gemma4_layer_split_adapter.h @@ -30,6 +30,7 @@ struct Gemma4LayerSplitSnapshot { int cur_pos = 0; int32_t last_tok = -1; std::vector shards; + std::vector prefill_last_logits; }; class Gemma4LayerSplitAdapter : public LayerSplitAdapter { @@ -51,6 +52,7 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter { bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) override; + bool supports_cpu_sampling() const override { return true; } bool snapshot_save(int slot) override; void snapshot_free(int slot) override; @@ -65,13 +67,17 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter { private: bool run_forward(const std::vector & tokens, int base_pos, - int & last_tok); + int & last_tok, + std::vector * logits_out = nullptr); Gemma4LayerSplitAdapterConfig cfg_; std::vector shards_; std::vector snapshot_backends_; std::vector snapshots_; static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots; + SamplerCfg sampler_; + std::mt19937_64 sampler_rng_{std::random_device{}()}; + std::vector prefill_last_logits_; }; void free_gemma4_layer_split_shards(std::vector & shards); diff --git a/server/src/qwen35/layer_split_forward.cpp b/server/src/qwen35/layer_split_forward.cpp index d1ab6658..5fd774cc 100644 --- a/server/src/qwen35/layer_split_forward.cpp +++ b/server/src/qwen35/layer_split_forward.cpp @@ -17,7 +17,7 @@ namespace dflash::common { -bool compute_target_split_argmax( +bool compute_target_split_projection( StepGraph & sg, const TargetWeights & w, ggml_backend_t backend, @@ -26,7 +26,8 @@ bool compute_target_split_argmax( int n_tokens, int hidden, int vocab, - std::vector & argmax_out) { + std::vector * argmax_out, + std::vector * logits_out) { step_graph_free(sg); ggml_init_params ip{}; ip.mem_size = 256 * 1024 * 1024; @@ -43,24 +44,51 @@ bool compute_target_split_argmax( ggml_tensor * logits = ggml_mul_mat(sg.ctx, w.output, normed); ggml_set_name(logits, "target_split_logits"); sg.logits = logits; - sg.argmax_tokens = ggml_argmax(sg.ctx, logits); - ggml_set_name(sg.argmax_tokens, "target_split_argmax"); - ggml_set_output(sg.argmax_tokens); + if (argmax_out) { + sg.argmax_tokens = ggml_argmax(sg.ctx, logits); + ggml_set_name(sg.argmax_tokens, "target_split_argmax"); + ggml_set_output(sg.argmax_tokens); + } + if (logits_out) { + ggml_set_output(sg.logits); + } sg.gf = ggml_new_graph_custom(sg.ctx, 1024, false); - ggml_build_forward_expand(sg.gf, sg.argmax_tokens); + if (argmax_out) ggml_build_forward_expand(sg.gf, sg.argmax_tokens); + if (logits_out) ggml_build_forward_expand(sg.gf, sg.logits); if (!sg.alloc) { sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); } if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) return false; auto st = ggml_backend_graph_compute(backend, sg.gf); if (st != GGML_STATUS_SUCCESS) return false; - (void)vocab; - argmax_out.assign((size_t)n_tokens, 0); - ggml_backend_tensor_get(sg.argmax_tokens, argmax_out.data(), 0, - sizeof(int32_t) * (size_t)n_tokens); + if (argmax_out) { + argmax_out->assign((size_t)n_tokens, 0); + ggml_backend_tensor_get(sg.argmax_tokens, argmax_out->data(), 0, + sizeof(int32_t) * (size_t)n_tokens); + } + if (logits_out) { + logits_out->assign((size_t)vocab * (size_t)n_tokens, 0.0f); + ggml_backend_tensor_get(sg.logits, logits_out->data(), 0, + sizeof(float) * (size_t)vocab * (size_t)n_tokens); + } return true; } +bool compute_target_split_argmax( + StepGraph & sg, + const TargetWeights & w, + ggml_backend_t backend, + ggml_tensor * act, + int token_offset, + int n_tokens, + int hidden, + int vocab, + std::vector & argmax_out) { + return compute_target_split_projection( + sg, w, backend, act, token_offset, n_tokens, hidden, vocab, + &argmax_out, nullptr); +} + bool run_qwen35_layer_split_forward( std::vector & shards, const TargetWeights & embed_source, @@ -208,9 +236,10 @@ bool run_qwen35_layer_split_forward( const bool need_all_argmax = argmax_out != nullptr; const int argmax_offset = need_all_argmax ? 0 : (n_tokens_total - 1); const int argmax_count = need_all_argmax ? n_tokens_total : 1; - const bool ok = compute_target_split_argmax( + const bool ok = compute_target_split_projection( final_sg, last_shard.weights, last_shard.backend, act_in, - argmax_offset, argmax_count, hidden, vocab, argmax_tokens); + argmax_offset, argmax_count, hidden, vocab, + &argmax_tokens, logits_out); step_graph_destroy(final_sg); activation_pair_free(acts); if (!ok) return false; @@ -220,7 +249,6 @@ bool run_qwen35_layer_split_forward( shard.cache.last_tok = last_tok; } if (argmax_out) *argmax_out = std::move(argmax_tokens); - if (logits_out) logits_out->clear(); return true; } diff --git a/server/src/qwen35/layer_split_forward.h b/server/src/qwen35/layer_split_forward.h index c04680fe..bb01bff0 100644 --- a/server/src/qwen35/layer_split_forward.h +++ b/server/src/qwen35/layer_split_forward.h @@ -32,6 +32,18 @@ bool compute_target_split_argmax( int vocab, std::vector & argmax_out); +bool compute_target_split_projection( + StepGraph & sg, + const TargetWeights & w, + ggml_backend_t backend, + ggml_tensor * act, + int token_offset, + int n_tokens, + int hidden, + int vocab, + std::vector * argmax_out, + std::vector * logits_out); + // Run a full forward pass through all shards, writing K/V into each shard's // cache. Returns the argmax of the last token in `last_tok`. // Optionally captures features into `feature_ring` / remote draft. diff --git a/server/src/qwen35/qwen35_layer_split_adapter.cpp b/server/src/qwen35/qwen35_layer_split_adapter.cpp index a9ece2c3..98f8b025 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.cpp +++ b/server/src/qwen35/qwen35_layer_split_adapter.cpp @@ -86,6 +86,7 @@ bool Qwen35LayerSplitAdapter::init() { for (auto & slot : prefix_snapshots_) { slot.resize(shards_.size()); } + snapshot_prefill_logits_.resize(PREFIX_SLOTS); draft_feature_snapshots_.resize(PREFIX_SLOTS); return true; @@ -171,6 +172,7 @@ void Qwen35LayerSplitAdapter::begin_request(const GenerateRequest & req) { void Qwen35LayerSplitAdapter::reset_request_state() { for (auto & shard : shards_) reset_target_cache(shard.cache); + prefill_last_logits_.clear(); } bool Qwen35LayerSplitAdapter::prefill(const std::vector & prompt, @@ -190,7 +192,8 @@ bool Qwen35LayerSplitAdapter::prefill(const std::vector & prompt, shards_, shards_.front().weights, prompt, base_pos, ubatch, last_tok, cfg_.kq_stride_pad, /*fa_window=*/0, (cfg_.run_dflash && !remote_draft_.active()) ? &feature_ring_ : nullptr, - /*argmax_out=*/nullptr, /*logits_out=*/nullptr, + /*argmax_out=*/nullptr, + &prefill_last_logits_, cfg_.run_dflash ? &remote_draft_ : nullptr); } @@ -215,6 +218,8 @@ bool Qwen35LayerSplitAdapter::snapshot_save(int slot) { return false; } } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS) return false; + snapshot_prefill_logits_[(size_t)slot] = prefill_last_logits_; if (!snapshot_draft_features(slot)) { snapshot_free(slot); return false; @@ -227,6 +232,9 @@ void Qwen35LayerSplitAdapter::snapshot_free(int slot) { for (auto & snap : prefix_snapshots_[(size_t)slot]) { free_prefix_snapshot(snap); } + if (snapshot_prefill_logits_.size() == (size_t)PREFIX_SLOTS) { + snapshot_prefill_logits_[(size_t)slot].clear(); + } free_draft_feature_snapshot(slot); } @@ -237,6 +245,10 @@ bool Qwen35LayerSplitAdapter::snapshot_used(int slot) const { for (const auto & snap : snaps) { if (!snap.ctx) return false; } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS || + snapshot_prefill_logits_[(size_t)slot].empty()) { + return false; + } if (cfg_.run_dflash && cfg_.draft_path) { if (draft_feature_snapshots_.size() != (size_t)PREFIX_SLOTS) return false; const auto & draft_snap = draft_feature_snapshots_[(size_t)slot]; @@ -261,6 +273,8 @@ bool Qwen35LayerSplitAdapter::snapshot_restore(int slot) { return false; } } + if (snapshot_prefill_logits_.size() != (size_t)PREFIX_SLOTS) return false; + prefill_last_logits_ = snapshot_prefill_logits_[(size_t)slot]; if (!restore_draft_features(slot)) return false; return true; } @@ -377,14 +391,22 @@ bool Qwen35LayerSplitAdapter::decode_ar( std::vector & out_tokens, const DaemonIO & io) { if (n_gen <= 0) return true; + const auto & w = shards_.front().weights; + const int vocab = w.n_vocab; + std::vector logits_buf; + if (sampler_.needs_logit_processing()) { + if ((int)prefill_last_logits_.size() != vocab) return false; + last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(last_tok); io.emit(last_tok); if (io.cancelled) { io.emit(-1); return true; } - if (is_eos_tok(last_tok, shards_.front().weights)) { + if (is_eos_tok(last_tok, w)) { io.emit(-1); return true; } @@ -393,16 +415,24 @@ bool Qwen35LayerSplitAdapter::decode_ar( for (int i = 1; i < n_gen; ++i) { std::vector one(1, last_tok); int next_tok = -1; + logits_buf.clear(); if (!run_qwen35_layer_split_forward( shards_, shards_.front().weights, one, committed, 1, next_tok, cfg_.kq_stride_pad, cfg_.fa_window, - cfg_.run_dflash ? &feature_ring_ : nullptr)) { + cfg_.run_dflash ? &feature_ring_ : nullptr, + /*argmax_out=*/nullptr, + sampler_.needs_logit_processing() ? &logits_buf : nullptr)) { return false; } + if (sampler_.needs_logit_processing()) { + if ((int)logits_buf.size() != vocab) return false; + next_tok = sample_logits(logits_buf.data(), vocab, sampler_, + out_tokens, sampler_rng_); + } out_tokens.push_back(next_tok); io.emit(next_tok); if (io.cancelled) break; - if (is_eos_tok(next_tok, shards_.front().weights)) break; + if (is_eos_tok(next_tok, w)) break; last_tok = next_tok; ++committed; } @@ -411,7 +441,7 @@ bool Qwen35LayerSplitAdapter::decode_ar( } bool Qwen35LayerSplitAdapter::can_dflash_decode() const { - return cfg_.run_dflash && cfg_.draft_path && sampler_.temp == 0.0f; + return cfg_.run_dflash && cfg_.draft_path && !sampler_.needs_logit_processing(); } bool Qwen35LayerSplitAdapter::decode_dflash( @@ -499,6 +529,7 @@ void Qwen35LayerSplitAdapter::shutdown() { for (auto & snap : slot) free_prefix_snapshot(snap); } prefix_snapshots_.clear(); + snapshot_prefill_logits_.clear(); draft_feature_snapshots_.clear(); auto shard_metas = layer_split_shard_metas(shards_); free_layer_split_snapshot_backends(shard_metas, snapshot_backends_); diff --git a/server/src/qwen35/qwen35_layer_split_adapter.h b/server/src/qwen35/qwen35_layer_split_adapter.h index 4011d837..68ce39fd 100644 --- a/server/src/qwen35/qwen35_layer_split_adapter.h +++ b/server/src/qwen35/qwen35_layer_split_adapter.h @@ -55,6 +55,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { bool decode_ar(int last_tok, int committed, int n_gen, std::vector & out_tokens, const DaemonIO & io) override; + bool supports_cpu_sampling() const override { return true; } bool can_dflash_decode() const override; bool decode_dflash(const std::vector & prompt, int base_pos, @@ -99,6 +100,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { bool pflash_drafter_loaded_ = false; static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots; std::vector> prefix_snapshots_; + std::vector> snapshot_prefill_logits_; std::vector snapshot_backends_; struct DraftFeatureSnapshot { int cur_pos = 0; @@ -114,6 +116,7 @@ class Qwen35LayerSplitAdapter : public LayerSplitAdapter { SamplerCfg sampler_; std::mt19937_64 sampler_rng_{std::random_device{}()}; std::unique_ptr dflash_target_; + std::vector prefill_last_logits_; }; } // namespace dflash::common diff --git a/server/test/test_server_unit.cpp b/server/test/test_server_unit.cpp index 1415aab3..26dd8169 100644 --- a/server/test/test_server_unit.cpp +++ b/server/test/test_server_unit.cpp @@ -1246,6 +1246,7 @@ struct MockLayerSplitAdapter : LayerSplitAdapter { std::vector emitted_tokens; bool dflash_enabled = false; bool dflash_called = false; + bool sampling_enabled = false; int shutdown_calls = 0; ModelBackend::CompressRequest last_compress_req; @@ -1280,6 +1281,7 @@ struct MockLayerSplitAdapter : LayerSplitAdapter { return true; } bool can_dflash_decode() const override { return dflash_enabled; } + bool supports_cpu_sampling() const override { return sampling_enabled; } bool decode_dflash(const std::vector & prompt, int base_pos, int last_tok, int n_gen, std::vector & out_tokens, const DaemonIO & io) override { @@ -1374,6 +1376,42 @@ static void test_layer_split_backend_inline_snapshot_and_restore_delta() { TEST_ASSERT(raw->dflash_last == 99); } +static void test_layer_split_backend_sampling_capability_gate() { + { + auto * raw = new MockLayerSplitAdapter(); + LayerSplitBackend backend{std::unique_ptr(raw)}; + + GenerateRequest req; + req.prompt = {10, 11}; + req.n_gen = 1; + req.do_sample = true; + req.sampler.temp = 0.8f; + DaemonIO io; + GenerateResult result = backend.generate(req, io); + + TEST_ASSERT(!result.ok); + TEST_ASSERT(result.error == "sampling_unsupported"); + } + + { + auto * raw = new MockLayerSplitAdapter(); + raw->sampling_enabled = true; + LayerSplitBackend backend{std::unique_ptr(raw)}; + + GenerateRequest req; + req.prompt = {10, 11}; + req.n_gen = 1; + req.do_sample = true; + req.sampler.temp = 0.8f; + DaemonIO io; + GenerateResult result = backend.generate(req, io); + + TEST_ASSERT(result.ok); + TEST_ASSERT(result.tokens.size() == 1); + TEST_ASSERT(result.tokens[0] == 12); + } +} + static void test_layer_split_compress_nopark_uses_default_drafter_path() { const std::string ids_path = "/tmp/dflash_test_layer_split_compress_ids.bin"; unlink(ids_path.c_str()); @@ -2548,6 +2586,7 @@ int main() { RUN_TEST(test_parse_target_device_list_single_gpu_is_not_layer_split); RUN_TEST(test_validate_layer_split_weights_shape); RUN_TEST(test_layer_split_backend_inline_snapshot_and_restore_delta); + RUN_TEST(test_layer_split_backend_sampling_capability_gate); RUN_TEST(test_layer_split_compress_nopark_uses_default_drafter_path); RUN_TEST(test_layer_split_compress_rejects_bad_keep_ratio); RUN_TEST(test_layer_split_backend_shutdown_is_idempotent);