[CUDA] [Improvement] Rope without copy#3704

Open
nastya236 wants to merge 5 commits into
mainfrom
rope-without-copy

@nastya236 nastya236 commented Jun 16, 2026

Collaborator

In RoPE, for non-contiguous inputs we currently do a full copy into a newly allocated buffer, so the input is not donatable.
For an input with swapped axes, e.g. (B, L, n_heads, head_dim).swapped(1, 2), we copy it to a row-contiguous buffer and write the output to a row-contiguous buffer.
Since RoPE is usually followed by scaled dot-product attention, which does not need row-contiguous input, we don't really need this copy.
This small change speeds up RoPE a bit and decreases peak memory a bit.
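
To make the layout issue concrete, here is a minimal pure-Python sketch of what swapping axes does to strides. The helper names (`row_contiguous_strides`, `swap_axes`, `is_row_contiguous`) are hypothetical and are not part of MLX; the contiguity check is a simplification that ignores size-1 dimensions.

```python
def row_contiguous_strides(shape):
    """Element strides of a row-major buffer with the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def swap_axes(shape, strides, a, b):
    """Swapping two axes only permutes shape and strides; no data moves."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return tuple(shape), tuple(strides)


def is_row_contiguous(shape, strides):
    """Simplified check: strides must equal the row-major strides."""
    return tuple(strides) == row_contiguous_strides(shape)


# Shape from the benchmark in this PR: (B, L, n_heads, head_dim)
base_shape = (16, 8192, 8, 128)
base_strides = row_contiguous_strides(base_shape)

# View with axes 1 and 2 swapped: (B, n_heads, L, head_dim)
view_shape, view_strides = swap_axes(base_shape, base_strides, 1, 2)

print(base_strides, is_row_contiguous(base_shape, base_strides))
print(view_strides, is_row_contiguous(view_shape, view_strides))
```

The swapped view shares the original buffer and differs only in its strides, which is why a kernel that accepts strides can skip the copy.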

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn

def time_rope_hs_transposed():
    rope = nn.RoPE(128)

    x = mx.random.uniform(shape=(16, 8192, 8, 128)).astype(mx.float16)
    mx.eval(x)
    x = x.transpose(0, 2, 1, 3)

    def rope_transposed(x):
        for _ in range(32):
            x = rope(x)
        return x

    time_fn(rope_transposed, x)

time_rope_hs_transposed()

On a Spark:

New: 98.82948 msec
Old: 113.83483 msec

It gives some speedup:
End-to-end training on Spark with Qwen 0.6B:
7050 tok/sec vs 6890 tok/sec

End-to-end training on B200 with Qwen 4B:
26499 vs 25687 tok/sec

More importantly, it decreases peak memory quite significantly: for example, from 171GB to 166GB for 4B training.
The diff is not as big as it seems; the kernel dispatch is just wrapped in one extra lambda.
There is probably a way to make it nicer.
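
For reference, the reported numbers can be turned into relative figures with a few lines of arithmetic. This sketch assumes the higher tok/sec value in each pair is the new code path (the description does not label them) and uses only the figures quoted above.

```python
# Micro-benchmark timings from this description (msec)
old_ms, new_ms = 113.83483, 98.82948
kernel_speedup = old_ms / new_ms

# End-to-end throughput (tok/sec); assumed order: new vs old
spark_gain = 7050 / 6890
b200_gain = 26499 / 25687

# Peak memory for 4B training (GB)
mem_saved_gb = 171 - 166
mem_saved_frac = mem_saved_gb / 171

# Size of one float16 buffer with the benchmark shape (16, 8192, 8, 128),
# i.e. the allocation that one avoided copy would have needed
buffer_bytes = 16 * 8192 * 8 * 128 * 2

print(f"micro-benchmark speedup: {kernel_speedup:.3f}x")
print(f"Spark end-to-end gain:   {spark_gain:.3f}x")
print(f"B200 end-to-end gain:    {b200_gain:.3f}x")
print(f"peak memory saved:       {mem_saved_gb} GB ({mem_saved_frac:.1%})")
print(f"benchmark buffer size:   {buffer_bytes / 2**20:.0f} MiB")
```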

@nastya236 nastya236 changed the title Rope without copy [CUDA] [Improvement] Rope without copy Jun 16, 2026
@nastya236 nastya236 requested a review from zcbenz June 16, 2026 16:51
Geramy added a commit to NripeshN/mlx that referenced this pull request Jun 16, 2026
…explore#3704 port)

Partial RoPE (dims_ < D, e.g. Qwen3-Next partial_rotary_factor=0.25) copied the
whole tensor to `out` so the untouched [dims_:D] tail was present, then rotated
in place — a full (often strided/General) copy per call. Port of ml-explore/mlx
ml-explore#3704: when the input is donatable, rotate the first dims_ channels IN PLACE and
adopt the input's layout for `out`, eliminating the copy. Downstream SDPA accepts
non-contiguous q/k, so it's safe. Falls back to the copy path when not donatable;
4/8-bit and contiguous paths unchanged. Verified coherent (Qwen3.6 4-bit).
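
As an illustration of why partial rotation needs no copy when the buffer can be reused, here is a minimal pure-Python sketch on a single vector. It assumes the interleaved pairing (2i, 2i+1) and a base of 10000; the function name `rope_partial_inplace` is hypothetical and this is not the kernel from either repository.

```python
import math


def rope_partial_inplace(x, dims, position, base=10000.0):
    """Rotate the first `dims` entries of the list `x` in place.

    Entries x[dims:] are never read or written, so they stay valid
    without copying them into a separate output buffer.
    """
    if dims % 2 != 0 or dims > len(x):
        raise ValueError("dims must be even and at most len(x)")
    for i in range(dims // 2):
        theta = position * base ** (-2.0 * i / dims)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        x[2 * i] = a * c - b * s
        x[2 * i + 1] = a * s + b * c
    return x


# Head of size 8 where only the first 4 channels are rotated
head = [1.0, 0.0, 1.0, 0.0, 5.0, 6.0, 7.0, 8.0]
rope_partial_inplace(head, dims=4, position=1)
print(head)
```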

@angeloskath angeloskath left a comment

Member


The results look really nice!

I am trying to figure out if it would be simpler to just pass 4 strides all the time.

A bit of a random thought, but since I saw it I might as well comment on it: at line 296, if (dims_ < D), we always copy. We don't need to if the input is going to be donated.

So I would likely do the following: check whether always passing 4 strides affects performance (you could also make the number of strides a template parameter, but that's also complex). You can also check it on the Spark just to be sure.

Then swap the check to "can I donate the input?": if yes, done; if not, check the dims to copy, etc.

Wdyt? If you think it is already good enough I am fine merging it as is.
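
A rough sketch of the "always 4 strides" idea in pure Python: if every input is described by exactly four strides, one indexing formula covers both the row-contiguous and the transposed layout. The helpers `pad_to_4d` and `element_offset` are hypothetical, and padding with leading size-1 dimensions is an assumption for illustration, not necessarily what the kernel does.

```python
def pad_to_4d(shape, strides):
    """Prepend size-1 dims so that shape and strides always have length 4."""
    pad = 4 - len(shape)
    if pad < 0:
        raise ValueError("more than 4 dims is not handled in this sketch")
    # The stride of a size-1 dim is never used, so 0 is a safe filler.
    return (1,) * pad + tuple(shape), (0,) * pad + tuple(strides)


def element_offset(index, strides):
    """Offset of one element in the underlying buffer, in elements."""
    return sum(i * s for i, s in zip(index, strides))


# Row-contiguous (B, L, H, D) = (2, 3, 4, 5) and its (B, H, L, D) view
contiguous_strides = (60, 20, 5, 1)
transposed_strides = (60, 5, 20, 1)

# view[b, h, l, d] must address the same element as base[b, l, h, d]
b, h, l, d = 1, 2, 1, 3
print(element_offset((b, h, l, d), transposed_strides))
print(element_offset((b, l, h, d), contiguous_strides))

# A 2-D input is padded to the same 4-stride description
print(pad_to_4d((3, 5), (5, 1)))
```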

@nastya236

nastya236 commented Jun 25, 2026

Collaborator Author

I think you are right about both. Passing 4 strides sounds like a good idea; we will get rid of hs_transpose. Regarding dims_ < D: of course, thank you. I had not even looked into partial rotation; we don't need to copy it!

@nastya236

Collaborator Author

I simplified the dispatch function. I think it makes sense; I tested it on a Spark and don't see any slowdowns.
So:

if input.ndims < 4 or row_contiguous -- we consider this an appropriate layout:

  • if we can donate the input -- reuse the buffer
  • if input is not donatable, we allocate a new buffer for the output with the same strides as input

if input has > 4 dims and is not row_contiguous:

  • we copy the input to a new buffer (we should not be here often)

if input is partial:

  • if input is donatable, we reuse input buffer
  • if not, we copy
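
The rules above can be sketched as a small decision function. This is an illustration in Python, not the C++ dispatch code from the PR. It makes two labeled assumptions: inputs with exactly 4 dims are treated as an appropriate layout (the list only says "< 4" and "> 4", but the kernel takes 4 strides), and the partial case is checked first. The function name and the returned labels are hypothetical.

```python
def choose_rope_buffer(ndim, row_contiguous, donatable, partial):
    """Return a label for how the output buffer would be obtained."""
    # Assumption: the partial-rotation case is decided before the layout.
    if partial:
        return "reuse_input" if donatable else "copy_input"

    # Assumption: ndim == 4 counts as an appropriate layout.
    if ndim <= 4 or row_contiguous:
        if donatable:
            return "reuse_input"
        # Fresh output buffer that keeps the input's strides.
        return "allocate_same_strides"

    # More than 4 dims and not row-contiguous: fall back to a copy.
    return "copy_input"


# Transposed (B, H, L, D) input, as in the benchmark from the first comment
print(choose_rope_buffer(4, row_contiguous=False, donatable=True, partial=False))
print(choose_rope_buffer(4, row_contiguous=False, donatable=False, partial=False))
print(choose_rope_buffer(5, row_contiguous=False, donatable=True, partial=False))
```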

@angeloskath angeloskath left a comment

Member


Looks great to me!

Feel free to run the same benchmark as the first comment to make sure we didn't regress due to the extra stride and merge.
