Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ add_library(dflash_common STATIC
src/laguna/laguna_target_graph.cpp
src/laguna/laguna_daemon.cpp
src/laguna/laguna_backend.cpp
src/laguna/laguna_layer_split_adapter.cpp
src/common/backend_ipc.cpp
src/common/dflash_feature_ring.cpp
src/common/dflash_capture.cpp
Expand All @@ -246,6 +247,7 @@ add_library(dflash_common STATIC
src/common/dflash_draft_graph.cpp
src/common/dflash_spec_decode.cpp
src/common/layer_split_backend.cpp
src/common/layer_split_runtime.cpp
src/qwen35/graph_builders.cpp
src/qwen35moe/qwen35moe_ffn.cpp
src/qwen35moe/qwen35moe_backend.cpp
Expand Down
17 changes: 17 additions & 0 deletions server/src/common/backend_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "qwen35_backend.h"
#include "qwen35moe_backend.h"
#include "laguna_backend.h"
#include "laguna_layer_split_adapter.h"
#include "qwen3_backend.h"
#include "gemma4_backend.h"
#include "gemma4_layer_split_adapter.h"
Expand Down Expand Up @@ -124,8 +125,24 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
return backend;

} else if (arch == "laguna") {
if (args.device.is_layer_split()) {
LagunaLayerSplitAdapterConfig cfg;
cfg.target_path = args.model_path;
cfg.device = args.device;
cfg.chunk = args.chunk;

auto adapter = std::make_unique<LagunaLayerSplitAdapter>(cfg);
auto backend = std::make_unique<LayerSplitBackend>(std::move(adapter));
if (!backend->init()) {
std::fprintf(stderr, "[backend_factory] LayerSplitBackend(laguna) init failed\n");
return nullptr;
}
return backend;
}

LagunaBackendArgs lcfg;
lcfg.target_path = args.model_path;
lcfg.device = args.device;
lcfg.max_ctx = args.device.max_ctx;
lcfg.chunk = args.chunk;
// kv_type defaults to Q8_0 in LagunaBackendArgs
Expand Down
3 changes: 2 additions & 1 deletion server/src/common/layer_split_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ GenerateResult LayerSplitBackend::run_from_state(const GenerateRequest & req,
result.error = "context";
return result;
}
if (req.do_sample && req.sampler.temp > 0.0f) {
if (req.do_sample && req.sampler.needs_logit_processing() &&
!adapter_->supports_cpu_sampling()) {
result.error = "sampling_unsupported";
return result;
}
Expand Down
1 change: 1 addition & 0 deletions server/src/common/layer_split_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class LayerSplitAdapter {
virtual bool decode_ar(int last_tok, int committed, int n_gen,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) = 0;
virtual bool supports_cpu_sampling() const { return false; }

virtual bool can_dflash_decode() const { return false; }
virtual bool decode_dflash(const std::vector<int32_t> & prompt,
Expand Down
64 changes: 64 additions & 0 deletions server/src/common/layer_split_runtime.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include "layer_split_runtime.h"

namespace dflash::common {

bool run_layer_split_ar_decode(
int last_tok,
int committed,
int n_gen,
int vocab,
const std::vector<float> & prefill_last_logits,
const SamplerCfg & sampler,
std::mt19937_64 & rng,
const LayerSplitForwardStep & forward_one,
const std::function<bool(int)> & is_eos,
std::vector<int32_t> & out_tokens,
const DaemonIO & io) {
if (n_gen <= 0) return true;

if (sampler.needs_logit_processing()) {
if ((int)prefill_last_logits.size() != vocab) return false;
last_tok = sample_logits(prefill_last_logits.data(), vocab, sampler,
out_tokens, rng);
}

out_tokens.push_back(last_tok);
io.emit(last_tok);
if (io.cancelled) {
io.emit(-1);
return true;
}
if (is_eos(last_tok)) {
io.emit(-1);
return true;
}
++committed;

std::vector<float> logits_buf;
for (int i = 1; i < n_gen; ++i) {
std::vector<int32_t> one(1, last_tok);
int next_tok = -1;
logits_buf.clear();
if (!forward_one(one, committed, next_tok,
sampler.needs_logit_processing() ? &logits_buf : nullptr)) {
return false;
}
if (sampler.needs_logit_processing()) {
if ((int)logits_buf.size() != vocab) return false;
next_tok = sample_logits(logits_buf.data(), vocab, sampler,
out_tokens, rng);
}

last_tok = next_tok;
out_tokens.push_back(last_tok);
io.emit(last_tok);
++committed;
if (io.cancelled) break;
if (is_eos(last_tok)) break;
}

io.emit(-1);
return true;
}

} // namespace dflash::common
91 changes: 91 additions & 0 deletions server/src/common/layer_split_runtime.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Shared target layer-split runtime helpers.
//
// Model adapters still own their partial loaders, graph builders, caches, and
// snapshot payloads. This file keeps shared adapter runtime flow in one place
// so new adapters do not copy the same shell.

#pragma once

#include "gguf_inspect.h"
#include "layer_split_utils.h"
#include "model_backend.h"
#include "placement/placement_config.h"
#include "sampler.h"

#include <cstdio>
#include <functional>
#include <random>
#include <vector>

namespace dflash::common {

struct LayerSplitRuntimeInit {
const char * target_path = nullptr;
const DevicePlacement * device = nullptr;
const char * log_prefix = "target-split";
};

template <typename Shard>
bool init_layer_split_runtime(const LayerSplitRuntimeInit & cfg,
std::vector<Shard> & shards,
std::vector<ggml_backend_t> & snapshot_backends) {
if (!cfg.target_path || !cfg.device ||
cfg.device->layer_split_gpus.size() < 2) {
std::fprintf(stderr, "[%s] invalid layer-split config\n", cfg.log_prefix);
return false;
}

const auto info = inspect_gguf_model_info(cfg.target_path);
const int n_layer = info.n_layer;
if (n_layer <= 0) {
std::fprintf(stderr, "[%s] failed to inspect target layer count\n",
cfg.log_prefix);
return false;
}

const auto ranges = compute_layer_ranges(
n_layer,
(int)cfg.device->layer_split_gpus.size(),
cfg.device->layer_split_weights);
if (ranges.size() != cfg.device->layer_split_gpus.size()) {
std::fprintf(stderr,
"[%s] bad layer split for %zu GPUs and %d layers\n",
cfg.log_prefix, cfg.device->layer_split_gpus.size(), n_layer);
return false;
}

shards.resize(cfg.device->layer_split_gpus.size());
auto shard_metas = layer_split_shard_metas(shards);
if (!init_layer_split_shard_metas(
shard_metas, cfg.device->layer_split_gpus, ranges,
cfg.log_prefix)) {
return false;
}

(void)enable_layer_split_peer_access(
cfg.device->layer_split_gpus, cfg.device->peer_access);

return init_layer_split_snapshot_backends(
shard_metas, snapshot_backends, cfg.log_prefix);
}

using LayerSplitForwardStep = std::function<bool(
const std::vector<int32_t> & tokens,
int committed,
int & next_tok,
std::vector<float> * logits_out)>;

bool run_layer_split_ar_decode(
int last_tok,
int committed,
int n_gen,
int vocab,
const std::vector<float> & prefill_last_logits,
const SamplerCfg & sampler,
std::mt19937_64 & rng,
const LayerSplitForwardStep & forward_one,
const std::function<bool(int)> & is_eos,
std::vector<int32_t> & out_tokens,
const DaemonIO & io);

} // namespace dflash::common
43 changes: 35 additions & 8 deletions server/src/gemma4/gemma4_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,13 +515,14 @@ bool build_gemma4_layer_step(
return ggml_gallocr_alloc_graph(sg.alloc, sg.gf);
}

bool compute_gemma4_split_argmax(
bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits) {
ggml_init_params ip{};
ip.mem_size = ggml_tensor_overhead() * 64 + ggml_graph_overhead() + 1024 * 1024;
ip.no_alloc = true;
Expand All @@ -539,9 +540,17 @@ bool compute_gemma4_split_argmax(
cur = ggml_tanh(ctx, cur);
cur = ggml_scale(ctx, cur, w.final_logit_softcap);
}
cur = ggml_argmax(ctx, cur);
ggml_set_output(cur);
ggml_build_forward_expand(gf, cur);
ggml_tensor * logits = cur;
ggml_tensor * argmax = nullptr;
if (out_logits) {
ggml_set_output(logits);
ggml_build_forward_expand(gf, logits);
}
if (out_argmax) {
argmax = ggml_argmax(ctx, logits);
ggml_set_output(argmax);
ggml_build_forward_expand(gf, argmax);
}

ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
if (!alloc || !ggml_gallocr_alloc_graph(alloc, gf)) {
Expand All @@ -554,14 +563,32 @@ bool compute_gemma4_split_argmax(
ggml_free(ctx);
return false;
}
out_argmax.resize((size_t)n_tokens);
ggml_backend_tensor_get(cur, out_argmax.data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
if (out_argmax) {
out_argmax->resize((size_t)n_tokens);
ggml_backend_tensor_get(argmax, out_argmax->data(), 0,
sizeof(int32_t) * (size_t)n_tokens);
}
if (out_logits) {
out_logits->resize((size_t)w.n_vocab * (size_t)n_tokens);
ggml_backend_tensor_get(logits, out_logits->data(), 0,
sizeof(float) * (size_t)w.n_vocab * (size_t)n_tokens);
}
ggml_gallocr_free(alloc);
ggml_free(ctx);
return true;
}

bool compute_gemma4_split_argmax(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> & out_argmax) {
return compute_gemma4_split_projection(
backend, w, act, token_offset, n_tokens, &out_argmax, nullptr);
}

bool gemma4_step(
ggml_backend_t backend,
const Gemma4Weights & w,
Expand Down
9 changes: 9 additions & 0 deletions server/src/gemma4/gemma4_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ bool compute_gemma4_split_argmax(
int n_tokens,
std::vector<int32_t> & out_argmax);

bool compute_gemma4_split_projection(
ggml_backend_t backend,
const Gemma4Weights & w,
ggml_tensor * act,
int token_offset,
int n_tokens,
std::vector<int32_t> * out_argmax,
std::vector<float> * out_logits);

// BSA sparse-FA prefill: process the full prompt at once using block-sparse
// attention for SWA layers (flash_prefill_forward_bf16). Full-attention layers
// use dense FA. Returns logits for the last token. Populates the KV cache
Expand Down
Loading
Loading