Skip to content

Commit

Permalink
code re-factoring and test cases for bf16
Browse files Browse the repository at this point in the history
  • Loading branch information
Arun Thangamani committed Feb 3, 2025
1 parent 7943e1e commit 310a21a
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 134 deletions.
89 changes: 49 additions & 40 deletions lib/TPP/Transforms/BrgemmLinalgTiling.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//===- BrgemmLinalgTiling.cpp -----------------------------------------*- C++-*-===//
//===- BrgemmLinalgTiling.cpp -----------------------------------------*-
//C++-*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down Expand Up @@ -57,35 +58,37 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
return failure();

// Check whether the tile sizes are valid
if (options.registerTileShape.size() != 3 && options.registerTileShape.size() != 2)
return failure();
if (options.registerTileShape.size() != 3 &&
options.registerTileShape.size() != 2)
return failure();

// Check whether the operation is brmatmul fp32 or bf16 type using reduction count
// Check whether the operation is brmatmul fp32 or bf16 type using
// reduction count
SmallVector<utils::IteratorType> brgemmIteratorTypes =
brgemmOp.getIteratorTypesArray();
int reductionCount =
std::count(brgemmIteratorTypes.begin(), brgemmIteratorTypes.end(),
utils::IteratorType::reduction);
if (reductionCount != 2 && reductionCount != 3)
return failure();
return failure();

// Get the register blocking tile shape from the user input
SmallVector<int64_t> mxnxkTile(3);
for (size_t i = 0; i < options.registerTileShape.size(); i++) {
mxnxkTile[i] = options.registerTileShape[i];
mxnxkTile[i] = options.registerTileShape[i];
}

// Set the K tile to 1 if the user did not provide one (i.e. the fp32 target)
if (options.registerTileShape.size() == 2)
mxnxkTile[2] = 1;

// k-tile size adjusted based on the vnni layout for bf16 type
auto tensorShape = dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
auto tensorShape =
dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType()).getShape();
if (tensorShape.size() == 4 && options.registerTileShape.size() == 3) {
mxnxkTile[2] = mxnxkTile[2] / tensorShape[3];
}


SmallVector<int> swap_i = {0, 2, 1};
size_t i = 0;
std::map<int, std::map<int, Value>> inductionVars;
Expand All @@ -100,42 +103,46 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
dyn_cast<MemRefType>(brgemmOp.getOperand(swap_i[i]).getType())
.getShape()[1];

//Tile size should not be greater than the upperBound
// Tile size should not be greater than the upperBound
if ((*itrShapeMNK) > upperBound)
return failure();
return failure();

Location loc = brgemmOp.getLoc();
Value zeroCst = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value ubCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, upperBound);
Value ubCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, upperBound);

Value stepCstTiledLoop = rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
Value stepCstTiledLoop =
rewriter.create<arith::ConstantIndexOp>(loc, *itrShapeMNK);
// Creates M, N, and K tile loops
scf::ForOp loopOp = rewriter.create<scf::ForOp>(brgemmOp.getLoc(),
zeroCst, ubCstTiledLoop, stepCstTiledLoop);
scf::ForOp loopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstTiledLoop, stepCstTiledLoop);
rewriter.setInsertionPointToStart(loopOp.getBody());
innermostForLoop = loopOp;

// Stores the induction variable with respect to the operands mapping its subview.
// Stores the induction variable with respect to the operands mapping its
// subview.
if (i == 0) {
inductionVars[0][1] = loopOp.getInductionVar();
inductionVars[2][0] = loopOp.getInductionVar();
} else if(i == 1) {
inductionVars[1][2] = loopOp.getInductionVar();
inductionVars[2][1] = loopOp.getInductionVar();
//Creates reduction loop after the N loop
inductionVars[0][1] = loopOp.getInductionVar();
inductionVars[2][0] = loopOp.getInductionVar();
} else if (i == 1) {
inductionVars[1][2] = loopOp.getInductionVar();
inductionVars[2][1] = loopOp.getInductionVar();
// Creates reduction loop after the N loop
Value ubCstReduction = rewriter.create<arith::ConstantIndexOp>(
loc, dyn_cast<MemRefType>(brgemmOp.getOperand(0).getType())
.getShape()[0]);
Value stepCstReduction = rewriter.create<arith::ConstantIndexOp>(loc, 1);
Value stepCstReduction =
rewriter.create<arith::ConstantIndexOp>(loc, 1);
scf::ForOp redloopOp = rewriter.create<scf::ForOp>(
brgemmOp.getLoc(), zeroCst, ubCstReduction, stepCstReduction);
rewriter.setInsertionPointToStart(redloopOp.getBody());
inductionVars[0][0] = redloopOp.getInductionVar();
inductionVars[1][0] = redloopOp.getInductionVar();

} else if(i == 2) {
inductionVars[0][2] = loopOp.getInductionVar();
inductionVars[1][1] = loopOp.getInductionVar();
} else if (i == 2) {
inductionVars[0][2] = loopOp.getInductionVar();
inductionVars[1][1] = loopOp.getInductionVar();
}
}

Expand All @@ -162,31 +169,30 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
auto tensorShape = dyn_cast<MemRefType>(input.getType()).getShape();
auto tileItr = tileshapes[i].begin();

// Iterates over the shape of each tensor and updates its offsets, indices, shapes, strides with respect to tile sizes
// Iterates over the shape of each tensor and updates its offsets, indices,
// shapes, strides with respect to tile sizes
for (size_t j = 0; j < tensorShape.size(); j++) {
if (j == 0 && (i < 2)) { // Updates the batch dimension
offsets.push_back(inductionVars[i][j]);
indices.push_back(1);
shape.push_back(rewriter.getIndexAttr(1));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(inductionVars[i][j]);
indices.push_back(1);
shape.push_back(rewriter.getIndexAttr(1));
strides.push_back(rewriter.getIndexAttr(1));
} else if (j < 3) { // Updates the M, N, and K dimensions
offsets.push_back(inductionVars[i][j]);
indices.push_back((*tileItr));
shape.push_back(rewriter.getIndexAttr(*tileItr));
strides.push_back(rewriter.getIndexAttr(1));
tileItr++;
} else { // Just copies the vnni layout dimensions
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
offsets.push_back(rewriter.getIndexAttr(0));
indices.push_back(tensorShape[j]);
shape.push_back(rewriter.getIndexAttr(tensorShape[j]));
strides.push_back(rewriter.getIndexAttr(1));
}

}

auto subview = rewriter.create<memref::SubViewOp>(
brgemmOp.getLoc(), MemRefType(),
input, offsets, shape, strides);
brgemmOp.getLoc(), MemRefType(), input, offsets, shape, strides);
brgemmOp.setOperand(i, subview);
}

Expand All @@ -204,11 +210,14 @@ struct LinalgOpTiling : OpRewritePattern<BrgemmOp> {
};

void populateBrgemmLinalgTilingPatterns(RewritePatternSet &patterns,
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling<linalg::GenericOp>, LinalgOpTiling<linalg::BatchReduceMatmulOp>>(patterns.getContext(), options);
BrgemmLinalgTilingOptions options) {
patterns.add<LinalgOpTiling<linalg::GenericOp>,
LinalgOpTiling<linalg::BatchReduceMatmulOp>>(
patterns.getContext(), options);
}

struct BrgemmLinalgTiling : public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {
struct BrgemmLinalgTiling
: public tpp::impl::BrgemmLinalgTilingBase<BrgemmLinalgTiling> {

using BrgemmLinalgTilingBase::BrgemmLinalgTilingBase;

Expand Down
Loading

0 comments on commit 310a21a

Please sign in to comment.