diff --git a/include/mlx-lm/common/gated_delta.h b/include/mlx-lm/common/gated_delta.h index 054ec928..65bf1c37 100644 --- a/include/mlx-lm/common/gated_delta.h +++ b/include/mlx-lm/common/gated_delta.h @@ -59,6 +59,13 @@ mlx::core::array inplace_write(const mlx::core::array& dst, mlx::core::array kv_inplace_update( const mlx::core::array& cache, const mlx::core::array& new_kv, int offset); +// Device-position variant: the write position is read from a device [1] int32 +// buffer (graph_decode_pos) so the build-once decode graph relaunches correctly. +// Output aliases the cache buffer — the new K/V is written directly into KV[pos]. +mlx::core::array kv_inplace_update_at( + const mlx::core::array& cache, const mlx::core::array& new_kv, + const mlx::core::array& pos); + // FlashQLA-style fused GDN decode step (T=1): folds q/k-RMSNorm + beta/g + // the delta recurrence into ONE kernel (replaces rms_norm(q)+rms_norm(k)+ // compiled beta/g + gated_delta_step). q,k: [B,1,Hk,Dk], v: [B,1,Hv,Dv], @@ -71,7 +78,10 @@ std::pair gdn_fused_decode( const mlx::core::array& b, const mlx::core::array& a_log, const mlx::core::array& dt_bias, const mlx::core::array& q_norm_w, const mlx::core::array& k_norm_w, - const mlx::core::array& state); + const mlx::core::array& state, + // When true, state_out aliases state_in (output written into the same fixed + // buffer) for build-once replay — no scratch slot, no copy. + bool inplace = false); // Fused GDN conv1d decode step: causal depthwise conv (KS taps) + silu + state // shift in one kernel. conv_state [B,KS-1,CD], qkv [B,1,CD], weight [CD,1,KS]. @@ -80,7 +90,9 @@ std::pair gdn_fused_decode( std::pair gdn_conv_step( const mlx::core::array& conv_state, const mlx::core::array& qkv, - const mlx::core::array& weight); + const mlx::core::array& weight, + // When true, new_state aliases conv_state (in-place) for build-once replay. + bool inplace = false); // Fused residual-add + RMSNorm. Returns (sum = a+b, normed = rmsnorm(sum)*weight) // in one kernel — eliminates the standalone residual add and keeps the sum diff --git a/include/mlx-lm/common/generate.h b/include/mlx-lm/common/generate.h index 00690696..633ac000 100644 --- a/include/mlx-lm/common/generate.h +++ b/include/mlx-lm/common/generate.h @@ -334,6 +334,10 @@ class TokenIterator { int pure_graph_state_ = 0; int pure_graph_cap_ = 0; // reserved KV capacity int pure_pos_ = 0; // host mirror of the device decode position + // Logits array recorded during the capture token; the captured exec's + // baked output buffer is overwritten by each replay, so re-reading it + // (with a forced copy) yields the current token's logits. + std::optional pure_logits_; // Run one decode step under the pure-graph path; returns the sampled token. mlx::core::array step_pure_graph(const LMInput::Text& previous); diff --git a/include/mlx-lm/llm/models/qwen35_moe.h b/include/mlx-lm/llm/models/qwen35_moe.h index b7e81a67..29f38ab7 100644 --- a/include/mlx-lm/llm/models/qwen35_moe.h +++ b/include/mlx-lm/llm/models/qwen35_moe.h @@ -157,6 +157,8 @@ class Qwen35MoEGatedDeltaNet { std::optional conv1d_w_dec_; // reshaped+transposed conv weight std::optional q_norm_w_; // full(head_k_dim, inv_scale^2) std::optional k_norm_w_; // full(head_k_dim, inv_scale) + std::optional a_log_f32_; + std::optional dt_bias_f32_; // The four in_proj projections (qkv, z, b, a) all map hidden_size -> their // own output width with identical input width / quantization params, so diff --git a/src/common/gated_delta.cpp b/src/common/gated_delta.cpp index 61f5772c..8be23b3a 100644 --- a/src/common/gated_delta.cpp +++ b/src/common/gated_delta.cpp @@ -291,6 +291,21 @@ static mx::fast::CustomKernelFunction& get_conv_step_kernel() { return kernel; } +// In-place variant: new_state (output 1) aliases conv_state (input 0). The +// kernel reads all taps before writing the shifted state, and each thread owns a +// (batch, channel), so the per-thread read-before-write makes the alias race-free +// — the conv state is updated in its fixed buffer with no copy. +static mx::fast::CustomKernelFunction& get_conv_step_kernel_inplace() { + static auto kernel = mx::fast::hip_kernel( + "gdn_conv_step", + {"conv_state", "qkv", "weight"}, + {"conv_out", "new_state"}, + conv_step_hip_source, + /*header=*/"", /*ensure_row_contiguous=*/true, /*shared_memory=*/0, + /*output_input_aliases=*/{{1, 0}}); + return kernel; +} + // Fused residual-add + RMSNorm: sum = a + b ; normed = rmsnorm(sum) * weight. // Returns BOTH (sum for the next residual, normed for the next matmul) in one // kernel, eliminating the standalone add (binary_vv) and keeping the residual @@ -488,6 +503,22 @@ static mx::fast::CustomKernelFunction& get_gdn_fused_decode_kernel() { return kernel; } +// In-place variant: state_out (output 1) aliases state_in (input 9) so the new +// SSM state is written back into the same fixed buffer — no copy. Used by the +// build-once decode path (the kernel reads the full state before writing, so the +// alias is race-free). Kept separate so the eager path (no sole-ref donation) +// keeps using the non-aliased kernel and never pays a donation-failure copy. +static mx::fast::CustomKernelFunction& get_gdn_fused_decode_kernel_inplace() { + static auto kernel = mx::fast::hip_kernel( + "gdn_fused_decode", + {"q", "k", "v", "b", "a", "a_log", "dt_bias", "q_norm_w", "k_norm_w", "state_in"}, + {"y", "state_out"}, + gdn_fused_decode_hip_source, + /*header=*/"", /*ensure_row_contiguous=*/true, /*shared_memory=*/0, + /*output_input_aliases=*/{{1, 9}}); + return kernel; +} + // --------------------------------------------------------------------------- // gatedDeltaKernel — dispatch the fused HIP kernel // --------------------------------------------------------------------------- @@ -821,7 +852,7 @@ std::pair gdn_fused_decode( const mx::array& a, const mx::array& b, const mx::array& a_log, const mx::array& dt_bias, const mx::array& q_norm_w, const mx::array& k_norm_w, - const mx::array& state) + const mx::array& state, bool inplace) { int B = q.shape(0); int Hk = q.shape(2), Dk = q.shape(3); @@ -832,7 +863,9 @@ std::pair gdn_fused_decode( auto t = q.dtype(); auto al = mx::astype(a_log, mx::float32); auto db = mx::astype(dt_bias, mx::float32); - auto results = get_gdn_fused_decode_kernel()( + auto& kern = inplace ? get_gdn_fused_decode_kernel_inplace() + : get_gdn_fused_decode_kernel(); + auto results = kern( {q, k, v, b, a, al, db, q_norm_w, k_norm_w, state}, {{B, 1, Hv, Dv}, state.shape()}, {t, t}, @@ -852,7 +885,8 @@ std::pair gdn_fused_decode( std::pair gdn_conv_step( const mx::array& conv_state, // [B, KS-1, CD] const mx::array& qkv, // [B, 1, CD] - const mx::array& weight) // [CD, 1, KS] + const mx::array& weight, // [CD, 1, KS] + bool inplace) { int B = qkv.shape(0); int CD = qkv.shape(2); @@ -864,7 +898,9 @@ std::pair gdn_conv_step( static const bool force_mxops = std::getenv("MLX_GDN_CONV_MXOPS") != nullptr; if (!force_mxops) { auto t = qkv.dtype(); - auto results = get_conv_step_kernel()( + auto& kern = inplace ? get_conv_step_kernel_inplace() + : get_conv_step_kernel(); + auto results = kern( {conv_state, qkv, weight}, {{B, 1, CD}, {B, KS - 1, CD}}, {t, t}, @@ -985,6 +1021,32 @@ mx::array kv_inplace_update( #endif } +mx::array kv_inplace_update_at( + const mx::array& cache, const mx::array& new_kv, const mx::array& pos) +{ +#if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM + // Same in-place accessor as kv_inplace_update, but the write position comes + // from a device [1] int32 buffer (graph_decode_pos) instead of a host int — + // so the build-once decode graph relaunches correctly as the loop advances + // pos device-side. Output aliases the cache buffer: the math output (roped K + // / V) is written DIRECTLY into KV[pos], no slice_update array op, no copy. + int B = cache.shape(0), H = cache.shape(1); + int ALLOC = cache.shape(2), D = cache.shape(3); + int N = new_kv.shape(2); + long total = (long)B * H * N * D; + auto res = get_kv_inplace_update_kernel()( + {cache, new_kv, pos}, + {cache.shape()}, {cache.dtype()}, + {static_cast(total), 1, 1}, {256, 1, 1}, + {{"B", B}, {"H", H}, {"ALLOC", ALLOC}, {"D", D}, {"N", N}}, + std::nullopt, true, {}); + return res[0]; +#else + (void)new_kv; (void)pos; + return cache; +#endif +} + mx::array inplace_write(const mx::array& dst, const mx::array& src) { #if defined(MLX_BUILD_ROCM) && MLX_BUILD_ROCM int n = static_cast(src.size()); diff --git a/src/common/generate.cpp b/src/common/generate.cpp index f429d14d..6493ca1b 100644 --- a/src/common/generate.cpp +++ b/src/common/generate.cpp @@ -19,16 +19,19 @@ #include namespace mlx::core { void gpu_set_graph_decode_mode(bool v); -// Build-once pure-relaunch decode + deterministic arena (rocm backend bridge). -void decode_pure_record(int slot); -void decode_pure_replay(int slot); -void decode_pure_off(); -size_t decode_pure_chain_len(int slot); +// Deterministic arena (rocm backend bridge). bool decode_arena_begin(size_t capacity, int device, void* stream); void decode_arena_reset(); +void decode_arena_freeze_floor(); +void decode_arena_reset_to_floor(); void decode_arena_end(); bool decode_arena_overflowed(); -void gpu_buffer_copy(array& dst, array& src); +long decode_inline_launch_count(); +// Full decode-step stream capture (build-once / replay). +bool decode_capture_begin(); +bool decode_capture_end_record(int slot); +bool decode_capture_replay(int slot); +void decode_capture_destroy(); } // namespace mlx::core #endif @@ -357,14 +360,14 @@ mx::array TokenIterator::step(const LMInput::Text& previous) { } #if defined(MLX_BUILD_ROCM) -// Build-once pure-relaunch decode step. State machine: -// 0 warmup — engage device-pos; warm mx::compile caches (no record) -// 1 record — record the per-token graph chain once -// 2 replay — relaunch the recorded chain every token -// 9 disabled — fell back to the normal path (arena overflow / mismatch) -// One graph suffices: the GatedDeltaNet recurrent state lives in a single static -// buffer updated in place (no parity ping-pong), and position + input token are -// injected each step via fixed-address device buffers. +// Build-once pure-relaunch decode step. Captures the whole forward into a HIP +// graph once, then relaunches the cached exec every token. State machine: +// 0 warmup -> 1 record -> 2 replay (9 = disabled: arena overflow / capture fail) +// Everything that varies per token lives in FIXED-address buffers so the +// recorded exec's baked pointers stay valid across relaunches: position and input +// token are device buffers injected each step; the GDN recurrent state is updated +// IN PLACE in its cache slots [0]/[1] (the fused kernels alias state-out to +// state-in); KV is written in place at the device position. No scratch, no copy. mx::array TokenIterator::step_pure_graph(const LMInput::Text& previous) { StreamGuard sg(generation_stream()); namespace mc = mlx::core; @@ -393,68 +396,99 @@ mx::array TokenIterator::step_pure_graph(const LMInput::Text& previous) { mlx_lm::advance_graph_decode_pos(1); pure_pos_ += 1; } - // Move GDN scratch next-state [2]/[3] -> read state [0]/[1] (immediate). - static const bool cpdbg = std::getenv("MLX_COPY_DEBUG") != nullptr; - int n_mamba = 0, n_scratch = 0; - if (pure_graph_state_ >= 1) { - for (auto& c : cache_) { - auto* m = c.as_mamba(); - if (!m) continue; - n_mamba++; - if ((*m)[2].has_value()) { - n_scratch++; - mc::gpu_buffer_copy((*m)[0].value(), (*m)[2].value()); - mc::gpu_buffer_copy((*m)[1].value(), (*m)[3].value()); - } + + // GDN recurrent state is updated IN PLACE in cache slots [0]/[1] by the fused + // kernels (state output aliases state input), and KV is written in place at + // the device position — so there is no scratch slot to copy back between + // relaunches. One recorded exec suffices: record once (state 1), replay (2). + const int replay_state = 2; + + auto disable = [&]() { + mc::decode_capture_destroy(); + mc::decode_arena_end(); + mlx_lm::set_graph_external_pos(false); + pure_graph_state_ = 9; + }; + + mx::array token = mx::array(0); + + if (!noreplay && pure_graph_state_ == replay_state && pure_logits_.has_value()) { + // REPLAY: input/pos already set above. Relaunch the recorded exec, then + // read the freshly-overwritten logits buffer (convert_to_token's sample + // kernel reads it at launch time). + mc::decode_arena_reset_to_floor(); // keep recorded buffers; sample above + if (mc::decode_capture_replay(0)) { + token = convert_to_token(*pure_logits_); + } else { + disable(); // capture lost -> rebuild via the eager fallback below } - if (cpdbg) fprintf(stderr, "[cp] mamba=%d scratch=%d\n", n_mamba, n_scratch); } - mc::gpu_set_graph_decode_mode(true); - - // Single static recurrent-state buffer (in-place RMW) -> ONE graph, no - // parity. Record once, then relaunch the same chain every token. - if (!noreplay) { - if (pure_graph_state_ == 1) { - mc::decode_arena_begin(arena_bytes, 0, nullptr); - mc::decode_arena_reset(); - mc::decode_pure_record(0); - } else if (pure_graph_state_ == 2) { - mc::decode_arena_reset(); - mc::decode_pure_replay(0); + + const bool is_record = + !noreplay && pure_graph_state_ >= 1 && pure_graph_state_ < replay_state; + if (pure_graph_state_ != replay_state || pure_graph_state_ == 9) { + // WARMUP (0), RECORD (1..replay_state-1), or fallback: run via call_fn. + if (is_record) { + if (pure_graph_state_ == 1) + mc::decode_arena_begin(arena_bytes, 0, nullptr); + mc::decode_arena_reset(); // record forward allocates from base + mc::decode_capture_begin(); // capture the eager call_fn that follows } - } + auto result = context_.call_fn( + in, cache_.empty() ? nullptr : &cache_, + state_.has_value() ? &state_.value() : nullptr); + state_ = result.state; - auto result = context_.call_fn( - in, cache_.empty() ? nullptr : &cache_, - state_.has_value() ? &state_.value() : nullptr); - state_ = result.state; - auto token = convert_to_token(result.logits); - // Force-eval token + GDN scratch states (the loop reads their raw buffers). - std::vector ev{token}; - for (auto& c : cache_) { - auto* m = c.as_mamba(); - if (m && (*m)[2].has_value()) { ev.push_back((*m)[2].value()); ev.push_back((*m)[3].value()); } + if (is_record) { + // Launch the forward INLINE (async_eval: no blocking sync, which is + // illegal mid-capture) so every kernel records into the capture. The + // in-place GDN state slots [0]/[1] are eval'd so their writing kernels + // are captured. + std::vector outs{result.logits}; + for (auto& c : cache_) { + auto* m = c.as_mamba(); + if (!m) continue; + if ((*m)[0].has_value()) outs.push_back((*m)[0].value()); + if ((*m)[1].has_value()) outs.push_back((*m)[1].value()); + } + mx::async_eval(outs); + if (mc::decode_capture_end_record(0)) { + pure_logits_ = result.logits; // buffer overwritten by each replay + // The captured forward's allocations occupy [0, floor); freeze it + // so replay sampling allocates above the recorded buffers. + mc::decode_arena_freeze_floor(); + } else { + disable(); + } + } + token = convert_to_token(result.logits); + // Force-eval token + in-place GDN state (the next relaunch reads them). + std::vector ev{token}; + for (auto& c : cache_) { + auto* m = c.as_mamba(); + if (!m) continue; + if ((*m)[0].has_value()) ev.push_back((*m)[0].value()); + if ((*m)[1].has_value()) ev.push_back((*m)[1].value()); + } + mx::eval(ev); } - mx::eval(ev); static const bool pure_dbg = std::getenv("MLX_PURE_DEBUG") != nullptr; if (pure_dbg) { - fprintf(stderr, "[pure] state=%d pos=%d in=%d sampled=%d\n", + static long prev_inline = 0; + long now_inline = mc::decode_inline_launch_count(); + fprintf(stderr, "[pure] state=%d pos=%d in=%d sampled=%d inline=%ld(+%ld)\n", pure_graph_state_, pure_pos_, - mlx_lm::graph_decode_input().item(), token.item()); + mlx_lm::graph_decode_input().item(), token.item(), + now_inline, now_inline - prev_inline); + prev_inline = now_inline; } - auto disable = [&]() { - mc::decode_pure_off(); - mc::decode_arena_end(); - mlx_lm::set_graph_external_pos(false); - pure_graph_state_ = 9; - }; if (pure_graph_state_ == 0) { pure_graph_state_ = 1; // next token records - } else if (pure_graph_state_ == 1) { + } else if (pure_graph_state_ >= 1 && pure_graph_state_ < replay_state) { if (mc::decode_arena_overflowed()) disable(); - else pure_graph_state_ = 2; // recorded -> replay + else pure_graph_state_ += 1; // recorded -> replay } return token; } @@ -931,7 +965,7 @@ std::optional TokenIterator::next() { #if defined(MLX_BUILD_ROCM) // Build-once pure-relaunch graph decode (opt-in, qwen35-moe device-pos path). static const bool pure_enabled = - std::getenv("MLX_DECODE_GRAPH_PURE") != nullptr; + std::getenv("MLX_DECODE_GRAPH_PURE_OFF") == nullptr; if (pure_enabled && pure_graph_state_ != 9 && !cache_.empty()) { if (pure_graph_cap_ == 0) { int off = 0; diff --git a/src/common/kv_cache.cpp b/src/common/kv_cache.cpp index 32d0134c..3798f5a3 100644 --- a/src/common/kv_cache.cpp +++ b/src/common/kv_cache.cpp @@ -144,15 +144,14 @@ void KVCacheSimple::set_position(size_t pos) { std::pair KVCacheSimple::update_at_pos( const mlx::core::array& new_keys, const mlx::core::array& new_values, const mlx::core::array& pos) { - // DynamicSliceUpdate at the device-side `pos` (axis 2). The buffer must be - // pre-allocated to capacity; the offset advances device-side so the built - // graph relaunches correctly as the loop advances pos. std::move releases the - // cache's reference so slice_update can donate (update in place) — keeping the - // buffer at a FIXED address, which the build-once graph's nodes bake into. - auto k = std::move(keys_.value()); - auto v = std::move(values_.value()); - keys_ = mx::slice_update(k, new_keys, pos, {2}); - values_ = mx::slice_update(v, new_values, pos, {2}); + // Write the new K/V directly into the pre-allocated cache buffer at the + // device-side `pos` via the in-place accessor (output aliases the cache, so + // the math output lands straight in KV[pos] — no slice_update array op, no + // copy). The buffer stays at a FIXED address the build-once graph bakes, and + // `pos` is a device scalar so the recorded graph relaunches correctly as the + // loop advances it. + keys_ = kv_inplace_update_at(keys_.value(), new_keys, pos); + values_ = kv_inplace_update_at(values_.value(), new_values, pos); offset_ += new_keys.shape(2); return {keys_.value(), values_.value()}; } diff --git a/src/llm/models/qwen35_moe.cpp b/src/llm/models/qwen35_moe.cpp index ae30cd45..2c1cf424 100644 --- a/src/llm/models/qwen35_moe.cpp +++ b/src/llm/models/qwen35_moe.cpp @@ -16,8 +16,10 @@ #include #include #include +#include #include #include +#include namespace mx = mlx::core; @@ -382,15 +384,6 @@ void Qwen35MoEGatedDeltaNet::ensure_in_proj_fused() { in_proj_fused_weight_); } -// In-place overwrite of dst with src via native slice_update donation (start=0). -static mlx::core::array gdn_state_overwrite_(mlx::core::array dst, - const mlx::core::array& src) { - int nd = dst.ndim(); - std::vector axes(nd); - for (int i = 0; i < nd; ++i) axes[i] = i; - return mx::slice_update(std::move(dst), src, - mx::zeros({nd}, mx::int32), axes); -} mx::array Qwen35MoEGatedDeltaNet::operator()( const mx::array& inputs, @@ -425,12 +418,13 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( // Conv1d processing. auto dtype = inputs.dtype(); - // Build-once HIP-graph decode: keep the conv + SSM recurrent state in a SINGLE - // static buffer updated IN PLACE (inplace_write). The recorded graph reads the - // fixed-address state, computes the new state, and writes it back into the same - // buffer — one kernel's read-before-write is safe within a launch, and the - // in-place write accumulates across relaunches. No double buffer / parity is - // needed (only stream capture required ping-pong; build+relaunch does not). + // Build-once HIP-graph decode keeps the conv + SSM recurrent state in + // fixed-address cache slots [0]/[1] and updates them IN PLACE: the fused + // kernels alias their state output to the state input (forced alias, like + // kv_inplace_update), so the new state is written back into the SAME buffer — + // no scratch slot, no copy, and the address the recorded exec bakes stays put + // across relaunches. The kernels read the full state before writing, so the + // in-place update is race-free. bool gdn_inplace = mlx_lm::graph_external_pos() && S == 1 && cache; mx::array conv_state(0.0f); @@ -449,15 +443,8 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( // eliminating their copy_gg/copy_g kernels. if (S == 1 && cache) { auto [conv_out, new_state] = - gdn_conv_step(conv_state, qkv, conv1d_weight_); - if (gdn_inplace && (*cache)[0].has_value()) { - // Read [0], write new state to scratch [2]; loop copies [2]->[0]. - if (!(*cache)[2].has_value()) - (*cache)[2] = mx::zeros_like((*cache)[0].value()); - (*cache)[2] = gdn_state_overwrite_(std::move((*cache)[2].value()), new_state); - } else { - (*cache)[0] = new_state; - } + gdn_conv_step(conv_state, qkv, conv1d_weight_, /*inplace=*/gdn_inplace); + (*cache)[0] = new_state; // Split into q, k, v auto q_out = mx::reshape(mx::slice(conv_out, {0, 0, 0}, {B, 1, key_dim_}), @@ -501,25 +488,24 @@ mx::array Qwen35MoEGatedDeltaNet::operator()( fprintf(stderr, "[st] read_ssm %.6e\n", c.item()); } + if (!a_log_f32_.has_value()) { + a_log_f32_ = mx::astype(a_log_, mx::float32); + dt_bias_f32_ = mx::astype(dt_bias_, mx::float32); + mx::eval(*a_log_f32_, *dt_bias_f32_); + } + if (use_fused_gdn) { mx::array o(0.0f), ns(0.0f); if (use_fused2) { std::tie(o, ns) = gdn_fused_decode( - q_out, k_out, v_out, a_val, b_val, a_log_, dt_bias_, - *q_norm_w_, *k_norm_w_, ssm_state); + q_out, k_out, v_out, a_val, b_val, *a_log_f32_, *dt_bias_f32_, + *q_norm_w_, *k_norm_w_, ssm_state, /*inplace=*/gdn_inplace); } else { std::tie(o, ns) = gated_delta_update( - q_out, k_out, v_out, a_val, b_val, a_log_, dt_bias_, - ssm_state, std::nullopt, /*inplace=*/false); - } - if (gdn_inplace && (*cache)[1].has_value()) { - // Read [1], write new state to scratch [3]; loop copies [3]->[1]. - if (!(*cache)[3].has_value()) - (*cache)[3] = mx::zeros_like((*cache)[1].value()); - (*cache)[3] = gdn_state_overwrite_(std::move((*cache)[3].value()), ns); - } else { - (*cache)[1] = ns; + q_out, k_out, v_out, a_val, b_val, *a_log_f32_, *dt_bias_f32_, + ssm_state, std::nullopt, /*inplace_state=*/gdn_inplace); } + (*cache)[1] = ns; auto normalized = norm_(o, z); return linear_fwd(mx::reshape(normalized, {B, S, -1}), out_proj_weight_); }