Problem
We currently hard-code `CHUNK_SIZE = 16` in `shader.rs` and `metal_msm.rs`, which forces the Metal kernels to create more columns than necessary for small inputs. Benchmarks show that the fixed size favors large inputs but penalizes small MSM workloads (e.g., fewer than 65,536 = 2¹⁶ points). The ZPrize 2023 WebGPU reference implementation switches to `CHUNK_SIZE = 4` for inputs smaller than 2¹⁶, giving a lighter MSM configuration and markedly better performance on smaller datasets.
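To show why a fixed `CHUNK_SIZE = 16` is heavy for small inputs, the sketch below computes the window layout implied by each chunk size. It assumes 256-bit scalars split into unsigned `chunk_size`-bit windows, with each window indexing `2^chunk_size` bucket columns (a common Pippenger-style layout). `SCALAR_BITS`, `WindowLayout`, and `layout` are hypothetical names, not code from `shader.rs` or `metal_msm.rs`.

```rust
// Assumption for illustration: scalars are treated as 256-bit values.
const SCALAR_BITS: u32 = 256;

/// Hypothetical window layout implied by a given CHUNK_SIZE.
#[derive(Debug, Clone, Copy, PartialEq)]
struct WindowLayout {
    chunk_size: u32,
    num_subtasks: u32, // number of chunk_size-bit windows per scalar
    num_columns: u32,  // buckets per window = 2^chunk_size
}

fn layout(chunk_size: u32) -> WindowLayout {
    assert!(chunk_size > 0 && chunk_size < 32, "chunk_size out of range");
    WindowLayout {
        chunk_size,
        // Ceiling division: a partial final window still needs a subtask.
        num_subtasks: (SCALAR_BITS + chunk_size - 1) / chunk_size,
        num_columns: 1u32 << chunk_size,
    }
}

fn main() {
    let num_points: u64 = 1 << 10;
    for chunk_size in [4, 16] {
        let l = layout(chunk_size);
        let total_buckets = l.num_subtasks as u64 * l.num_columns as u64;
        println!(
            "CHUNK_SIZE = {:2}: {} subtasks x {} columns = {} buckets for {} points",
            l.chunk_size, l.num_subtasks, l.num_columns, total_buckets, num_points
        );
    }
}
```

Under these assumptions, 2¹⁰ points with `CHUNK_SIZE = 16` map onto 1,048,576 buckets (16 × 65,536), so at most 1,024 of the 65,536 buckets in any subtask can be non-empty; `CHUNK_SIZE = 4` needs only 1,024 buckets in total (64 × 16).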
Details / Proposed Solution
- Expose `CHUNK_SIZE` as a compile-time or runtime parameter.
  - Use `4` when `num_points < 2**16`; keep `16` otherwise.
- Update dispatch logic in host code (Metal & WebGPU).
  - Recalculate `num_columns`, `threads_per_threadgroup`, and `threadgroup_count` based on the selected `CHUNK_SIZE`.
- Audit shader code.
  - Replace hard-coded literals with constants passed via push constants or a uniform buffer.
  - Validate loop bounds to prevent infinite loops or out-of-bounds memory writes when `CHUNK_SIZE = 4`.
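The selection rule and the host-side recalculation could look roughly like the following sketch. The `2^16` threshold and the values `4` and `16` come from this issue; everything else is assumed: 256-bit scalars, one GPU thread per (subtask, column) bucket, and an example cap of 256 threads per threadgroup (on Metal the real cap would come from the pipeline's `maxTotalThreadsPerThreadgroup`). `DispatchParams`, `select_chunk_size`, and `dispatch_params` are hypothetical names.

```rust
// From the issue: CHUNK_SIZE = 4 below 2^16 points, 16 otherwise.
const SMALL_INPUT_THRESHOLD: usize = 1 << 16;
// Assumption: 256-bit scalars.
const SCALAR_BITS: u32 = 256;

#[derive(Debug, Clone, Copy, PartialEq)]
struct DispatchParams {
    chunk_size: u32,
    num_subtasks: u32,
    num_columns: u32,
    threads_per_threadgroup: u32,
    threadgroup_count: u32,
}

fn select_chunk_size(num_points: usize) -> u32 {
    if num_points < SMALL_INPUT_THRESHOLD { 4 } else { 16 }
}

fn ceil_div(a: u32, b: u32) -> u32 {
    (a + b - 1) / b
}

/// Derives dispatch parameters from the selected chunk size.
/// Assumption: one thread per (subtask, column) bucket.
fn dispatch_params(num_points: usize, max_threads_per_threadgroup: u32) -> DispatchParams {
    assert!(max_threads_per_threadgroup > 0);
    let chunk_size = select_chunk_size(num_points);
    let num_subtasks = ceil_div(SCALAR_BITS, chunk_size);
    let num_columns = 1u32 << chunk_size;
    let total_threads = num_subtasks * num_columns;
    // Never request a threadgroup larger than the total amount of work.
    let threads_per_threadgroup = max_threads_per_threadgroup.min(total_threads);
    DispatchParams {
        chunk_size,
        num_subtasks,
        num_columns,
        threads_per_threadgroup,
        threadgroup_count: ceil_div(total_threads, threads_per_threadgroup),
    }
}

fn main() {
    for log_n in [10, 15, 16, 20] {
        let p = dispatch_params(1usize << log_n, 256);
        println!("2^{log_n} points -> {p:?}");
    }
}
```

Deriving every count from `chunk_size` in one place keeps the Metal and WebGPU hosts from drifting apart; the same struct could be written into the uniform buffer so the shaders read the values instead of hard-coded literals.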
Acceptance Criteria
- `cargo test --release test_metal_msm_pipeline -- --nocapture` passes for both `CHUNK_SIZE = 4` and `16`.
- End-to-end MSM benchmarks finish without runtime errors for point counts 2¹⁰, 2¹⁵, 2¹⁶, and 2²⁰.
- For inputs ≤ 2¹⁵, `CHUNK_SIZE = 4` shows a ≥ 10 % speed-up vs the current `master` on an M-series GPU.
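Alongside the GPU test, a CPU reference of the scalar decomposition could back up the shader audit for both chunk sizes: every window must be a valid bucket index (`< 2^chunk_size`), the loop must run exactly `256 / chunk_size` times, and recombining the windows must reproduce the scalar. This is a hypothetical helper, not the existing `test_metal_msm_pipeline`; it assumes 256-bit scalars stored as four little-endian `u64` limbs and unsigned windows.

```rust
/// Splits a 256-bit scalar (four little-endian u64 limbs) into
/// chunk_size-bit windows. Requires chunk_size to divide 64 so that no
/// window straddles a limb boundary (true for both 4 and 16).
fn decompose(scalar: &[u64; 4], chunk_size: u32) -> Vec<u32> {
    assert!(chunk_size > 0 && chunk_size <= 32 && 64 % chunk_size == 0);
    let num_subtasks = 256 / chunk_size;
    let mask = (1u64 << chunk_size) - 1;
    (0..num_subtasks)
        .map(|i| {
            let bit = i * chunk_size;
            ((scalar[(bit / 64) as usize] >> (bit % 64)) & mask) as u32
        })
        .collect()
}

/// Inverse of `decompose`: ORs each window back into its bit position.
fn recompose(chunks: &[u32], chunk_size: u32) -> [u64; 4] {
    let mut out = [0u64; 4];
    for (i, &c) in chunks.iter().enumerate() {
        let bit = i as u32 * chunk_size;
        out[(bit / 64) as usize] |= (c as u64) << (bit % 64);
    }
    out
}

/// True if the window count is exact, every window is a valid bucket
/// index, and the decomposition is lossless.
fn check(scalar: &[u64; 4], chunk_size: u32) -> bool {
    let num_columns = 1u32 << chunk_size;
    let chunks = decompose(scalar, chunk_size);
    chunks.len() as u32 == 256 / chunk_size
        && chunks.iter().all(|&c| c < num_columns)
        && recompose(&chunks, chunk_size) == *scalar
}

fn main() {
    let scalars = [
        [0u64; 4],
        [u64::MAX; 4],
        [0x0123_4567_89ab_cdef, 0xfedc_ba98_7654_3210, 1, 0x8000_0000_0000_0000],
    ];
    for chunk_size in [4, 16] {
        let ok = scalars.iter().all(|s| check(s, chunk_size));
        println!("CHUNK_SIZE = {chunk_size}: {}", if ok { "ok" } else { "FAILED" });
    }
}
```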
Status: Done