Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/speculative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,12 @@ struct common_speculative_state_draft_mtp : public common_speculative_impl {
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}

// During prompt sync the draft MTP context only needs its cache/state
// updated. Host-visible pre-norm rows are consumed during draft()
// generation, not while mirroring prompt batches.
llama_set_embeddings_pre_norm(ctx_dft, false);
const int32_t rc = llama_decode(ctx_dft, batch);
llama_set_embeddings_pre_norm(ctx_dft, true);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
return false;
Expand Down
45 changes: 28 additions & 17 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -888,14 +888,23 @@ float * llama_context::get_embeddings_pre_norm() {
}

float * llama_context::get_embeddings_pre_norm_ith(int32_t i) {
output_reorder();

try {
if (embd_pre_norm.data == nullptr) {
throw std::runtime_error("no pre-norm embeddings");
}

const int64_t j = output_resolve_row(i);
int64_t j = i;
if (j < 0) {
j = n_outputs_pre_norm + j;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs_pre_norm));
}
}

if (j >= n_outputs_pre_norm) {
throw std::runtime_error(format("pre-norm embeddings id out of range [0, %d)", n_outputs_pre_norm));
}

const uint32_t n_embd = model.hparams.n_embd;
return embd_pre_norm.data + j*n_embd;
} catch (const std::exception & err) {
Expand Down Expand Up @@ -1346,6 +1355,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
}

n_outputs = n_tokens;
n_outputs_pre_norm = cparams.embeddings_pre_norm ? n_tokens : 0;

const auto causal_attn_org = cparams.causal_attn;

Expand Down Expand Up @@ -1731,12 +1741,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
}

// reserve output buffer
if (output_reserve(n_outputs_all) < n_outputs_all) {
if (output_reserve(n_outputs_all, cparams.embeddings_pre_norm ? n_tokens_all : n_outputs_all) < n_outputs_all) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
return -2;
};

int64_t n_outputs_prev = 0;
int64_t n_outputs_pre_norm_prev = 0;

do {
const auto & ubatch = mctx->get_ubatch();
Expand Down Expand Up @@ -1882,16 +1893,17 @@ int llama_context::decode(const llama_batch & batch_inp) {

// extract pre-norm embeddings (hidden state before the final output norm)
// only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored.
if (embd_pre_norm.data && t_h_pre_norm && n_outputs > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm);
GGML_ASSERT(backend_h != nullptr);

const uint32_t n_embd = hparams.n_embd;
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_prev*n_embd;
const int64_t n_outputs_pre_norm_new = t_h_pre_norm->ne[1];
float * embd_pre_norm_out = embd_pre_norm.data + n_outputs_pre_norm_prev*n_embd;

GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_pre_norm.size);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs*n_embd*sizeof(float));
GGML_ASSERT(n_outputs_pre_norm_prev + n_outputs_pre_norm_new <= (int64_t) embd_pre_norm.size/(int64_t)n_embd);
ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_outputs_pre_norm_new*n_embd*sizeof(float));
n_outputs_pre_norm_prev += n_outputs_pre_norm_new;
}

// Copy backend sampling output if this ubatch produced any sampling tensors.
Expand All @@ -1912,6 +1924,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

// set to total number of outputs in the batch, for use in llama_get_logits_ith
n_outputs = n_outputs_all;
n_outputs_pre_norm = n_outputs_pre_norm_prev;

// set output mappings
if (n_outputs > 0) {
Expand Down Expand Up @@ -1970,11 +1983,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
// output
//

uint32_t llama_context::output_reserve(int32_t n_outputs) {
uint32_t llama_context::output_reserve(int32_t n_outputs, int32_t n_outputs_pre_norm_req) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;

const int64_t n_outputs_max = std::max<int64_t>(n_outputs, n_seq_max());
const int64_t n_outputs_pre_norm_max = cparams.embeddings_pre_norm
? std::max<int64_t>(n_outputs_pre_norm_req < 0 ? n_outputs : n_outputs_pre_norm_req, n_seq_max())
: 0;

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
Expand All @@ -1997,7 +2013,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

logits.size = has_logits ? n_vocab*n_outputs_max : 0;
embd.size = has_embd ? n_embd_out*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0;
embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_pre_norm_max : 0;

// Allocate backend sampling output buffers if there are backend samplers configured.
const bool has_sampling = !sampling.samplers.empty();
Expand Down Expand Up @@ -2102,6 +2118,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
std::fill(output_ids.begin(), output_ids.end(), -1);

this->n_outputs = 0;
this->n_outputs_pre_norm = 0;

return n_outputs_max;
}
Expand All @@ -2126,12 +2143,6 @@ void llama_context::output_reorder() {
}
}

if (embd_pre_norm.size > 0) {
for (uint64_t k = 0; k < n_embd; k++) {
std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]);
}
}

