-
Notifications
You must be signed in to change notification settings - Fork 31
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Brgemm register tiling for bf16 type #1005
base: main
Are you sure you want to change the base?
Conversation
@rengolin Request to review this PR for |
if (options.registerTileShape.size() == 2) | ||
mxnxkTile[2] = 1; | ||
|
||
// k-tile size adjusted based on the vnni layout for bf16 type |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has baked-in assumptions that are not verified.
As the pass now operates on generic, we need to strictly filter ops that are accepted. I think you need to at least ensure it is a VNNI contraction first - there should be some suitable helpers in VnniUtils
.
If f32 generic should be supported as well, it might need some extra checks there too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out this.
Added more checks for f32
and bf16
type. Used a check from vnniutils
.
// CONF1-LABEL: memref.global "private" constant @__constant_48x32x32xf32 : memref<48x32x32xf32> = dense<1.000000e+00> {alignment = 64 : i64} | ||
// CONF1-LABEL: func.func @chainned_gemm_do_register_tiling( | ||
// CONF1-SAME: %[[VAL_0:.*]]: memref<8x48x32x32xf32>) -> memref<8x48x32x32xf32> { | ||
// CONF1: %[[VAL_1:.*]] = arith.constant 1 : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you use more descriptive named for the captured values?
Also, these check feel too explicit, maybe you could omit some details
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have simplified the checks for smaller tests. Please have a look, if it is fine will replicate the same.
// Creates M, N, and K tile loops | ||
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(), | ||
zeroCst, ubCstTiledLoop, stepCstTiledLoop); | ||
scf::ForOp loopOp = rewriter.create<scf::ForOp>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I am understanding right, this transform is meant to operate on linalg ops. As I expect all the ops you want to support will implement TilingInterface, would it be possible to just use the TileUsingFor
transform instead of manually implementing tiling?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it a target for linalg ops. But, now we are focusing only on batch reduce matmul
. As far as I know, the plan is to make tiling
, hoisting
, vector-to-fma
, and vector-to-amx
as a one register blocking
transform schedule in the upstream.
Also, we extended this pass to support bf16
type, so not taught of using TileUsingFor
transform schedule.
%0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> | ||
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> | ||
%expand_shape = memref.expand_shape %arg0 [[0], [1], [2], [3, 4]] output_shape [8, 32, 32, 16, 2] : memref<8x32x32x32xbf16> into memref<8x32x32x16x2xbf16> | ||
scf.forall (%arg1, %arg2) in (8, 32) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If these scf.forall are not needed by the (matcher of the) transform can we please get rid of them? Same goes for all the unittests in this file and other surrounding IR that does not influence the code-under-test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The A
, B
, and C
in linalg.generic
or in b r matmul
are of memref.subview
types. They are extracted from args
based on the induction variable of scf.forall
, so we rely on them.
ps: We do tile
only if the A
, B
, and C
are subviews
.
This PR extends the
brgemm register tiling
pass to supportbf16
type. The changes:linalg.batch_reduce_matmul
forfp32
andlinal.generic
forvnni
opt bf16,bf16
type.