metal: SSM_SCAN performance #14743
Conversation
This may not be necessary, but it more closely mirrors the CUDA kernel

Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
This is a first attempt at optimizing the Metal kernel. The changes here are:

- Launch the kernel with a thread group of size `d_state`
- Use simd groups and shared memory to do the summation for the `y` computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on prefill and a 15% speedup on decode.

Signed-off-by: Gabe Goodhart <[email protected]>
@compilade I'd particularly love your feedback on this!
It looks like the use of
I see that in
Force-pushed from 26524d0 to 8d5a25d
Adding a version check did not do the trick: https://github.com/ggml-org/llama.cpp/actions/runs/16354752616/job/46210213747?pr=14743#step:6:27203
I'm able to repro the failure on my old M1 Max 64GB MacBook Pro (macOS 14.4.1). Support printouts look like:
This indicates that there's some difference between
Signed-off-by: Gabe Goodhart <[email protected]>
Ok, passing now on my M1 and M3. 🤞 CI is good too! The fix seemed to be the second use of

```shell
./bin/llama-batched-bench -m ~/models/granite-4.0-tiny-preview/ggml-model-Q4_K_M.gguf -c 131072 -b 2048 -ub 512 -npp 256,2560 -ntg 256 -npl 1,2 -ngl 99
```
```diff
@@ -519,6 +519,7 @@ typedef struct {
     int64_t n_group;
     int64_t n_seq_tokens;
     int64_t n_seqs;
+    int64_t s_off;
```
I'm not sure this is actually necessary. I did this when initially trying to emulate the CUDA kernel better, which passes this as an arg. It does seem like it might be slightly faster to avoid computing this in the kernel, though that could be offset by the latency of an additional `int64_t` being passed to the device?
…relationships

Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
Force-pushed from 0817add to 21db0b5
I experimented with attempting to only compute the per-token scalars (`dt_soft_plus`, `x_dt`, `dA`) on a single thread and broadcast them:

```metal
/*
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
const float x_dt = x[0] * dt_soft_plus;
const float dA = exp(dt_soft_plus * A[0]);
/*/
// Compute once per recursion and share across threads
float dt_soft_plus = 0.0f;
float x_dt = 0.0f;
float dA = 0.0f;
if (tiisg == 0) {
    dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
    x_dt = x[0] * dt_soft_plus;
    dA = exp(dt_soft_plus * A[0]);
}
dt_soft_plus = simd_broadcast_first(dt_soft_plus);
x_dt = simd_broadcast_first(x_dt);
dA = simd_broadcast_first(dA);
//*/
```

I also had a similar version that used
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
Writing another (failed) idea down for posterity. To get further parallelism, it seems that there should be a way to parallelize over `n_t` to unroll the recurrence.

```metal
const int64_t i1 = (int64_t)(tgpig.x / nc);
const int64_t i2 = tgpig.x % nc;
```

I commented out the actual

Given this, I think the PR is ready for review @compilade @ggerganov
> To get further parallelism, it seems that there should be a way to parallelize over `n_t` to unroll the recurrence.
I feel like to truly parallelize this, the duality of Mamba-2's SSM scan with semi-structured matrices could be used. But implementing that is a whole project in itself.
**ggml/src/ggml-metal/ggml-metal.metal** (Outdated)

```metal
const int64_t i = tpitg.x + i1*nc;
const float state = (s0[i] * dA) + (B[tpitg.x] * x_dt);
s[i] = state;
```
I wonder if it would be faster to store the intermediate states in a local variable instead of repeatedly in the destination buffer.

I'm not very experienced with Metal (the naïve version you're starting from was pretty much my first Metal kernel), but I assume it should be possible? Unless I'm misunderstanding the memory model, each thread only handles a single state (as in `s[i]` always refers to the same place, but differs between threads).

I think this would only affect prompt processing speed, not really small batches, though.
Interesting thought! I'm also very much still learning the memory model, so I'll play with this idea and see how far I can get it.
Great suggestion. It was easy to implement and gives a nice bump in performance. Will commit and push shortly:
| PP | TG | B | N_KV | T_PP s | S_PP t/s | T_TG s | S_TG t/s | T s | S t/s |
|---|---|---|---|---|---|---|---|---|---|
| 256 | 256 | 1 | 512 | 0.312 | 821.40 | 5.700 | 44.91 | 6.012 | 85.17 |
| 256 | 256 | 2 | 1024 | 0.582 | 880.34 | 9.658 | 53.01 | 10.240 | 100.00 |
| 2560 | 256 | 1 | 2816 | 2.845 | 899.88 | 5.854 | 43.73 | 8.699 | 323.71 |
| 2560 | 256 | 2 | 5632 | 5.700 | 898.23 | 9.911 | 51.66 | 15.611 | 360.76 |
Yep, I agree that this is the right next step, but also a much bigger undertaking. Hopefully the added parallelism here is a step in the right direction at least.
Branch: GraniteFourPerf
Signed-off-by: Gabe Goodhart <[email protected]>
## Description

This is an attempt to improve the overall performance of the `mamba2` implementation of `SSM_SCAN` for `metal`. I'm specifically interested in improving performance for Granite Four, but will hopefully achieve a nice speedup for other models as well.

## Changes

- In the `mamba2` clauses of the SSM_SCAN case for `ggml_metal_encode_node`, launch the kernel with `threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)` (`d_state` threads) and a shared memory buffer of size `32 * sizeof(float)` (SIMD size)
- In `kernel_ssm_scan_f32_group`, remove the loop over `nc` (`d_state`) and instead use `simd_sum` to perform the final `y` calculation.

## Testing

All testing was done on my M3 Max 64GB MacBook Pro running macOS 15.5.

### Validity testing

To ensure correctness, I used `test-backend-ops`:

output

### Performance Testing

To test the performance improvements, I used `llama-batched-bench`. I ran with a baseline on `master` (`01612b`) and a comparison against raw CPU (`-ngl 0`).

Metal (baseline `01612b`)

Metal (with changes)

CPU
## Discussion

This is my very first foray into kernel implementation, so there are very likely problems with this implementation that an experienced eye will catch. The first part that I would love eyes on is the implementation of the sequential `simd_sum` calls, the use of `threadgroup_barrier`, and the sizing of the shared memory. This was guesswork based on the implementation of `kernel_sum_rows`, and it passes `test-backend-ops`, but I don't feel 100% solid about it.