feat(glm52): MLA decode brick — projection, absorb, cache-pack, FlashMLA sparse decode#477
Conversation
Cherry-pick oracle-validated MLA decode kernels from the pp8 branch and wire them for EP1/DP1 (no PP): CUDA kernels: - glm52_mla_assembly.cu: query assemble + cache pack (hand-written, 142 lines) - glm52_moe_quant.cu: per-token-group FP8 quant (hand-written, 180 lines) - glm52_trtllm_grouped_fp8.cu: TRTLLM CUTLASS blockscale GEMM (vendored) - csrc/shared/linear.cu: gemm_strided_batched_bf16_cuda (cuBLAS) Rust ops + FFI: - ops/glm52: mla_assembly, moe_quant, trtllm_linear - ops/linear: gemm_strided_batched_bf16 - ffi/glm52, ffi/shared: new extern declarations - build.rs: TRTLLM CUTLASS includes for glm52_trtllm_grouped_fp8.cu Model crate: - fp8.rs: ProjWeight, fp8_linear, dequant_kv_b (shared by MLA/dense/MoE) - mla_decode.rs: single-layer MLA decode forward (hidden->o, bs=1) - tests/mla_decode_oracle.rs: end-to-end MLA forward vs HF oracle Hand-written CUDA perf debt: glm52_mla_assembly.cu, glm52_moe_quant.cu — correct, not tuned, first ncu candidates (see dp1-ep8-decode-plan.md).
…c to private The prototype oracle fixture chain (HF forward dump -> layer0.npz -> probe bins -> Rust test) was not self-contained: the dump script that generates layer0.npz was never in the repo, so the test could not be reproduced by anyone else. Remove the test + prep scripts + inline test module; defer the oracle gate to a follow-up that designs a reproducible fixture pipeline. Also move tokenspeed-kernel-gap.md to docs/private/ (gitignored) — it references third-party internal names and is a local research note, not PR material. Add mla-decode-brick.md as the PR1 dev doc (build, kernel inventory, hand-written perf-debt flags).
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b88b4f7d1b
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| log = { workspace = true } | ||
| memmap2 = { workspace = true } | ||
| openinfer-core = { workspace = true } | ||
| openinfer-kernels = { workspace = true, features = ["glm52"] } |
There was a problem hiding this comment.
Keep GLM5.2 kernels behind the model feature
Because openinfer-glm52 is a workspace member, this non-optional dependency makes any cargo test --release --workspace --lib / cargo check --workspace build openinfer-kernels with glm52 enabled; that feature also enables moe, and the kernels build script then requires the DeepEP/DeepGEMM/FlashMLA submodules plus the GLM5.2/NCCL/H200 toolchain even when the server glm52 feature was not selected. This regresses the documented default Qwen3-only workspace test/build path, so the kernels dependency should be optional and activated only by an explicit GLM5.2 crate/server feature.
Useful? React with 👍 / 👎.
Summary
First forward brick for GLM5.2 on top of the load-weight scaffold (#476): single-layer MLA decode forward (
hidden[6144] -> o[6144], bs=1, full top-k). At context ≤ 2048 the full top-k is DSA-equivalent, so this is the attention correctness foundation for the DP1/EP8 decode path.The runner/coordinator remains fail-closed — no server wiring, no scheduler, no weight-loading changes. New modules are
#[allow(dead_code)]until a later PR wires them into the executor.See
docs/models/glm52/dp1-ep8-decode-plan.mdfor the full 5-PR roadmap (PR1 MLA brick → PR2 DSA indexer → PR3 EP1 forward → PR4 DeepEP EP8 → PR5 scheduler).What lands
CUDA kernels (4 files):
glm52_mla_assembly.cu(142 lines, hand-written) — query concat + interleave RoPE + fp8 cache packglm52_moe_quant.cu(180 lines, hand-written) — per-128-group amax → e4m3 FP8 quantglm52_trtllm_grouped_fp8.cu(276 lines, vendored TRTLLM CUTLASS) — blockscale GEMM, m=1 linear + grouped MoEcsrc/shared/linear.cu(+35 lines) —cublasGemmStridedBatchedExwrapperRust ops + FFI (8 files):
ops/glm52/{mla_assembly,moe_quant,trtllm_linear}.rs— launch wrappers with contract validationops/linear.rs—gemm_strided_batched_bf16(cuBLAS strided batched)Model crate (3 files):
fp8.rs—ProjWeight,fp8_linear(quant → relay scale → TRTLLM launch),dequant_kv_b(host fp8→bf16 absorb factors)mla_decode.rs—Glm52MlaLayerWeights+glm52_mla_decode_forwardlib.rs— registers new modulesDocs (3 files):
dp1-ep8-decode-plan.md— 5-PR roadmap with hand-written kernel perf-debt flagsmla-decode-brick.md— PR1 dev doc (build, kernel inventory)index.md— glm52 routing sectionHand-written CUDA perf debt
glm52_mla_assembly.cuandglm52_moe_quant.cuare hand-written, not vendored. Both are memory-bound elementwise/reduce kernels (not GEMM/attention compute). They are correct but not tuned: single-issue-per-element, no vectorized load/store, no occupancy targeting. First ncu candidates when decode TPOT is measured. If a fused C-ABI alternative appears in vendored FlashInfer/TRTLLM, replace rather than optimize in place.Oracle gate — deferred
This PR does not include an oracle test. The prototype had a fixture pipeline (HF forward dump →
layer0.npz→ probe bins → Rust test), but the dump script was never in the repo, making the chain non-reproducible. The oracle gate is deferred to a follow-up that designs a self-contained fixture pipeline.Build
Requires SM90a (H200), CUDA 12.6+ (
cuLibraryLoadDatafor DeepGEMM JIT), NCCL 2.30.4+ (DeepEP submodule viamoefeature).export OPENINFER_NCCL_ROOT=/path/to/nvidia/nccl git submodule update --init --recursive cargo check --release -p openinfer-glm52Verified compiling on H200 (sm_90a, CUDA 12.8, NCCL 2.30.7).
Checklist
cargo fmt --all --checkclean