
feat(glm52): MLA decode brick — projection, absorb, cache-pack, FlashMLA sparse decode #477

Merged: xiaguan merged 5 commits into main from feat/glm52-mla-decode-brick on Jun 30, 2026

@xiaguan xiaguan commented Jun 30, 2026

Summary

First forward brick for GLM5.2 on top of the load-weight scaffold (#476): single-layer MLA decode forward (hidden[6144] -> o[6144], bs=1, full top-k). At context ≤ 2048 the full top-k is DSA-equivalent, so this is the attention correctness foundation for the DP1/EP8 decode path.
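
The equivalence claim can be checked with a small host-side sketch. It assumes the sparse budget is 2048 tokens (inferred from the "≤ 2048" bound above, not stated elsewhere in this PR), and `top_k_indices` is a hypothetical helper, not the indexer planned for PR2. When the cache holds no more tokens than the budget, top-k selection returns every position, so full attention and sparse attention see the same set.

```rust
/// Return the positions of the `k` highest-scoring cached tokens, sorted ascending.
/// Hypothetical helper for illustration only; not the PR's DSA indexer.
fn top_k_indices(scores: &[f32], k: usize) -> Vec<usize> {
    let mut order: Vec<usize> = (0..scores.len()).collect();
    // Sort positions by descending score; total_cmp gives a total order on f32.
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
    // Keep at most k positions; with scores.len() <= k nothing is dropped.
    order.truncate(k.min(scores.len()));
    order.sort_unstable();
    order
}

/// True when sparse selection with budget `k` keeps every cached position.
fn selection_is_full(scores: &[f32], k: usize) -> bool {
    top_k_indices(scores, k).len() == scores.len()
}
```

With a 1500-token cache and a 2048 budget, `selection_is_full` returns true; with a 4096-token cache it returns false, which is where the real indexer starts to matter.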

The runner/coordinator remains fail-closed — no server wiring, no scheduler, no weight-loading changes. New modules are #[allow(dead_code)] until a later PR wires them into the executor.

See docs/models/glm52/dp1-ep8-decode-plan.md for the full 5-PR roadmap (PR1 MLA brick → PR2 DSA indexer → PR3 EP1 forward → PR4 DeepEP EP8 → PR5 scheduler).

What lands

CUDA kernels (4 files):

  • glm52_mla_assembly.cu (142 lines, hand-written) — query concat + interleave RoPE + fp8 cache pack
  • glm52_moe_quant.cu (180 lines, hand-written) — per-128-group amax → e4m3 FP8 quant
  • glm52_trtllm_grouped_fp8.cu (276 lines, vendored TRTLLM CUTLASS) — blockscale GEMM, m=1 linear + grouped MoE
  • csrc/shared/linear.cu (+35 lines) — cublasGemmStridedBatchedEx wrapper
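
For readers unfamiliar with the "interleave RoPE" step in glm52_mla_assembly.cu, here is a minimal host-side reference. It assumes "interleave" means the adjacent-pair layout, where elements `(x[2i], x[2i+1])` are rotated together; the function name, the `theta` base parameter, and the slice-based signature are illustrative and not taken from the kernel.

```rust
/// Apply rotary position embedding in place using the interleaved layout:
/// each adjacent pair (x[2i], x[2i+1]) is rotated by angle pos * theta^(-2i/d).
/// Assumed reference only; the CUDA kernel's actual layout and base may differ.
fn rope_interleaved(x: &mut [f32], pos: usize, theta: f32) {
    let d = x.len();
    assert!(d % 2 == 0, "rotary dimension must be even");
    for i in 0..d / 2 {
        // Per-pair frequency: lower pair indices rotate faster.
        let freq = theta.powf(-(2.0 * i as f32) / d as f32);
        let (sin, cos) = (pos as f32 * freq).sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}
```

Because each pair undergoes a plain 2D rotation, position 0 leaves the vector unchanged and every position preserves its norm. Both properties make cheap sanity checks for a future oracle test.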

Rust ops + FFI (8 files):

  • ops/glm52/{mla_assembly,moe_quant,trtllm_linear}.rs: launch wrappers with contract validation
  • ops/linear.rs: gemm_strided_batched_bf16 (cuBLAS strided batched)
  • FFI externs for all new CUDA symbols

Model crate (3 files):

  • fp8.rs: ProjWeight, fp8_linear (quant → relay scale → TRTLLM launch), dequant_kv_b (host fp8→bf16 absorb factors)
  • mla_decode.rs: Glm52MlaLayerWeights + glm52_mla_decode_forward
  • lib.rs: registers new modules
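
As a rough picture of what a host-side FP8 dequant such as dequant_kv_b involves, the sketch below decodes E4M3 bytes (the variant without infinities, max 448) and applies one scale per square tile. The function names, the f32 output (the PR targets bf16), and the row-major tile-scale layout are assumptions for illustration, not the code in fp8.rs.

```rust
/// Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7).
/// This is the variant with no infinities; S.1111.111 encodes NaN.
fn e4m3_to_f32(byte: u8) -> f32 {
    let sign = if byte & 0x80 != 0 { -1.0f32 } else { 1.0f32 };
    let exp = ((byte >> 3) & 0x0F) as i32;
    let man_bits = byte & 0x07;
    if exp == 0x0F && man_bits == 0x07 {
        return f32::NAN;
    }
    let man = man_bits as f32 / 8.0;
    let mag = if exp == 0 {
        // Subnormal: no implicit leading one, exponent fixed at 1 - bias = -6.
        man * 2f32.powi(-6)
    } else {
        (1.0 + man) * 2f32.powi(exp - 7)
    };
    sign * mag
}

/// Dequantize a row-major [rows x cols] FP8 matrix with one f32 scale per
/// `block` x `block` tile (assumed layout; scales are row-major over tiles).
fn dequant_blockscale(w: &[u8], scales: &[f32], rows: usize, cols: usize, block: usize) -> Vec<f32> {
    let tile_rows = (rows + block - 1) / block;
    let tile_cols = (cols + block - 1) / block;
    assert_eq!(w.len(), rows * cols);
    assert_eq!(scales.len(), tile_rows * tile_cols);
    let mut out = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            let scale = scales[(r / block) * tile_cols + c / block];
            out.push(e4m3_to_f32(w[r * cols + c]) * scale);
        }
    }
    out
}
```

For example, byte `0x38` decodes to 1.0 and `0x7E` to 448.0, the largest finite E4M3 value.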

Docs (3 files):

  • dp1-ep8-decode-plan.md — 5-PR roadmap with hand-written kernel perf-debt flags
  • mla-decode-brick.md — PR1 dev doc (build, kernel inventory)
  • index.md — glm52 routing section

Hand-written CUDA perf debt

glm52_mla_assembly.cu and glm52_moe_quant.cu are hand-written, not vendored. Both are memory-bound elementwise/reduce kernels (not GEMM/attention compute). They are correct but not tuned: single-issue-per-element, no vectorized load/store, no occupancy targeting. First ncu candidates when decode TPOT is measured. If a fused C-ABI alternative appears in vendored FlashInfer/TRTLLM, replace rather than optimize in place.
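
To make the reduce-then-elementwise shape of the quant kernel concrete, here is a CPU reference for per-group amax scaling. It is a sketch under assumptions: the group size is a parameter (128 in this PR), the scale is `amax / 448` (448 being the largest finite E4M3 value), and the final rounding to an 8-bit encoding is left out. `quant_per_group` is an illustrative name, not the kernel's entry point.

```rust
/// Largest finite value representable in FP8 E4M3.
const E4M3_MAX: f32 = 448.0;

/// Per-group amax scaling: returns (scaled values clamped to the E4M3 range,
/// one scale per group). Rounding to the 8-bit encoding is omitted here.
fn quant_per_group(x: &[f32], group: usize) -> (Vec<f32>, Vec<f32>) {
    assert!(group > 0 && x.len() % group == 0, "length must be a multiple of the group size");
    let mut scaled = Vec::with_capacity(x.len());
    let mut scales = Vec::with_capacity(x.len() / group);
    for chunk in x.chunks_exact(group) {
        // Reduce step: absolute maximum over the group.
        let amax = chunk.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        // Guard against an all-zero group so the division below stays finite.
        let scale = (amax / E4M3_MAX).max(f32::MIN_POSITIVE);
        scales.push(scale);
        // Elementwise step: divide by the group scale and clamp to the FP8 range.
        scaled.extend(chunk.iter().map(|v| (v / scale).clamp(-E4M3_MAX, E4M3_MAX)));
    }
    (scaled, scales)
}
```

A 6144-wide hidden vector with groups of 128 yields 48 scales. Each element is read once for the reduce and once for the scaling, which is why the kernel is bandwidth-bound rather than compute-bound.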

Oracle gate — deferred

This PR does not include an oracle test. The prototype had a fixture pipeline (HF forward dump → layer0.npz → probe bins → Rust test), but the dump script was never in the repo, making the chain non-reproducible. The oracle gate is deferred to a follow-up that designs a self-contained fixture pipeline.

Build

Requires SM90a (H200), CUDA 12.6+ (cuLibraryLoadData for DeepGEMM JIT), NCCL 2.30.4+ (DeepEP submodule via moe feature).

export OPENINFER_NCCL_ROOT=/path/to/nvidia/nccl
git submodule update --init --recursive
cargo check --release -p openinfer-glm52

Verified compiling on H200 (sm_90a, CUDA 12.8, NCCL 2.30.7).

Checklist

  • Compiles on H200
  • cargo fmt --all --check clean
  • Runner still fail-closed (no server/scheduler changes)
  • Oracle gate (deferred)
  • ncu profiling of hand-written kernels (deferred to PR5)

xiaguan and others added 5 commits June 30, 2026 21:42
Cherry-pick oracle-validated MLA decode kernels from the pp8 branch and
wire them for EP1/DP1 (no PP):

CUDA kernels:
- glm52_mla_assembly.cu: query assemble + cache pack (hand-written, 142 lines)
- glm52_moe_quant.cu: per-token-group FP8 quant (hand-written, 180 lines)
- glm52_trtllm_grouped_fp8.cu: TRTLLM CUTLASS blockscale GEMM (vendored)
- csrc/shared/linear.cu: gemm_strided_batched_bf16_cuda (cuBLAS)

Rust ops + FFI:
- ops/glm52: mla_assembly, moe_quant, trtllm_linear
- ops/linear: gemm_strided_batched_bf16
- ffi/glm52, ffi/shared: new extern declarations
- build.rs: TRTLLM CUTLASS includes for glm52_trtllm_grouped_fp8.cu

Model crate:
- fp8.rs: ProjWeight, fp8_linear, dequant_kv_b (shared by MLA/dense/MoE)
- mla_decode.rs: single-layer MLA decode forward (hidden->o, bs=1)
- tests/mla_decode_oracle.rs: end-to-end MLA forward vs HF oracle

Hand-written CUDA perf debt: glm52_mla_assembly.cu, glm52_moe_quant.cu —
correct, not tuned, first ncu candidates (see dp1-ep8-decode-plan.md).
…c to private

The prototype oracle fixture chain (HF forward dump -> layer0.npz ->
probe bins -> Rust test) was not self-contained: the dump script that
generates layer0.npz was never in the repo, so the test could not be
reproduced by anyone else. Remove the test + prep scripts + inline test
module; defer the oracle gate to a follow-up that designs a reproducible
fixture pipeline.

Also move tokenspeed-kernel-gap.md to docs/private/ (gitignored) — it
references third-party internal names and is a local research note, not
PR material.

Add mla-decode-brick.md as the PR1 dev doc (build, kernel inventory,
hand-written perf-debt flags).
@xiaguan xiaguan merged commit 90909d5 into main Jun 30, 2026
1 check passed

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b88b4f7d1b

log = { workspace = true }
memmap2 = { workspace = true }
openinfer-core = { workspace = true }
openinfer-kernels = { workspace = true, features = ["glm52"] }

P1: Keep GLM5.2 kernels behind the model feature

Because openinfer-glm52 is a workspace member, this non-optional dependency makes any cargo test --release --workspace --lib or cargo check --workspace build openinfer-kernels with glm52 enabled. That feature also enables moe, so the kernels build script then requires the DeepEP/DeepGEMM/FlashMLA submodules plus the GLM5.2/NCCL/H200 toolchain, even when the server glm52 feature was not selected. This regresses the documented default Qwen3-only workspace test/build path. The kernels dependency should be optional and activated only by an explicit GLM5.2 crate/server feature.
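
One way the suggestion could look in the openinfer-glm52 manifest is sketched below. The crate-level feature name `glm52` is hypothetical, and the modules that call into the kernels would also need a matching `#[cfg(feature = "glm52")]` gate for the default build to compile.

```toml
[dependencies]
log = { workspace = true }
memmap2 = { workspace = true }
openinfer-core = { workspace = true }
# Optional: only pulled in when the crate's own feature is enabled.
openinfer-kernels = { workspace = true, optional = true, features = ["glm52"] }

[features]
default = []
# Hypothetical feature name; the server crate would forward its glm52 feature here.
glm52 = ["dep:openinfer-kernels"]
```

With this shape, a plain cargo check --workspace would skip the kernels dependency, and the GLM5.2 toolchain would only be required when the feature is requested explicitly.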
