Skip to content

Conversation

@chiennv2000
Copy link

@chiennv2000 chiennv2000 commented Sep 23, 2025

This PR introduces several enhancements to the attention kernel, including the implementation of a backward pass, memory optimization for grouped query attention, and a bug fix.

1. Bug Fix: Incorrect Attention with Query Offset: Fixed a bug where the attention kernel produced incorrect results when the query offset (start_q) was non-zero. The kernel's starting loop bound (lo) was incorrectly initialized to start_q, causing the computation to skip the initial keys in the KV cache.

2. Improve GQA Memory Optimization: The K and V tensors were explicitly expanded using torch.repeat_interleave, which materialized large tensors in memory. It is better to handle it by manipulating pointers to map query heads to their corresponding KV head

3. Backward Pass Implementation: Implement a custom backward pass, making the module fully differentiable and usable for end-to-end model training.


Testing: All pytestcases have been updated to validate both the forward and backward passes against a reference PyTorch implementation:

================= 18 passed, 6 skipped in 5.06s =================

@chiennv2000 chiennv2000 changed the title Fix: Correct attention kernel loop bounds Fix: Correct/Improve the triton attention kernel Sep 28, 2025
@chiennv2000
Copy link
Author

@Maratyszcza @dkundel-openai please review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant