Skip to content

Commit 5ef8fdc

Browse files
committed
fix deepTileMatmul
1 parent b8f89e9 commit 5ef8fdc

File tree

1 file changed

+21
-37
lines changed

1 file changed

+21
-37
lines changed

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -322,11 +322,10 @@ static unsigned getOprandDim(linalg::LinalgOp &linalgOp, unsigned iteratorPos,
322322
return linalgOp.getShape(linalgOp.getDpsInputOperand(operandIdx))[dimPos];
323323
}
324324

325-
static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326-
Operation *op,
327-
bool isExtract,
328-
SmallVector<int64_t> size,
329-
int shrinDimNum = 0) {
325+
static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
326+
Operation *op, bool isExtract,
327+
SmallVector<int64_t> size,
328+
int shrinDimNum = 0) {
330329
OpBuilder::InsertionGuard guard(rewriter);
331330
rewriter.setInsertionPoint(op);
332331
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(op)) {
@@ -348,15 +347,12 @@ static LogicalResult setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
348347
extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes,
349348
mixedStrides);
350349
}
351-
} else {
352-
return failure();
353350
}
354-
return mlir::success();
355351
}
356352

357-
static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
358-
Operation *op, Value source,
359-
SmallVector<int64_t> size) {
353+
static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
354+
Value source,
355+
SmallVector<int64_t> size) {
360356
OpBuilder::InsertionGuard guard(rewriter);
361357
rewriter.setInsertionPoint(op);
362358
if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(op)) {
@@ -369,10 +365,7 @@ static LogicalResult setStaticSizeForInsertSliceOp(RewriterBase &rewriter,
369365
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
370366
insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes,
371367
mixedStrides);
372-
} else {
373-
return failure();
374368
}
375-
return success();
376369
}
377370

378371
using InnermostFullResultCallBackFn = std::function<FailureOr<linalg::LinalgOp>(
@@ -691,7 +684,6 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
691684
linalg::LinalgOp originOp,
692685
linalg::LinalgOp currentOp,
693686
innerBodyGenerationOption &option) const {
694-
695687
mlir::easybuild::EasyBuilder eb{rewriter, originOp.getLoc()};
696688
auto operandDimTypes = getOprandDimType(originOp);
697689
auto cfg = MatmulConfigAnalysis(originOp.getOperation()).getConfig();
@@ -744,6 +736,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
744736
CInnermostDims =
745737
SmallVector<int64_t>{cfg.innerMostMBlock, cfg.innerMostNBlock};
746738
}
739+
747740
if (NDimNum > 1) {
748741
firstN = true;
749742
firstK = true;
@@ -780,21 +773,17 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
780773

781774
// update the extractSlice to static size, replace it with
782775
// useBlockedLayout when
783-
if (failed(setStaticSizeForExtractSliceOp(
784-
rewriter, currentOp.getDpsInits()[0].getDefiningOp(), true,
785-
CInnermostDims, MDimNum > 1 ? 2 : 0)) ||
786-
failed(setStaticSizeForExtractSliceOp(
787-
rewriter, currentOp.getDpsInputs()[1].getDefiningOp(), true,
788-
BInnermostDims, NDimNum > 1)) ||
789-
failed(setStaticSizeForExtractSliceOp(
790-
rewriter, currentOp.getDpsInputs()[0].getDefiningOp(), true,
791-
AInnermostDims, MDimNum > 1)) ||
792-
(currentOp.getDpsInits().size() > 1 &&
793-
failed(setStaticSizeForExtractSliceOp(
794-
rewriter, currentOp.getDpsInits()[1].getDefiningOp(), true,
795-
CInnermostDims, MDimNum > 1 ? 2 : 0)))) {
796-
return failure();
776+
setStaticSizeForExtractSliceOp(rewriter,
777+
currentOp.getDpsInputs()[1].getDefiningOp(),
778+
true, BInnermostDims, NDimNum > 1);
779+
setStaticSizeForExtractSliceOp(rewriter,
780+
currentOp.getDpsInputs()[0].getDefiningOp(),
781+
true, AInnermostDims, MDimNum > 1);
782+
for (auto init : currentOp.getDpsInits()) {
783+
setStaticSizeForExtractSliceOp(rewriter, init.getDefiningOp(), true,
784+
CInnermostDims, MDimNum > 1 ? 2 : 0);
797785
}
786+
798787
// View the tensor to brgemm required format
799788
Value dataOprand = tensorViewRankedTensor(
800789
rewriter,
@@ -841,10 +830,7 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
841830

842831
// Insert the result back to the original tensor
843832
for (Operation *user : currentOp->getResult(0).getUsers()) {
844-
if (failed(setStaticSizeForInsertSliceOp(rewriter, user, result,
845-
CInnermostDims))) {
846-
return failure();
847-
}
833+
setStaticSizeForInsertSliceOp(rewriter, user, result, CInnermostDims);
848834
}
849835

850836
if (option.needLowPrecisionCast) {
@@ -869,10 +855,8 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
869855
auto ifOp = eb.getLastOperaion();
870856
// set static size for the insertSliceOp of copyOp
871857
for (Operation *user : currentOp->getResult(1).getUsers()) {
872-
if (failed(setStaticSizeForInsertSliceOp(
873-
rewriter, user, ifOp->getResult(0), CInnermostDims))) {
874-
return failure();
875-
}
858+
setStaticSizeForInsertSliceOp(rewriter, user, ifOp->getResult(0),
859+
CInnermostDims);
876860
}
877861
rewriter.replaceOp(currentOp, {matmul->getResult(0), ifOp->getResult(0)});
878862
} else {

0 commit comments

Comments
 (0)