Support dynamic chunk size (4 vs 16) for small MSM inputs (< 2^16) #79

@moven0831

Description

Problem

We currently hard-code CHUNK_SIZE = 16 in shader.rs and metal_msm.rs, which forces the Metal kernels to create more columns than necessary. Benchmarks show that the fixed size favors large inputs but penalizes small MSM workloads (e.g., fewer than 65,536 points). The ZPrize 2023 WebGPU reference implementation switches to CHUNK_SIZE = 4 for inputs smaller than 2^16 points. This gives a lighter MSM configuration and markedly better performance on smaller datasets.

Reference:
https://github.com/z-prize/2023-entries/blob/6cc68aeb63071d90817aeff4b55b34444fae42a8/prize-2-msm-wasm/webgpu-only/tal-derei-koh-wei-jie/src/submission/submission.ts#L80
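
The effect of the 2^16 threshold on the kernel shape can be sketched on the host side. This is a minimal sketch, not code from this repository. `MsmConfig`, `msm_config`, and `SCALAR_BITS` are hypothetical names. The sketch assumes that each subtask has 2^CHUNK_SIZE columns and that scalars are handled as 256-bit values, which are split into ceil(256 / CHUNK_SIZE) subtasks.

```rust
/// Assumed scalar width in bits (hypothetical; adjust to the curve in use).
const SCALAR_BITS: u32 = 256;
/// Inputs below this many points use the small chunk size.
const SMALL_INPUT_THRESHOLD: usize = 1 << 16;

/// Hypothetical MSM configuration derived from the selected chunk size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MsmConfig {
    chunk_size: u32,
    num_columns: u32,  // one column per possible chunk value: 2^chunk_size
    num_subtasks: u32, // windows per scalar: ceil(SCALAR_BITS / chunk_size)
}

fn msm_config(num_points: usize) -> MsmConfig {
    // 4-bit chunks below 2^16 points, 16-bit chunks otherwise.
    let chunk_size = if num_points < SMALL_INPUT_THRESHOLD { 4 } else { 16 };
    MsmConfig {
        chunk_size,
        num_columns: 1 << chunk_size,
        num_subtasks: SCALAR_BITS.div_ceil(chunk_size),
    }
}

fn main() {
    for n in [1usize << 10, 1 << 15, 1 << 16, 1 << 20] {
        println!("{n:>8} points -> {:?}", msm_config(n));
    }
}
```

Under these assumptions, 2^10 points map to 16 columns across 64 subtasks. At 2^16 points and above, the configuration switches to 65,536 columns across 16 subtasks. This mismatch between column count and input size is what a fixed CHUNK_SIZE = 16 imposes on small inputs.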

Details / Proposed Solution

  1. Expose CHUNK_SIZE as a compile-time or runtime parameter

    • Use 4 when num_points < 2^16; keep 16 otherwise.
  2. Update dispatch logic in host code (Metal & WebGPU)

    • Recalculate num_columns, threads_per_threadgroup, and threadgroup_count based on the selected CHUNK_SIZE.
  3. Audit shader code

    • Replace hard-coded literals with constants passed via the push-constants / uniform buffer.
    • Validate loop bounds to prevent infinite loops or out-of-bounds memory writes when CHUNK_SIZE = 4.
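
Steps 2 and 3 can be sketched together. Everything below is hypothetical and is not taken from shader.rs or metal_msm.rs:

  • `MsmParams` stands in for the push-constant / uniform block.
  • `dispatch_for` assumes one thread per (subtask, column) pair, capped at a device limit.
  • `chunk_indices` mirrors a kernel loop that slices a 256-bit scalar stored as eight little-endian u32 limbs.

The loop runs exactly `num_subtasks` times and masks each chunk to `num_columns - 1`. As a result, no chunk index can fall outside the column range for either chunk size.

```rust
/// Hypothetical uniform / push-constant block replacing hard-coded literals.
/// `#[repr(C)]` gives it a predictable layout to mirror on the shader side.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MsmParams {
    input_size: u32,
    chunk_size: u32,
    num_columns: u32,
    num_subtasks: u32,
}

/// Hypothetical dispatch geometry: one thread per (subtask, column) pair.
#[derive(Debug, PartialEq, Eq)]
struct Dispatch {
    threads_per_threadgroup: u32,
    threadgroup_count: u32,
}

fn params_for(input_size: u32) -> MsmParams {
    // Selection rule from step 1: 4 below 2^16 points, 16 otherwise.
    let chunk_size = if input_size < (1 << 16) { 4 } else { 16 };
    MsmParams {
        input_size,
        chunk_size,
        num_columns: 1 << chunk_size,
        // Assumes 256-bit scalars.
        num_subtasks: 256u32.div_ceil(chunk_size),
    }
}

fn dispatch_for(p: &MsmParams, max_threads_per_threadgroup: u32) -> Dispatch {
    let total_threads = p.num_columns * p.num_subtasks;
    // Never size a threadgroup larger than the total work.
    let threads = total_threads.min(max_threads_per_threadgroup);
    Dispatch {
        threads_per_threadgroup: threads,
        // Round up so a partial final threadgroup still covers every pair.
        threadgroup_count: total_threads.div_ceil(threads),
    }
}

/// Mirrors a kernel loop that slices a 256-bit scalar (eight little-endian
/// u32 limbs) into `chunk_size`-bit windows.
fn chunk_indices(scalar_limbs: &[u32; 8], p: &MsmParams) -> Vec<u32> {
    // Both 4 and 16 divide 32, so a window never straddles two limbs.
    assert!(32 % p.chunk_size == 0);
    let mask = (1u32 << p.chunk_size) - 1;
    (0..p.num_subtasks)
        .map(|i| {
            // Highest start bit is (num_subtasks - 1) * chunk_size < 256,
            // so the limb index always stays in 0..8.
            let bit = i * p.chunk_size;
            let limb = (bit / 32) as usize;
            // The mask keeps every index below num_columns.
            (scalar_limbs[limb] >> (bit % 32)) & mask
        })
        .collect()
}

fn main() {
    for n in [1u32 << 10, 1 << 20] {
        let p = params_for(n);
        println!("{p:?} -> {:?}", dispatch_for(&p, 1024));
    }
}
```

The host derives num_columns and num_subtasks from a single chunk_size value and passes all three through one buffer. This keeps the host dispatch math and the kernel's loop bounds from drifting apart.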

Acceptance Criteria

  • cargo test --release test_metal_msm_pipeline -- --nocapture passes for both CHUNK_SIZE = 4 and 16.
  • End-to-end MSM benchmarks finish without runtime errors for point counts: 2¹⁰, 2¹⁵, 2¹⁶, 2²⁰.
  • For inputs of ≤ 2¹⁵ points, CHUNK_SIZE = 4 shows a ≥ 10% speed-up over current master on an M-series GPU.
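
CHUNK_SIZE only changes how the work is partitioned, not the result. A cheap CPU-side companion to the first criterion is therefore to run a toy bucket-method MSM with both chunk sizes and compare each result against the naive sum. This sketch is not the Metal pipeline. It swaps the elliptic-curve group for u64 under wrapping addition and uses 32-bit scalars so that both 4 and 16 divide the bit width evenly. `bucket_msm` and `naive_msm` are hypothetical helpers.

```rust
/// Toy bucket-method (Pippenger-style) MSM over u64 with wrapping addition.
/// Hypothetical helper: it stands in for the GPU pipeline only to show that
/// the result does not depend on chunk_size.
fn bucket_msm(scalars: &[u32], points: &[u64], chunk_size: u32) -> u64 {
    assert!(32 % chunk_size == 0, "chunk_size must divide the 32-bit scalar width");
    let num_subtasks = 32 / chunk_size;
    let num_columns = 1usize << chunk_size;
    let mask = (num_columns - 1) as u32;
    let mut result = 0u64;
    // Walk windows from most to least significant (Horner's rule).
    for w in (0..num_subtasks).rev() {
        // Multiply the accumulator by 2^chunk_size (chunk_size doublings on a curve).
        result = result.wrapping_mul(1u64 << chunk_size);
        // Scatter each point into the bucket selected by this window's chunk.
        let mut buckets = vec![0u64; num_columns];
        for (s, p) in scalars.iter().zip(points) {
            let idx = ((s >> (w * chunk_size)) & mask) as usize;
            buckets[idx] = buckets[idx].wrapping_add(*p);
        }
        // Running-sum reduction: computes sum_j j * buckets[j] using only additions.
        let (mut running, mut window_sum) = (0u64, 0u64);
        for b in buckets[1..].iter().rev() {
            running = running.wrapping_add(*b);
            window_sum = window_sum.wrapping_add(running);
        }
        result = result.wrapping_add(window_sum);
    }
    result
}

/// Reference result: sum of s_i * p_i with the same wrapping arithmetic.
fn naive_msm(scalars: &[u32], points: &[u64]) -> u64 {
    scalars
        .iter()
        .zip(points)
        .fold(0u64, |acc, (s, p)| acc.wrapping_add((*s as u64).wrapping_mul(*p)))
}

fn main() {
    // Deterministic pseudo-random inputs from a 64-bit LCG (no external crates).
    let mut state = 0x1234_5678_9abc_def0u64;
    let mut next = || {
        state = state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407);
        state
    };
    let points: Vec<u64> = (0..1024).map(|_| next()).collect();
    let scalars: Vec<u32> = (0..1024).map(|_| (next() >> 32) as u32).collect();

    let expected = naive_msm(&scalars, &points);
    for chunk_size in [4, 16] {
        assert_eq!(bucket_msm(&scalars, &points, chunk_size), expected);
        println!("chunk_size = {chunk_size}: matches naive MSM");
    }
}
```

The same pattern of comparing both chunk sizes against one reference result carries over to the GPU test. The toy group would be replaced by real curve points, and the toy MSM by the Metal pipeline.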

Metadata

  • Assignees: none
  • Type: No type
  • Project status: Done
  • Milestone: No milestone
  • Relationships: None yet
  • Development: No branches or pull requests