
[Codegen] Run LoopCoalescingPass at the end of warp reduce #19950

Merged: 2 commits into iree-org:main on Feb 12, 2025

Conversation

@IanWood1 (Contributor) commented on Feb 10, 2025:

I observed that this improved some of the punet dispatches using warp reduction from 80us to 60us (example: https://gist.github.com/IanWood1/5b0e4fcb4e90a02525b94ea4347145f5). However, the improvement isn't visible above the noise of CI. Locally, I saw a ~0.7ms improvement to punet.

This also seems to improve batch norm dispatches: in a small benchmark, one went from 80us to 60us.

Signed-off-by: Ian Wood <[email protected]>
@IanWood1 requested a review from pashu123 on February 10, 2025 at 20:42
@IanWood1 marked this pull request as ready for review on February 10, 2025 at 20:42
@pashu123 (Contributor) commented:

For reviewers:

  %2 = scf.for %arg0 = %c0 to %c60 step %c4 iter_args(%arg1 = %cst_1) -> (vector<1x4xf32>) {
    %60 = scf.for %arg2 = %c0 to %c3840 step %c256 iter_args(%arg3 = %arg1) -> (vector<1x4xf32>) {
      %61 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 floordiv 64)>(%arg0)[%thread_id_x]
      %62 = affine.apply affine_map<(d0)[s0] -> (d0 + s0 * 4 - (s0 floordiv 64) * 256)>(%arg2)[%thread_id_x]
      %63 = vector.transfer_read %0[%workgroup_id_y, %workgroup_id_x, %61, %62], %cst_2 {in_bounds = [true, true]} : memref<2x32x60x3840xf16, strided<[7372800, 230400, 3840, 1], offset: 42741504>, #hal.descriptor_type<storage_buffer>>, vector<1x4xf16>
      %64 = arith.extf %63 : vector<1x4xf16> to vector<1x4xf32>
      %65 = arith.addf %64, %arg3 : vector<1x4xf32>
      scf.yield %65 : vector<1x4xf32>
    }
    scf.yield %60 : vector<1x4xf32>
  }

The pass collapses the two nested loops into a single loop (besides normalizing them) and recovers the original induction variables with affine.delinearize_index: the 15 outer iterations (60 / 4) and 15 inner iterations (3840 / 256) become one loop of 225 iterations.

 %2 = scf.for %arg0 = %c0_4 to %c225 step %c1 iter_args(%arg1 = %cst_1) -> (vector<1x4xf32>) {
      %60:2 = affine.delinearize_index %arg0 into (15, 15) : index, index
      %61 = affine.apply #map(%60#1)
      %62 = affine.apply #map1(%60#0)
      %63 = affine.apply #map2(%62)[%thread_id_x]
      %64 = affine.apply #map3(%61)[%thread_id_x]
      %65 = vector.transfer_read %0[%workgroup_id_y, %workgroup_id_x, %63, %64], %cst_2 {in_bounds = [true, true]} : memref<2x32x60x3840xf16, strided<[7372800, 230400, 3840, 1], offset: 42741504>, #hal.descriptor_type<storage_buffer>>, vector<1x4xf16>
      %66 = arith.extf %65 : vector<1x4xf16> to vector<1x4xf32>
      %67 = arith.addf %66, %arg1 : vector<1x4xf32>
      scf.yield %67 : vector<1x4xf32>
  }
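
Side note, in case it helps with reproducing this outside of IREE: below is a rough standalone sketch of the same transformation on a made-up reduced function, run with the upstream loop-coalescing pass directly (e.g. mlir-opt --affine-loop-coalescing). The actual change just appends the pass to the warp-reduction codegen pipeline, so the invocation and the function below are only illustrative.

// Hypothetical reduced example; run with the upstream pass, e.g.:
//   mlir-opt --affine-loop-coalescing coalesce_example.mlir
func.func @coalesce_example(%buf: memref<60x3840xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %c4 = arith.constant 4 : index
  %c60 = arith.constant 60 : index
  %c256 = arith.constant 256 : index
  %c3840 = arith.constant 3840 : index
  %zero = arith.constant 0.0 : f32
  // Perfectly nested 15 x 15 reduction loops, analogous to the IR above.
  %sum = scf.for %i = %c0 to %c60 step %c4 iter_args(%acc0 = %zero) -> (f32) {
    %inner = scf.for %j = %c0 to %c3840 step %c256 iter_args(%acc1 = %acc0) -> (f32) {
      %v = memref.load %buf[%i, %j] : memref<60x3840xf32>
      %next = arith.addf %acc1, %v : f32
      scf.yield %next : f32
    }
    scf.yield %inner : f32
  }
  return %sum : f32
}

Running the pass on this should collapse the nest into a single 225-iteration scf.for and recover the two indices via affine.delinearize_index, matching the shape of the coalesced IR above.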

@pashu123 (Contributor) left a comment:

LGTM! Please make sure onnx-test-suite passes.

@MaheshRavishankar (Contributor) commented:

Let's land this. Maybe try the onnx suite to be sure, but this seems harmless enough. If we have issues we can always revert.

@IanWood1 (Contributor, Author) commented:

Okay, I'll run the onnx tests first. Workflow here: https://github.com/nod-ai/SHARK-TestSuite/actions/runs/13257180035

@kuhar (Member) left a comment:

Could we have a LIT test for this?

@IanWood1 (Contributor, Author) commented on Feb 11, 2025:

> Could we have a LIT test for this?

Any suggestions on where to add a test? Should it be added to https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_rocm.mlir?

@MaheshRavishankar (Contributor) commented:

That would be the place.

@MaheshRavishankar merged commit 06eaead into iree-org:main on Feb 12, 2025. 43 checks passed.
hanhanW pushed a commit to hanhanW/iree that referenced this pull request on Feb 13, 2025.