Merge OpenAI Triton commit c2fd8e1 #3699

Merged
merged 6 commits from amyachev/merge10 into main on Mar 17, 2025
Conversation

anmyachev
Contributor

@anmyachev anmyachev commented Mar 17, 2025

This PR changes the Triton base from a9475c9 to c2fd8e1 (Mar 10).
Pass rate: 92.47% -> 89.93%

Please do not squash and merge this PR.

bingyizh233 and others added 6 commits March 7, 2025 23:06
I have discussed the subtiling strategy on Hopper WGMMA with @ThomasRaoux.

The subtiling used for Ampere also works for Hopper WGMMA.
I have verified the correctness by manually rewriting the following ttgir into the subtiled fashion shown in the code below.

The only change needed is to disable an assert, as shown in this PR. Wondering if there is any specific reason for keeping this assert, because NVMMASharedEncodingAttr keeps the swizzle-size information regardless of the matrix shape.

Disabling this assert also passes all the tests.


Original code:
```
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0]}>
#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = true, elementBitWidth = 16}>
#smem = #ttg.shared_memory

      %152 = ttg.memdesc_subview %59[%150, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x64xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> loc(#loc23)
      %154 = ttg.memdesc_subview %60[%147, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> loc(#loc24)
      %153 = ttg.local_load %152 : !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> -> tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %155 = arith.sitofp %153 : tensor<128x64xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %156 = ttng.warp_group_dot %155, %154, %arg13 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc28)
      %157:2 = ttng.warp_group_dot_wait %156, %154 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> loc(#loc28)
```


With subtiling:
```
      %152 = ttg.memdesc_subview %59[%150, %c0_i32, %c0_i32] : !ttg.memdesc<2x128x64xi8, #shared, #smem, mutable> -> !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> loc(#loc23)
      %154 = ttg.memdesc_subview %60[%147, %c0_i32, %c0_i32] : !ttg.memdesc<3x64x256xbf16, #shared1, #smem, mutable> -> !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> loc(#loc24)

      %c0_i32 = arith.constant 0 : i32 loc(#loc1)
      %1000 = ttg.memdesc_subview %152[%c0_i32, %c0_i32] : !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> -> !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  loc(#loc37)
      %3000 = ttg.memdesc_subview %154[%c0_i32, %c0_i32] :  !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> ->  !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256>  loc(#loc37)
      %1001 = ttg.local_load %1000 : !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  -> tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %1002 = arith.sitofp %1001 : tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %2000 = ttng.warp_group_dot %1002, %3000, %arg13 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc28)

      %c16_i32 = arith.constant 16 : i32 loc(#loc37)
      %1003 = ttg.memdesc_subview %152[%c0_i32, %c16_i32] : !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> -> !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  loc(#loc37)
      %3001 = ttg.memdesc_subview %154[%c16_i32, %c0_i32] :  !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> ->  !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256>  loc(#loc37)
      %1004 = ttg.local_load %1003 :  !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  -> tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %1005 = arith.sitofp %1004 : tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %2001 = ttng.warp_group_dot %1005, %3001, %2000 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc28)

      %c32_i32 = arith.constant 32 : i32 loc(#loc37) 
      %1006 = ttg.memdesc_subview %152[%c0_i32, %c32_i32] : !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> -> !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  loc(#loc37)
      %3002 = ttg.memdesc_subview %154[%c32_i32, %c0_i32] :  !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> ->  !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256>  loc(#loc37)
      %1007 = ttg.local_load %1006 :  !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64> -> tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %1008 = arith.sitofp %1007 : tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %2002 = ttng.warp_group_dot %1008, %3002, %2001 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc28)

      %c48_i32 = arith.constant 48 : i32 loc(#loc37) 
      %1009 = ttg.memdesc_subview %152[%c0_i32, %c48_i32] : !ttg.memdesc<128x64xi8, #shared, #smem, mutable, 2x128x64> -> !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>  loc(#loc37)
      %3003 = ttg.memdesc_subview %154[%c48_i32, %c0_i32] :  !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> ->  !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256>  loc(#loc37)
      %1010 = ttg.local_load %1009 : !ttg.memdesc<128x16xi8, #shared, #smem, mutable, 2x128x64>    -> tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %1011 = arith.sitofp %1010 : tensor<128x16xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> to tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> loc(#loc27)
      %2003 = ttng.warp_group_dot %1011, %3003, %2002 {inputPrecision = 0 : i32, isAsync = true} : tensor<128x16xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * !ttg.memdesc<16x256xbf16, #shared1, #smem, mutable, 3x64x256> -> tensor<128x256xf32, #mma> loc(#loc28)

      %157:2 = ttng.warp_group_dot_wait %2003, %154 {pendings = 0 : i32} : tensor<128x256xf32, #mma>, !ttg.memdesc<64x256xbf16, #shared1, #smem, mutable, 3x64x256> loc(#loc28)
```
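
For intuition, here is a minimal NumPy sketch (an illustration only, not part of the PR): splitting the K dimension into 16-wide subtiles and chaining the accumulator through successive dots gives the same result as the single 128x64 @ 64x256 dot, with float32 arrays standing in for the i8/bf16 operands.
```
import numpy as np

# Illustrative equivalence check (not the actual ttgir rewrite): split the K
# dimension of a 128x64 @ 64x256 dot into 16-wide subtiles and chain the
# accumulator, mirroring the subtiled ttgir above.
M, K, N, SUB_K = 128, 64, 256, 16

rng = np.random.default_rng(0)
a = rng.integers(-128, 128, size=(M, K)).astype(np.float32)  # stand-in for the i8 operand after sitofp
b = rng.standard_normal((K, N), dtype=np.float32)            # stand-in for the bf16 operand

full = a @ b  # the original single warp_group_dot

acc = np.zeros((M, N), dtype=np.float32)
for k0 in range(0, K, SUB_K):
    # Each iteration mirrors one pair of memdesc_subviews plus a chained
    # warp_group_dot: a K-slice of A and of B, accumulated into acc.
    acc += a[:, k0:k0 + SUB_K] @ b[k0:k0 + SUB_K, :]

assert np.allclose(full, acc, rtol=1e-3, atol=1e-2)
```
The actual pass works on memdescs and WGMMA ops rather than NumPy arrays; the sketch only checks the arithmetic equivalence of the chained accumulation.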




# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.

- [x] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best
    practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
    including the "tests should be minimal" section. (Usually running Python
    code and using the instructions it generates is not minimal.)
```
cvt with .e4m3x2/.e5m2x2 requires sm89 or higher.

cvt.satfinite.{e4m3x2, e5m2x2}.{f32, f16x2} requires sm_89 or higher.
```
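These diagnostics mean the fp8 `cvt` forms only exist on sm_89 and newer GPUs. As a hedged sketch of the kind of guard this implies (assuming PyTorch and pytest are available; the helper and test names are illustrative, not the ones used in the Triton test suite):
```
import pytest
import torch

def _fp8_cvt_supported() -> bool:
    # cvt with .e4m3x2/.e5m2x2 requires sm_89 or newer, so check the GPU's
    # compute capability before running fp8 conversion tests.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return (major, minor) >= (8, 9)

@pytest.mark.skipif(not _fp8_cvt_supported(), reason="fp8 cvt requires sm_89 or higher")
def test_fp8_conversion():  # hypothetical test name, for illustration only
    ...
```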
Restricts the memory ops considered for each of the clusters to only those
that are used to compute the dot product. This fixes a potential issue: if a
persistent kernel were to put a `tl.load` in its prologue, this
transformation would produce invalid IR.
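
To make the intent concrete, here is a toy Python sketch under assumed data structures (a hypothetical `operands` attribute on hashable op objects, not the real pass or Triton's IR classes): only memory ops on the dot's backward slice are kept, so a prologue `tl.load` is ignored.
```
# Toy sketch only: keep just the memory ops that lie on the dot's backward
# slice, i.e. those that actually feed the dot product.
def memory_ops_feeding_dot(dot_op, memory_ops):
    memory_ops = set(memory_ops)
    feeding, seen = set(), set()
    stack = list(dot_op.operands)  # walk the use-def chain backwards from the dot
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        if op in memory_ops:
            feeding.add(op)
        stack.extend(getattr(op, "operands", ()))
    return feeding
```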

---------

Co-authored-by: Nick Riasanovsky <[email protected]>
Enabled more tests in test_mxfp8_mxfp4_matmul for the AMD backend.
@anmyachev anmyachev marked this pull request as ready for review March 17, 2025 15:29
@whitneywhtsang whitneywhtsang merged commit 30bcd8e into main Mar 17, 2025
6 checks passed
@whitneywhtsang whitneywhtsang deleted the amyachev/merge10 branch March 17, 2025 16:08