
Conversation


@ghost ghost commented Sep 5, 2025

During my attempts to improve efficiency in a multi-GPU setup with Ampere+Turing cards, I found that disabling paged attention increases PP speed ~3x (non-TP mode, Mistral Large, 8.0bpw). For the past two days I profiled the prefill() code trying to understand what actually slows paged attention down.
In the end the problem appears to be in the get_block_index func, which was probably meant to be made non-blocking as well in commit 843cec5, but for some reason it wasn't. After making it non-blocking, the profiler's picture for paged and non-paged becomes identical.
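
For reference, here is a minimal sketch of the pattern the change follows. The class, attribute, and cache-dict names below are my own illustration, not the actual exllamav2 source; the real change simply mirrors what get_cache_seqlens already does:

```python
import torch

# Hypothetical sketch only -- the real attn_params object in exllamav2 looks
# different, but the pattern is the same: cache a per-device copy of the block
# index and enqueue the copy with non_blocking=True instead of letting the
# host thread synchronize on it.
class AttnParamsSketch:
    def __init__(self, block_index: torch.Tensor):
        self.block_index = block_index      # block table on its source device
        self._per_device = {}               # device_idx -> cached copy

    def get_block_index(self, device_idx: int) -> torch.Tensor:
        if device_idx not in self._per_device:
            # non_blocking=True is the whole fix: the copy is queued on the
            # stream and the CPU no longer waits for all previously enqueued
            # work on the source GPU to finish first.
            self._per_device[device_idx] = self.block_index.to(
                f"cuda:{device_idx}", non_blocking=True
            )
        return self._per_device[device_idx]
```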

To provide some numbers, this is the part of the log that shows how much time is spent processing each module in the loop here during the prefill stage:

    Module 150
             Peer copy, s1 <torch.cuda.Stream device=cuda:3 cuda_stream=0x736ca0acc920>, s2 <torch.cuda.Stream device=cuda:4 cuda_stream=0x736c68ad2480>
        safe_move_tensor[torch.Size([1, 1996, 12288])] (cuda:3->cuda:4): 0.000247 seconds
        forward: 0.000096 seconds
    Module 151
        safe_move_tensor[torch.Size([1, 1996, 12288])] (cuda:4->cuda:4): 0.000004 seconds
          start 1: 0.000001 seconds
          start 2: 0.000000 seconds
          start 3: 0.000001 seconds
          start 4: 0.000093 seconds
          start 5: 1.357264 seconds
          before is_q 1: 0.000000 seconds
          before is_q 2: 0.000002 seconds
          before is_q 3: 0.000019 seconds
          before is_q 4: 0.000001 seconds
          before is_q 5: 0.000000 seconds
          before q_attn_forward_1: 0.000043 seconds
          q_attn_forward_1: 0.000118 seconds
          before flash_attn_with_kvcache: 0.000193 seconds
          flash_attn_with_kvcache: 0.000044 seconds
          before q_attn_forward_2: 0.000005 seconds
          q_attn_forward_2: 0.000027 seconds
          after q_attn_forward_2: 0.000000 seconds
        forward: 1.357875 seconds

Because the modules are processed sequentially in layer order, the hidden-states tensor is also moved sequentially from one GPU to another. Here we have reached the modules that process a layer residing on GPU 4, so the tensor has to move from GPU 3 to GPU 4 (module 150 is likely ExLlamaV2MLP, module 151 is ExLlamaV2Attention). Then, during attention processing, almost 1.4 seconds are spent inside this line (sorry for the weird startN names, I had to think up many different labels to measure the time between each call):
block_table = attn_params.get_block_index(self.device_idx)
And it is the sole reason the forward call for the attention module takes so long.
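
A note on how to read these numbers: they are plain wall-clock checkpoints on the host thread, and CUDA work is asynchronous, so the time shows up at whichever call forces the host to wait, not at the call that actually queued the work. Roughly, the instrumentation was along these lines (the checkpoint helper and labels are my own, not exllamav2 code):

```python
import time

# Ad-hoc wall-clock checkpoints on the host thread. Because GPU kernels and
# async copies run in the background, a blocking call absorbs the wait for
# everything queued before it.
_last = time.perf_counter()

def checkpoint(label: str) -> None:
    global _last
    now = time.perf_counter()
    print(f"  {label}: {now - _last:.6f} seconds")
    _last = now

# Used inside the attention forward, e.g.:
#   checkpoint("start 4")
#   block_table = attn_params.get_block_index(self.device_idx)  # blocking copy
#   checkpoint("start 5")   # ~1.36 s landed here before the fix
```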

After changing get_block_index to the non-blocking version, the log looks like this:

    Module 150
             Peer copy, s1 <torch.cuda.Stream device=cuda:3 cuda_stream=0x75ed5cacc9a0>, s2 <torch.cuda.Stream device=cuda:4 cuda_stream=0x75ed24ad2900>
        safe_move_tensor[torch.Size([1, 1996, 12288])] (cuda:3->cuda:4): 0.011406 seconds
        forward: 0.000114 seconds
    Module 151
        safe_move_tensor[torch.Size([1, 1996, 12288])] (cuda:4->cuda:4): 0.000005 seconds
          start 1: 0.000001 seconds
          start 2: 0.000000 seconds
          start 3: 0.000001 seconds
          start 4: 0.000101 seconds
          start 5: 0.000088 seconds
          before is_q 1: 0.000000 seconds
          before is_q 2: 0.000001 seconds
          before is_q 3: 0.000011 seconds
          before is_q 4: 0.000001 seconds
          before is_q 5: 0.000000 seconds
          before q_attn_forward_1: 0.000031 seconds
          q_attn_forward_1: 0.000070 seconds
          before flash_attn_with_kvcache: 0.000100 seconds
          flash_attn_with_kvcache: 0.000040 seconds
          before q_attn_forward_2: 0.000005 seconds
          q_attn_forward_2: 0.000028 seconds
          after q_attn_forward_2: 0.000000 seconds
        forward: 0.000535 seconds

So now there is almost no wait here. The new place where we wait is the first attention module:

Module 1
             From CPU copy, s <torch.cuda.Stream device=cuda:0 cuda_stream=0x75ee9026c120>
        safe_move_tensor[torch.Size([1, 2048, 12288])] (cpu:None->cuda:0): 1.263441 seconds
          start 1: 0.000001 seconds
          start 2: 0.000001 seconds
          start 3: 0.000001 seconds
          start 4: 0.000104 seconds
          start 5: 0.000048 seconds
          before is_q 1: 0.000000 seconds
          before is_q 2: 0.000002 seconds
          before is_q 3: 0.000011 seconds
          before is_q 4: 0.000001 seconds
          before is_q 5: 0.000000 seconds
          before q_attn_forward_1: 0.000031 seconds
          q_attn_forward_1: 0.000087 seconds
          before flash_attn_with_kvcache: 0.005945 seconds
          flash_attn_with_kvcache: 0.000044 seconds
          before q_attn_forward_2: 0.000005 seconds
          q_attn_forward_2: 0.000027 seconds
          after q_attn_forward_2: 0.000000 seconds
        forward: 0.006366 seconds

So, in the end, instead of waiting on every GPU per chunk, we wait only once at the beginning of the chunk.

And the final numbers from tabbyAPI (I had to power-limit some cards, so the speed improvement is smaller, but still):
Before the fix:

179 tokens generated in 70.44 seconds (Queue: 0.0 s, Process: 0 cached tokens and 10189 new tokens at 255.43 T/s, Generate: 5.86 T/s, Context: 10189 tokens)

After the fix:

340 tokens generated in 78.94 seconds (Queue: 0.0 s, Process: 0 cached tokens and 10189 new tokens at 472.37 T/s, Generate: 5.93 T/s, Context: 10189 tokens)

In my previous experiments comparing paged and non-paged attention, I got something like 280 T/s -> 740 T/s.

Honestly, I know very little about the attention mechanism itself and how it should be implemented, or whether a non-blocking tensor move is even safe here, etc. So I would like someone to review this change and maybe test it in your multi-GPU setup (you don't have to turn paged attention off; this is a fix for paged attention, so just apply the change and test in non-tensor-parallel mode).

Commit: making it consistent with get_cache_seqlens
@ghost ghost closed this by deleting the head repository Sep 21, 2025
This pull request was closed.