Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/common/model_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "ggml.h"
#include "ggml-backend.h"
#include "sampler.h"
#include "placement/draft_residency.h"

namespace dflash::common {

Expand Down Expand Up @@ -250,6 +251,7 @@ struct ModelBackend {
std::string drafter_path; // GGUF path (for lazy-load)
int drafter_gpu = 0; // backend-local GPU for PFlash drafter
bool skip_park = false; // true on >=32GB GPUs
DraftResidencyAction residency_action = DraftResidencyAction::KeepLoaded;
};

struct CompressResult {
Expand Down
246 changes: 142 additions & 104 deletions server/src/gemma4/gemma4_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,89 +55,9 @@ bool Gemma4Backend::init() {
}
cache_.fa_window = cfg_.fa_window;

// Load draft model for speculative decode
if (cfg_.draft_path) {
const int draft_gpu = (cfg_.draft_gpu >= 0) ? cfg_.draft_gpu : cfg_.device.gpu;
draft_backend_ = ggml_backend_cuda_init(draft_gpu);
if (!draft_backend_) {
std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu);
} else {
// Load draft GGUF — pass nullptr for target (Gemma4 != TargetWeights)
if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) {
std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error());
ggml_backend_free(draft_backend_); draft_backend_ = nullptr;
} else {
// Override mask_token_id for Gemma4 (token 4 per model card)
dw_.mask_token_id = 4;

// Fix draft dimensions from actual tensor shapes (GGUF metadata is wrong)
// fc.weight: [fc_in, draft_hidden]
const int draft_hidden = (int)dw_.fc->ne[1];
const int fc_in = (int)dw_.fc->ne[0];
const int n_capture = fc_in / w_.n_embd;

if (draft_hidden != dw_.n_embd) {
std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n",
dw_.n_embd, draft_hidden);
dw_.n_embd = draft_hidden;
}
// Infer n_head from wq shape: wq.ne[1] = n_head * head_dim
if (dw_.n_layer > 0 && dw_.layers[0].wq) {
const int q_dim = (int)dw_.layers[0].wq->ne[1];
const int inferred_n_head = q_dim / dw_.head_dim;
if (inferred_n_head != dw_.n_head) {
std::printf("[gemma4] draft: overriding n_head %d -> %d\n",
dw_.n_head, inferred_n_head);
dw_.n_head = inferred_n_head;
}
}
// Infer n_ff from ffn_gate shape
if (dw_.n_layer > 0 && dw_.layers[0].w_gate) {
const int inferred_ff = (int)dw_.layers[0].w_gate->ne[1];
if (inferred_ff != dw_.n_ff) {
std::printf("[gemma4] draft: overriding n_ff %d -> %d\n",
dw_.n_ff, inferred_ff);
dw_.n_ff = inferred_ff;
}
}
// Override n_target_layers from fc shape
dw_.n_target_layers = n_capture;

// Gemma4 DFlash draft: layers 0-3 are SWA (causal), layer 4 is full (non-causal)
// (from model card: layer_types = [sliding*4, full_attention])
dw_.swa_window = 2048;
for (int i = 0; i < dw_.n_layer - 1 && i < (int)dw_.layers.size(); i++)
dw_.layers[i].is_swa = true;

std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d "
"draft_hidden=%d n_capture_layers=%d swa=%d\n",
fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window);

// Allocate target_feat ring buffer
constexpr int TARGET_FEAT_CAP = 4096;
const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP);
if (!create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) {
std::fprintf(stderr, "[gemma4] target_feat alloc failed\n");
} else {
// Init feature mirror on draft GPU
const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap);
if (!draft_feature_mirror_init(feature_mirror_, draft_backend_,
draft_gpu, cfg_.device.gpu, mirror_cap,
n_capture, w_.n_embd)) {
std::fprintf(stderr, "[gemma4] feature mirror init failed\n");
} else {
// Create DFlash target adapter
dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_);
std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n",
n_capture, mirror_cap);
std::printf("[gemma4] capture_layer_ids:");
for (int k = 0; k < (int)cache_.capture_layer_ids.size(); k++)
std::printf(" %d", cache_.capture_layer_ids[k]);
std::printf("\n");
}
}
}
}
// Load draft model for speculative decode.
if (cfg_.draft_path && !load_decode_draft()) {
std::fprintf(stderr, "[gemma4] draft unavailable; speculative decode disabled\n");
}

