Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Brgemm register tiling for bf16 type #1005

Open
wants to merge 6 commits into
base: main
Choose a base branch
from

Conversation

arun-thmn
Copy link
Contributor

@arun-thmn arun-thmn commented Feb 3, 2025

This PR extends the brgemm register tiling pass to support bf16 type. The changes:

  1. Template the existing pass to execute on linalg.batch_reduce_matmul for fp32 and linal.generic for vnni opt bf16,
  2. Test-cases for bf16 type.

@arun-thmn arun-thmn added the benchmark-full Benchmark all targets label Feb 3, 2025
@arun-thmn arun-thmn marked this pull request as ready for review February 3, 2025 03:38
@arun-thmn
Copy link
Contributor Author

@rengolin Request to review this PR for bf16 register tile support. I have re-written the tiling pass with new logic (template and more checks) to tile both fp32 and f16 (vnni). If you have time, I request you to review it as a new pass (as the existing tiling for fp32, I did it immediately joining Intel with lesser understanding of concepts).

@arun-thmn arun-thmn added benchmark-full Benchmark all targets and removed benchmark-full Benchmark all targets labels Feb 3, 2025
lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;

// k-tile size adjusted based on the vnni layout for bf16 type
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has baked-in assumptions that are not verified.
As the pass now operates on generic, we need to strictly filter ops that are accepted. I think you need to at least ensure it is a VNNI contraction first - there should be some suitable helpers in VnniUtils.
If f32 generic should be supported as well, it might need some extra checks there too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out this.
Added more checks for f32 and bf16 type. Used a check from vnniutils.

lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
lib/TPP/Transforms/BrgemmLinalgTiling.cpp Outdated Show resolved Hide resolved
test/Integration/tile-brgemm-linalg-matmul-bf16.mlir Outdated Show resolved Hide resolved
test/Integration/tile-brgemm-linalg-matmul-bf16.mlir Outdated Show resolved Hide resolved
// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64}
// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling(
// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> {
// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you use more descriptive named for the captured values?

Also, these check feel too explicit, maybe you could omit some details

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have simplified the checks for smaller tests. Please have a look, if it is fine will replicate the same.

// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
zeroCst, ubCstTiledLoop, stepCstTiledLoop);
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the TileUsingFor transform instead of manually implementing tiling?

Copy link
Contributor Author

@arun-thmn arun-thmn Feb 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it a target for linalg ops. But, now we are focusing only on batch reduce matmul. As far as I know, the plan is to make tiling, hoisting, vector-to-fma, and vector-to-amx as a one register blocking transform schedule in the upstream.
Also, we extended this pass to support bf16 type, so not taught of using TileUsingFor transform schedule.

%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16>
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16>
scf.forall (%arg1, %arg2) in (8, 32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these scf.forall are not needed by the (matcher of the) transform can we please get rid of them? Same goes for all the unittests in this file and other surrounding IR that does not influence the code-under-test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The A, B, and C in linalg.generic or in b r matmul are of memref.subview types. They are extracted from args based on the induction variable of scf.forall, so we rely on them.
ps: We do tile only if the A, B, and C are subviews.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
benchmark-full Benchmark all targets
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants