Skip to content
16 changes: 14 additions & 2 deletions include/mlx-lm/common/gated_delta.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ mlx::core::array inplace_write(const mlx::core::array& dst,
mlx::core::array kv_inplace_update(
const mlx::core::array& cache, const mlx::core::array& new_kv, int offset);

// Device-position variant of kv_inplace_update: the write position is read from
// a device [1] int32 buffer (graph_decode_pos) rather than a host int, so the
// build-once decode graph relaunches correctly as the position advances.
//   cache  — KV cache; the implementation reads its dims 0..3 as B, H, ALLOC, D.
//   new_kv — new K/V to store; its dim 2 is passed to the kernel as N.
//   pos    — the device-side write position described above.
// Output aliases the cache buffer — the new K/V is written directly into KV[pos].
// In builds without MLX_BUILD_ROCM the implementation performs no write and
// simply returns `cache`.
mlx::core::array kv_inplace_update_at(
const mlx::core::array& cache, const mlx::core::array& new_kv,
const mlx::core::array& pos);

// FlashQLA-style fused GDN decode step (T=1): folds q/k-RMSNorm + beta/g +
// the delta recurrence into ONE kernel (replaces rms_norm(q)+rms_norm(k)+
// compiled beta/g + gated_delta_step). q,k: [B,1,Hk,Dk], v: [B,1,Hv,Dv],
Expand All @@ -71,7 +78,10 @@ std::pair<mlx::core::array, mlx::core::array> gdn_fused_decode(
const mlx::core::array& b, const mlx::core::array& a_log,
const mlx::core::array& dt_bias,
const mlx::core::array& q_norm_w, const mlx::core::array& k_norm_w,
const mlx::core::array& state);
const mlx::core::array& state,
// When true, state_out aliases state_in (output written into the same fixed
// buffer) for build-once replay — no scratch slot, no copy.
bool inplace = false);

// Fused GDN conv1d decode step: causal depthwise conv (KS taps) + silu + state
// shift in one kernel. conv_state [B,KS-1,CD], qkv [B,1,CD], weight [CD,1,KS].
Expand All @@ -80,7 +90,9 @@ std::pair<mlx::core::array, mlx::core::array> gdn_fused_decode(
std::pair<mlx::core::array, mlx::core::array> gdn_conv_step(
const mlx::core::array& conv_state,
const mlx::core::array& qkv,
const mlx::core::array& weight);
const mlx::core::array& weight,
// When true, new_state aliases conv_state (in-place) for build-once replay.
bool inplace = false);

// Fused residual-add + RMSNorm. Returns (sum = a+b, normed = rmsnorm(sum)*weight)
// in one kernel — eliminates the standalone residual add and keeps the sum
Expand Down
4 changes: 4 additions & 0 deletions include/mlx-lm/common/generate.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ class TokenIterator {
int pure_graph_state_ = 0;
int pure_graph_cap_ = 0; // reserved KV capacity
int pure_pos_ = 0; // host mirror of the device decode position
// Logits array recorded during the capture token; the captured exec's
// baked output buffer is overwritten by each replay, so re-reading it
// (with a forced copy) yields the current token's logits.
std::optional<mlx::core::array> pure_logits_;
// Run one decode step under the pure-graph path; returns the sampled token.
mlx::core::array step_pure_graph(const LMInput::Text& previous);

Expand Down
2 changes: 2 additions & 0 deletions include/mlx-lm/llm/models/qwen35_moe.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class Qwen35MoEGatedDeltaNet {
std::optional<mlx::core::array> conv1d_w_dec_; // reshaped+transposed conv weight
std::optional<mlx::core::array> q_norm_w_; // full(head_k_dim, inv_scale^2)
std::optional<mlx::core::array> k_norm_w_; // full(head_k_dim, inv_scale)
std::optional<mlx::core::array> a_log_f32_;    // presumably a float32 copy of a_log — TODO confirm where it is filled
std::optional<mlx::core::array> dt_bias_f32_;  // presumably a float32 copy of dt_bias — TODO confirm where it is filled

// The four in_proj projections (qkv, z, b, a) all map hidden_size -> their
// own output width with identical input width / quantization params, so
Expand Down
70 changes: 66 additions & 4 deletions src/common/gated_delta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,21 @@ static mx::fast::CustomKernelFunction& get_conv_step_kernel() {
return kernel;
}

// Returns the in-place flavour of the conv-step kernel, constructing it on the
// first call. It is declared with output_input_aliases {{1, 0}}: output 1
// ("new_state") shares storage with input 0 ("conv_state"), so the shifted conv
// state is written back into its fixed buffer and no copy is made.
// The alias is race-free because the kernel reads all taps before writing the
// shifted state and each thread owns one (batch, channel), i.e. a thread only
// overwrites values it has already read.
static mx::fast::CustomKernelFunction& get_conv_step_kernel_inplace() {
  // Function-local static: built once, then shared by every caller.
  static auto inplace_kernel = mx::fast::hip_kernel(
      "gdn_conv_step",
      {"conv_state", "qkv", "weight"},
      {"conv_out", "new_state"},
      conv_step_hip_source,
      /*header=*/"",
      /*ensure_row_contiguous=*/true,
      /*shared_memory=*/0,
      /*output_input_aliases=*/{{1, 0}});
  return inplace_kernel;
}

// Fused residual-add + RMSNorm: sum = a + b ; normed = rmsnorm(sum) * weight.
// Returns BOTH (sum for the next residual, normed for the next matmul) in one
// kernel, eliminating the standalone add (binary_vv) and keeping the residual
Expand Down Expand Up @@ -488,6 +503,22 @@ static mx::fast::CustomKernelFunction& get_gdn_fused_decode_kernel() {
return kernel;
}

// Returns the in-place flavour of the fused GDN decode kernel, constructing it
// on the first call. It is declared with output_input_aliases {{1, 9}}: output 1
// ("state_out") shares storage with input 9 ("state_in"), so the new SSM state
// is written back into the same fixed buffer with no copy. The build-once decode
// path uses it; the kernel reads the full state before writing, which keeps the
// alias race-free.
// It is deliberately a separate kernel object: the eager path (no sole-ref
// donation) keeps using the non-aliased kernel and never pays a
// donation-failure copy.
static mx::fast::CustomKernelFunction& get_gdn_fused_decode_kernel_inplace() {
  // Function-local static: built once, then shared by every caller.
  static auto inplace_kernel = mx::fast::hip_kernel(
      "gdn_fused_decode",
      {"q", "k", "v", "b", "a",
       "a_log", "dt_bias", "q_norm_w", "k_norm_w", "state_in"},
      {"y", "state_out"},
      gdn_fused_decode_hip_source,
      /*header=*/"",
      /*ensure_row_contiguous=*/true,
      /*shared_memory=*/0,
      /*output_input_aliases=*/{{1, 9}});
  return inplace_kernel;
}

// ---------------------------------------------------------------------------
// gatedDeltaKernel — dispatch the fused HIP kernel
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -821,7 +852,7 @@ std::pair<mx::array, mx::array> gdn_fused_decode(
const mx::array& a, const mx::array& b,
const mx::array& a_log, const mx::array& dt_bias,
const mx::array& q_norm_w, const mx::array& k_norm_w,
const mx::array& state)
const mx::array& state, bool inplace)
{
int B = q.shape(0);
int Hk = q.shape(2), Dk = q.shape(3);
Expand All @@ -832,7 +863,9 @@ std::pair<mx::array, mx::array> gdn_fused_decode(
auto t = q.dtype();
auto al = mx::astype(a_log, mx::float32);
auto db = mx::astype(dt_bias, mx::float32);
auto results = get_gdn_fused_decode_kernel()(
auto& kern = inplace ? get_gdn_fused_decode_kernel_inplace()
: get_gdn_fused_decode_kernel();
auto results = kern(
{q, k, v, b, a, al, db, q_norm_w, k_norm_w, state},
{{B, 1, Hv, Dv}, state.shape()},
{t, t},
Expand All @@ -852,7 +885,8 @@ std::pair<mx::array, mx::array> gdn_fused_decode(
std::pair<mx::array, mx::array> gdn_conv_step(
const mx::array& conv_state, // [B, KS-1, CD]
const mx::array& qkv, // [B, 1, CD]
const mx::array& weight) // [CD, 1, KS]
const mx::array& weight, // [CD, 1, KS]
bool inplace)
{
int B = qkv.shape(0);
int CD = qkv.shape(2);
Expand All @@ -864,7 +898,9 @@ std::pair<mx::array, mx::array> gdn_conv_step(
static const bool force_mxops = std::getenv("MLX_GDN_CONV_MXOPS") != nullptr;
if (!force_mxops) {
auto t = qkv.dtype();
auto results = get_conv_step_kernel()(
auto& kern = inplace ? get_conv_step_kernel_inplace()
: get_conv_step_kernel();
auto results = kern(
{conv_state, qkv, weight},
{{B, 1, CD}, {B, KS - 1, CD}},
{t, t},
Expand Down Expand Up @@ -985,6 +1021,32 @@ mx::array kv_inplace_update(
#endif
}

// Writes `new_kv` into `cache` at a position taken from the device buffer `pos`
// and returns the kernel's single output.
//
// This is the device-position counterpart of kv_inplace_update: the position is
// a device [1] int32 buffer (graph_decode_pos) instead of a host int, so the
// build-once decode graph can be relaunched while the loop advances the
// position device-side. The output aliases the cache buffer, so the roped K / V
// lands directly in KV[pos] with no slice_update op and no copy.
//
// Without MLX_BUILD_ROCM nothing is written and `cache` is returned as-is.
mx::array kv_inplace_update_at(
    const mx::array& cache, const mx::array& new_kv, const mx::array& pos)
{
#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM
  // Dimensions forwarded to the kernel as template arguments.
  int batch = cache.shape(0);
  int heads = cache.shape(1);
  int alloc_len = cache.shape(2);
  int head_dim = cache.shape(3);
  int new_len = new_kv.shape(2);

  // One thread per element of the new K/V slab, launched in groups of 256.
  long elem_count = static_cast<long>(batch) * heads * new_len * head_dim;

  auto outputs = get_kv_inplace_update_kernel()(
      {cache, new_kv, pos},
      {cache.shape()},
      {cache.dtype()},
      {static_cast<int>(elem_count), 1, 1},
      {256, 1, 1},
      {{"B", batch},
       {"H", heads},
       {"ALLOC", alloc_len},
       {"D", head_dim},
       {"N", new_len}},
      std::nullopt, true, {});
  return outputs[0];
#else
  // No ROCm backend: mark the otherwise-unused parameters and pass the cache through.
  (void)new_kv;
  (void)pos;
  return cache;
#endif
}

mx::array inplace_write(const mx::array& dst, const mx::array& src) {
#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM
int n = static_cast<int>(src.size());
Expand Down
160 changes: 97 additions & 63 deletions src/common/generate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,19 @@
#include <mlx-lm/common/graph_decode.h>
namespace mlx::core {
void gpu_set_graph_decode_mode(bool v);
// Build-once pure-relaunch decode + deterministic arena (rocm backend bridge).
void decode_pure_record(int slot);
void decode_pure_replay(int slot);
void decode_pure_off();
size_t decode_pure_chain_len(int slot);
// Deterministic arena (rocm backend bridge).
bool decode_arena_begin(size_t capacity, int device, void* stream);
void decode_arena_reset();
void decode_arena_freeze_floor();
void decode_arena_reset_to_floor();
void decode_arena_end();
bool decode_arena_overflowed();
void gpu_buffer_copy(array& dst, array& src);
long decode_inline_launch_count();
// Full decode-step stream capture (build-once / replay).
bool decode_capture_begin();
bool decode_capture_end_record(int slot);
bool decode_capture_replay(int slot);
void decode_capture_destroy();
} // namespace mlx::core
#endif

Expand Down Expand Up @@ -357,14 +360,14 @@ mx::array TokenIterator::step(const LMInput::Text& previous) {
}

#if defined(MLX_BUILD_ROCM)
// Build-once pure-relaunch decode step. State machine:
// 0 warmup — engage device-pos; warm mx::compile caches (no record)
// 1 record — record the per-token graph chain once
// 2 replay — relaunch the recorded chain every token
// 9 disabled — fell back to the normal path (arena overflow / mismatch)
// One graph suffices: the GatedDeltaNet recurrent state lives in a single static
// buffer updated in place (no parity ping-pong), and position + input token are
// injected each step via fixed-address device buffers.
// Build-once pure-relaunch decode step. Captures the whole forward into a HIP
// graph once, then relaunches the cached exec every token. State machine:
// 0 warmup -> 1 record -> 2 replay (9 = disabled: arena overflow / capture fail)
// Everything that varies per token lives in FIXED-address buffers so the
// recorded exec's baked pointers stay valid across relaunches: position and input
// token are device buffers injected each step; the GDN recurrent state is updated
// IN PLACE in its cache slots [0]/[1] (the fused kernels alias state-out to
// state-in); KV is written in place at the device position. No scratch, no copy.
mx::array TokenIterator::step_pure_graph(const LMInput::Text& previous) {
StreamGuard sg(generation_stream());
namespace mc = mlx::core;
Expand Down Expand Up @@ -393,68 +396,99 @@ mx::array TokenIterator::step_pure_graph(const LMInput::Text& previous) {
mlx_lm::advance_graph_decode_pos(1);
pure_pos_ += 1;
}
// Move GDN scratch next-state [2]/[3] -> read state [0]/[1] (immediate).
static const bool cpdbg = std::getenv("MLX_COPY_DEBUG") != nullptr;
int n_mamba = 0, n_scratch = 0;
if (pure_graph_state_ >= 1) {
for (auto& c : cache_) {
auto* m = c.as_mamba();
if (!m) continue;
n_mamba++;
if ((*m)[2].has_value()) {
n_scratch++;
mc::gpu_buffer_copy((*m)[0].value(), (*m)[2].value());
mc::gpu_buffer_copy((*m)[1].value(), (*m)[3].value());
}

// GDN recurrent state is updated IN PLACE in cache slots [0]/[1] by the fused
// kernels (state output aliases state input), and KV is written in place at
// the device position — so there is no scratch slot to copy back between
// relaunches. One recorded exec suffices: record once (state 1), replay (2).
const int replay_state = 2;

auto disable = [&]() {
mc::decode_capture_destroy();
mc::decode_arena_end();
mlx_lm::set_graph_external_pos(false);
pure_graph_state_ = 9;
};

mx::array token = mx::array(0);

if (!noreplay && pure_graph_state_ == replay_state && pure_logits_.has_value()) {
// REPLAY: input/pos already set above. Relaunch the recorded exec, then
// read the freshly-overwritten logits buffer (convert_to_token's sample
// kernel reads it at launch time).
mc::decode_arena_reset_to_floor(); // keep recorded buffers; sample above
if (mc::decode_capture_replay(0)) {
token = convert_to_token(*pure_logits_);
} else {
disable(); // capture lost -> rebuild via the eager fallback below
}
if (cpdbg) fprintf(stderr, "[cp] mamba=%d scratch=%d\n", n_mamba, n_scratch);
}
mc::gpu_set_graph_decode_mode(true);

// Single static recurrent-state buffer (in-place RMW) -> ONE graph, no
// parity. Record once, then relaunch the same chain every token.
if (!noreplay) {
if (pure_graph_state_ == 1) {
mc::decode_arena_begin(arena_bytes, 0, nullptr);
mc::decode_arena_reset();
mc::decode_pure_record(0);
} else if (pure_graph_state_ == 2) {
mc::decode_arena_reset();
mc::decode_pure_replay(0);

const bool is_record =
!noreplay && pure_graph_state_ >= 1 && pure_graph_state_ < replay_state;
if (pure_graph_state_ != replay_state || pure_graph_state_ == 9) {
// WARMUP (0), RECORD (1..replay_state-1), or fallback: run via call_fn.
if (is_record) {
if (pure_graph_state_ == 1)
mc::decode_arena_begin(arena_bytes, 0, nullptr);
mc::decode_arena_reset(); // record forward allocates from base
mc::decode_capture_begin(); // capture the eager call_fn that follows
}
}
auto result = context_.call_fn(
in, cache_.empty() ? nullptr : &cache_,
state_.has_value() ? &state_.value() : nullptr);
state_ = result.state;

auto result = context_.call_fn(
in, cache_.empty() ? nullptr : &cache_,
state_.has_value() ? &state_.value() : nullptr);
state_ = result.state;
auto token = convert_to_token(result.logits);
// Force-eval token + GDN scratch states (the loop reads their raw buffers).
std::vector<mx::array> ev{token};
for (auto& c : cache_) {
auto* m = c.as_mamba();
if (m && (*m)[2].has_value()) { ev.push_back((*m)[2].value()); ev.push_back((*m)[3].value()); }
if (is_record) {
// Launch the forward INLINE (async_eval: no blocking sync, which is
// illegal mid-capture) so every kernel records into the capture. The
// in-place GDN state slots [0]/[1] are eval'd so their writing kernels
// are captured.
std::vector<mx::array> outs{result.logits};
for (auto& c : cache_) {
auto* m = c.as_mamba();
if (!m) continue;
if ((*m)[0].has_value()) outs.push_back((*m)[0].value());
if ((*m)[1].has_value()) outs.push_back((*m)[1].value());
}
mx::async_eval(outs);
if (mc::decode_capture_end_record(0)) {
pure_logits_ = result.logits; // buffer overwritten by each replay
// The captured forward's allocations occupy [0, floor); freeze it
// so replay sampling allocates above the recorded buffers.
mc::decode_arena_freeze_floor();
} else {
disable();
}
}
token = convert_to_token(result.logits);
// Force-eval token + in-place GDN state (the next relaunch reads them).
std::vector<mx::array> ev{token};
for (auto& c : cache_) {
auto* m = c.as_mamba();
if (!m) continue;
if ((*m)[0].has_value()) ev.push_back((*m)[0].value());
if ((*m)[1].has_value()) ev.push_back((*m)[1].value());
}
mx::eval(ev);
}
mx::eval(ev);

static const bool pure_dbg = std::getenv("MLX_PURE_DEBUG") != nullptr;
if (pure_dbg) {
fprintf(stderr, "[pure] state=%d pos=%d in=%d sampled=%d\n",
static long prev_inline = 0;
long now_inline = mc::decode_inline_launch_count();
fprintf(stderr, "[pure] state=%d pos=%d in=%d sampled=%d inline=%ld(+%ld)\n",
pure_graph_state_, pure_pos_,
mlx_lm::graph_decode_input().item<int>(), token.item<int>());
mlx_lm::graph_decode_input().item<int>(), token.item<int>(),
now_inline, now_inline - prev_inline);
prev_inline = now_inline;
}

auto disable = [&]() {
mc::decode_pure_off();
mc::decode_arena_end();
mlx_lm::set_graph_external_pos(false);
pure_graph_state_ = 9;
};
if (pure_graph_state_ == 0) {
pure_graph_state_ = 1; // next token records
} else if (pure_graph_state_ == 1) {
} else if (pure_graph_state_ >= 1 && pure_graph_state_ < replay_state) {
if (mc::decode_arena_overflowed()) disable();
else pure_graph_state_ = 2; // recorded -> replay
else pure_graph_state_ += 1; // recorded -> replay
}
return token;
}
Expand Down Expand Up @@ -931,7 +965,7 @@ std::optional<int> TokenIterator::next() {
#if defined(MLX_BUILD_ROCM)
// Build-once pure-relaunch graph decode (qwen35-moe device-pos path). Enabled by default; set MLX_DECODE_GRAPH_PURE_OFF to disable it.
static const bool pure_enabled =
std::getenv("MLX_DECODE_GRAPH_PURE") != nullptr;
std::getenv("MLX_DECODE_GRAPH_PURE_OFF") == nullptr;
if (pure_enabled && pure_graph_state_ != 9 && !cache_.empty()) {
if (pure_graph_cap_ == 0) {
int off = 0;
Expand Down
Loading
Loading