[CUDA] [Improvement] Rope without copy#3704
Conversation
…explore#3704 port) Partial RoPE (dims_ < D, e.g. Qwen3-Next partial_rotary_factor=0.25) copied the whole tensor to `out` so the untouched [dims_:D] tail was present, then rotated in place — a full (often strided/General) copy per call. Port of ml-explore/mlx ml-explore#3704: when the input is donatable, rotate the first dims_ channels IN PLACE and adopt the input's layout for `out`, eliminating the copy. Downstream SDPA accepts non-contiguous q/k, so it's safe. Falls back to the copy path when not donatable; 4/8-bit and contiguous paths unchanged. Verified coherent (Qwen3.6 4-bit).
angeloskath
left a comment
There was a problem hiding this comment.
The results look really nice!
I am trying to figure out if it would be simpler to just pass 4 strides all the time.
A bit of a random thought but since I saw it I might as well comment about it. Line 296 if (dims_ < D) we always copy. We don't need to if it is going to be donated.
So I would do likely the following, check if always 4 strides affects performance (you could also do the number of strides as a template parameter but that's also complex). You can also check it on the spark just to be sure.
Then swap the check to can I donate the input if yes done, if not check for dims to copy etc.
Wdyt? If you think it is already good enough I am fine merging it as is.
|
I think you right about both. I think passing 4 strides sounds like a good idea, we will get rid of |
|
I simplified the dispatch function, I think it makes sense, I tested it on a spark, don't see any slow downs. if
if input has > 4 dims and is not row_contigous:
if input is partial:
|
angeloskath
left a comment
There was a problem hiding this comment.
Looks great to me!
Feel free to run the same benchmark as the first comment to make sure we didn't regress due to the extra stride and merge.
In rope for non-contiguous inputs we do a full copy and allocate a new buffer, so the input is not donatable.
So for an input with swapped axes like:
(B, L, n_heads, head_dim).swapped(1, 2)we copy to row-contiguous buffer and write output to row-contiguous buffer.Since usually rope is followed by scaled dot product attention, that does not need row-contiguous input, we don't really need this copy.
This small change speeds up rope a bit and decreases peak memory a bit.
On a spark:
It gives some speed up:
End-to-end training on Spark with Qwen 0.6B:
7050 tok/sec vs 6890 tok/sec
End-to-end training on B200 with Qwen 4B:
26499 vs 25687 tok/sec
More importantly it decreases peak memory quite significantly. For example, for 4B training from 171GB to 166GB.
The diff is not as big as it seems, kernel dispatch is just wrapped in one extra lambda.
Probably there is a way to make it nicer.