[FlexAttention] Triton XPU didn't get correct value with the block io if the base address is not restricted aligned #3704

Open
chengjunlu opened this issue Mar 18, 2025 · 1 comment · May be fixed by #3712

Comments

@chengjunlu
Contributor

chengjunlu commented Mar 18, 2025

Describe the bug

In the FlexDecoding test case, we found that block IO returns incorrect matrix values when the base address is not aligned.

Inductor generates code like this:

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(QK_HEAD_DIM, KV_LEN),                # (d, N)
        strides=(stride_kk, stride_kn),
        offsets=(0, off_n),
        block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
        order=(0, 1)
    )

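It adds the offset directly to the base pointer, so the base address of the block pointer may be unaligned. The dump below prints one line per dimension: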

K_block_ptr base: 0xff000000002007ca
K_block_ptr shape:  [64]
K_block_ptr shape:  [2048]
K_block_ptr strides:  [1]
K_block_ptr strides:  [64]
K_block_ptr offsets:  [0]
K_block_ptr offsets:  [0]
K_block_ptr block_shape:  [64]
K_block_ptr block_shape:  [64]
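
As a quick check of the dumped base address, the following sketch computes how far it lies past the nearest 64-byte and 4-byte boundaries. The helper name `misalignment` is illustrative only and is not part of Triton.

```python
# Base address copied from the K_block_ptr dump above.
BASE = 0xFF000000002007CA


def misalignment(addr: int, alignment: int) -> int:
    """Return how many bytes addr lies past the previous aligned boundary."""
    return addr % alignment


print(misalignment(BASE, 64))  # 10: not 64-byte aligned
print(misalignment(BASE, 4))   # 2: not 4-byte aligned either
```

The base is therefore only 2-byte aligned.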

Environment details

Triton XPU: Latest

@chengjunlu chengjunlu self-assigned this Mar 18, 2025
@chengjunlu
Contributor Author

In the 2D block IO lowering, we already compensate for a base that is not 64-byte aligned by folding the misalignment into OffsetX and BaseWidth.
However, there is an extra restriction on OffsetX: it has to be 4-byte aligned.
We need to fall back to a gather load when OffsetX is not 4-byte aligned.
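
The decision described above can be sketched in plain Python. This is a simplified model, not the actual lowering code: `choose_lowering`, `Lowering`, and the constants are hypothetical names, OffsetX is modeled in bytes, and the matching BaseWidth adjustment is omitted.

```python
from dataclasses import dataclass

BASE_ALIGN = 64      # bytes: base alignment that the lowering compensates for
OFFSET_X_ALIGN = 4   # bytes: OffsetX restriction described in the comment


@dataclass
class Lowering:
    kind: str            # "block_2d" or "gather"
    base: int            # base address the load would use
    offset_x_bytes: int  # byte offset folded into OffsetX


def choose_lowering(base: int, offset_x_bytes: int = 0) -> Lowering:
    """Pick 2D block IO when the compensated OffsetX is legal, else gather."""
    # Move the base down to the previous 64-byte boundary and
    # push the residual bytes into OffsetX.
    residual = base % BASE_ALIGN
    compensated_offset = offset_x_bytes + residual
    if compensated_offset % OFFSET_X_ALIGN != 0:
        # OffsetX would violate the 4-byte rule: use a gather load instead.
        return Lowering("gather", base, offset_x_bytes)
    return Lowering("block_2d", base - residual, compensated_offset)


print(choose_lowering(0xFF000000002007CA).kind)  # gather (residual is 10 bytes)
print(choose_lowering(0xFF000000002007C8).kind)  # block_2d (residual is 8 bytes)
```

With the base address from the report, the residual is 10 bytes, so the sketch selects the gather path.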
