Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion server/src/common/layer_split_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ GenerateResult LayerSplitBackend::run_from_state(const GenerateRequest & req,
result.error = "context";
return result;
}
if (req.do_sample && req.sampler.temp > 0.0f) {
if (req.do_sample && req.sampler.needs_logit_processing() &&
!adapter_->supports_cpu_sampling()) {
result.error = "sampling_unsupported";
return result;
}
Expand Down
1 change: 1 addition & 0 deletions server/src/common/layer_split_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LayerSplitAdapter {
virtual bool decode_ar(int last_tok, int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) = 0;
virtual bool supports_cpu_sampling() const { return false; }

virtual bool can_dflash_decode() const { return false; }
virtual bool decode_dflash(const std::vector<int32_t> & prompt,
Expand Down
43 changes: 35 additions & 8 deletions server/src/gemma4/gemma4_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,14 @@ bool build_gemma4_layer_step(
return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
}

bool compute_gemma4_split_argmax(
bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits) {
ggml_init_params ip{};
ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024;
ip.no_alloc = true;
Expand All @@ -539,9 +540,17 @@ bool compute_gemma4_split_argmax(
cur = ggml_tanh(ctx, cur);
cur = ggml_scale(ctx, cur, w.final_logit_softcap);
}
cur = ggml_argmax(ctx, cur);
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur);
ggml_tensor * logits = cur;
ggml_tensor * argmax = nullptr;
if (out_logits) {
ggml_set_output(logits);
ggml_build_forward_expand(gf, logits);
}
if (out_argmax) {
argmax = ggml_argmax(ctx, logits);
ggml_set_output(argmax);
ggml_build_forward_expand(gf, argmax);
}

ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) {
Expand All @@ -554,14 +563,32 @@ bool compute_gemma4_split_argmax(
ggml_free(ctx);
return false;
}
out_argmax.resize((size_t)n_tokens);
ggml_backend_tensor_get(cur, out_argmax.data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
if (out_argmax) {
out_argmax->resize((size_t)n_tokens);
ggml_backend_tensor_get(argmax, out_argmax->data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
}
if (out_logits) {
out_logits->resize((size_t)w.n_vocab * (size_t)n_tokens);
ggml_backend_tensor_get(logits, out_logits->data(), 0,
sizeof(float) * (size_t)w.n_vocab * (size_t)n_tokens);
}
ggml_gallocr_free(alloc);
ggml_free(ctx);
return true;
}

bool compute_gemma4_split_argmax(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
return compute_gemma4_split_projection(
backend, w, act, token_offset, n_tokens, &out_argmax, nullptr);
}

bool gemma4_step(
ggml_backend_t backend,
const Gemma4Weights & w,
Expand Down
9 changes: 9 additions & 0 deletions server/src/gemma4/gemma4_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ bool compute_gemma4_split_argmax(
int n_tokens,
std::vector<int32_t> & out_argmax);

bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits);

// BSA sparse-FA prefill: process the full prompt at once using block-sparse
// attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers
// use dense FA. Returns logits for the last token. Populates the KV cache
Expand Down
37 changes: 31 additions & 6 deletions server/src/gemma4/gemma4_layer_split_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,25 @@ bool Gemma4LayerSplitAdapter::init() {
}

void Gemma4LayerSplitAdapter::begin_request(const GenerateRequest & req) {
(void)req;
sampler_ = req.sampler;
if (req.do_sample && sampler_.seed != 0) {
sampler_rng_.seed(sampler_.seed);
}
}

void Gemma4LayerSplitAdapter::reset_request_state() {
for (auto & shard : shards_) {
shard.cache.cur_pos = 0;
shard.cache.last_tok = -1;
}
prefill_last_logits_.clear();
}

bool Gemma4LayerSplitAdapter::run_forward(
const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok) {
int & last_tok,
std::vector<float> * logits_out) {
if (shards_.empty() || tokens.empty()) return false;
const Gemma4Weights & ref = shards_.front().weights;
const int hidden = ref.n_embd;
Expand Down Expand Up @@ -336,9 +341,9 @@ bool Gemma4LayerSplitAdapter::run_forward(

std::vector<int32_t> argmax;
Gemma4LayerSplitShard & last = shards_.back();
const bool ok = compute_gemma4_split_argmax(
const bool ok = compute_gemma4_split_projection(
last.backend, last.weights, act_in,
n_tokens_total - 1, 1, argmax);
n_tokens_total - 1, 1, &argmax, logits_out);
activation_buffer_free(orig);
activation_pair_free(acts);
if (!ok || argmax.empty()) return false;
Expand All @@ -353,7 +358,7 @@ bool Gemma4LayerSplitAdapter::run_forward(
bool Gemma4LayerSplitAdapter::prefill(const std::vector<int32_t> & prompt,
int base_pos,
int & last_tok) {
return run_forward(prompt, base_pos, last_tok);
return run_forward(prompt, base_pos, last_tok, &prefill_last_logits_);
}

bool Gemma4LayerSplitAdapter::decode_ar(
Expand All @@ -366,6 +371,13 @@ bool Gemma4LayerSplitAdapter::decode_ar(
if (shards_.empty()) return false;

const auto & w = shards_.front().weights;
const int vocab = w.n_vocab;
std::vector<float> logits_buf;
if (sampler_.needs_logit_processing()) {
if ((int)prefill_last_logits_.size() != vocab) return false;
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_,
out_tokens, sampler_rng_);
}
out_tokens.push_back(last_tok);
io.emit(last_tok);
if (io.cancelled) {
Expand All @@ -381,7 +393,16 @@ bool Gemma4LayerSplitAdapter::decode_ar(
for (int i = 1; i < n_gen; ++i) {
std::vector<int32_t> one(1, last_tok);
int next_tok = -1;
if (!run_forward(one, committed - 1, next_tok)) return false;
logits_buf.clear();
if (!run_forward(one, committed - 1, next_tok,
sampler_.needs_logit_processing() ? &logits_buf : nullptr)) {
return false;
}
if (sampler_.needs_logit_processing()) {
if ((int)logits_buf.size() != vocab) return false;
next_tok = sample_logits(logits_buf.data(), vocab, sampler_,
out_tokens, sampler_rng_);
}
last_tok = next_tok;
out_tokens.push_back(last_tok);
io.emit(last_tok);
Expand Down Expand Up @@ -455,6 +476,7 @@ bool Gemma4LayerSplitAdapter::snapshot_save(int slot) {
}
snap.cur_pos = snap_pos;
snap.last_tok = shards_.front().cache.last_tok;
snap.prefill_last_logits = prefill_last_logits_;
return true;
}

Expand All @@ -466,6 +488,7 @@ void Gemma4LayerSplitAdapter::snapshot_free(int slot) {
}
snap.cur_pos = 0;
snap.last_tok = -1;
snap.prefill_last_logits.clear();
if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size());
}

Expand All @@ -476,6 +499,7 @@ bool Gemma4LayerSplitAdapter::snapshot_used(int slot) const {
}
const auto & snap = snapshots_[(size_t)slot];
if (snap.cur_pos <= 0 || snap.shards.size() != shards_.size()) return false;
if (snap.prefill_last_logits.empty()) return false;
for (const auto & ss : snap.shards) {
if (!ss.ctx) return false;
}
Expand Down Expand Up @@ -515,6 +539,7 @@ bool Gemma4LayerSplitAdapter::snapshot_restore(int slot) {
shards_[i].cache.cur_pos = snap.cur_pos;
shards_[i].cache.last_tok = snap.last_tok;
}
prefill_last_logits_ = snap.prefill_last_logits;
return true;
}

Expand Down
8 changes: 7 additions & 1 deletion server/src/gemma4/gemma4_layer_split_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct Gemma4LayerSplitSnapshot {
int cur_pos = 0;
int32_t last_tok = -1;
std::vector<Gemma4Snapshot> shards;
std::vector<float> prefill_last_logits;
};

class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
Expand All @@ -51,6 +52,7 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
bool decode_ar(int last_tok, int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) override;
bool supports_cpu_sampling() const override { return true; }

bool snapshot_save(int slot) override;
void snapshot_free(int slot) override;
Expand All @@ -65,13 +67,17 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
private:
bool run_forward(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok);
int & last_tok,
std::vector<float> * logits_out = nullptr);

Gemma4LayerSplitAdapterConfig cfg_;
std::vector<Gemma4LayerSplitShard> shards_;
std::vector<ggml_backend_t> snapshot_backends_;
std::vector<Gemma4LayerSplitSnapshot> snapshots_;
static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots;
SamplerCfg sampler_;
std::mt19937_64 sampler_rng_{std::random_device{}()};
std::vector<float> prefill_last_logits_;
};

void free_gemma4_layer_split_shards(std::vector<Gemma4LayerSplitShard> & shards);
Expand Down
54 changes: 41 additions & 13 deletions server/src/qwen35/layer_split_forward.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

namespace dflash::common {

bool compute_target_split_argmax(
bool compute_target_split_projection(
StepGraph & sg,
const TargetWeights & w,
ggml_backend_t backend,
Expand All @@ -26,7 +26,8 @@ bool compute_target_split_argmax(
int n_tokens,
int hidden,
int vocab,
std::vector<int32_t> & argmax_out) {
std::vector<int32_t> * argmax_out,
std::vector<float> * logits_out) {
step_graph_free(sg);
ggml_init_params ip{};
ip.mem_size = 256 * 1024 * 1024;
Expand All @@ -43,24 +44,51 @@ bool compute_target_split_argmax(
ggml_tensor * logits = ggml_mul_mat(sg.ctx, w.output, normed);
ggml_set_name(logits, "target_split_logits");
sg.logits = logits;
sg.argmax_tokens = ggml_argmax(sg.ctx, logits);
ggml_set_name(sg.argmax_tokens, "target_split_argmax");
ggml_set_output(sg.argmax_tokens);
if (argmax_out) {
sg.argmax_tokens = ggml_argmax(sg.ctx, logits);
ggml_set_name(sg.argmax_tokens, "target_split_argmax");
ggml_set_output(sg.argmax_tokens);
}
if (logits_out) {
ggml_set_output(sg.logits);
}
sg.gf = ggml_new_graph_custom(sg.ctx, 1024, false);
ggml_build_forward_expand(sg.gf, sg.argmax_tokens);
if (argmax_out) ggml_build_forward_expand(sg.gf, sg.argmax_tokens);
if (logits_out) ggml_build_forward_expand(sg.gf, sg.logits);
if (!sg.alloc) {
sg.alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
}
if (!ggml_gallocr_alloc_graph(sg.alloc, sg.gf)) return false;
auto st = ggml_backend_graph_compute(backend, sg.gf);
if (st != GGML_STATUS_SUCCESS) return false;
(void)vocab;
argmax_out.assign((size_t)n_tokens, 0);
ggml_backend_tensor_get(sg.argmax_tokens, argmax_out.data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
if (argmax_out) {
argmax_out->assign((size_t)n_tokens, 0);
ggml_backend_tensor_get(sg.argmax_tokens, argmax_out->data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
}
if (logits_out) {
logits_out->assign((size_t)vocab * (size_t)n_tokens, 0.0f);
ggml_backend_tensor_get(sg.logits, logits_out->data(), 0,
sizeof(float) * (size_t)vocab * (size_t)n_tokens);
}
return true;
}

bool compute_target_split_argmax(
StepGraph & sg,
const TargetWeights & w,
ggml_backend_t backend,
ggml_tensor * act,
int token_offset,
int n_tokens,
int hidden,
int vocab,
std::vector<int32_t> & argmax_out) {
return compute_target_split_projection(
sg, w, backend, act, token_offset, n_tokens, hidden, vocab,
&argmax_out, nullptr);
}

bool run_qwen35_layer_split_forward(
std::vector<Qwen35LayerSplitShard> & shards,
const TargetWeights & embed_source,
Expand Down Expand Up @@ -208,9 +236,10 @@ bool run_qwen35_layer_split_forward(
const bool need_all_argmax = argmax_out != nullptr;
const int argmax_offset = need_all_argmax ? 0 : (n_tokens_total - 1);
const int argmax_count = need_all_argmax ? n_tokens_total : 1;
const bool ok = compute_target_split_argmax(
const bool ok = compute_target_split_projection(
final_sg, last_shard.weights, last_shard.backend, act_in,
argmax_offset, argmax_count, hidden, vocab, argmax_tokens);
argmax_offset, argmax_count, hidden, vocab,
&argmax_tokens, logits_out);
step_graph_destroy(final_sg);
activation_pair_free(acts);
if (!ok) return false;
Expand All @@ -220,7 +249,6 @@ bool run_qwen35_layer_split_forward(
shard.cache.last_tok = last_tok;
}
if (argmax_out) *argmax_out = std::move(argmax_tokens);
if (logits_out) logits_out->clear();
return true;
}

Expand Down
12 changes: 12 additions & 0 deletions server/src/qwen35/layer_split_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,18 @@ bool compute_target_split_argmax(
int vocab,
std::vector<int32_t> & argmax_out);

bool compute_target_split_projection(
StepGraph & sg,
const TargetWeights & w,
ggml_backend_t backend,
ggml_tensor * act,
int token_offset,
int n_tokens,
int hidden,
int vocab,
std::vector<int32_t> * argmax_out,
std::vector<float> * logits_out);

// Run a full forward pass through all shards, writing K/V into each shard's
// cache. Returns the argmax of the last token in `last_tok`.
// Optionally captures features into `feature_ring` / remote draft.
Expand Down
Loading
Loading