if (!sampling.samplers.empty()) {
assert(sampling.logits.size > 0);
assert(sampling.probs.size > 0);
Expand Down
5 changes: 3 additions & 2 deletions src/llama-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ struct llama_context {

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
uint32_t output_reserve(int32_t n_outputs);
uint32_t output_reserve(int32_t n_outputs, int32_t n_outputs_pre_norm = -1);

void output_reorder();

Expand Down Expand Up @@ -282,10 +282,11 @@ struct llama_context {
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
buffer_view<float> embd = {nullptr, 0};

// hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd])
// hidden state before the final output norm (2-dimensional array: [n_outputs_pre_norm][n_embd])
// populated only when cparams.embeddings_pre_norm is enabled and the model graph
// sets llm_graph_result::t_h_pre_norm
buffer_view<float> embd_pre_norm = {nullptr, 0};
int32_t n_outputs_pre_norm = 0;

struct sampling_info {
// !samplers.empty() to check if any samplers are active
Expand Down
8 changes: 8 additions & 0 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,14 @@ void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
GGML_ASSERT(out_ids);

if (n_outputs == 0) {
return;
}

if (out_ids->buffer == nullptr) {
return;
}

const int64_t n_tokens = ubatch->n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
Expand Down
29 changes: 25 additions & 4 deletions src/models/qwen35.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
auto * inp = build_inp_mem_hybrid();

ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inp_out_ids = (n_outputs > 0 && (!cparams.embeddings_pre_norm || n_outputs < n_tokens)) ? build_inp_out_ids() : nullptr;

// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
Expand All @@ -176,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}

if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && !cparams.embeddings_pre_norm) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
Expand Down Expand Up @@ -211,6 +211,16 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;

if (n_outputs == 0) {
ggml_build_forward_expand(gf, cur);
return;
}

if (inp_out_ids && cparams.embeddings_pre_norm && n_outputs < n_tokens) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "h_pre_norm_out", -1);
}

// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

Expand Down Expand Up @@ -520,8 +530,9 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr

res->add_input(std::move(inp));

ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = (n_outputs > 0 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
Expand Down Expand Up @@ -610,6 +621,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;

if (n_outputs == 0) {
ggml_build_forward_expand(gf, cur);
return;
}

if (inp_out_ids && n_outputs < n_tokens) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "mtp_h_pre_norm_out", -1);
}

ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
Expand Down
29 changes: 25 additions & 4 deletions src/models/qwen35moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
auto * inp = build_inp_mem_hybrid();

ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = build_inp_out_ids();
ggml_tensor * inp_out_ids = (n_outputs > 0 && (!cparams.embeddings_pre_norm || n_outputs < n_tokens)) ? build_inp_out_ids() : nullptr;

// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers;
Expand All @@ -199,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il);
}

if (il == n_transformer_layers - 1 && inp_out_ids) {
if (il == n_transformer_layers - 1 && inp_out_ids && !cparams.embeddings_pre_norm) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
Expand Down Expand Up @@ -234,6 +234,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;

if (n_outputs == 0) {
ggml_build_forward_expand(gf, cur);
return;
}

if (inp_out_ids && cparams.embeddings_pre_norm && n_outputs < n_tokens) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "h_pre_norm_out", -1);
}

// Final norm
cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);

Expand Down Expand Up @@ -584,8 +594,9 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm

res->add_input(std::move(inp));

ggml_tensor * inp_pos = build_inp_pos();
auto * inp_attn = build_attn_inp_kv();
ggml_tensor * inp_pos = build_inp_pos();
ggml_tensor * inp_out_ids = (n_outputs > 0 && n_outputs < n_tokens) ? build_inp_out_ids() : nullptr;
auto * inp_attn = build_attn_inp_kv();

ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il);
cb(h_norm, "mtp_hnorm", il);
Expand Down Expand Up @@ -706,6 +717,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm
cb(cur, "h_pre_norm", -1);
res->t_h_pre_norm = cur;

if (n_outputs == 0) {
ggml_build_forward_expand(gf, cur);
return;
}

if (inp_out_ids && n_outputs < n_tokens) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
cb(cur, "mtp_h_pre_norm_out", -1);
}

ggml_tensor * head_norm_w = layer.nextn.shared_head_norm
? layer.nextn.shared_head_norm
: model.output_norm;
Expand Down
9 changes: 5 additions & 4 deletions tools/server/server-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,11 @@ struct server_slot {

bool need_embd() const {
GGML_ASSERT(task);
return task->need_embd() || (spec && common_speculative_need_embd(spec));
return task->need_embd();
}

// Returns true only when `spec` is set and common_speculative_need_embd(spec)
// evaluates to true; returns false otherwise.
// Unlike need_embd(), this does not read `task`, so there is no GGML_ASSERT(task) here.
// NOTE(review): presumably signals that the speculative (draft/MTP) path needs the
// pre-norm hidden states rather than regular output embeddings — confirm against callers.
bool need_embd_pre_norm() const {
return spec && common_speculative_need_embd(spec);
}

// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
Expand Down Expand Up @@ -2801,9 +2805,6 @@ struct server_context_impl {
break;
}

// embedding requires all tokens in the batch to be output;
// MTP also wants logits at every prompt position so the
// streaming hook can mirror t_h_pre_norm into ctx_dft.
common_batch_add(batch,
cur_tok,
slot.prompt.tokens.pos_next(),
Expand Down