metal: SSM_SCAN performance #14743

Open · gabe-l-hart wants to merge 6 commits into master
Conversation

gabe-l-hart (Collaborator)

Description

This is an attempt to improve the overall performance of the mamba2 implementation of SSM_SCAN for Metal. I'm specifically interested in improving performance for Granite Four, but this will hopefully yield a nice speedup for other models as well.

Changes

  • In the mamba2 clauses of the SSM_SCAN case for ggml_metal_encode_node, launch the kernel with threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1) (d_state threads) and a shared memory buffer of size 32 * sizeof(float) (SIMD size)
  • In kernel_ssm_scan_f32_group, remove the loop over nc (d_state) and instead use simd_sum to perform the final y calculation (a rough sketch of the resulting per-thread pattern is shown below).
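
To make the second bullet concrete, here is roughly the per-thread shape of the change. This is a sketch rather than the literal diff; the variable names (nc, i1, dA, x_dt, tpitg, s0, s, B, C) are the ones the existing kernel uses, as seen in the review snippets further down.

```metal
// Sketch: the threadgroup is launched with d_state threads, so each thread owns
// exactly one element of the recurrent state for its slot.
const int64_t i = tpitg.x + i1*nc;                        // this thread's state element
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);   // recurrence update for one element
s[i] = state;

// The old serial loop over nc is gone: each thread contributes its element's share of y,
// and simd_sum reduces those contributions across the SIMD group in one step.
const float sumf = simd_sum(state * C[tpitg.x]);
```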

Testing

All testing was done on my M3 Max 64GB MacBook Pro running macOS 15.5.

Validity testing

To ensure correctness, I used test-backend-ops:

./bin/test-backend-ops -o SSM_SCAN
output
Backend 1/3: Metal
  Device description: Apple M3 Max
  Device memory: 49152 MB (49146 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): OK
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4): OK
  SSM_SCAN(type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4): OK
  6551/6551 tests passed
  Backend Metal: OK
ggml_metal_free: deallocating
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
ggml_metal_mem_pool_free: freeing memory pool, num heaps = 0 (total = 0)
Backend 2/3: BLAS
  Device description: Accelerate
  Device memory: 0 MB (0 MB free)

  SSM_SCAN(type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  SSM_SCAN(type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  SSM_SCAN(type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4): not supported [BLAS] 
  6551/6551 tests passed
  Backend BLAS: OK
Backend 3/3: CPU
  Skipping CPU backend
3/3 backends passed
OK

Performance Testing

To test the performance improvements, I used llama-batched-bench. I ran a baseline on master (01612b) and compared against pure CPU execution (-ngl 0).

# Test 256/256/1 for simple usage
./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 256 -ntg 256 -npl 1 -ngl 99

# Test 2560/256/[1,2] for more realistic usage with context
./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 2560 -ntg 256 -npl 1,2 -ngl 99

Metal (baseline 01612b)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.852 | 300.60 | 6.391 | 40.06 | 7.243 | 70.69 |
| 2560 | 256 | 1 | 2816 | 8.399 | 304.80 | 6.314 | 40.54 | 14.713 | 191.39 |
| 2560 | 256 | 2 | 5632 | 16.503 | 310.25 | 11.220 | 45.63 | 27.723 | 203.15 |

Metal (with changes)

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.296 | 863.45 | 5.839 | 43.84 | 6.136 | 83.44 |
| 2560 | 256 | 1 | 2816 | 2.690 | 951.77 | 5.780 | 44.29 | 8.470 | 332.49 |
| 2560 | 256 | 2 | 5632 | 5.398 | 948.50 | 9.754 | 52.49 | 15.152 | 371.71 |

CPU

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.876 | 292.33 | 3.509 | 72.95 | 4.385 | 116.76 |
| 2560 | 256 | 1 | 2816 | 7.263 | 352.46 | 3.652 | 70.11 | 10.915 | 258.00 |
| 2560 | 256 | 2 | 5632 | 16.695 | 306.68 | 5.105 | 100.30 | 21.800 | 258.35 |

Discussion

This is my very first foray into kernel implementation, so there are very likely problems with this implementation that an experienced eye will catch. The first part that I would love eyes on is the implementation of the sequential simd_sum calls, the use of threadgroup_barrier, and the sizing of the shared memory. This was guesswork based on the implementation of kernel_sum_rows, and while it passes test-backend-ops, I don't feel 100% solid about it. A sketch of the pattern I'm describing is below.
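
To make the question concrete, the reduction has roughly the following shape. This is a sketch, not the literal diff: shared here stands for the 32-float threadgroup buffer allocated on the host side, and tiisg / sgitg are the usual thread-index-in-simdgroup / simdgroup-index-in-threadgroup arguments as named elsewhere in ggml-metal.metal.

```metal
// Stage 1: each SIMD group reduces its 32 per-element contributions to y.
const float partial = simd_sum(state * C[tpitg.x]);

// Lane 0 of each SIMD group publishes its partial sum to threadgroup memory.
if (tiisg == 0) {
    shared[sgitg] = partial;
}
threadgroup_barrier(mem_flags::mem_threadgroup);  // make the partials visible to all SIMD groups

// Stage 2: combine the per-SIMD-group partials into the final y value.
// The sequencing of this second step is the part I feel least solid about.
```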

This may not be necessary, but it more closely mirrors the CUDA kernel

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y
  computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on
prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart (Collaborator, Author)

@compilade I'd particularly love your feedback on this!

github-actions bot added the ggml and Apple Metal labels on Jul 17, 2025
gabe-l-hart changed the title from "metail: SSM_SCAN performance" to "metal: SSM_SCAN performance" on Jul 17, 2025
gabe-l-hart (Collaborator, Author) commented Jul 17, 2025

It looks like the use of simd_sum means that this now requires MTLGPUFamilyApple7, while the CI runner is on MTLGPUFamilyApple5. My laptop, where this is passing, shows:

ggml_metal_init: GPU family: MTLGPUFamilyApple9  (1009)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)

I see that in ggml-metal.m, sum_rows, which also uses simd_sum in its implementation, does not do any version checking in its enablement flag, but I'm wondering whether using simd_sum in SSM_SCAN should be gated behind a version check somehow.

gabe-l-hart (Collaborator, Author)

I'm able to repro the failure on my old M1 Max 64GB MacBook Pro (macOS 14.4.1). Support printouts look like:

ggml_metal_init: GPU family: MTLGPUFamilyApple7  (1007)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3  (5001)

This indicates that some difference between MTLGPUFamilyApple7 and MTLGPUFamilyApple9 is showing up here.

gabe-l-hart (Collaborator, Author) commented Jul 18, 2025

Ok, passing now on my M1 and M3. 🤞 CI is good too!

The fix was to replace the second use of simd_sum, which was not working as expected, with a for loop over the number of simd groups in the threadgroup (a sketch is below the numbers). This seems to result in a slight slowdown at prefill:

./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 256,2560 -ntg 256 -npl 1,2 -ngl 99

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.337 | 759.78 | 5.703 | 44.89 | 6.040 | 84.77 |
| 256 | 256 | 2 | 1024 | 0.620 | 825.87 | 10.115 | 50.62 | 10.735 | 95.39 |
| 2560 | 256 | 1 | 2816 | 2.971 | 861.68 | 5.990 | 42.74 | 8.960 | 314.27 |
| 2560 | 256 | 2 | 5632 | 5.973 | 857.13 | 10.225 | 50.07 | 16.199 | 347.68 |
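
For reference, the revised second stage now looks roughly like this (again a sketch; n_sg is an illustrative name for the number of simd groups in the threadgroup, i.e. d_state / 32):

```metal
// Revised stage 2: after the barrier, sum the per-SIMD-group partials with a plain
// loop rather than a second simd_sum, which was misbehaving on some GPU families.
threadgroup_barrier(mem_flags::mem_threadgroup);

float sumf = 0.0f;
for (int i = 0; i < n_sg; ++i) {   // n_sg: number of SIMD groups (illustrative name)
    sumf += shared[i];
}
// sumf is then written out as the y value, as in the existing kernel
```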

@@ -519,6 +519,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
int64_t s_off;
gabe-l-hart (Collaborator, Author)

I'm not sure this is actually necessary. I did this when initially trying to better emulate the CUDA kernel, which passes this as an arg. It does seem like it might be slightly faster to avoid computing this in the kernel, though that could be offset by the cost of passing an additional int64_t to the device.

gabe-l-hart (Collaborator, Author)

…relationships

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart (Collaborator, Author) commented Jul 18, 2025

I experimented with computing dt_soft_plus / x_dt / dA only once per (simd group | threadgroup). I was able to get it working once per simd group, but the performance actually seems to be slightly worse. Leaving this here for posterity and/or inspiration if others have thoughts:

        /*
        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
        const float dA = exp(dt_soft_plus * A[0]);

        /*/

        // Compute once per recursion and share across threads
        float dt_soft_plus = 0.0f;
        float x_dt = 0.0f;
        float dA = 0.0f;
        if (tiisg == 0) {
            dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
            x_dt = x[0] * dt_soft_plus;
            dA = exp(dt_soft_plus * A[0]);
        }

        dt_soft_plus = simd_broadcast_first(dt_soft_plus);
        x_dt = simd_broadcast_first(x_dt);
        dA = simd_broadcast_first(dA);

        //*/

I also had a similar version that used shared memory to do this once per threadgroup, but it also saw a slight performance degradation because it required a threadgroup_barrier to broadcast between simd groups (a sketch is below).
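
For completeness, the once-per-threadgroup variant looked roughly like this (sketch; shared_vals is an assumed 3-float threadgroup buffer, not necessarily the name used in the actual experiment):

```metal
// Compute the per-token scalars once per threadgroup and broadcast them via shared memory.
if (tpitg.x == 0) {
    shared_vals[0] = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; // dt_soft_plus
    shared_vals[1] = x[0] * shared_vals[0];                           // x_dt
    shared_vals[2] = exp(shared_vals[0] * A[0]);                      // dA
}
threadgroup_barrier(mem_flags::mem_threadgroup); // the barrier that made this variant slower

const float dt_soft_plus = shared_vals[0];
const float x_dt         = shared_vals[1];
const float dA           = shared_vals[2];
```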

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart (Collaborator, Author)

Writing another (failed) idea down for posterity. To get further parallelism, it seems that there should be a way to parallelize over n_t to unroll the recurrence. As a sanity check on performance gain, I tried dispatching with dispatchThreadgroups:MTLSizeMake(d_inner * n_seq_tokens, n_head, n_seqs) and computing i1 and i2 from tgpig.x as:

    const int64_t i1 = (int64_t)(tgpig.x / nc);
    const int64_t i2 = tgpig.x % nc;

I commented out the actual i2 for loop and didn't bother getting the state computation to actually work. I then ran llama-batched-bench to see if the extra parallel dimension helped, and it was much worse (~2x slower for prefill). This likely has to do with memory alignment, so I'm going to pause pushing in this direction.


Given this, I think the PR is ready for review @compilade @ggerganov

compilade (Collaborator) left a comment

> To get further parallelism, it seems that there should be a way to parallelize over n_t to unroll the recurrence.

I feel like to truly parallelize this, the duality of Mamba-2's SSM scan with semi-structured matrices could be used. But implementing that is a whole project in itself.

s[i] = state;
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
s[i] = state;
compilade (Collaborator)

I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer.

I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible?

Unless I'm misunderstanding the memory model, each thread only handles a single state (as in s[i] always refers to the same place, but differs between threads).

I think this would only affect prompt processing speed, not really small batches, though.
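
A rough sketch of what I mean, using the variable names from the diff above (the per-token pointer bookkeeping of the real loop is elided):

```metal
// Keep this thread's single state element in a register across the token recurrence
// and store it back to the destination buffer only once, at the end.
const int64_t i = tpitg.x + i1*nc;
float state = s0[i];                                // load the previous state once

for (int64_t i2 = 0; i2 < n_t; ++i2) {
    // ... dt_soft_plus, x_dt, dA for this token, as in the existing kernel ...
    state = (state * dA) + (B[tpitg.x] * x_dt);     // recurrence stays in a register
    // ... y for this token is still reduced from state * C[tpitg.x] via simd_sum ...
}

s[i] = state;                                       // single store of the final state
```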

gabe-l-hart (Collaborator, Author)

Interesting thought! I'm also very much still learning the memory model, so I'll play with this idea and see how far I can get it.

gabe-l-hart (Collaborator, Author)

Great suggestion. It was easy to implement and gives a nice bump in performance. Will commit and push shortly

| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---:|---:|--:|-----:|-------:|---------:|-------:|---------:|----:|------:|
| 256 | 256 | 1 | 512 | 0.312 | 821.40 | 5.700 | 44.91 | 6.012 | 85.17 |
| 256 | 256 | 2 | 1024 | 0.582 | 880.34 | 9.658 | 53.01 | 10.240 | 100.00 |
| 2560 | 256 | 1 | 2816 | 2.845 | 899.88 | 5.854 | 43.73 | 8.699 | 323.71 |
| 2560 | 256 | 2 | 5632 | 5.700 | 898.23 | 9.911 | 51.66 | 15.611 | 360.76 |

gabe-l-hart (Collaborator, Author)

> I feel like to truly parallelize this, the duality of Mamba-2's SSM scan with semi-structured matrices could be used. But implementing that is a whole project in itself.

Yep, I agree that this is the right next step, but also a much bigger undertaking. Hopefully the added parallelism here is a step in the right direction at least.

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart <[email protected]>
Labels: Apple Metal, ggml