Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ add_library(dflash_common STATIC
src/laguna/laguna_target_graph.cpp
src/laguna/laguna_daemon.cpp
src/laguna/laguna_backend.cpp
src/laguna/laguna_layer_split_adapter.cpp
src/common/backend_ipc.cpp
src/common/dflash_feature_ring.cpp
src/common/dflash_capture.cpp
Expand Down
17 changes: 17 additions & 0 deletions server/src/common/backend_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "qwen35_backend.h"
#include "qwen35moe_backend.h"
#include "laguna_backend.h"
#include "laguna_layer_split_adapter.h"
#include "qwen3_backend.h"
#include "gemma4_backend.h"
#include "gemma4_layer_split_adapter.h"
Expand Down Expand Up @@ -124,8 +125,24 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
return backend;

} else if (arch == "laguna") {
if (args.device.is_layer_split()) {
LagunaLayerSplitAdapterConfig cfg;
cfg.target_path = args.model_path;
cfg.device = args.device;
cfg.chunk = args.chunk;

auto adapter = std::make_unique<LagunaLayerSplitAdapter>(cfg);
auto backend = std::make_unique<LayerSplitBackend>(std::move(adapter));
if (!backend->init()) {
std::fprintf(stderr, "[backend_factory] LayerSplitBackend(laguna) init failed\n");
return nullptr;
}
return backend;
}

LagunaBackendArgs lcfg;
lcfg.target_path = args.model_path;
lcfg.device = args.device;
lcfg.max_ctx = args.device.max_ctx;
lcfg.chunk = args.chunk;
// kv_type defaults to Q8_0 in LagunaBackendArgs
Expand Down
3 changes: 2 additions & 1 deletion server/src/common/layer_split_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ GenerateResult LayerSplitBackend::run_from_state(const GenerateRequest & req,
result.error = "context";
return result;
}
if (req.do_sample && req.sampler.temp > 0.0f) {
if (req.do_sample && req.sampler.needs_logit_processing() &&
!adapter_->supports_cpu_sampling()) {
result.error = "sampling_unsupported";
return result;
}
Expand Down
1 change: 1 addition & 0 deletions server/src/common/layer_split_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LayerSplitAdapter {
virtual bool decode_ar(int last_tok, int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) = 0;
virtual bool supports_cpu_sampling() const { return false; }

virtual bool can_dflash_decode() const { return false; }
virtual bool decode_dflash(const std::vector<int32_t> & prompt,
Expand Down
43 changes: 35 additions & 8 deletions server/src/gemma4/gemma4_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,14 @@ bool build_gemma4_layer_step(
return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
}

bool compute_gemma4_split_argmax(
bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits) {
ggml_init_params ip{};
ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024;
ip.no_alloc = true;
Expand All @@ -539,9 +540,17 @@ bool compute_gemma4_split_argmax(
cur = ggml_tanh(ctx, cur);
cur = ggml_scale(ctx, cur, w.final_logit_softcap);
}
cur = ggml_argmax(ctx, cur);
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur);
ggml_tensor * logits = cur;
ggml_tensor * argmax = nullptr;
if (out_logits) {
ggml_set_output(logits);
ggml_build_forward_expand(gf, logits);
}
if (out_argmax) {
argmax = ggml_argmax(ctx, logits);
ggml_set_output(argmax);
ggml_build_forward_expand(gf, argmax);
}

ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) {
Expand All @@ -554,14 +563,32 @@ bool compute_gemma4_split_argmax(
ggml_free(ctx);
return false;
}
out_argmax.resize((size_t)n_tokens);
ggml_backend_tensor_get(cur, out_argmax.data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
if (out_argmax) {
out_argmax->resize((size_t)n_tokens);
ggml_backend_tensor_get(argmax, out_argmax->data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
}
if (out_logits) {
out_logits->resize((size_t)w.n_vocab * (size_t)n_tokens);
ggml_backend_tensor_get(logits, out_logits->data(), 0,
sizeof(float) * (size_t)w.n_vocab * (size_t)n_tokens);
}
ggml_gallocr_free(alloc);
ggml_free(ctx);
return true;
}

bool compute_gemma4_split_argmax(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
return compute_gemma4_split_projection(
backend, w, act, token_offset, n_tokens, &out_argmax, nullptr);
}

bool gemma4_step(
ggml_backend_t backend,
const Gemma4Weights & w,
Expand Down
9 changes: 9 additions & 0 deletions server/src/gemma4/gemma4_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ bool compute_gemma4_split_argmax(
int n_tokens,
std::vector<int32_t> & out_argmax);

bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits);

// BSA sparse-FA prefill: process the full prompt at once using block-sparse
// attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers
// use dense FA. Returns logits for the last token. Populates the KV cache
Expand Down
37 changes: 31 additions & 6 deletions server/src/gemma4/gemma4_layer_split_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,20 +140,25 @@ bool Gemma4LayerSplitAdapter::init() {
}

void Gemma4LayerSplitAdapter::begin_request(const GenerateRequest & req) {
(void)req;
sampler_ = req.sampler;
if (req.do_sample && sampler_.seed != 0) {
sampler_rng_.seed(sampler_.seed);
}
}

void Gemma4LayerSplitAdapter::reset_request_state() {
for (auto & shard : shards_) {
shard.cache.cur_pos = 0;
shard.cache.last_tok = -1;
}
prefill_last_logits_.clear();
}

bool Gemma4LayerSplitAdapter::run_forward(
const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok) {
int & last_tok,
std::vector<float> * logits_out) {
if (shards_.empty() || tokens.empty()) return false;
const Gemma4Weights & ref = shards_.front().weights;
const int hidden = ref.n_embd;
Expand Down Expand Up @@ -336,9 +341,9 @@ bool Gemma4LayerSplitAdapter::run_forward(

std::vector<int32_t> argmax;
Gemma4LayerSplitShard & last = shards_.back();
const bool ok = compute_gemma4_split_argmax(
const bool ok = compute_gemma4_split_projection(
last.backend, last.weights, act_in,
n_tokens_total - 1, 1, argmax);
n_tokens_total - 1, 1, &argmax, logits_out);
activation_buffer_free(orig);
activation_pair_free(acts);
if (!ok || argmax.empty()) return false;
Expand All @@ -353,7 +358,7 @@ bool Gemma4LayerSplitAdapter::run_forward(
bool Gemma4LayerSplitAdapter::prefill(const std::vector<int32_t> & prompt,
int base_pos,
int & last_tok) {
return run_forward(prompt, base_pos, last_tok);
return run_forward(prompt, base_pos, last_tok, &prefill_last_logits_);
}

bool Gemma4LayerSplitAdapter::decode_ar(
Expand All @@ -366,6 +371,13 @@ bool Gemma4LayerSplitAdapter::decode_ar(
if (shards_.empty()) return false;

const auto & w = shards_.front().weights;
const int vocab = w.n_vocab;
std::vector<float> logits_buf;
if (sampler_.needs_logit_processing()) {
if ((int)prefill_last_logits_.size() != vocab) return false;
last_tok = sample_logits(prefill_last_logits_.data(), vocab, sampler_,
out_tokens, sampler_rng_);
}
out_tokens.push_back(last_tok);
io.emit(last_tok);
if (io.cancelled) {
Expand All @@ -381,7 +393,16 @@ bool Gemma4LayerSplitAdapter::decode_ar(
for (int i = 1; i < n_gen; ++i) {
std::vector<int32_t> one(1, last_tok);
int next_tok = -1;
if (!run_forward(one, committed - 1, next_tok)) return false;
logits_buf.clear();
if (!run_forward(one, committed - 1, next_tok,
sampler_.needs_logit_processing() ? &logits_buf : nullptr)) {
return false;
}
if (sampler_.needs_logit_processing()) {
if ((int)logits_buf.size() != vocab) return false;
next_tok = sample_logits(logits_buf.data(), vocab, sampler_,
out_tokens, sampler_rng_);
}
last_tok = next_tok;
out_tokens.push_back(last_tok);
io.emit(last_tok);
Expand Down Expand Up @@ -455,6 +476,7 @@ bool Gemma4LayerSplitAdapter::snapshot_save(int slot) {
}
snap.cur_pos = snap_pos;
snap.last_tok = shards_.front().cache.last_tok;
snap.prefill_last_logits = prefill_last_logits_;
return true;
}

Expand All @@ -466,6 +488,7 @@ void Gemma4LayerSplitAdapter::snapshot_free(int slot) {
}
snap.cur_pos = 0;
snap.last_tok = -1;
snap.prefill_last_logits.clear();
if (snap.shards.size() != shards_.size()) snap.shards.resize(shards_.size());
}

Expand All @@ -476,6 +499,7 @@ bool Gemma4LayerSplitAdapter::snapshot_used(int slot) const {
}
const auto & snap = snapshots_[(size_t)slot];
if (snap.cur_pos <= 0 || snap.shards.size() != shards_.size()) return false;
if (snap.prefill_last_logits.empty()) return false;
for (const auto & ss : snap.shards) {
if (!ss.ctx) return false;
}
Expand Down Expand Up @@ -515,6 +539,7 @@ bool Gemma4LayerSplitAdapter::snapshot_restore(int slot) {
shards_[i].cache.cur_pos = snap.cur_pos;
shards_[i].cache.last_tok = snap.last_tok;
}
prefill_last_logits_ = snap.prefill_last_logits;
return true;
}

Expand Down
8 changes: 7 additions & 1 deletion server/src/gemma4/gemma4_layer_split_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ struct Gemma4LayerSplitSnapshot {
int cur_pos = 0;
int32_t last_tok = -1;
std::vector<Gemma4Snapshot> shards;
std::vector<float> prefill_last_logits;
};

class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
Expand All @@ -51,6 +52,7 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
bool decode_ar(int last_tok, int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) override;
bool supports_cpu_sampling() const override { return true; }

bool snapshot_save(int slot) override;
void snapshot_free(int slot) override;
Expand All @@ -65,13 +67,17 @@ class Gemma4LayerSplitAdapter : public LayerSplitAdapter {
private:
bool run_forward(const std::vector<int32_t> & tokens,
int base_pos,
int & last_tok);
int & last_tok,
std::vector<float> * logits_out = nullptr);

Gemma4LayerSplitAdapterConfig cfg_;
std::vector<Gemma4LayerSplitShard> shards_;
std::vector<ggml_backend_t> snapshot_backends_;
std::vector<Gemma4LayerSplitSnapshot> snapshots_;
static constexpr int PREFIX_SLOTS = ModelBackend::kMaxSlots;
SamplerCfg sampler_;
std::mt19937_64 sampler_rng_{std::random_device{}()};
std::vector<float> prefill_last_logits_;
};

void free_gemma4_layer_split_shards(std::vector<Gemma4LayerSplitShard> & shards);
Expand Down
4 changes: 2 additions & 2 deletions server/src/laguna/laguna_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ LagunaBackend::LagunaBackend(const LagunaBackendArgs & args)
LagunaBackend::~LagunaBackend() { shutdown(); }

bool LagunaBackend::init() {
backend_ = ggml_backend_cuda_init(0);
backend_ = ggml_backend_cuda_init(args_.device.gpu);
if (!backend_) {
std::fprintf(stderr, "cuda init failed\n");
std::fprintf(stderr, "cuda init failed gpu=%d\n", args_.device.gpu);
return false;
}

Expand Down
2 changes: 2 additions & 0 deletions server/src/laguna/laguna_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "model_backend.h"
#include "laguna_internal.h"
#include "placement/placement_config.h"
#include "qwen3_drafter.h"

#include "ggml.h"
Expand All @@ -22,6 +23,7 @@ namespace dflash::common {

struct LagunaBackendArgs {
std::string target_path;
DevicePlacement device;
int max_ctx = 16384;
int chunk = 2048;
ggml_type kv_type = GGML_TYPE_Q8_0;
Expand Down
1 change: 1 addition & 0 deletions server/src/laguna/laguna_daemon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace dflash::common {
int run_laguna_daemon(const LagunaDaemonArgs & args) {
LagunaBackendArgs bargs;
bargs.target_path = args.target_path;
bargs.device = args.device;
bargs.max_ctx = args.device.max_ctx;
bargs.chunk = args.chunk;
bargs.kv_type = args.kv_type;
Expand Down
Loading
Loading