std::printf("[gemma4] init ok: %d layers, embd=%d, vocab=%d, max_ctx=%d\n",
Expand All @@ -157,8 +77,15 @@ void Gemma4Backend::print_ready_banner() const {
// ── Park / Unpark ──────────────────────────────────────────────────────

bool Gemma4Backend::park(const std::string & what) {
(void)what;
if (parked_) return true;
const bool want_draft = (what.empty() || what == "all" || what == "draft");
const bool want_target = (what.empty() || what == "all" || what == "target");

if (want_draft && !draft_parked_) {
free_decode_draft();
draft_parked_ = true;
std::printf("[gemma4] draft released\n"); std::fflush(stdout);
}
if (!want_target || parked_) return true;

// Free snapshots first (they reference the snap_backend buffer)
for (int i = 0; i < PREFIX_SLOTS; ++i) {
Expand All @@ -177,24 +104,37 @@ bool Gemma4Backend::park(const std::string & what) {
}

bool Gemma4Backend::unpark(const std::string & what) {
(void)what;
if (!parked_) return true;
const bool want_draft = (what.empty() || what == "all" || what == "draft");
const bool want_target = (what.empty() || what == "all" || what == "target");

if (want_target && !parked_) {
// target already resident
} else if (want_target && parked_) {
// Reload weights from disk
if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) {
std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n");
return false;
}

// Reload weights from disk
if (!load_gemma4_gguf(cfg_.model_path, backend_, w_)) {
std::fprintf(stderr, "[gemma4] unpark: failed to reload weights\n");
return false;
}
// Recreate KV cache
if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) {
std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n");
free_gemma4_weights(w_);
return false;
}
cache_.fa_window = cfg_.fa_window;

// Recreate KV cache
if (!create_gemma4_cache(backend_, w_, cfg_.device.max_ctx, cache_)) {
std::fprintf(stderr, "[gemma4] unpark: failed to recreate cache\n");
free_gemma4_weights(w_);
return false;
parked_ = false;
std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout);
if (cfg_.draft_path && !draft_parked_ && draft_backend_) {
delete dflash_target_;
dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_);
}
}

parked_ = false;
std::printf("[gemma4] unparked (VRAM restored)\n"); std::fflush(stdout);
if (want_draft && draft_parked_ && cfg_.draft_path) {
if (!load_decode_draft()) return false;
}
return true;
}

Expand Down Expand Up @@ -1104,16 +1044,114 @@ bool Gemma4Backend::try_handle_command(const std::string & line,
return false; // no arch-specific commands
}

// Loads the speculative-decode draft model and wires it up to the target.
//
// Sequence: create the CUDA backend for the draft GPU, load the draft GGUF,
// patch the draft hyper-parameters from the actual tensor shapes, allocate the
// target_feat buffer on the target cache (unless one already exists), set up
// the feature mirror on the draft GPU, and finally (re)create the DFlash
// target adapter.
//
// Returns true when the draft is resident on exit (including the case where it
// already was). Returns false when no draft path is configured or when any
// step fails; once the draft weights have been loaded, every failure path
// releases what was acquired through free_decode_draft().
bool Gemma4Backend::load_decode_draft() {
    if (!cfg_.draft_path) return false;

    // Already resident: nothing to load, just make sure the parked flag agrees.
    if (draft_backend_ && feature_mirror_.target_feat) {
        draft_parked_ = false;
        return true;
    }

    // A negative draft_gpu means "use the same GPU as the target".
    const int draft_gpu = (cfg_.draft_gpu >= 0) ? cfg_.draft_gpu : cfg_.device.gpu;
    draft_backend_ = ggml_backend_cuda_init(draft_gpu);
    if (!draft_backend_) {
        std::fprintf(stderr, "[gemma4] draft CUDA init failed (gpu=%d)\n", draft_gpu);
        return false;
    }
    // nullptr target: the Gemma4 weights are not passed to the draft loader.
    if (!load_draft_gguf(cfg_.draft_path, draft_backend_, dw_, nullptr)) {
        std::fprintf(stderr, "[gemma4] draft load failed: %s\n", dflash27b_last_error());
        ggml_backend_free(draft_backend_);
        draft_backend_ = nullptr;
        return false;
    }

    // The shape fix-ups below read dw_.fc and divide by the target hidden size.
    // unpark("draft") can reach this function without reloading the target, so
    // fail cleanly instead of dereferencing null or dividing by zero.
    if (!dw_.fc || w_.n_embd <= 0) {
        std::fprintf(stderr,
                     "[gemma4] draft load failed: missing fc weight or target weights not loaded\n");
        free_decode_draft();
        return false;
    }

    // Mask token id is hard-coded for Gemma4 rather than taken from the GGUF.
    dw_.mask_token_id = 4;

    // Derive the real dimensions from the fc weight: ne[0] is the fc input
    // width, ne[1] the draft hidden size.
    const int draft_hidden = static_cast<int>(dw_.fc->ne[1]);
    const int fc_in        = static_cast<int>(dw_.fc->ne[0]);
    const int n_capture    = fc_in / w_.n_embd;

    if (draft_hidden != dw_.n_embd) {
        std::printf("[gemma4] draft: overriding n_embd %d -> %d (from fc weight)\n",
                    dw_.n_embd, draft_hidden);
        dw_.n_embd = draft_hidden;
    }

    // Only look at layer 0 when it really exists; n_layer comes from metadata
    // and the loop further down already allows layers.size() to be smaller.
    const bool has_layer0 = dw_.n_layer > 0 && dw_.layers.size() > 0;

    // n_head = wq.ne[1] / head_dim; skipped when head_dim is not positive.
    if (has_layer0 && dw_.layers[0].wq && dw_.head_dim > 0) {
        const int q_dim           = static_cast<int>(dw_.layers[0].wq->ne[1]);
        const int inferred_n_head = q_dim / dw_.head_dim;
        if (inferred_n_head != dw_.n_head) {
            std::printf("[gemma4] draft: overriding n_head %d -> %d\n",
                        dw_.n_head, inferred_n_head);
            dw_.n_head = inferred_n_head;
        }
    }
    // n_ff is taken from the gate projection's ne[1].
    if (has_layer0 && dw_.layers[0].w_gate) {
        const int inferred_ff = static_cast<int>(dw_.layers[0].w_gate->ne[1]);
        if (inferred_ff != dw_.n_ff) {
            std::printf("[gemma4] draft: overriding n_ff %d -> %d\n",
                        dw_.n_ff, inferred_ff);
            dw_.n_ff = inferred_ff;
        }
    }

    dw_.n_target_layers = n_capture;

    // Every draft layer except the last is flagged as sliding-window attention.
    dw_.swa_window = 2048;
    for (int i = 0; i < dw_.n_layer - 1 && i < static_cast<int>(dw_.layers.size()); i++) {
        dw_.layers[i].is_swa = true;
    }

    std::printf("[gemma4] draft loaded: fc_in=%d target_hidden=%d "
                "draft_hidden=%d n_capture_layers=%d swa=%d\n",
                fc_in, w_.n_embd, draft_hidden, n_capture, dw_.swa_window);

    // target_feat lives on the target cache; keep an existing one, otherwise
    // allocate it with a capacity capped at TARGET_FEAT_CAP.
    constexpr int TARGET_FEAT_CAP = 4096;
    const int feat_cap = std::min(cfg_.device.max_ctx, TARGET_FEAT_CAP);
    if (!cache_.target_feat &&
        !create_gemma4_target_feat(backend_, cache_, n_capture, w_.n_embd, feat_cap)) {
        std::fprintf(stderr, "[gemma4] target_feat alloc failed\n");
        free_decode_draft();
        return false;
    }

    // Mirror of the captured features on the draft GPU.
    const int mirror_cap = std::min(cfg_.draft_ctx_max, feat_cap);
    if (!draft_feature_mirror_init(feature_mirror_, draft_backend_,
                                   draft_gpu, cfg_.device.gpu, mirror_cap,
                                   n_capture, w_.n_embd)) {
        std::fprintf(stderr, "[gemma4] feature mirror init failed\n");
        free_decode_draft();
        return false;
    }

    // Replace any previous adapter so it is built against the current
    // weights, cache and backend.
    delete dflash_target_;
    dflash_target_ = new Gemma4DFlashTarget(w_, cache_, backend_);
    draft_parked_ = false;

    std::printf("[gemma4] spec-decode ready: capture_layers=%d mirror_cap=%d\n",
                n_capture, mirror_cap);
    std::printf("[gemma4] capture_layer_ids:");
    for (int k = 0; k < static_cast<int>(cache_.capture_layer_ids.size()); k++) {
        std::printf(" %d", cache_.capture_layer_ids[k]);
    }
    std::printf("\n");
    return true;
}

void Gemma4Backend::free_decode_draft() {
delete dflash_target_;
dflash_target_ = nullptr;
draft_feature_mirror_free(feature_mirror_);
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
free_gemma4_target_feat(cache_);
if (dw_.ctx) {
free_draft_weights(dw_);
}
if (draft_backend_) {
ggml_backend_free(draft_backend_);
draft_backend_ = nullptr;
}
}

// ── Shutdown ───────────────────────────────────────────────────────────

void Gemma4Backend::shutdown() {
for (int i = 0; i < PREFIX_SLOTS; ++i) snapshot_free(i);
free_drafter();
// Clean up DFlash spec-decode resources
delete dflash_target_; dflash_target_ = nullptr;
draft_feature_mirror_free(feature_mirror_);
if (dw_.ctx) { free_draft_weights(dw_); }
if (draft_backend_) { ggml_backend_free(draft_backend_); draft_backend_ = nullptr; }
free_decode_draft();
free_gemma4_cache(cache_);
free_gemma4_weights(w_);
free_snapshot_backend(snap_backend_, backend_);
Expand Down
3 changes: 3 additions & 0 deletions server/src/gemma4/gemma4_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ class Gemma4Backend : public ModelBackend {
const DaemonIO & io,
const BudgetHook * budget_hook = nullptr,
bool * forced_close_out = nullptr);

bool load_decode_draft();
void free_decode_draft();
};

} // namespace dflash::common
1 change: 1 addition & 0 deletions server/src/gemma4/gemma4_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ bool create_gemma4_cache_partial(ggml_backend_t backend,
void free_gemma4_cache(Gemma4Cache & c);

// Release the target_feat ring buffer (its backend buffer and ggml context) and
// reset the capture metadata on the cache. Does nothing harmful when no buffer
// is currently allocated.
void free_gemma4_target_feat(Gemma4Cache & c);
// Allocate target_feat ring buffer (call after draft load determines n_capture_layers).
// Returns false when any of n_capture_layers, hidden_size or cap is not positive.
bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache,
int n_capture_layers, int hidden_size, int cap);

Expand Down
9 changes: 9 additions & 0 deletions server/src/gemma4/gemma4_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,15 @@ void free_gemma4_cache(Gemma4Cache & c) {
c.cur_pos = 0;
}

// Tears down the target_feat storage held by the cache: frees the backend
// buffer and the ggml context when present, then clears the tensor pointer,
// its capacity, the capture-layer count and the capture-layer id list.
void free_gemma4_target_feat(Gemma4Cache & c) {
    // Buffer first, then the context.
    if (c.feat_buf != nullptr) {
        ggml_backend_buffer_free(c.feat_buf);
        c.feat_buf = nullptr;
    }
    if (c.feat_ctx != nullptr) {
        ggml_free(c.feat_ctx);
        c.feat_ctx = nullptr;
    }

    // Reset the bookkeeping so the cache reports "no target_feat".
    c.target_feat      = nullptr;
    c.target_feat_cap  = 0;
    c.n_capture_layers = 0;
    c.capture_layer_ids.clear();
}

bool create_gemma4_target_feat(ggml_backend_t backend, Gemma4Cache & cache,
int n_capture_layers, int hidden_size, int cap) {
if (n_capture_layers <= 0 || hidden_size <= 0 || cap <= 0) return false;
Expand Down
Loading
Loading