From db2399e9ce8380008bc0a07e092e0fbcdd5d06c8 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Wed, 31 Jul 2024 23:49:00 -0700 Subject: [PATCH 01/23] Support named op layout propagation and pack processing --- include/gc/Analysis/GlobalAnalysis.h | 152 ++++ include/gc/Transforms/Passes.td | 34 + include/gc/Transforms/Transforms.h | 28 + lib/gc/Analysis/CMakeLists.txt | 1 + lib/gc/Analysis/GlobalAnalysis.cpp | 348 +++++++++ lib/gc/Transforms/CMakeLists.txt | 3 + lib/gc/Transforms/LowerPackUnpack.cpp | 83 +++ lib/gc/Transforms/Pipeline.cpp | 4 + lib/gc/Transforms/PostProcessPackUnpack.cpp | 179 +++++ lib/gc/Transforms/PropagateLayout.cpp | 677 ++++++++++++++++++ packMatmul.patch | 32 + .../named-op-layout-propagation.mlir | 12 + test/mlir/test/gc/Transforms/pack-matmul.mlir | 58 ++ 13 files changed, 1611 insertions(+) create mode 100644 include/gc/Analysis/GlobalAnalysis.h create mode 100644 include/gc/Transforms/Transforms.h create mode 100644 lib/gc/Analysis/GlobalAnalysis.cpp create mode 100644 lib/gc/Transforms/LowerPackUnpack.cpp create mode 100644 lib/gc/Transforms/PostProcessPackUnpack.cpp create mode 100644 lib/gc/Transforms/PropagateLayout.cpp create mode 100644 packMatmul.patch create mode 100644 test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir create mode 100644 test/mlir/test/gc/Transforms/pack-matmul.mlir diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h new file mode 100644 index 000000000..48fe9677e --- /dev/null +++ b/include/gc/Analysis/GlobalAnalysis.h @@ -0,0 +1,152 @@ +//===- GlobalAnalysis.h - Graph Compiler analysis pass ----------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_GLOBALANALYSIS_H +#define MLIR_ANALYSIS_GLOBALANALYSIS_H + +#include + +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace gc { + +using namespace mlir; + +class TensorLayout { +public: + TensorLayout(ArrayRef outerAxis, ArrayRef innerAxis, + ArrayRef tileSizes) + : outerAxis(outerAxis), innerAxis(innerAxis), tileSizes(tileSizes) { + assert(innerAxis.size() == tileSizes.size()); + } + + bool isPlainLayout() const { + for (int64_t i = 0; i < static_cast(outerAxis.size()); ++i) { + if (i != outerAxis[i]) + return false; + } + return tileSizes.empty() && innerAxis.empty(); + } + + static TensorLayout createPlainLayout(int64_t rank) { + SmallVector outerAxis(rank, 0); + std::iota(outerAxis.begin(), outerAxis.end(), 0); + return TensorLayout(outerAxis, SmallVector{}, + SmallVector{}); + } + + DenseMap> getPlainToPackedAxisMapping() { + DenseMap> axisMapping; + int64_t outerAxisSize = outerAxis.size(); + for (int64_t i = 0; i < outerAxisSize; ++i) { + axisMapping[outerAxis[i]].push_back(i); + } + for (int64_t i = 0; i < static_cast(innerAxis.size()); ++i) { + axisMapping[innerAxis[i]].push_back(outerAxisSize + i); + } + return axisMapping; + } + + FailureOr getPlainAxis(int64_t idx) { + int64_t totalRank = outerAxis.size() + innerAxis.size(); + if (idx >= totalRank || idx < 0) { + return failure(); + } else if (idx >= static_cast(outerAxis.size())) { + return innerAxis[idx - outerAxis.size()]; + } else { + return outerAxis[idx]; + } + } + + size_t getRank() const { return outerAxis.size(); } + + SmallVector getOuterAxis() const { return outerAxis; } + + SmallVector getInnerAxis() const { return innerAxis; } + + SmallVector getTileSizes() const { return tileSizes; } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const TensorLayout &layout); + + bool operator==(const TensorLayout &layout); + +private: + SmallVector outerAxis; + SmallVector innerAxis; + SmallVector tileSizes; +}; + +class OperatorLayout { +public: + OperatorLayout() {} + + OperatorLayout(SmallVector inputLayouts, + SmallVector outputLayouts) { + supportedInputLayouts = inputLayouts; + supportedOutputLayouts = outputLayouts; + } + + SmallVector getSupportedInputLayouts() const { + return supportedInputLayouts; + } + + SmallVector getSupportedOutputLayouts() const { + return supportedOutputLayouts; + } + + TensorLayout getOutputLayout(int64_t idx) const { + assert(idx < static_cast(supportedOutputLayouts.size())); + return supportedOutputLayouts[idx]; + } + + bool isPlain() const { + for (const auto &layout : llvm::concat( + supportedInputLayouts, supportedOutputLayouts)) { + if (!layout.isPlainLayout()) + return false; + } + return true; + } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const OperatorLayout &opLayout); + +private: + SmallVector supportedInputLayouts; + SmallVector supportedOutputLayouts; +}; + +class GlobalAnalysis { +public: + explicit GlobalAnalysis(Operation *root); + + FailureOr getOpLayout(Operation *op) { + if (layoutCache.find(op) != layoutCache.end()) + return layoutCache[op]; + else + return failure("Current op does not have layout information."); + } + +private: + DenseMap layoutCache; +}; + +namespace utils { 
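Aside (editorial example, not part of the patch): the TensorLayout triple above is easiest to read on a concrete case. Below is a minimal standalone C++ sketch of the plain-to-packed axis mapping (mirroring getPlainToPackedAxisMapping) for the MKmk layout used later in this series; the std containers stand in for the LLVM ADT types and all values are hypothetical.

```cpp
// Standalone illustration of TensorLayout's plain->packed axis mapping.
// Assumed MKmk layout: outerAxis = [0, 1], innerAxis = [0, 1] (m, k tiled
// off M and K); tile sizes omitted since they don't affect the mapping.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> outerAxis = {0, 1}; // M, K kept in place
  std::vector<int64_t> innerAxis = {0, 1}; // m tiled off M, k tiled off K
  // plain axis -> list of packed axes it expands into (outer first).
  std::map<int64_t, std::vector<int64_t>> mapping;
  for (int64_t i = 0; i < (int64_t)outerAxis.size(); ++i)
    mapping[outerAxis[i]].push_back(i);
  for (int64_t i = 0; i < (int64_t)innerAxis.size(); ++i)
    mapping[innerAxis[i]].push_back((int64_t)outerAxis.size() + i);
  for (auto &[plain, packed] : mapping) {
    std::cout << "plain axis " << plain << " -> packed axes:";
    for (auto p : packed)
      std::cout << " " << p;
    std::cout << "\n"; // prints 0 -> {0, 2} (M, m) and 1 -> {1, 3} (K, k)
  }
}
```

This matches how tensor.pack splits a tiled dimension into an outer block count and an inner tile.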
+bool isPackableNamedOp(Operation *op);
+}
+} // namespace gc
+} // namespace mlir
+
+#endif
diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td
index d5330851b..bc9057c2d 100644
--- a/include/gc/Transforms/Passes.td
+++ b/include/gc/Transforms/Passes.td
@@ -106,4 +106,38 @@ def MergeNestedForall : Pass<"merge-nested-forall"> {
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
+  let summary = "Insert and propagate tensor.pack to pack the computation of linalg named ops and tensor ops.";
+  let description = [{
+    Insert and propagate tensor.pack on linalg named ops and tensor ops.
+  }];
+  let dependentDialects = [
+    "mlir::tensor::TensorDialect",
+    "mlir::linalg::LinalgDialect",
+    "mlir::linalgx::LinalgxDialect"
+  ];
+}
+
+def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
+  let summary = "Fold and simplify pack and unpack ops.";
+  let description = [{
+    Fold and simplify pack and unpack ops.
+  }];
+  let dependentDialects = [
+    "mlir::tensor::TensorDialect",
+    "mlir::linalg::LinalgDialect"
+  ];
+}
+
+def LowerPackUnpack : Pass<"lower-pack-unpack"> {
+  let summary = "Lower pack and unpack ops.";
+  let description = [{
+    Lower pack and unpack into transpose and shape-related ops.
+  }];
+  let dependentDialects = [
+    "mlir::tensor::TensorDialect",
+    "mlir::linalg::LinalgDialect"
+  ];
+}
+
 #endif // GC_DIALECT_GC_PASSES
diff --git a/include/gc/Transforms/Transforms.h b/include/gc/Transforms/Transforms.h
new file mode 100644
index 000000000..0e0ed76c7
--- /dev/null
+++ b/include/gc/Transforms/Transforms.h
@@ -0,0 +1,28 @@
+//===- Transforms.h - transformation utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef GC_TRANSFORMS_TRANSFORMS_H
+#define GC_TRANSFORMS_TRANSFORMS_H
+
+#include "gc/Analysis/GlobalAnalysis.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+namespace mlir {
+namespace gc {
+FailureOr packNamedOp(RewriterBase &rewriter,
+                      linalg::LinalgOp linalgOp,
+                      OperatorLayout opLayout);
+
+LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter,
+                                       linalg::LinalgOp linalgOp,
+                                       OperatorLayout opLayout);
+} // namespace gc
+} // namespace mlir
+
+#endif // GC_TRANSFORMS_TRANSFORMS_H
diff --git a/lib/gc/Analysis/CMakeLists.txt b/lib/gc/Analysis/CMakeLists.txt
index d7160f350..55a689d86 100644
--- a/lib/gc/Analysis/CMakeLists.txt
+++ b/lib/gc/Analysis/CMakeLists.txt
@@ -5,6 +5,7 @@ gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
 gc_add_mlir_library(GcAnalysis
   TargetDescriptionAnalysis.cpp
   MatmulConfigAnalysis.cpp
+  GlobalAnalysis.cpp
 
   DEPENDS
   GraphCompilerPassIncGen
diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
new file mode 100644
index 000000000..b634daea1
--- /dev/null
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -0,0 +1,348 @@
+//===- GlobalAnalysis.cpp - Propagate packing on linalg named ops *- C++-*-===//
+//
+// This file is only temporarily used to extend upstream or upcoming utility in
+// TilingInterface, which finally aims for upstream.
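Aside (editorial example, not part of the patch): inferIndexingMapRelation below maps each result dim of a target indexing map to the base result dim it matches, or -1 for a new dim, and getReversedIndexMap inverts that relation. A standalone sketch under the simplifying assumption that both maps are projected permutations given as lists of dim positions; values are hypothetical.

```cpp
// Standalone sketch of the target->base dim inference and its reversal.
// base map (d0, d1, d2) -> (d0, d2) is encoded as {0, 2};
// target map (d0, d1, d2) -> (d2, d1) as {2, 1}. Duplicate checks elided.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> base = {0, 2}, target = {2, 1};
  std::map<int64_t, int64_t> targetToBase; // target result idx -> base idx
  for (size_t j = 0; j < target.size(); ++j) {
    targetToBase[j] = -1; // -1: a dim the base map doesn't produce
    for (size_t i = 0; i < base.size(); ++i)
      if (base[i] == target[j])
        targetToBase[j] = i;
  }
  std::map<int64_t, int64_t> baseToTarget; // the reversed map
  for (size_t i = 0; i < base.size(); ++i)
    baseToTarget[i] = -1;
  for (auto &[j, i] : targetToBase)
    if (i >= 0)
      baseToTarget[i] = j;
  for (auto &[j, i] : targetToBase)
    std::cout << "target " << j << " -> base " << i << "\n"; // 0->1, 1->-1
}
```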
+// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "gc/Analysis/GlobalAnalysis.h" +#include "gc/Analysis/MatmulConfigAnalysis.h" + +namespace mlir { +namespace gc { + +#define DEBUG_TYPE "global-analysis" + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const TensorLayout &layoutCache) { + SmallVector outerAxis = layoutCache.getOuterAxis(); + SmallVector innerAxis = layoutCache.getInnerAxis(); + SmallVector tileSizes = layoutCache.getTileSizes(); + ss << "["; + llvm::interleaveComma(outerAxis, ss); + if (!innerAxis.empty()) { + ss << "; "; + llvm::interleaveComma(innerAxis, ss); + } + ss << "]"; + if (!tileSizes.empty()) { + ss << "; {"; + llvm::interleaveComma(tileSizes, ss); + ss << "}"; + } + return ss; +} + +bool TensorLayout::operator==(const TensorLayout &layout) { + return (this->outerAxis == layout.getOuterAxis()) && + (this->innerAxis == layout.getInnerAxis()) && + (this->tileSizes == layout.getTileSizes()); +} + +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, + const OperatorLayout &opLayout) { + for (auto &&[idx, layoutCache] : + llvm::enumerate(opLayout.getSupportedInputLayouts())) { + ss << "input " << idx << "'s layout: " << layoutCache << "\n"; + } + for (auto &&[idx, layoutCache] : + llvm::enumerate(opLayout.getSupportedOutputLayouts())) { + ss << "output " << idx << "'s layout: " << layoutCache << "\n"; + } + return ss; +} + +// inferring the relationship of two indexing map +// j -> i, means j is represented as the same symbol as i +// we don't allow duplicate in symbols +// e.g. if 2 j corresponding to 1 i, then return failure +static FailureOr> +inferIndexingMapRelation(AffineMap indexingMapBase, + AffineMap indexingMapTarget) { + DenseMap res; + ArrayRef resultsBase = indexingMapBase.getResults(); + ArrayRef resultsTarget = indexingMapTarget.getResults(); + for (size_t j = 0; j < resultsTarget.size(); ++j) { + for (size_t i = 0; i < resultsBase.size(); ++i) { + auto base = dyn_cast(resultsBase[i]); + auto target = dyn_cast(resultsTarget[j]); + if (base && target && base.getPosition() == target.getPosition()) { + if (res.find(j) != res.end()) + return failure(); + res[j] = i; + } + } + if (res.find(j) == res.end()) + res[j] = -1; + } + // check res + DenseSet indexSet; + for (auto pair : res) { + if (indexSet.find(pair.second) != indexSet.end()) { + return failure(); + } + if (pair.second >= 0) { + indexSet.insert(pair.second); + } + } + return res; +} + +// given j --> i and max rank of i, return i --> j +static DenseMap +getReversedIndexMap(const DenseMap &indexMap, + size_t maxRank) { + DenseMap res; + for (auto pair : indexMap) { + if (pair.second >= 0) { + res[pair.second] = pair.first; + } + } + for (size_t i = 0; i < maxRank; ++i) { + if (res.find(i) == res.end()) { + res[i] = -1; + } + } + return res; +} + +static FailureOr +inferTargetLayout(TensorLayout layoutBase, + const DenseMap &indexMap) { + SmallVector baseOuterAxis = layoutBase.getOuterAxis(); + SmallVector baseInnerAxis = layoutBase.getInnerAxis(); + SmallVector baseTileSizes = layoutBase.getTileSizes(); + SmallVector targetOuterAxis; + SmallVector targetInnerAxis; + SmallVector targetTileSizes; + DenseMap reverseIndexMap = + getReversedIndexMap(indexMap, layoutBase.getRank()); + for (auto oa : baseOuterAxis) { + if (reverseIndexMap[oa] >= 0) { + targetOuterAxis.push_back(reverseIndexMap[oa]); + } + } + // filling up new j axes + SmallVector newDimAxis; + for (auto pair : indexMap) { + if (pair.second < 0) { + 
newDimAxis.push_back(pair.first);
+    }
+  }
+  targetOuterAxis.insert(targetOuterAxis.begin(), newDimAxis.begin(),
+                         newDimAxis.end());
+  for (auto &&[ia, ts] : llvm::zip(baseInnerAxis, baseTileSizes)) {
+    if (reverseIndexMap[ia] >= 0) {
+      targetInnerAxis.push_back(reverseIndexMap[ia]);
+      targetTileSizes.push_back(ts);
+    }
+  }
+  return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes);
+}
+
+static size_t getTargetInputIdx(ArrayRef curInputLayouts) {
+  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
+    if (!curInputLayouts[i].isPlainLayout()) {
+      return i;
+    }
+  }
+  return 0;
+}
+
+static bool supportedContractionNamedOpList(linalg::LinalgOp &linalgOp) {
+  if (isa(
+          linalgOp))
+    return true;
+  return false;
+}
+
+std::pair, SmallVector>
+getPackingAxis(int64_t numRank, bool transposed) {
+  assert(numRank >= 2 &&
+         "The rank of matmul semantic contraction op shall be at least 2.");
+  SmallVector outerAxisPerm(numRank);
+  SmallVector innerAxisPos(2);
+  std::iota(outerAxisPerm.begin(), outerAxisPerm.end(), 0);
+  innerAxisPos[0] = numRank - 2;
+  innerAxisPos[1] = numRank - 1;
+  if (transposed) {
+    std::swap(outerAxisPerm[numRank - 2], outerAxisPerm[numRank - 1]);
+    std::swap(innerAxisPos[0], innerAxisPos[1]);
+  }
+  return std::make_pair(outerAxisPerm, innerAxisPos);
+}
+
+GlobalAnalysis::GlobalAnalysis(Operation *root) {
+  root->walk([&](Operation *op) {
+    if (auto linalgOp = dyn_cast(op)) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Inferring layout of op: " << op->getName() << "\n");
+      auto curInputs = linalgOp.getDpsInputOperands();
+      auto curResults = linalgOp.getOperation()->getResults();
+      // ---------------- Get Current Input Layouts -------------------
+      SmallVector curInputLayouts;
+      for (auto input : curInputs) {
+        auto parent = input->get().getDefiningOp();
+        if (layoutCache.find(parent) != layoutCache.end()) {
+          // TODO(yifei): it is not always 0 here
+          curInputLayouts.push_back(layoutCache[parent].getOutputLayout(0));
+        } else {
+          curInputLayouts.push_back(TensorLayout::createPlainLayout(
+              linalgOp.getMatchingIndexingMap(input).getNumResults()));
+        }
+      }
+      // ------ Get Current Op's Suggested Layout & Do Propagation ------
+      IRRewriter rewriter(linalgOp);
+      // TODO: extend to packed/vnni matmul ops
+      if (supportedContractionNamedOpList(linalgOp)) {
+        // get input and output rank
+        auto ARank = cast(linalgOp.getDpsInputs()[0].getType())
+                         .getShape()
+                         .size();
+        auto BRank = cast(linalgOp.getDpsInputs()[1].getType())
+                         .getShape()
+                         .size();
+        auto CRank =
+            cast(linalgOp.getOperation()->getResults()[0].getType())
+                .getShape()
+                .size();
+        bool ASideTransposed =
+            isa(
+                linalgOp);
+        bool BSideTransposed =
+            isa(
+                linalgOp);
+        // set outer & inner axis values
+        auto APackInfo = getPackingAxis(ARank, ASideTransposed);
+        auto BPackInfo = getPackingAxis(BRank, BSideTransposed);
+        auto CPackInfo = getPackingAxis(CRank, /*transposed*/ false);
+        // query the cost model for tile sizes
+        MatmulConfig cfg =
+            MatmulConfigAnalysis(linalgOp.getOperation()).getConfig();
+        uint32_t iim = cfg.innerMostMBlock, iin = cfg.innerMostNBlock,
+                 iik = cfg.innerMostKBlock;
+        // current layout is MKmk, NKkn, MNmn
+        TensorLayout ALayout(
+            APackInfo.first, APackInfo.second,
+            SmallVector{rewriter.getIndexAttr(iim),
+                        rewriter.getIndexAttr(iik)});
+        TensorLayout BLayout(
+            BPackInfo.first, BPackInfo.second,
+            SmallVector{rewriter.getIndexAttr(iik),
+                        rewriter.getIndexAttr(iin)});
+        TensorLayout CLayout(
+            CPackInfo.first, CPackInfo.second,
+            SmallVector{rewriter.getIndexAttr(iim),
+                        
rewriter.getIndexAttr(iin)}); + OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout}); + layoutCache[linalgOp] = suggestedLayout; + } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) && + !supportedContractionNamedOpList(linalgOp)) { + SmallVector inputLayouts, outputLayouts; + size_t targetIdx = getTargetInputIdx(curInputLayouts); + // TODO(yifei): wisely choose the input format basis + // Let's only refer to input[0] for now + for (size_t i = 0; i < curInputs.size(); ++i) { + // getMatchingIndexingMap + if (i != targetIdx) { + auto res = inferIndexingMapRelation( + linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + linalgOp.getMatchingIndexingMap(curInputs[i])); + TensorLayout inputLayout = + *inferTargetLayout(curInputLayouts[targetIdx], *res); + inputLayouts.push_back(inputLayout); + } else { + inputLayouts.push_back(curInputLayouts[targetIdx]); + } + } + auto res_out = inferIndexingMapRelation( + linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + linalgOp.getIndexingMapMatchingResult(curResults[0])); + TensorLayout outputLayout = + *inferTargetLayout(curInputLayouts[targetIdx], *res_out); + outputLayouts.push_back(outputLayout); + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + layoutCache[linalgOp] = suggestedLayout; + } + } else if (auto padOp = dyn_cast(op)) { + auto inputOperand = padOp.getSource(); + auto inputRank = + cast(inputOperand.getType()).getShape().size(); + auto parent = inputOperand.getDefiningOp(); + TensorLayout curInputLayout = + layoutCache.find(parent) != layoutCache.end() + ? layoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputRank); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{curInputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + layoutCache[padOp] = suggestedLayout; + } else if (auto expandShapeOp = dyn_cast(op)) { + auto reassociation = expandShapeOp.getReassociation(); + auto staticOutputShape = expandShapeOp.getStaticOutputShape(); + auto parent = expandShapeOp.getSrc().getDefiningOp(); + auto inputShape = expandShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + layoutCache.find(parent) != layoutCache.end() + ? 
layoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + DenseMap outputInputIdxMapping, inputOutputIndexMapping; + int64_t accumulationOffset = 0; + for (int64_t i = 0; i < static_cast(reassociation.size()); ++i) { + auto subReassociation = llvm::cast(reassociation[i]); + for (int64_t j = 0; j < static_cast(subReassociation.size()); + ++j) { + if (staticOutputShape[accumulationOffset + j] == inputShape[i]) { + outputInputIdxMapping[accumulationOffset + j] = i; + inputOutputIndexMapping[i] = accumulationOffset + j; + } + } + accumulationOffset += subReassociation.size(); + } + auto inputOuterAxis = curInputLayout.getOuterAxis(); + auto inputInnerAxis = curInputLayout.getInnerAxis(); + int64_t diffDifference = staticOutputShape.size() - inputShape.size(); + int64_t startIdx = 0; + SmallVector outputOuterAxis, outputInnerAxis; + for (int64_t i = 0; i < static_cast(staticOutputShape.size()); + ++i) { + if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) { + outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] + + diffDifference); + } else { + outputOuterAxis.push_back(startIdx++); + } + } + for (int64_t i = 0; i < static_cast(inputInnerAxis.size()); + ++i) { + outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]); + } + TensorLayout outputLayout(outputOuterAxis, outputInnerAxis, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + layoutCache[expandShapeOp] = suggestedLayout; + } + }); +} + +namespace utils { +bool isPackableNamedOp(Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + if (!supportedContractionNamedOpList(linalgOp)) { + return true; + } + } else if (isa( + op)) + return true; + return false; +} +} // namespace utils +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/CMakeLists.txt b/lib/gc/Transforms/CMakeLists.txt index 705e257d7..5a57da040 100644 --- a/lib/gc/Transforms/CMakeLists.txt +++ b/lib/gc/Transforms/CMakeLists.txt @@ -15,6 +15,9 @@ gc_add_mlir_library(GcPasses IterativeTilingAndFusion.cpp TilingUsingInterfaceX.cpp VerifyTargetDescription.cpp + PropagateLayout.cpp + PostProcessPackUnpack.cpp + LowerPackUnpack.cpp DeepTileContractionNamedOp.cpp TilingUtil.cpp SinkOpIntoInnerLoop.cpp diff --git a/lib/gc/Transforms/LowerPackUnpack.cpp b/lib/gc/Transforms/LowerPackUnpack.cpp new file mode 100644 index 000000000..811f8a4ce --- /dev/null +++ b/lib/gc/Transforms/LowerPackUnpack.cpp @@ -0,0 +1,83 @@ +//===- LowerPackUnpack.cpp - Lower pack unpack into linalg ops *---- C++-*-===// +// +// This file is only temporarily used to extend upstream or upcoming utility in +// TilingInterface, which finally aims for upstream. 
+// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "gc/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Passes.h" +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_LOWERPACKUNPACK +#include "gc/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "lower-pack-unpack" + +using namespace mlir; + +// copied from tpp +// A wrapper pattern that calls linalg::lowerPack on tensor::PackOp. It lowers +// a tensor.pack op to tensor.pad + tensor.expand_shape + linalg.transpose ops. +struct LowerPackPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp op, + PatternRewriter &rewriter) const override { + FailureOr res = linalg::lowerPack(rewriter, op); + if (failed(res)) { + return rewriter.notifyMatchFailure( + op, "cannot lower to pad + expand + transpose"); + } + return success(); + } +}; + +// A wrapper pattern that calls linalg::lowerUnPack on tensor::UnPackOp. It +// lowers a tensor.unpack op to tensor.empty + linalg.transpose + +// tensor.collapse_shape + tensor.extract_slice ops. +struct LowerUnPackPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp op, + PatternRewriter &rewriter) const override { + if (failed(linalg::lowerUnPack(rewriter, op))) { + return rewriter.notifyMatchFailure( + op, "cannot lower to empty + transpose + reshape + extract_slice"); + } + return success(); + } +}; + +class LowerPackUnpack : public impl::LowerPackUnpackBase { +public: + using impl::LowerPackUnpackBase::LowerPackUnpackBase; + void runOnOperation() final; +}; + +void LowerPackUnpack::runOnOperation() { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/Pipeline.cpp b/lib/gc/Transforms/Pipeline.cpp index f198c6c75..162369e4c 100644 --- a/lib/gc/Transforms/Pipeline.cpp +++ b/lib/gc/Transforms/Pipeline.cpp @@ -50,6 +50,8 @@ void populateFrontendPasses(mlir::OpPassManager &pm) { void populateTensorPasses(mlir::OpPassManager &pm) { // todo: padding propagation pass // todo: layout propagation pass + pm.addPass(createPropagateLayoutOnNamedOps()); + pm.addPass(createPostProcessPackUnpack()); // todo: tensor constant propagation pass // linalg.matmul lowering to (scf.loop + linalg.brgemm) pass pm.addNestedPass(createDeepTileContractionNamedOp()); @@ -64,6 +66,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) { pm.addNestedPass(createLinalgGeneralizeNamedOpsPass()); pm.addPass(createLoopInvariantCodeMotionPass()); pm.addPass(createControlFlowSinkPass()); + // TODO(yifei): remove lower pack here + pm.addPass(createLowerPackUnpack()); populateCleanUpPasses(pm); } diff --git 
a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp new file mode 100644 index 000000000..0427cfaf9 --- /dev/null +++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp @@ -0,0 +1,179 @@ +//===- PostProcessPackUnpack.cpp - Fold and simplify pack unpack *-- C++-*-===// +// +// This file is only temporarily used to extend upstream or upcoming utility in +// TilingInterface, which finally aims for upstream. +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "gc/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Passes.h" +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_POSTPROCESSPACKUNPACK +#include "gc/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "post-process-pack-unpack" + +using namespace mlir; + +// Helper pattern - lower tensor.pack operations that pack constants. +struct LowerConstantPacking : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + auto constOp = packOp.getSource().getDefiningOp(); + if (!constOp) + return failure(); + // Must be a dense constant. + auto denseAttr = dyn_cast(constOp.getValue()); + if (!denseAttr) + return failure(); + + // Bail out if the pack is used as a writing operation i.e., the destination + // is not a tensor.empty. + if (!packOp.getDest().getDefiningOp()) + return rewriter.notifyMatchFailure(packOp, + "expects empty tensor destination"); + // Pack destination must have static shape. + if (!packOp.getDestType().hasStaticShape()) + return rewriter.notifyMatchFailure( + packOp, "expects destination with static shape"); + + // If it is a splat constant, skip and let tensor.pack folder to handle this + // case. + if (denseAttr.isSplat()) + return rewriter.notifyMatchFailure( + packOp, "skip pack - existing folder covers constant splats"); + + return linalg::lowerPack(rewriter, packOp); + } +}; + +static void tppPopulateConstantFoldPack(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx); + // Apply canonicalization to fold trivial cases and linalg constant folders + // to cleanup lowered packs. 
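Aside (editorial example, not part of the patch): the point of lowering non-splat constant packs is that the blocked form can then be computed entirely at compile time by the folders populated below. A standalone sketch of what such folding materializes for a hypothetical 4x4 constant packed into 2x2 blocks, MKmk-style:

```cpp
// Standalone sketch: materialize a 4x4 row-major constant into a
// 2x2x2x2 blocked buffer, as constant pack folding effectively does.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t M = 4, K = 4, m = 2, k = 2;
  std::vector<int> plain(M * K);
  for (int i = 0; i < M * K; ++i)
    plain[i] = i; // hypothetical constant data
  // packed[Mo][Ko][mi][ki] = plain[Mo*m + mi][Ko*k + ki]
  std::vector<int> packed(M * K);
  for (int64_t Mo = 0; Mo < M / m; ++Mo)
    for (int64_t Ko = 0; Ko < K / k; ++Ko)
      for (int64_t mi = 0; mi < m; ++mi)
        for (int64_t ki = 0; ki < k; ++ki)
          packed[((Mo * (K / k) + Ko) * m + mi) * k + ki] =
              plain[(Mo * m + mi) * K + (Ko * k + ki)];
  std::cout << packed[5] << "\n"; // block (0,1), element (0,1): plain(0,3)=3
}
```

Splat constants are deliberately skipped above since the existing tensor.pack folder already handles them without materializing anything.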
+ linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); + tensor::PackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::populateRewriteAsConstantPatterns( + patterns, [](OpOperand *) -> bool { return true; }); + linalg::populateConstantFoldLinalgOperations( + patterns, [](OpOperand *) -> bool { return true; }); +} + +struct EliminateDummyPack : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::PackOp packOp, + PatternRewriter &rewriter) const override { + if (packOp.getStaticInnerTiles().empty() && + packOp.getInnerTiles().empty()) { + auto outerPerm = packOp.getOuterDimsPerm(); + for (size_t i = 0; i < outerPerm.size(); ++i) { + if (outerPerm[i] != i) { + return rewriter.notifyMatchFailure(packOp, "Not dummy"); + } + } + auto source = packOp.getSource(); + rewriter.replaceAllOpUsesWith(packOp, source); + packOp->erase(); + return success(); + } else { + return rewriter.notifyMatchFailure(packOp, "Not dummy"); + } + } +}; + +struct EliminateDummyUnpack : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, + PatternRewriter &rewriter) const override { + if (unpackOp.getStaticInnerTiles().empty() && + unpackOp.getInnerTiles().empty()) { + auto outerPerm = unpackOp.getOuterDimsPerm(); + for (size_t i = 0; i < outerPerm.size(); ++i) { + if (outerPerm[i] != i) { + return rewriter.notifyMatchFailure(unpackOp, "Not dummy"); + } + } + auto source = unpackOp.getSource(); + rewriter.replaceAllOpUsesWith(unpackOp, source); + unpackOp->erase(); + return success(); + } else { + return rewriter.notifyMatchFailure(unpackOp, "Not dummy"); + } + } +}; + +static void populateEliminateDummyPackUnpack(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(ctx); +} + +class PostProcessPackUnpack + : public impl::PostProcessPackUnpackBase { +public: + using impl::PostProcessPackUnpackBase< + PostProcessPackUnpack>::PostProcessPackUnpackBase; + void runOnOperation() final; +}; + +static void tppPopulateSimplifyPacking(RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + tensor::populateSimplifyPackAndUnpackPatterns(patterns); + tensor::populateFoldTensorEmptyPatterns(patterns); + tensor::PackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::UnPackOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, ctx); + tensor::CastOp::getCanonicalizationPatterns(patterns, ctx); + tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + tensor::EmptyOp::getCanonicalizationPatterns(patterns, ctx); + tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); + tensor::ParallelInsertSliceOp::getCanonicalizationPatterns(patterns, ctx); + scf::ForallOp::getCanonicalizationPatterns(patterns, ctx); + // Propagate packs/unpacks only through expand shapes at this point. + // This captures the transformation scope of the replaced downstream pass. 
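Aside (editorial example, not part of the patch): EliminateDummyPack above reduces to a small predicate, shown here as a standalone sketch; isDummyPack is a hypothetical name, not an API in this patch.

```cpp
// Standalone sketch of the "dummy pack" condition: no inner tiles and an
// identity (or empty) outer permutation mean the pack is a no-op copy.
#include <cstdint>
#include <vector>

bool isDummyPack(const std::vector<int64_t> &innerTiles,
                 const std::vector<int64_t> &outerPerm) {
  if (!innerTiles.empty())
    return false; // real blocking, keep the pack
  for (size_t i = 0; i < outerPerm.size(); ++i)
    if (outerPerm[i] != (int64_t)i)
      return false; // a real transpose, keep the pack
  return true; // replaceable by its source
}

int main() { return isDummyPack({}, {0, 1, 2}) ? 0 : 1; }
```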
+ linalg::populateDataLayoutPropagationPatterns( + patterns, [](Operation *op) { return isa(op); }); + ctx->getLoadedDialect()->getCanonicalizationPatterns( + patterns); + // patterns.add(ctx); + tensor::populateReassociativeReshapeFoldingPatterns(patterns); +} + +void PostProcessPackUnpack::runOnOperation() { + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + // constant fold packing + tppPopulateConstantFoldPack(patterns); + // simplify packing + tppPopulateSimplifyPacking(patterns); + // gc new packing related simplification + populateEliminateDummyPackUnpack(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +} // namespace gc +} // namespace mlir diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp new file mode 100644 index 000000000..4a4eace68 --- /dev/null +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -0,0 +1,677 @@ +//===- PropagateLayoutOnNamedOps.cpp - Propagate packing on linalg named ops*- +// C++-*-===// +// +// This file is only temporarily used to extend upstream or upcoming utility in +// TilingInterface, which finally aims for upstream. +// +//===----------------------------------------------------------------------===// + +#include + +#include "gc/Transforms/Transforms.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseMap.h" + +#include "gc/Dialect/Linalgx/LinalgxDialect.h" +#include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Transforms/Passes.h" +namespace mlir { +namespace gc { +#define GEN_PASS_DEF_PROPAGATELAYOUTONNAMEDOPS +#include "gc/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "named-op-layout-propagation" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::tensor; + +static SmallVector getPackedAxes(ArrayRef dimensions, + TensorLayout targetLayout) { + SmallVector result(dimensions); + // permuting on outer axis + auto outerPerm = targetLayout.getOuterAxis(); + for (size_t i = 0; i < dimensions.size(); ++i) { + result[i] = outerPerm[dimensions[i]]; + } + // inserting inner axis + auto innerPos = targetLayout.getInnerAxis(); + for (size_t i = 0; i < dimensions.size(); ++i) { + if (std::find(innerPos.begin(), innerPos.end(), dimensions[i]) != + innerPos.end()) { + result.push_back(i + targetLayout.getOuterAxis().size()); + } + } + return result; +} + +static SmallVector getPackedPermAxes(ArrayRef plainPermAxes, + TensorLayout inputLayout, + TensorLayout outputLayout) { + // dim(result, i) = dim(input, permutation[i]) + // input: permutation[i] --> output: i + // input: permutation[i] --> packed input: std::find(permutation[i]) - begin() + // output: i --> packed output: std::find(permutation[i]) - begin() + int64_t packedRank = + outputLayout.getInnerAxis().size() + outputLayout.getOuterAxis().size(); + SmallVector result(packedRank, 0); + SmallVector inputCount(inputLayout.getOuterAxis().size(), 0); + auto axisPlainToPacked = inputLayout.getPlainToPackedAxisMapping(); + for (int64_t i = 0; i < packedRank; ++i) { + // packedOutput[i] --> originalOutputAxis --> originalInputAxis + // TODO: add failed check here + int64_t originalOutputAxis = 
*outputLayout.getPlainAxis(i); + int64_t originalInputAxis = plainPermAxes[originalOutputAxis]; + SmallVector packedInputAxes = axisPlainToPacked[originalInputAxis]; + result[i] = packedInputAxes[inputCount[originalInputAxis]++]; + } + return result; +} + +// extends linalg::pack(...) for named ops +FailureOr packNamedOp(RewriterBase &rewriter, + linalg::LinalgOp linalgOp, + OperatorLayout opLayout) { + if (linalgOp.hasPureBufferSemantics()) + return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); + LLVM_DEBUG(llvm::dbgs() << "Try packing named op " + << linalgOp.getOperation()->getName() << ".\n"); + Location loc = linalgOp->getLoc(); + SmallVector packOps; + SmallVector unPackOps; + SmallVector inputsAndInits, results; + SmallVector initOperands = llvm::to_vector(llvm::map_range( + linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector inputOperands = linalgOp.getDpsInputOperands(); + SmallVector inputLayouts = opLayout.getSupportedInputLayouts(); + SmallVector initLayouts = opLayout.getSupportedOutputLayouts(); + // check all inputs and inits are tensor, otherwise no need for layout + // propagation + bool allTensor = + llvm::all_of(inputOperands, + [](OpOperand *opOperand) { + return mlir::isa(opOperand->get().getType()); + }) && + llvm::all_of(initOperands, [](OpOperand *opOperand) { + return mlir::isa(opOperand->get().getType()); + }); + if (!allTensor) { + LLVM_DEBUG(llvm::dbgs() << "At least one input of named op: " + << linalgOp.getOperation()->getName() + << " is not tensor. Skip.\n"); + return failure("The op does not need packing."); + } + for (const auto &operandsList : {inputOperands, initOperands}) { + for (OpOperand *opOperand : operandsList) { + size_t pos = opOperand->getOperandNumber(); + Value operand = opOperand->get(); + TensorLayout targetLayout = pos >= inputLayouts.size() + ? initLayouts[pos - inputLayouts.size()] + : inputLayouts[pos]; + SmallVector outerPerm = targetLayout.getOuterAxis(); + SmallVector innerPos = targetLayout.getInnerAxis(); + SmallVector innerPackSizes = targetLayout.getTileSizes(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, operand, innerPackSizes, innerPos, outerPerm); + ShapedType operandType = cast(operand.getType()); + bool areConstantTiles = + llvm::all_of(innerPackSizes, [](OpFoldResult tile) { + return getConstantIntValue(tile).has_value(); + }); + if (areConstantTiles && operandType.hasStaticShape() && + !tensor::PackOp::requirePaddingValue( + operandType.getShape(), innerPos, + cast(dest.getType()).getShape(), {}, + innerPackSizes)) { + packOps.push_back(rewriter.create( + loc, operand, dest, innerPos, innerPackSizes, std::nullopt, + outerPerm)); + } else { + // TODO: value of the padding attribute should be determined by + // consumers. + auto zeroAttr = + rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + packOps.push_back(rewriter.create( + loc, operand, dest, innerPos, innerPackSizes, zero, outerPerm)); + } + inputsAndInits.push_back(packOps.back()); + } + } + + // Step 3. Build the packed op, use the type of `inits` as result types. 
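Aside (editorial example, not part of the patch): the permutation remapping in getPackedPermAxes above can be checked by hand. A standalone sketch for a plain 2D transpose with both operands packed from 2D to 4D; the layouts are hypothetical (MKmk input, KMkm output):

```cpp
// Standalone sketch of getPackedPermAxes: walk each packed output axis
// back to its plain axis, through the plain permutation, and forward to
// the next unused packed axis of that plain input axis.
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // dim(out, i) = dim(in, perm[i]); here a plain 2D transpose.
  std::vector<int64_t> perm = {1, 0};
  // Packed output axis -> plain output axis (KMkm: outer [0,1], inner [0,1]).
  std::vector<int64_t> outPlainOfPacked = {0, 1, 0, 1};
  // Plain input axis -> its packed axes, outer first (MKmk layout).
  std::map<int64_t, std::vector<int64_t>> inPlainToPacked = {{0, {0, 2}},
                                                             {1, {1, 3}}};
  std::vector<int64_t> counter(2, 0), result(4);
  for (int64_t i = 0; i < 4; ++i) {
    int64_t plainOut = outPlainOfPacked[i];
    int64_t plainIn = perm[plainOut];
    result[i] = inPlainToPacked[plainIn][counter[plainIn]++];
  }
  for (auto r : result)
    std::cout << r << " "; // prints 1 0 3 2
  std::cout << "\n";
}
```

That is, the packed transpose out[K, M, k, m] = transpose(in[M, K, m, k]) carries the permutation [1, 0, 3, 2].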
+ ValueRange inputs = + ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); + ValueRange inits = + ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); + // TODO: deal with generic + linalg::LinalgOp packedLinalgOp; + if (auto reduceOp = dyn_cast(&linalgOp)) { + SmallVector packedAxes = + getPackedAxes(reduceOp->getDimensions(), inputLayouts[0]); + packedLinalgOp = rewriter.create( + loc, inits.getTypes(), inputs, inits, packedAxes); + packedLinalgOp->getRegion(0).takeBody(linalgOp->getRegion(0)); + } else if (auto broadcastOp = dyn_cast(&linalgOp)) { + SmallVector packedAxes = + getPackedAxes(broadcastOp->getDimensions(), initLayouts[0]); + packedLinalgOp = rewriter.create(loc, inputs[0], + inits[0], packedAxes); + } else if (auto transposeOp = dyn_cast(&linalgOp)) { + SmallVector packedPermAxes = getPackedPermAxes( + transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]); + packedLinalgOp = rewriter.create( + loc, inputs[0], inits[0], packedPermAxes); + } else if (isa(linalgOp) || + isa(linalgOp) || isa(linalgOp) || + isa(linalgOp) || isa(linalgOp)) { + return failure( + "Packing logic not implemented for SoftMax/Generic/Map/Yield/Index."); + } else { + packedLinalgOp = mlir::clone( + rewriter, linalgOp, SmallVector{inputsAndInits.back().getType()}, + inputsAndInits); + } + + // Step 4. Unpack all the op results. + for (OpResult result : packedLinalgOp->getResults()) { + int64_t resultNum = result.getResultNumber(); + tensor::PackOp maybePackedInit = + inits[resultNum].getDefiningOp(); + if (!maybePackedInit) { + results.push_back(result); + continue; + } + // Build the symmetrical UnPackOp to the existing PackOp. + unPackOps.push_back(rewriter.create( + packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), + maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(), + maybePackedInit.getOuterDimsPerm())); + results.push_back(unPackOps.back()); + } + + // Step 5. Replace `linalgOp`. + rewriter.replaceOp(linalgOp, results); + + // Return packedLinalgOp. 
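Aside (editorial example, not part of the patch): Step 4 below builds each unpack from the matching pack's own metadata, so the pair is an exact round trip. A standalone sketch of that inverse relationship on a hypothetical 1D buffer with tile size 2:

```cpp
// Standalone sketch: an unpack built from a pack's metadata inverts it.
// 1D example, tile size 2: pack 6 -> 3x2, then unpack 3x2 -> 6.
#include <iostream>
#include <vector>

int main() {
  std::vector<int> src = {10, 11, 12, 13, 14, 15};
  const int tile = 2;
  std::vector<std::vector<int>> packed(src.size() / tile,
                                       std::vector<int>(tile));
  for (size_t i = 0; i < src.size(); ++i)
    packed[i / tile][i % tile] = src[i]; // pack
  std::vector<int> roundTrip(src.size());
  for (size_t o = 0; o < packed.size(); ++o)
    for (int in = 0; in < tile; ++in)
      roundTrip[o * tile + in] = packed[o][in]; // symmetric unpack
  std::cout << (roundTrip == src) << "\n"; // prints 1
}
```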
+ return linalg::PackResult{ + packOps, cast(packedLinalgOp.getOperation()), + unPackOps}; +} + +// check whether the op is already packed or not +static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { + // check whether rank match + if (auto linalgOp = dyn_cast(op)) { + assert(linalgOp.getDpsInits().size() == + opLayout.getSupportedOutputLayouts().size() && + linalgOp.getDpsInputs().size() == + opLayout.getSupportedInputLayouts().size()); + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedInputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInputs()[index].getType()) + .getShape() + .size()) + return true; + } + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedOutputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInits()[index].getType()) + .getShape() + .size()) + return true; + } + } else { + assert(op->getNumOperands() == 1 && op->getNumResults() == 1); + } + return false; +} + +using ControlPackNamedOpsFn = + std::function(Operation *)>; + +class PropagateLayoutOnNamedOps + : public impl::PropagateLayoutOnNamedOpsBase { +public: + using impl::PropagateLayoutOnNamedOpsBase< + PropagateLayoutOnNamedOps>::PropagateLayoutOnNamedOpsBase; + void runOnOperation() final; +}; + +LogicalResult graphAlreadyPacked(MLIRContext *ctx, mlir::Operation *graph) { + IRRewriter rewriter(ctx); + auto walk = graph->walk([&](Operation *op) { + if (mlir::gc::utils::isPackableNamedOp(op) && op->hasAttr("packed")) { + LLVM_DEBUG(llvm::dbgs() + << "Graph already packed. Stop layout propagation.\n"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (walk.wasInterrupted()) { + return failure(); + } + return success(); +} + +LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, + ControlPackNamedOpsFn controlFn) { + IRRewriter rewriter(ctx); + auto walk = graph->walk([&](Operation *op) { + if (mlir::gc::utils::isPackableNamedOp(op)) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n"); + FailureOr opLayout = controlFn(op); + if (failed(opLayout)) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " does not have layout information.\n"); + return WalkResult::skip(); + } + if ((*opLayout).isPlain()) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " has plain layout, skip packing.\n"); + return WalkResult::advance(); + } + // pack op into ideal layout + LLVM_DEBUG(llvm::dbgs() + << "Op " << op->getName() << "'s inferred layout:\n" + << *opLayout << "\n"); + // insert pack + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + if (checkPacked(op, *opLayout)) { + LLVM_DEBUG(llvm::dbgs() + << "Op " << op->getName() << " already packed.\n"); + return WalkResult::advance(); + } + if (auto linalgOp = dyn_cast(op)) { + FailureOr packedOp = + packNamedOp(rewriter, linalgOp, *opLayout); + if (failed(packedOp)) { + return WalkResult::skip(); + } else { + packedOp->packedLinalgOp->setAttr("packed", + rewriter.getBoolAttr(true)); + } + } else if (auto expandShapeOp = dyn_cast(op)) { + // Location loc = expandShapeOp->getLoc(); + // auto inputLayout = opLayout->getSupportedInputLayouts()[0]; + // auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; + // Value dest = tensor::PackOp::createDestinationTensor( + // rewriter, loc, expandShapeOp.getSrc(), + // 
inputLayout.getTileSizes(), inputLayout.getInnerAxis(), + // inputLayout.getOuterAxis()); + // Value packedSource = rewriter.create( + // loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(), + // inputLayout.getTileSizes(), std::nullopt, + // inputLayout.getOuterAxis()); + // auto resultType = RankedTensorType::get( + // expandShapeOp.getStaticOutputShape(), + // expandShapeOp.getSrcType().getElementType()); + // RankedTensorType resultPackType = tensor::PackOp::inferPackedType( + // resultType, vector::getAsIntegers(outputLayout.getTileSizes()), + // outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + // auto reassocExpand = getReassociationIndicesForReshape( + // cast(dest.getType()), resultPackType); + // auto packedExpandShape = rewriter.create( + // loc, expandShapeOp.getSrcType().getElementType(), packedSource, + // *reassocExpand); + // Value result = rewriter.create( + // packedExpandShape->getLoc(), packedExpandShape, + // packedExpandShape, outputLayout.getInnerAxis(), + // outputLayout.getTileSizes(), outputLayout.getOuterAxis()); + // rewriter.replaceOp(expandShapeOp, result); + } + } + return WalkResult::advance(); + }); + return success(); +} + +static void createAndReplaceWithGenericVNNIMatmul( + RewriterBase &rewriter, MLIRContext *context, SmallVector inputs, + SmallVector inits, int64_t batchDimSize, int64_t blockingFactor, + Operation *matmulOp) { + AffineMap mapInput, mapWeight, mapOutput; + int64_t dims = batchDimSize + 7; + SmallVector exprs(dims); + // dims is in order B1, ..., Bn, M, N, K, m, n, k, vnni + bindDimsList(context, exprs); + SmallVector batchExprs(exprs.begin(), + exprs.begin() + batchDimSize); + AffineExpr M = exprs[batchDimSize], N = exprs[batchDimSize + 1], + K = exprs[batchDimSize + 2], m = exprs[batchDimSize + 3], + n = exprs[batchDimSize + 4], k = exprs[batchDimSize + 5], + vnni = exprs[batchDimSize + 6]; + SmallVector resultA{M, K, m, k}; + SmallVector resultB{N, K, k.floorDiv(blockingFactor), n, vnni}; + SmallVector resultC{M, N, m, n}; + resultA.insert(resultA.begin(), batchExprs.begin(), batchExprs.end()); + resultB.insert(resultB.begin(), batchExprs.begin(), batchExprs.end()); + resultC.insert(resultC.begin(), batchExprs.begin(), batchExprs.end()); + mapInput = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultA, context); + mapWeight = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultB, context); + mapOutput = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultC, context); + SmallVector batchIterators( + batchDimSize, mlir::utils::IteratorType::parallel); + SmallVector iterators{ + mlir::utils::IteratorType::parallel, + mlir::utils::IteratorType::parallel, + mlir::utils::IteratorType::reduction, + mlir::utils::IteratorType::parallel, + mlir::utils::IteratorType::parallel, + mlir::utils::IteratorType::reduction, + mlir::utils::IteratorType::reduction}; + iterators.insert(iterators.begin(), batchIterators.begin(), + batchIterators.end()); + auto replacementOp = rewriter.create( + matmulOp->getLoc(), inits[0].getType(), inputs, inits, + ArrayRef{mapInput, mapWeight, mapOutput}, iterators, + /*doc=*/"", /*libraryCall=*/""); + rewriter.inlineRegionBefore(matmulOp->getRegion(0), replacementOp.getRegion(), + replacementOp.getRegion().begin()); + rewriter.replaceOp(matmulOp, replacementOp.getResult(0)); +} + +template +static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { + auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType()); + if (!elementType.isBF16() && !elementType.isInteger(8)) + 
return rewriter.notifyMatchFailure(mmt4dOp, "require bf16/int8 data type"); + Location loc = mmt4dOp.getLoc(); + // BNKnk --> BNKkn2k + int64_t weightRank = + cast(mmt4dOp.getInputs()[1].getType()).getShape().size(); + // pack innermost k axis + SmallVector innerPos{weightRank - 1}; + int64_t blockingFactor = elementType.isBF16() ? 2 : 4; + SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)}; + // BNKnk --> BNKkn2k + int64_t batchDimSize = weightRank - 4; + SmallVector batchPerm(batchDimSize, 0); + std::iota(batchPerm.begin(), batchPerm.end(), 0); + SmallVector outerPerm{batchDimSize, batchDimSize + 1, + batchDimSize + 3, batchDimSize + 2}; + outerPerm.insert(outerPerm.begin(), batchPerm.begin(), batchPerm.end()); + OpOperand *RHSOperand = mmt4dOp.getDpsInputOperand(1); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, RHSOperand->get(), tileSize, innerPos, outerPerm); + Value VNNIPack = + rewriter.create(loc, RHSOperand->get(), dest, innerPos, + tileSize, std::nullopt, outerPerm); + SmallVector inputsValues{mmt4dOp.getInputs()[0], VNNIPack}; + if (!batchDimSize) { + auto vnniOp = rewriter.create( + loc, mmt4dOp.getDpsInits().getTypes(), inputsValues, + mmt4dOp.getDpsInits()); + rewriter.replaceOp(mmt4dOp, vnniOp); + } else { + mlir::gc::createAndReplaceWithGenericVNNIMatmul( + rewriter, mmt4dOp.getContext(), inputsValues, mmt4dOp.getDpsInits(), + batchDimSize, blockingFactor, mmt4dOp); + } + return success(); +} + +// strictly check whether the packed matmul is BMKmk & BNKkn +static bool isMM4DMatmul(linalg::GenericOp matmulOp) { + SmallVector indexingMaps = matmulOp.getIndexingMapsArray(); + auto iterators = matmulOp.getIteratorTypesArray(); + AffineMap inputMap = indexingMaps[0], weightMap = indexingMaps[1], + outputMap = indexingMaps[2]; + int64_t inputRank = inputMap.getNumResults(), + weightRank = weightMap.getNumResults(), + outputRank = outputMap.getNumResults(); + // check rank + if ((weightRank < 4) || (inputRank != weightRank) || + (weightRank != outputRank)) + return false; + // check mapping --> find batch, M, N, K + FailureOr res = + mlir::linalg::inferContractionDims(matmulOp); + assert(succeeded(res) && "unexpected failure in infer contraction dims"); + unsigned batchDimSize = res->batch.size(); + SmallVector expectedM{batchDimSize, batchDimSize + 3}; + SmallVector expectedN{batchDimSize + 1, batchDimSize + 4}; + SmallVector expectedK{batchDimSize + 2, batchDimSize + 5}; + if (expectedM == res->m && expectedN == res->n && expectedK == res->k) + return true; + return false; +} + +/* +If possible, pack to Mm2DVnniOp or Mm4DVnniOp. +If not possible, pack to GenericOp. +*/ +static LogicalResult packVNNIGeneric(RewriterBase &rewriter, + linalg::GenericOp matmulOp) { + if (matmulOp.getDpsInputs().size() != 2) + return rewriter.notifyMatchFailure(matmulOp, "require 2 inputs"); + + auto elementType = getElementTypeOrSelf(matmulOp.getInputs()[0].getType()); + if (!elementType.isBF16() && !elementType.isInteger(8)) + return rewriter.notifyMatchFailure(matmulOp, "require bf16/int8 data type"); + + if (matmulOp.hasDynamicShape()) + return rewriter.notifyMatchFailure(matmulOp, "require static shape"); + + if (matmulOp.hasPureBufferSemantics()) + return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics"); + + // isContractionInterfaceImpl checks the following restrictions: + // 1. has 2 inputs && 1 outputs + // 2. has >=1 reduction loop + // 3. all affine maps are projected permutations: + // a. no symbols or zeros in result + // b. 
result is a non-duplicated subset of input + // 4. op body contains both mul&&add + if (!mlir::linalg::isaContractionOpInterface(matmulOp)) + return rewriter.notifyMatchFailure(matmulOp, "require matmul semantics"); + + // check whether generic op is packed as BMKmk & BNKkn + if (!isMM4DMatmul(matmulOp)) + return rewriter.notifyMatchFailure(matmulOp, + "require packed MM4D matmul semantics"); + + OpOperand &weight = matmulOp->getOpOperand(1); + // TODO(yifei): check ISA + Location loc = matmulOp.getLoc(); + int64_t blockingFactor = elementType.isBF16() ? 2 : 4; + SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)}; + // get weight's rank + int64_t weightRank = + cast(weight.get().getType()).getShape().size(); + auto innerPos = SmallVector{weightRank - 2}; + // pack weight. + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, weight.get(), tileSize, innerPos, SmallVector{}); + Value VNNIPack = rewriter.create( + loc, weight.get(), dest, innerPos, tileSize, std::nullopt); + + int64_t batchDimSize = weightRank - 4; + SmallVector inputsValues{matmulOp.getInputs()[0], VNNIPack}; + if (!batchDimSize) { + Value operandC = matmulOp.getDpsInits()[0]; + auto VNNIMatmulOp = rewriter.create( + loc, operandC.getType(), inputsValues, ValueRange{operandC}); + rewriter.replaceOp(matmulOp, VNNIMatmulOp); + } else { + mlir::gc::createAndReplaceWithGenericVNNIMatmul( + rewriter, matmulOp.getContext(), inputsValues, matmulOp.getDpsInits(), + batchDimSize, blockingFactor, matmulOp); + } + return success(); +} + +template struct PackVNNI : public OpRewritePattern { + PackVNNI(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(OpTy linalgOp, + PatternRewriter &rewriter) const override { + if (failed(packVNNIMMT4D(rewriter, linalgOp))) + return failure(); + return success(); + } +}; + +template <> +struct PackVNNI + : public OpRewritePattern { + PackVNNI(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + LogicalResult matchAndRewrite(linalg::GenericOp matmulOp, + PatternRewriter &rewriter) const override { + if (failed(packVNNIGeneric(rewriter, matmulOp))) + return failure(); + return success(); + } +}; + +/* +Match patterns like broadcast + pack, uplift pack +*/ +struct UpliftPackOverBroadcast : public OpRewritePattern { + UpliftPackOverBroadcast(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + LogicalResult matchAndRewrite(tensor::PackOp pack, + PatternRewriter &rewriter) const override { + auto broadcastOp = pack.getSource().getDefiningOp(); + if (!broadcastOp || !broadcastOp.getResult()[0].hasOneUse()) { + return failure(); + } + SmallVector innerTileSizes = pack.getStaticTiles(); + SmallVector innerDimsPos(pack.getInnerDimsPos()); + SmallVector outerDimsPerm(pack.getOuterDimsPerm()); + int64_t rank = + cast(pack.getSource().getType()).getShape().size(); + if (outerDimsPerm.empty()) { + outerDimsPerm.resize(rank); + std::iota(outerDimsPerm.begin(), outerDimsPerm.end(), 0); + } + ArrayRef broadcastAxis = broadcastOp.getDimensions(); + SmallVector newInnerDimsPos, newOuterDimsPerm, packedBroadcastAxis; + SmallVector newInnerTileSizes; + llvm::SmallDenseMap axisMapping; + int64_t axisCounter = 0; + for (int64_t axis = 0; axis < rank; ++axis) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + // if the axis is not broadcasted, keep it + axisMapping[axis] = axisCounter++; + } + } 
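Aside (editorial example, not part of the patch): the axis bookkeeping in UpliftPackOverBroadcast can be traced standalone. For a rank-2 pack with identity outer permutation and innerDimsPos = [0, 1] over a broadcast along axis 0 (the shape pattern in the matmul+add test below), the remapping computes, with hypothetical values:

```cpp
// Standalone sketch of the uplift's axis remapping: drop broadcast axes
// from the pack, and re-broadcast along the packed positions instead.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  const int64_t rank = 2;
  std::vector<int64_t> broadcastAxis = {0};
  std::vector<int64_t> outerDimsPerm = {0, 1}, innerDimsPos = {0, 1};
  auto isBcast = [&](int64_t a) {
    return std::find(broadcastAxis.begin(), broadcastAxis.end(), a) !=
           broadcastAxis.end();
  };
  // Non-broadcast axes keep their relative order in the uplifted source.
  std::map<int64_t, int64_t> axisMapping;
  int64_t counter = 0;
  for (int64_t a = 0; a < rank; ++a) {
    if (!isBcast(a))
      axisMapping[a] = counter++;
  }
  std::vector<int64_t> packedBroadcastAxis, newOuter, newInner;
  for (int64_t i = 0; i < rank; ++i) {
    if (isBcast(outerDimsPerm[i]))
      packedBroadcastAxis.push_back(i); // broadcast again after packing
    else
      newOuter.push_back(axisMapping[outerDimsPerm[i]]);
  }
  for (int64_t i = 0; i < (int64_t)innerDimsPos.size(); ++i) {
    if (isBcast(innerDimsPos[i]))
      packedBroadcastAxis.push_back(i + rank);
    else
      newInner.push_back(axisMapping[innerDimsPos[i]]);
  }
  std::cout << "packed broadcast dims:";
  for (auto d : packedBroadcastAxis)
    std::cout << " " << d; // prints "packed broadcast dims: 0 2"
  std::cout << "\n";
}
```

So the smaller operand is packed first (only along its surviving axis) and the broadcast then runs over packed dims 0 and 2, which is cheaper than packing the broadcasted result.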
+ // update broadcast dims + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != + broadcastAxis.end()) { + packedBroadcastAxis.push_back(index); + } + } + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != + broadcastAxis.end()) { + packedBroadcastAxis.push_back(index + rank); + } + } + // update packing axis + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + newOuterDimsPerm.push_back(axisMapping[axis]); + } + } + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == + broadcastAxis.end()) { + newInnerDimsPos.push_back(axisMapping[axis]); + newInnerTileSizes.push_back( + rewriter.getIndexAttr(innerTileSizes[index])); + } + } + // replace ops + auto loc = broadcastOp.getLoc(); + auto dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, broadcastOp.getDpsInputs()[0], newInnerTileSizes, + newInnerDimsPos, newOuterDimsPerm); + Value packedSource = rewriter.create( + loc, broadcastOp.getDpsInputs()[0], dest, newInnerDimsPos, + newInnerTileSizes, + /*padding=*/std::nullopt, newOuterDimsPerm); + auto newBroadcastOp = rewriter.create( + loc, packedSource, pack.getDest(), packedBroadcastAxis); + rewriter.replaceOp(pack, newBroadcastOp.getResults()); + return success(); + } +}; + +void PropagateLayoutOnNamedOps::runOnOperation() { + MLIRContext *ctx = &getContext(); + mlir::Operation *graph = getOperation(); + // stage0: check if the graph has been packed + if (failed(graphAlreadyPacked(ctx, graph))) + return; + // stage1: pack matmul + RewritePatternSet packMatmulPatterns(&getContext()); + mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn = + [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions { + mlir::linalg::BlockPackMatmulOptions options; + auto &layoutAnalysisResult = getAnalysis(); + auto matmulLayout = *(layoutAnalysisResult.getOpLayout(op)); + TensorLayout LHSLayout = matmulLayout.getSupportedInputLayouts()[0]; + TensorLayout RHSLayout = matmulLayout.getSupportedInputLayouts()[1]; + // hardcode to let B side to be NKkn + options.rhsTransposeOuterBlocks = true; + options.rhsTransposeInnerBlocks = false; + assert(LHSLayout.getTileSizes()[1] == RHSLayout.getTileSizes()[0] && + "Inconsistent matmul tile size."); + options.blockFactors.push_back( + *getConstantIntValue(LHSLayout.getTileSizes()[0])); + options.blockFactors.push_back( + *getConstantIntValue(LHSLayout.getTileSizes()[1])); + options.blockFactors.push_back( + *getConstantIntValue(RHSLayout.getTileSizes()[1])); + return options; + }; + linalg::populateBlockPackMatmulPatterns(packMatmulPatterns, + packMatmulControlFn); + if (failed( + applyPatternsAndFoldGreedily(graph, std::move(packMatmulPatterns)))) + return signalPassFailure(); + + // stage2: pack VNNI + RewritePatternSet packVNNIPatterns(&getContext()); + packVNNIPatterns.add, PackVNNI, + PackVNNI>(ctx); + if (failed(applyPatternsAndFoldGreedily(graph, std::move(packVNNIPatterns)))) + return signalPassFailure(); + + // stage3: propagate layout on other named ops + ControlPackNamedOpsFn layoutControlFn = + [&](Operation *op) -> FailureOr { + auto &layoutAnalysisResult = getAnalysis(); + return layoutAnalysisResult.getOpLayout(op); + }; + if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn))) + return 
signalPassFailure(); + + // stage4: uplift pack through broadcast + RewritePatternSet upliftPatterns(&getContext()); + upliftPatterns.add(ctx); + if (failed(applyPatternsAndFoldGreedily(graph, std::move(upliftPatterns)))) + return signalPassFailure(); +} + +} // namespace gc +} // namespace mlir diff --git a/packMatmul.patch b/packMatmul.patch new file mode 100644 index 000000000..ef695240c --- /dev/null +++ b/packMatmul.patch @@ -0,0 +1,32 @@ +diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +index 91d4efa3372b..f3f61ff92140 100644 +--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp ++++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +@@ -210,6 +210,19 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + packedMatmul->packOps[1] = packedRhs->transposedPackOp; + packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp; + ++ // rewrite generic to mmt4d ++ if (!options->lhsTransposeOuterBlocks && !options->lhsTransposeInnerBlocks && ++ options->rhsTransposeOuterBlocks && options->rhsTransposeInnerBlocks && ++ options->mnkOrder == SmallVector{0, 1, 2}) { ++ auto originalLinalgOp = packedMatmul->packedLinalgOp; ++ rewriter.setInsertionPoint(originalLinalgOp); ++ auto mmt4d = rewriter.create( ++ originalLinalgOp.getLoc(), originalLinalgOp.getDpsInits().getTypes(), ++ originalLinalgOp.getDpsInputs(), originalLinalgOp.getDpsInits()); ++ rewriter.replaceOp(originalLinalgOp, mmt4d); ++ packedMatmul->packedLinalgOp = mmt4d; ++ } ++ + return packedMatmul; + } + +@@ -307,6 +320,7 @@ struct LinalgBlockPackMatmul + }; + } // namespace + ++// extend to transform to mmt4d or batch_mmt4d + void linalg::populateBlockPackMatmulPatterns( + RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { + patterns.add, diff --git a/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir b/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir new file mode 100644 index 000000000..d3ca62e73 --- /dev/null +++ b/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir @@ -0,0 +1,12 @@ +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s + +// CHECK-LABEL: @matmul_add +func.func @matmul_add(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<32xf32>) -> tensor<128x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x32xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%1 : tensor<128x32xf32>) -> tensor<128x32xf32> + %3 = linalg.broadcast ins(%arg2 : tensor<32xf32>) outs(%0 : tensor<128x32xf32>) dimensions = [0] + %4 = linalg.add ins(%2, %3 : tensor<128x32xf32>, tensor<128x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> + return %4 : tensor<128x32xf32> +} diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir new file mode 100644 index 000000000..2c2ab8a5f --- /dev/null +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -0,0 +1,58 @@ +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s + +// CHECK-LABEL: @single_matmul_f32 +func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x32xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : 
tensor<128x32xf32>) -> tensor<128x32xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> + return %2 : tensor<128x32xf32> +} +// CHECK-COUNT-3: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-1: tensor.unpack + +// CHECK-LABEL: @single_matmul_bf16 +func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<128x32xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> + return %2 : tensor<128x32xbf16> +} +// CHECK-COUNT-4: tensor.pack +// CHECK-COUNT-1: linalgx.mm4d_vnni +// CHECK-COUNT-1: tensor.unpack + +// CHECK-LABEL: @single_batch_matmul_bf16 +func.func @single_batch_matmul_bf16(%arg0: tensor<64x128x64xbf16>, %arg1: tensor<64x64x32xbf16>) -> tensor<64x128x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<64x128x32xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16> + %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<64x128x64xbf16>, tensor<64x64x32xbf16>) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16> + return %2 : tensor<64x128x32xbf16> +} +// CHECK-COUNT-4: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-1: tensor.unpack + +func.func @pack_vnni_mmt4d(%arg0: tensor<4x2x32x32xbf16>, %arg1: tensor<1x2x32x32xbf16>) -> tensor<4x1x32x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4x1x32x32xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16> + %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<4x2x32x32xbf16>, tensor<1x2x32x32xbf16>) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16> + return %2 : tensor<4x1x32x32xbf16> +} +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalgx.mm4d_vnni + +func.func @pack_vnni_batchmmt4d(%arg0: tensor<4x4x2x32x32xbf16>, %arg1: tensor<4x1x2x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<4x4x1x32x32xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> + %2 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<4x4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> + return %2 : tensor<4x4x1x32x32xbf16> +} +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalg.generic + From 966fc8fcb046cb87784f0aea74b40a33fe3e726e Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Sun, 4 Aug 2024 22:54:20 -0700 Subject: [PATCH 02/23] fix layout propagation on expand shape --- lib/gc/Analysis/GlobalAnalysis.cpp | 123 +++++++++++++++++--------- lib/gc/Transforms/PropagateLayout.cpp | 102 ++++++++++----------- 2 files changed, 131 insertions(+), 94 deletions(-) diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index b634daea1..bf09074ba 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -55,13 +55,16 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, return ss; } -// inferring the relationship of two indexing map -// j -> i, means j is represented as the same symbol as i -// we don't allow duplicate in symbols -// e.g. 
if 2 j corresponding to 1 i, then return failure +// infer the relation between two indexing maps +// returns target dim -> base dim, means target is the same as input +// we don't allow duplication, e.g. 2 target corresponding to 1 base static FailureOr> inferIndexingMapRelation(AffineMap indexingMapBase, AffineMap indexingMapTarget) { + // symbols are not allowed to occur + if (indexingMapBase.getNumSymbols() != 0 || + indexingMapTarget.getNumSymbols() != 0) + return failure(); DenseMap res; ArrayRef resultsBase = indexingMapBase.getResults(); ArrayRef resultsTarget = indexingMapTarget.getResults(); @@ -70,6 +73,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase, auto base = dyn_cast(resultsBase[i]); auto target = dyn_cast(resultsTarget[j]); if (base && target && base.getPosition() == target.getPosition()) { + // dim j already mapped to certain i if (res.find(j) != res.end()) return failure(); res[j] = i; @@ -91,7 +95,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase, return res; } -// given j --> i and max rank of i, return i --> j +// given target --> base and max rank of base, return base --> target static DenseMap getReversedIndexMap(const DenseMap &indexMap, size_t maxRank) { @@ -109,7 +113,7 @@ getReversedIndexMap(const DenseMap &indexMap, return res; } -static FailureOr +static TensorLayout inferTargetLayout(TensorLayout layoutBase, const DenseMap &indexMap) { SmallVector baseOuterAxis = layoutBase.getOuterAxis(); @@ -177,6 +181,39 @@ getPackingAxis(int64_t numRank, bool transposed) { return std::make_pair(outerAxisPerm, innerAxisPos); } +// copied from mlir +static SmallVector +projectToInnerMostNonUnitDimsPos(ArrayRef dimsPos, + ArrayRef reassocIndices, + ArrayRef targetShape) { + SmallVector projectedDimsPos; + for (auto pos : dimsPos) { + // In the case all dims are unit, this will return the inner-most one. + int64_t projectedPos = reassocIndices[pos].back(); + for (auto i : llvm::reverse(reassocIndices[pos])) { + int64_t dim = targetShape[i]; + if (dim > 1 || ShapedType::isDynamic(dim)) { + projectedPos = i; + break; + } + } + projectedDimsPos.push_back(projectedPos); + } + return projectedDimsPos; +} + +/// Check if all dims in dimsPos are divisible by the corresponding tile sizes. 
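+/// For example, with shape = [128, 64], dimsPos = [0, 1] and
+/// tileSizes = [32, 16], both 128 % 32 and 64 % 16 are 0, so the check
+/// passes; a dynamic dim or a non-zero remainder fails it and the caller
+/// skips packing the reshape op.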
+static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, + ArrayRef shape, + ArrayRef tileSizes) { + for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { + int64_t dim = shape[pos]; + if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) + return false; + } + return true; +} + GlobalAnalysis::GlobalAnalysis(Operation *root) { root->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { @@ -198,9 +235,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { } // ------ Get Current Op's Suggested Layout & Do Propagation ------ IRRewriter rewriter(linalgOp); - // TODO: extend to packed/vnni matmul ops if (supportedContractionNamedOpList(linalgOp)) { - // get input and output rank + // infer layout for linalg contraction named ops auto ARank = cast(linalgOp.getDpsInputs()[0].getType()) .getShape() .size(); @@ -242,29 +278,36 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout}); layoutCache[linalgOp] = suggestedLayout; } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) && + !mlir::linalg::isaConvolutionOpInterface(linalgOp) && !supportedContractionNamedOpList(linalgOp)) { + // infer layout for non-contraction/non-convolution linalg named ops + // and linalg generic ops SmallVector inputLayouts, outputLayouts; size_t targetIdx = getTargetInputIdx(curInputLayouts); - // TODO(yifei): wisely choose the input format basis - // Let's only refer to input[0] for now for (size_t i = 0; i < curInputs.size(); ++i) { // getMatchingIndexingMap if (i != targetIdx) { - auto res = inferIndexingMapRelation( + auto indexRelation = inferIndexingMapRelation( linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), linalgOp.getMatchingIndexingMap(curInputs[i])); + if (failed(indexRelation)) { + return WalkResult::skip(); + } TensorLayout inputLayout = - *inferTargetLayout(curInputLayouts[targetIdx], *res); + inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); inputLayouts.push_back(inputLayout); } else { inputLayouts.push_back(curInputLayouts[targetIdx]); } } - auto res_out = inferIndexingMapRelation( + auto indexRelation = inferIndexingMapRelation( linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), linalgOp.getIndexingMapMatchingResult(curResults[0])); + if (failed(indexRelation)) { + return WalkResult::skip(); + } TensorLayout outputLayout = - *inferTargetLayout(curInputLayouts[targetIdx], *res_out); + inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); outputLayouts.push_back(outputLayout); OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[linalgOp] = suggestedLayout; @@ -283,7 +326,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[padOp] = suggestedLayout; } else if (auto expandShapeOp = dyn_cast(op)) { - auto reassociation = expandShapeOp.getReassociation(); + SmallVector reassocIndices = + expandShapeOp.getReassociationIndices(); auto staticOutputShape = expandShapeOp.getStaticOutputShape(); auto parent = expandShapeOp.getSrc().getDefiningOp(); auto inputShape = expandShapeOp.getSrcType().getShape(); @@ -291,44 +335,35 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { layoutCache.find(parent) != layoutCache.end() ? 
layoutCache[parent].getOutputLayout(0) : TensorLayout::createPlainLayout(inputShape.size()); - DenseMap outputInputIdxMapping, inputOutputIndexMapping; - int64_t accumulationOffset = 0; - for (int64_t i = 0; i < static_cast(reassociation.size()); ++i) { - auto subReassociation = llvm::cast(reassociation[i]); - for (int64_t j = 0; j < static_cast(subReassociation.size()); - ++j) { - if (staticOutputShape[accumulationOffset + j] == inputShape[i]) { - outputInputIdxMapping[accumulationOffset + j] = i; - inputOutputIndexMapping[i] = accumulationOffset + j; - } - } - accumulationOffset += subReassociation.size(); + SmallVector innerTileSizes; + auto intTileSizes = getConstantIntValues(curInputLayout.getTileSizes()); + if (intTileSizes) { + innerTileSizes = *intTileSizes; } - auto inputOuterAxis = curInputLayout.getOuterAxis(); - auto inputInnerAxis = curInputLayout.getInnerAxis(); - int64_t diffDifference = staticOutputShape.size() - inputShape.size(); - int64_t startIdx = 0; - SmallVector outputOuterAxis, outputInnerAxis; - for (int64_t i = 0; i < static_cast(staticOutputShape.size()); - ++i) { - if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) { - outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] + - diffDifference); - } else { - outputOuterAxis.push_back(startIdx++); - } + ArrayRef innerDimsPos = curInputLayout.getInnerAxis(); + ArrayRef outerDimsPerm = curInputLayout.getOuterAxis(); + SmallVector projectedInnerDimsPos = + projectToInnerMostNonUnitDimsPos(curInputLayout.getInnerAxis(), + reassocIndices, staticOutputShape); + + if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape, + innerTileSizes)) { + return WalkResult::skip(); } - for (int64_t i = 0; i < static_cast(inputInnerAxis.size()); - ++i) { - outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]); + SmallVector newOuterDimsPerm; + for (auto outerPos : outerDimsPerm) { + newOuterDimsPerm.insert(newOuterDimsPerm.end(), + reassocIndices[outerPos].begin(), + reassocIndices[outerPos].end()); } - TensorLayout outputLayout(outputOuterAxis, outputInnerAxis, + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, curInputLayout.getTileSizes()); SmallVector inputLayouts{curInputLayout}, outputLayouts{outputLayout}; OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[expandShapeOp] = suggestedLayout; } + return WalkResult::advance(); }); } diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 4a4eace68..d6298322d 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -76,6 +76,21 @@ static SmallVector getPackedPermAxes(ArrayRef plainPermAxes, return result; } +static int64_t applyPermutationAndReindexReassoc( + SmallVector &reassocIndices, + ArrayRef permutation) { + if (!permutation.empty()) + applyPermutationToVector(reassocIndices, permutation); + int64_t nextPos = 0; + for (ReassociationIndices &indices : reassocIndices) { + for (auto &index : indices) { + index = nextPos; + nextPos += 1; + } + } + return nextPos; +} + // extends linalg::pack(...) 
for named ops FailureOr packNamedOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, @@ -250,26 +265,10 @@ class PropagateLayoutOnNamedOps void runOnOperation() final; }; -LogicalResult graphAlreadyPacked(MLIRContext *ctx, mlir::Operation *graph) { - IRRewriter rewriter(ctx); - auto walk = graph->walk([&](Operation *op) { - if (mlir::gc::utils::isPackableNamedOp(op) && op->hasAttr("packed")) { - LLVM_DEBUG(llvm::dbgs() - << "Graph already packed. Stop layout propagation.\n"); - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (walk.wasInterrupted()) { - return failure(); - } - return success(); -} - LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, ControlPackNamedOpsFn controlFn) { IRRewriter rewriter(ctx); - auto walk = graph->walk([&](Operation *op) { + graph->walk([&](Operation *op) { if (mlir::gc::utils::isPackableNamedOp(op)) { LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n"); FailureOr opLayout = controlFn(op); @@ -300,38 +299,44 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, packNamedOp(rewriter, linalgOp, *opLayout); if (failed(packedOp)) { return WalkResult::skip(); - } else { - packedOp->packedLinalgOp->setAttr("packed", - rewriter.getBoolAttr(true)); } } else if (auto expandShapeOp = dyn_cast(op)) { - // Location loc = expandShapeOp->getLoc(); - // auto inputLayout = opLayout->getSupportedInputLayouts()[0]; - // auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; - // Value dest = tensor::PackOp::createDestinationTensor( - // rewriter, loc, expandShapeOp.getSrc(), - // inputLayout.getTileSizes(), inputLayout.getInnerAxis(), - // inputLayout.getOuterAxis()); - // Value packedSource = rewriter.create( - // loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(), - // inputLayout.getTileSizes(), std::nullopt, - // inputLayout.getOuterAxis()); - // auto resultType = RankedTensorType::get( - // expandShapeOp.getStaticOutputShape(), - // expandShapeOp.getSrcType().getElementType()); - // RankedTensorType resultPackType = tensor::PackOp::inferPackedType( - // resultType, vector::getAsIntegers(outputLayout.getTileSizes()), - // outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - // auto reassocExpand = getReassociationIndicesForReshape( - // cast(dest.getType()), resultPackType); - // auto packedExpandShape = rewriter.create( - // loc, expandShapeOp.getSrcType().getElementType(), packedSource, - // *reassocExpand); - // Value result = rewriter.create( - // packedExpandShape->getLoc(), packedExpandShape, - // packedExpandShape, outputLayout.getInnerAxis(), - // outputLayout.getTileSizes(), outputLayout.getOuterAxis()); - // rewriter.replaceOp(expandShapeOp, result); + Location loc = expandShapeOp->getLoc(); + auto inputLayout = opLayout->getSupportedInputLayouts()[0]; + auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; + LLVM_DEBUG(llvm::dbgs() << "Input layout: " << inputLayout << ".\n"); + LLVM_DEBUG(llvm::dbgs() << "Output layout: " << outputLayout << ".\n"); + Value curSrc = expandShapeOp.getSrc(); + Value curDst = expandShapeOp.getResult(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, curSrc, inputLayout.getTileSizes(), + inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); + Value packedSource = rewriter.create( + loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), std::nullopt, + inputLayout.getOuterAxis()); + SmallVector newReassocIndices = + 
expandShapeOp.getReassociationIndices(); + int64_t nextPos = applyPermutationAndReindexReassoc( + newReassocIndices, inputLayout.getOuterAxis()); + // Then add direct mapping for the inner tile dims. + for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { + newReassocIndices.push_back({nextPos}); + nextPos += 1; + } + RankedTensorType newExpandType = tensor::PackOp::inferPackedType( + dyn_cast(curDst.getType()), + *getConstantIntValues(outputLayout.getTileSizes()), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + Value packedExpandShape = rewriter.create( + loc, newExpandType, packedSource, newReassocIndices); + auto unpackDst = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedExpandShape, outputLayout.getTileSizes(), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + auto newUnPackOp = rewriter.create( + loc, packedExpandShape, unpackDst, outputLayout.getInnerAxis(), + outputLayout.getTileSizes(), outputLayout.getOuterAxis()); + rewriter.replaceOp(expandShapeOp, newUnPackOp); } } return WalkResult::advance(); @@ -619,9 +624,6 @@ struct UpliftPackOverBroadcast : public OpRewritePattern { void PropagateLayoutOnNamedOps::runOnOperation() { MLIRContext *ctx = &getContext(); mlir::Operation *graph = getOperation(); - // stage0: check if the graph has been packed - if (failed(graphAlreadyPacked(ctx, graph))) - return; // stage1: pack matmul RewritePatternSet packMatmulPatterns(&getContext()); mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn = From 8bbf601a6135d16069198b67b188390c1624a9d5 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Mon, 5 Aug 2024 19:21:59 -0700 Subject: [PATCH 03/23] update debug info --- lib/gc/Analysis/GlobalAnalysis.cpp | 31 +++++++++++++++++---------- lib/gc/Transforms/PropagateLayout.cpp | 2 -- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index bf09074ba..9e5f94da4 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -44,13 +44,15 @@ bool TensorLayout::operator==(const TensorLayout &layout) { llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const OperatorLayout &opLayout) { - for (auto &&[idx, layoutCache] : - llvm::enumerate(opLayout.getSupportedInputLayouts())) { - ss << "input " << idx << "'s layout: " << layoutCache << "\n"; + if (!opLayout.getSupportedInputLayouts().empty()) { + ss << "Input layouts: "; + llvm::interleave(opLayout.getSupportedInputLayouts(), ss, "; "); + ss << ". "; } - for (auto &&[idx, layoutCache] : - llvm::enumerate(opLayout.getSupportedOutputLayouts())) { - ss << "output " << idx << "'s layout: " << layoutCache << "\n"; + if (!opLayout.getSupportedOutputLayouts().empty()) { + ss << "Output layouts: "; + llvm::interleave(opLayout.getSupportedOutputLayouts(), ss, "; "); + ss << ". 
"; } return ss; } @@ -217,8 +219,6 @@ static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, GlobalAnalysis::GlobalAnalysis(Operation *root) { root->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { - LLVM_DEBUG(llvm::dbgs() - << "Inferring layout of op: " << op->getName() << "\n"); auto curInputs = linalgOp.getDpsInputOperands(); auto curResults = linalgOp.getOperation()->getResults(); // ---------------- Get Current Input Layouts ------------------- @@ -277,8 +277,11 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { rewriter.getIndexAttr(iin)}); OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout}); layoutCache[linalgOp] = suggestedLayout; + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << suggestedLayout << "\n"); } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) && - !mlir::linalg::isaConvolutionOpInterface(linalgOp) && + !isa( + linalgOp.getOperation()) && !supportedContractionNamedOpList(linalgOp)) { // infer layout for non-contraction/non-convolution linalg named ops // and linalg generic ops @@ -311,6 +314,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { outputLayouts.push_back(outputLayout); OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[linalgOp] = suggestedLayout; + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << suggestedLayout << "\n"); } } else if (auto padOp = dyn_cast(op)) { auto inputOperand = padOp.getSource(); @@ -325,6 +330,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { outputLayouts{curInputLayout}; OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[padOp] = suggestedLayout; + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << suggestedLayout << "\n"); } else if (auto expandShapeOp = dyn_cast(op)) { SmallVector reassocIndices = expandShapeOp.getReassociationIndices(); @@ -343,8 +350,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { ArrayRef innerDimsPos = curInputLayout.getInnerAxis(); ArrayRef outerDimsPerm = curInputLayout.getOuterAxis(); SmallVector projectedInnerDimsPos = - projectToInnerMostNonUnitDimsPos(curInputLayout.getInnerAxis(), - reassocIndices, staticOutputShape); + projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, + staticOutputShape); if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape, innerTileSizes)) { @@ -362,6 +369,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { outputLayouts{outputLayout}; OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[expandShapeOp] = suggestedLayout; + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << suggestedLayout << "\n"); } return WalkResult::advance(); }); diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index d6298322d..e0674c196 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -304,8 +304,6 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, Location loc = expandShapeOp->getLoc(); auto inputLayout = opLayout->getSupportedInputLayouts()[0]; auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; - LLVM_DEBUG(llvm::dbgs() << "Input layout: " << inputLayout << ".\n"); - LLVM_DEBUG(llvm::dbgs() << "Output layout: " << outputLayout << ".\n"); Value curSrc = expandShapeOp.getSrc(); Value curDst = expandShapeOp.getResult(); Value dest = tensor::PackOp::createDestinationTensor( From 
5af47a3be20ef1697d2a30face68350528b0ee08 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Mon, 5 Aug 2024 23:22:17 -0700
Subject: [PATCH 04/23] update to OpOperand, adapt to latest llvm main

---
 lib/gc/Transforms/PostProcessPackUnpack.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp
index 0427cfaf9..55d09ba3b 100644
--- a/lib/gc/Transforms/PostProcessPackUnpack.cpp
+++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp
@@ -155,8 +155,9 @@ static void tppPopulateSimplifyPacking(RewritePatternSet &patterns) {
   scf::ForallOp::getCanonicalizationPatterns(patterns, ctx);
   // Propagate packs/unpacks only through expand shapes at this point.
   // This captures the transformation scope of the replaced downstream pass.
-  linalg::populateDataLayoutPropagationPatterns(
-      patterns, [](Operation *op) { return isa<tensor::ExpandShapeOp>(op); });
+  linalg::populateDataLayoutPropagationPatterns(patterns, [](OpOperand *op) {
+    return isa<tensor::ExpandShapeOp>(op->getOwner());
+  });
   ctx->getLoadedDialect<tensor::TensorDialect>()->getCanonicalizationPatterns(
       patterns);
   // patterns.add(ctx);

From c31ca2176355c0588f8a99f553dab1fa0fa0111c Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Tue, 6 Aug 2024 01:52:27 -0700
Subject: [PATCH 05/23] fix getPackedAxes

---
 lib/gc/Analysis/GlobalAnalysis.cpp    | 2 ++
 lib/gc/Transforms/PropagateLayout.cpp | 6 ++++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index 9e5f94da4..7d14ad145 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -138,6 +138,8 @@ inferTargetLayout(TensorLayout layoutBase,
       newDimAxis.push_back(pair.first);
     }
   }
+  // TODO(yifei): reconsider whether pushing all new axes to the beginning of
+  // the outer perm is best for performance
   targetOuterAxis.insert(targetOuterAxis.begin(), newDimAxis.begin(),
                          newDimAxis.end());
   for (auto &&[ia, ts] : llvm::zip(baseInnerAxis, baseTileSizes)) {
diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp
index e0674c196..99baf9333 100644
--- a/lib/gc/Transforms/PropagateLayout.cpp
+++ b/lib/gc/Transforms/PropagateLayout.cpp
@@ -36,11 +36,13 @@ using namespace mlir::tensor;
 
 static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
                                           TensorLayout targetLayout) {
-  SmallVector<int64_t> result(dimensions);
+  SmallVector<int64_t> result;
   // permuting on outer axis
   auto outerPerm = targetLayout.getOuterAxis();
   for (size_t i = 0; i < dimensions.size(); ++i) {
-    result[i] = outerPerm[dimensions[i]];
+    auto pos = std::find(outerPerm.begin(), outerPerm.end(), dimensions[i]);
+    assert(pos != outerPerm.end() && "dimension must be within outer perm.");
+    result.push_back(std::distance(outerPerm.begin(), pos));
   }
   // inserting inner axis
   auto innerPos = targetLayout.getInnerAxis();

From 4967daf0ecc5e2f06af8b5b56897290198b29493 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Wed, 7 Aug 2024 03:50:55 -0700
Subject: [PATCH 06/23] extend to collapse shape

---
 lib/gc/Analysis/GlobalAnalysis.cpp    | 81 +++++++++++++++++++++++++--
 lib/gc/Transforms/PropagateLayout.cpp | 35 ++++++++++++
 2 files changed, 111 insertions(+), 5 deletions(-)

diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index 7d14ad145..3fb54c4b7 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -10,6 +10,8 @@
 
 #include "gc/Analysis/GlobalAnalysis.h"
 #include "gc/Analysis/MatmulConfigAnalysis.h"
+#include
"llvm/ADT/SetOperations.h" +#include "llvm/ADT/SetVector.h" namespace mlir { namespace gc { @@ -345,14 +347,16 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { ? layoutCache[parent].getOutputLayout(0) : TensorLayout::createPlainLayout(inputShape.size()); SmallVector innerTileSizes; - auto intTileSizes = getConstantIntValues(curInputLayout.getTileSizes()); - if (intTileSizes) { - innerTileSizes = *intTileSizes; + auto tileSizes = getConstantIntValues(curInputLayout.getTileSizes()); + if (tileSizes) { + innerTileSizes = *tileSizes; + } else { + return WalkResult::skip(); } - ArrayRef innerDimsPos = curInputLayout.getInnerAxis(); + ArrayRef innerPosPos = curInputLayout.getInnerAxis(); ArrayRef outerDimsPerm = curInputLayout.getOuterAxis(); SmallVector projectedInnerDimsPos = - projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, + projectToInnerMostNonUnitDimsPos(innerPosPos, reassocIndices, staticOutputShape); if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape, @@ -373,6 +377,73 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { layoutCache[expandShapeOp] = suggestedLayout; LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() << " is: " << suggestedLayout << "\n"); + } else if (auto collapseShapeOp = dyn_cast(op)) { + SmallVector reassocIndices = + collapseShapeOp.getReassociationIndices(); + auto parent = collapseShapeOp.getSrc().getDefiningOp(); + auto inputShape = collapseShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + layoutCache.find(parent) != layoutCache.end() + ? layoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + auto innerPos = curInputLayout.getInnerAxis(); + llvm::SetVector innerPosSet(innerPos.begin(), innerPos.end()); + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + // For each reassociation, figure out which dimensions get packed if + // any. + llvm::SetVector collapseDimPos(indices.begin(), indices.end()); + llvm::SetVector packedDims = + llvm::set_intersection(innerPosSet, collapseDimPos); + // only one of the collapsed indices can be packed + if (packedDims.size() > 1) + return WalkResult::skip(); + // Only the inner-most expanded dimension should be packed. Otherwise, + // elements order will be affected after operation reordering. + if (!packedDims.empty() && packedDims[0] != indices.back()) + return WalkResult::skip(); + } + + // Project pack.inner_dims_pos to positions before shape expansion. 
+ SmallVector projectedInnerDimsPos; + for (auto pos : innerPos) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == pos; + })) { + projectedInnerDimsPos.push_back(idx); + break; + } + } + } + assert(projectedInnerDimsPos.size() == innerPos.size() && + "Invalid dim pos projection"); + + // outerPerm shall be a permutation of reassocIndices + auto outerPerm = curInputLayout.getOuterAxis(); + SmallVector newOuterDimsPerm; + int64_t axisIdx = 0; + while (axisIdx < outerPerm.size()) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == outerPerm[axisIdx]; + })) { + for (auto collapseDim : indices) { + if (collapseDim != outerPerm[axisIdx++]) + return WalkResult::skip(); + } + newOuterDimsPerm.push_back(idx); + break; + } + } + } + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + layoutCache[collapseShapeOp] = suggestedLayout; + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << suggestedLayout << "\n"); } return WalkResult::advance(); }); diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 99baf9333..bded9b53b 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -337,6 +337,41 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, loc, packedExpandShape, unpackDst, outputLayout.getInnerAxis(), outputLayout.getTileSizes(), outputLayout.getOuterAxis()); rewriter.replaceOp(expandShapeOp, newUnPackOp); + } else if (auto collapseShapeOp = dyn_cast(op)) { + Location loc = collapseShapeOp->getLoc(); + auto inputLayout = opLayout->getSupportedInputLayouts()[0]; + auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; + Value curSrc = collapseShapeOp.getSrc(); + Value curDst = collapseShapeOp.getResult(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, curSrc, inputLayout.getTileSizes(), + inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); + Value packedSource = rewriter.create( + loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), std::nullopt, + inputLayout.getOuterAxis()); + SmallVector newReassocIndices = + collapseShapeOp.getReassociationIndices(); + int64_t nextPos = applyPermutationAndReindexReassoc( + newReassocIndices, outputLayout.getOuterAxis()); + // Then add direct mapping for the inner tile dims. 
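+          // E.g. with one inner tile dim, reassocIndices [[0], [1, 2]]
+          // becomes [[0], [1, 2], [3]]: each tile dim maps to its own
+          // trailing group, left untouched by the collapse.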
+ for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { + newReassocIndices.push_back({nextPos}); + nextPos += 1; + } + RankedTensorType newCollapseType = tensor::PackOp::inferPackedType( + dyn_cast(curDst.getType()), + *getConstantIntValues(outputLayout.getTileSizes()), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + Value packedCollapseShape = rewriter.create( + loc, newCollapseType, packedSource, newReassocIndices); + auto unpackDst = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedCollapseShape, outputLayout.getTileSizes(), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + auto newUnPackOp = rewriter.create( + loc, packedCollapseShape, unpackDst, outputLayout.getInnerAxis(), + outputLayout.getTileSizes(), outputLayout.getOuterAxis()); + rewriter.replaceOp(collapseShapeOp, newUnPackOp); } } return WalkResult::advance(); From 2f4a9afec8c14390e748e31d7c6944a9b4722b94 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Thu, 8 Aug 2024 23:27:06 -0700 Subject: [PATCH 07/23] replace empty inner pos pack with transpose --- include/gc/Transforms/Transforms.h | 5 +- lib/gc/Transforms/PropagateLayout.cpp | 127 ++++++++++++++------------ 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/include/gc/Transforms/Transforms.h b/include/gc/Transforms/Transforms.h index 0e0ed76c7..590fe3cfc 100644 --- a/include/gc/Transforms/Transforms.h +++ b/include/gc/Transforms/Transforms.h @@ -15,9 +15,8 @@ namespace mlir { namespace gc { -FailureOr packNamedOp(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - OperatorLayout opLayout); +LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + OperatorLayout opLayout); LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter, linalg::LinalgOp linalgOp, diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index bded9b53b..d8e61da86 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -34,6 +34,44 @@ using namespace mlir; using namespace mlir::arith; using namespace mlir::tensor; +// insert pack when innerPosDims is non-empty +// insert linalg.transpose otherwise +static Value insertLayoutPack(RewriterBase &rewriter, Location loc, Value input, + Value dest, ArrayRef innerDimsPos, + ArrayRef innerTiles, + ArrayRef outerDimsPerm) { + if (!innerDimsPos.empty()) + return rewriter.create( + loc, input, dest, innerDimsPos, innerTiles, + /*padding=*/std::nullopt, outerDimsPerm); + else { + return rewriter.create(loc, input, dest, outerDimsPerm) + .getResults()[0]; + } +} + +// insert unpack when innerPosDims is non-empty +// insert linalg.transpose otherwise +static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc, + Value input, ArrayRef innerDimsPos, + ArrayRef innerTiles, + ArrayRef outerDimsPerm) { + Value dest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, input, innerTiles, innerDimsPos, outerDimsPerm); + if (!innerDimsPos.empty()) { + return rewriter.create(loc, input, dest, innerDimsPos, + innerTiles, outerDimsPerm); + } else { + // inverse the permutationVector + SmallVector permAxes(outerDimsPerm.size()); + for (auto [idx, axis] : llvm::enumerate(outerDimsPerm)) { + permAxes[axis] = idx; + } + return rewriter.create(loc, input, dest, permAxes) + .getResults()[0]; + } +} + static SmallVector getPackedAxes(ArrayRef dimensions, TensorLayout targetLayout) { SmallVector result; @@ -94,16 +132,15 @@ static int64_t applyPermutationAndReindexReassoc( } // extends 
linalg::pack(...) for named ops -FailureOr packNamedOp(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - OperatorLayout opLayout) { +LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, + OperatorLayout opLayout) { if (linalgOp.hasPureBufferSemantics()) return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); LLVM_DEBUG(llvm::dbgs() << "Try packing named op " << linalgOp.getOperation()->getName() << ".\n"); Location loc = linalgOp->getLoc(); - SmallVector packOps; - SmallVector unPackOps; + SmallVector packOps; + SmallVector unPackOps; SmallVector inputsAndInits, results; SmallVector initOperands = llvm::to_vector(llvm::map_range( linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); @@ -148,17 +185,10 @@ FailureOr packNamedOp(RewriterBase &rewriter, operandType.getShape(), innerPos, cast(dest.getType()).getShape(), {}, innerPackSizes)) { - packOps.push_back(rewriter.create( - loc, operand, dest, innerPos, innerPackSizes, std::nullopt, - outerPerm)); + packOps.push_back(insertLayoutPack( + rewriter, loc, operand, dest, innerPos, innerPackSizes, outerPerm)); } else { - // TODO: value of the padding attribute should be determined by - // consumers. - auto zeroAttr = - rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); - Value zero = rewriter.create(loc, zeroAttr); - packOps.push_back(rewriter.create( - loc, operand, dest, innerPos, innerPackSizes, zero, outerPerm)); + return failure(); } inputsAndInits.push_back(packOps.back()); } @@ -187,11 +217,10 @@ FailureOr packNamedOp(RewriterBase &rewriter, transposeOp->getPermutation(), inputLayouts[0], initLayouts[0]); packedLinalgOp = rewriter.create( loc, inputs[0], inits[0], packedPermAxes); - } else if (isa(linalgOp) || - isa(linalgOp) || isa(linalgOp) || + } else if (isa(linalgOp) || isa(linalgOp) || isa(linalgOp) || isa(linalgOp)) { return failure( - "Packing logic not implemented for SoftMax/Generic/Map/Yield/Index."); + "Packing logic not implemented for SoftMax/Map/Yield/Index."); } else { packedLinalgOp = mlir::clone( rewriter, linalgOp, SmallVector{inputsAndInits.back().getType()}, @@ -201,27 +230,20 @@ FailureOr packNamedOp(RewriterBase &rewriter, // Step 4. Unpack all the op results. for (OpResult result : packedLinalgOp->getResults()) { int64_t resultNum = result.getResultNumber(); - tensor::PackOp maybePackedInit = - inits[resultNum].getDefiningOp(); - if (!maybePackedInit) { - results.push_back(result); - continue; - } + assert(resultNum < initLayouts.size() && + "Linalg op results num exceeds inits num."); // Build the symmetrical UnPackOp to the existing PackOp. - unPackOps.push_back(rewriter.create( - packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), - maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles(), - maybePackedInit.getOuterDimsPerm())); + unPackOps.push_back( + insertLayoutUnpack(rewriter, packedLinalgOp->getLoc(), result, + initLayouts[resultNum].getInnerAxis(), + initLayouts[resultNum].getTileSizes(), + initLayouts[resultNum].getOuterAxis())); results.push_back(unPackOps.back()); } // Step 5. Replace `linalgOp`. rewriter.replaceOp(linalgOp, results); - - // Return packedLinalgOp. 
- return linalg::PackResult{ - packOps, cast(packedLinalgOp.getOperation()), - unPackOps}; + return success(); } // check whether the op is already packed or not @@ -297,9 +319,7 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, return WalkResult::advance(); } if (auto linalgOp = dyn_cast(op)) { - FailureOr packedOp = - packNamedOp(rewriter, linalgOp, *opLayout); - if (failed(packedOp)) { + if (failed(packLinalgOp(rewriter, linalgOp, *opLayout))) { return WalkResult::skip(); } } else if (auto expandShapeOp = dyn_cast(op)) { @@ -311,10 +331,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, Value dest = tensor::PackOp::createDestinationTensor( rewriter, loc, curSrc, inputLayout.getTileSizes(), inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); - Value packedSource = rewriter.create( - loc, curSrc, dest, inputLayout.getInnerAxis(), - inputLayout.getTileSizes(), std::nullopt, - inputLayout.getOuterAxis()); + Value packedSource = insertLayoutPack( + rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), inputLayout.getOuterAxis()); SmallVector newReassocIndices = expandShapeOp.getReassociationIndices(); int64_t nextPos = applyPermutationAndReindexReassoc( @@ -330,11 +349,8 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); Value packedExpandShape = rewriter.create( loc, newExpandType, packedSource, newReassocIndices); - auto unpackDst = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, packedExpandShape, outputLayout.getTileSizes(), - outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - auto newUnPackOp = rewriter.create( - loc, packedExpandShape, unpackDst, outputLayout.getInnerAxis(), + Value newUnPackOp = insertLayoutUnpack( + rewriter, loc, packedExpandShape, outputLayout.getInnerAxis(), outputLayout.getTileSizes(), outputLayout.getOuterAxis()); rewriter.replaceOp(expandShapeOp, newUnPackOp); } else if (auto collapseShapeOp = dyn_cast(op)) { @@ -346,10 +362,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, Value dest = tensor::PackOp::createDestinationTensor( rewriter, loc, curSrc, inputLayout.getTileSizes(), inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); - Value packedSource = rewriter.create( - loc, curSrc, dest, inputLayout.getInnerAxis(), - inputLayout.getTileSizes(), std::nullopt, - inputLayout.getOuterAxis()); + Value packedSource = insertLayoutPack( + rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), inputLayout.getOuterAxis()); SmallVector newReassocIndices = collapseShapeOp.getReassociationIndices(); int64_t nextPos = applyPermutationAndReindexReassoc( @@ -365,11 +380,8 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); Value packedCollapseShape = rewriter.create( loc, newCollapseType, packedSource, newReassocIndices); - auto unpackDst = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, packedCollapseShape, outputLayout.getTileSizes(), - outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - auto newUnPackOp = rewriter.create( - loc, packedCollapseShape, unpackDst, outputLayout.getInnerAxis(), + Value newUnPackOp = insertLayoutUnpack( + rewriter, loc, packedCollapseShape, outputLayout.getInnerAxis(), outputLayout.getTileSizes(), outputLayout.getOuterAxis()); 
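+          // The net effect is a pack -> collapse_shape -> unpack sandwich:
+          // the op's original result type is preserved, and the later
+          // simplification passes can fold the pack/unpack pairs away.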
rewriter.replaceOp(collapseShapeOp, newUnPackOp); } @@ -645,10 +657,9 @@ struct UpliftPackOverBroadcast : public OpRewritePattern { auto dest = tensor::PackOp::createDestinationTensor( rewriter, loc, broadcastOp.getDpsInputs()[0], newInnerTileSizes, newInnerDimsPos, newOuterDimsPerm); - Value packedSource = rewriter.create( - loc, broadcastOp.getDpsInputs()[0], dest, newInnerDimsPos, - newInnerTileSizes, - /*padding=*/std::nullopt, newOuterDimsPerm); + Value packedSource = + insertLayoutPack(rewriter, loc, broadcastOp.getDpsInputs()[0], dest, + newInnerDimsPos, newInnerTileSizes, newOuterDimsPerm); auto newBroadcastOp = rewriter.create( loc, packedSource, pack.getDest(), packedBroadcastAxis); rewriter.replaceOp(pack, newBroadcastOp.getResults()); From df32b62043ef8bf3e4a6d7427d9267d21caad4c4 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Sun, 11 Aug 2024 02:12:04 -0700 Subject: [PATCH 08/23] add transpose canonicalization --- lib/gc/Transforms/PostProcessPackUnpack.cpp | 34 ++++++--------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp index 55d09ba3b..3e34865d7 100644 --- a/lib/gc/Transforms/PostProcessPackUnpack.cpp +++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp @@ -5,35 +5,28 @@ // //===----------------------------------------------------------------------===// -#include #include -#include "gc/Transforms/Transforms.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "gc/Dialect/Linalgx/LinalgxDialect.h" -#include "gc/Dialect/Linalgx/LinalgxOps.h" #include "gc/Transforms/Passes.h" +#include "gc/Transforms/Transforms.h" + namespace mlir { namespace gc { #define GEN_PASS_DEF_POSTPROCESSPACKUNPACK #include "gc/Transforms/Passes.h.inc" -#define DEBUG_TYPE "post-process-pack-unpack" - using namespace mlir; -// Helper pattern - lower tensor.pack operations that pack constants. +// copied from tpp - lower tensor.pack operations that pack constants. struct LowerConstantPacking : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -67,11 +60,9 @@ struct LowerConstantPacking : public OpRewritePattern { } }; -static void tppPopulateConstantFoldPack(RewritePatternSet &patterns) { +static void populateConstantFoldPacking(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); patterns.add(ctx); - // Apply canonicalization to fold trivial cases and linalg constant folders - // to cleanup lowered packs. 
linalg::FillOp::getCanonicalizationPatterns(patterns, ctx); tensor::PackOp::getCanonicalizationPatterns(patterns, ctx); tensor::populateRewriteAsConstantPatterns( @@ -139,7 +130,7 @@ class PostProcessPackUnpack void runOnOperation() final; }; -static void tppPopulateSimplifyPacking(RewritePatternSet &patterns) { +static void populateSimplifyPacking(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); tensor::populateSimplifyPackAndUnpackPatterns(patterns); tensor::populateFoldTensorEmptyPatterns(patterns); @@ -153,26 +144,21 @@ static void tppPopulateSimplifyPacking(RewritePatternSet &patterns) { tensor::PadOp::getCanonicalizationPatterns(patterns, ctx); tensor::ParallelInsertSliceOp::getCanonicalizationPatterns(patterns, ctx); scf::ForallOp::getCanonicalizationPatterns(patterns, ctx); - // Propagate packs/unpacks only through expand shapes at this point. - // This captures the transformation scope of the replaced downstream pass. - linalg::populateDataLayoutPropagationPatterns(patterns, [](OpOperand *op) { - return isa(op->getOwner()); - }); ctx->getLoadedDialect()->getCanonicalizationPatterns( patterns); - // patterns.add(ctx); tensor::populateReassociativeReshapeFoldingPatterns(patterns); } void PostProcessPackUnpack::runOnOperation() { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); - // constant fold packing - tppPopulateConstantFoldPack(patterns); + // constant fold packing and transpose + populateConstantFoldPacking(patterns); // simplify packing - tppPopulateSimplifyPacking(patterns); - // gc new packing related simplification + populateSimplifyPacking(patterns); populateEliminateDummyPackUnpack(patterns); + // simplify transpose inserted to perform packing + linalg::TransposeOp::getCanonicalizationPatterns(patterns, ctx); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } From 83306302b3c90062ac0b6f7285d7384e57d7b087 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Sun, 11 Aug 2024 03:30:42 -0700 Subject: [PATCH 09/23] update --- include/gc/Analysis/GlobalAnalysis.h | 10 ++++++-- lib/gc/Analysis/GlobalAnalysis.cpp | 35 +++++++++++---------------- lib/gc/Transforms/PropagateLayout.cpp | 7 ++++-- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h index 48fe9677e..3309bc432 100644 --- a/include/gc/Analysis/GlobalAnalysis.h +++ b/include/gc/Analysis/GlobalAnalysis.h @@ -31,12 +31,18 @@ class TensorLayout { assert(innerAxis.size() == tileSizes.size()); } - bool isPlainLayout() const { + static bool isPlainOuterAxis(ArrayRef outerAxis) { for (int64_t i = 0; i < static_cast(outerAxis.size()); ++i) { if (i != outerAxis[i]) return false; } - return tileSizes.empty() && innerAxis.empty(); + return true; + } + + bool isPlainLayout() const { + if (isPlainOuterAxis(outerAxis)) + return tileSizes.empty() && innerAxis.empty(); + return false; } static TensorLayout createPlainLayout(int64_t rank) { diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 3fb54c4b7..c1102ecee 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -6,7 +6,6 @@ //===----------------------------------------------------------------------===// #include -#include #include "gc/Analysis/GlobalAnalysis.h" #include "gc/Analysis/MatmulConfigAnalysis.h" @@ -60,7 +59,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, } // infer the relation between two indexing maps -// returns target dim -> base dim, 
means target is the same as input +// returns target dim -> base dim, means target is the same as base // we don't allow duplication, e.g. 2 target corresponding to 1 base static FailureOr> inferIndexingMapRelation(AffineMap indexingMapBase, @@ -208,7 +207,8 @@ projectToInnerMostNonUnitDimsPos(ArrayRef dimsPos, return projectedDimsPos; } -/// Check if all dims in dimsPos are divisible by the corresponding tile sizes. +// copied from mlir +// Check if all dims in dimsPos are divisible by the corresponding tile sizes. static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, ArrayRef shape, ArrayRef tileSizes) { @@ -221,11 +221,12 @@ static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, } GlobalAnalysis::GlobalAnalysis(Operation *root) { + IRRewriter rewriter(root); root->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { auto curInputs = linalgOp.getDpsInputOperands(); auto curResults = linalgOp.getOperation()->getResults(); - // ---------------- Get Current Input Layouts ------------------- + // get current op's input layouts SmallVector curInputLayouts; for (auto input : curInputs) { auto parent = input->get().getDefiningOp(); @@ -237,8 +238,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { linalgOp.getMatchingIndexingMap(input).getNumResults())); } } - // ------ Get Current Op's Suggested Layout & Do Propagation ------ - IRRewriter rewriter(linalgOp); + // infer current op's output layout accordingly if (supportedContractionNamedOpList(linalgOp)) { // infer layout for linalg contraction named ops auto ARank = cast(linalgOp.getDpsInputs()[0].getType()) @@ -266,7 +266,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { MatmulConfigAnalysis(linalgOp.getOperation()).getConfig(); uint32_t iim = cfg.innerMostKBlock, iin = cfg.innerMostNBlock, iik = cfg.innerMostKBlock; - // current layout is MKmk, NKkn, MNmn + // current default layout is MKmk, NKkn, MNmn TensorLayout ALayout( APackInfo.first, APackInfo.second, SmallVector{rewriter.getIndexAttr(iim), @@ -281,12 +281,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { rewriter.getIndexAttr(iin)}); OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout}); layoutCache[linalgOp] = suggestedLayout; - LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() - << " is: " << suggestedLayout << "\n"); - } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) && - !isa( - linalgOp.getOperation()) && - !supportedContractionNamedOpList(linalgOp)) { + } else if (mlir::gc::utils::isPackableNamedOp(op)) { // infer layout for non-contraction/non-convolution linalg named ops // and linalg generic ops SmallVector inputLayouts, outputLayouts; @@ -318,8 +313,6 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { outputLayouts.push_back(outputLayout); OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[linalgOp] = suggestedLayout; - LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() - << " is: " << suggestedLayout << "\n"); } } else if (auto padOp = dyn_cast(op)) { auto inputOperand = padOp.getSource(); @@ -334,8 +327,6 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { outputLayouts{curInputLayout}; OperatorLayout suggestedLayout(inputLayouts, outputLayouts); layoutCache[padOp] = suggestedLayout; - LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() - << " is: " << suggestedLayout << "\n"); } else if (auto expandShapeOp = dyn_cast(op)) { SmallVector reassocIndices = expandShapeOp.getReassociationIndices(); @@ -375,8 +366,6 @@ 
GlobalAnalysis::GlobalAnalysis(Operation *root) {
       outputLayouts{outputLayout};
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[expandShapeOp] = suggestedLayout;
-      LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
-                              << " is: " << suggestedLayout << "\n");
     } else if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(op)) {
       SmallVector<ReassociationIndices> reassocIndices =
           collapseShapeOp.getReassociationIndices();
@@ -442,8 +431,10 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
           outputLayouts{outputLayout};
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[collapseShapeOp] = suggestedLayout;
+    }
+    if (layoutCache.find(op) != layoutCache.end()) {
       LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName()
-                              << " is: " << suggestedLayout << "\n");
+                              << " is: " << layoutCache[op] << "\n");
     }
     return WalkResult::advance();
   });
@@ -452,7 +443,9 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
 namespace utils {
 bool isPackableNamedOp(Operation *op) {
   if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
-    if (!supportedContractionNamedOpList(linalgOp)) {
+    if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
+        !isa(linalgOp.getOperation()) &&
+        !supportedContractionNamedOpList(linalgOp)) {
       return true;
     }
   } else if (isa(
diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp
index d8e61da86..44d951ea1 100644
--- a/lib/gc/Transforms/PropagateLayout.cpp
+++ b/lib/gc/Transforms/PropagateLayout.cpp
@@ -44,10 +44,11 @@ static Value insertLayoutPack(RewriterBase &rewriter, Location loc, Value input,
     return rewriter.create<tensor::PackOp>(
         loc, input, dest, innerDimsPos, innerTiles,
         /*padding=*/std::nullopt, outerDimsPerm);
-  else {
+  if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) {
    return rewriter.create<linalg::TransposeOp>(loc, input, dest, outerDimsPerm)
        .getResults()[0];
  }
+  return input;
 }
 
 // insert unpack when innerPosDims is non-empty
 // insert linalg.transpose otherwise
@@ -61,7 +62,8 @@ static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc,
   if (!innerDimsPos.empty()) {
     return rewriter.create<tensor::UnPackOp>(loc, input, dest, innerDimsPos,
                                              innerTiles, outerDimsPerm);
-  } else {
+  }
+  if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) {
     // inverse the permutationVector
     SmallVector<int64_t> permAxes(outerDimsPerm.size());
     for (auto [idx, axis] : llvm::enumerate(outerDimsPerm)) {
@@ -70,6 +72,7 @@ static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc,
     return rewriter.create<linalg::TransposeOp>(loc, input, dest, permAxes)
         .getResults()[0];
   }
+  return input;
 }
 
 static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,

From a24db0ae2e07c206949b095476947f8ad208c309 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Tue, 20 Aug 2024 23:52:39 -0700
Subject: [PATCH 10/23] update pack matmul and pack vnni

---
 include/gc/Analysis/GlobalAnalysis.h  |   8 +-
 lib/gc/Analysis/GlobalAnalysis.cpp    | 156 +++++++++++++++++++-------
 lib/gc/Transforms/PropagateLayout.cpp |  54 +++++----
 3 files changed, 152 insertions(+), 66 deletions(-)

diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h
index 3309bc432..46b6726ea 100644
--- a/include/gc/Analysis/GlobalAnalysis.h
+++ b/include/gc/Analysis/GlobalAnalysis.h
@@ -39,12 +39,14 @@ class TensorLayout {
     return true;
   }
 
-  bool isPlainLayout() const {
+  bool isPlain() const {
     if (isPlainOuterAxis(outerAxis))
       return tileSizes.empty() && innerAxis.empty();
     return false;
   }
 
+  bool isBlocking() const { return !tileSizes.empty() && !innerAxis.empty(); }
+
   static TensorLayout createPlainLayout(int64_t rank) {
     SmallVector<int64_t> outerAxis(rank, 0);
     std::iota(outerAxis.begin(),
outerAxis.end(), 0);
@@ -86,7 +88,7 @@ class TensorLayout {
   friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                                        const TensorLayout &layout);
 
-  bool operator==(const TensorLayout &layout);
+  bool operator==(const TensorLayout &layout) const;
 
 private:
   SmallVector<int64_t> outerAxis;
   SmallVector<int64_t> innerAxis;
@@ -120,7 +122,7 @@ class OperatorLayout {
   bool isPlain() const {
     for (const auto &layout : llvm::concat<const TensorLayout>(
              supportedInputLayouts, supportedOutputLayouts)) {
-      if (!layout.isPlainLayout())
+      if (!layout.isPlain())
         return false;
     }
     return true;
diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index c1102ecee..c254709ca 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -37,7 +37,7 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
   return ss;
 }
 
-bool TensorLayout::operator==(const TensorLayout &layout) {
+bool TensorLayout::operator==(const TensorLayout &layout) const {
   return (this->outerAxis == layout.getOuterAxis()) &&
          (this->innerAxis == layout.getInnerAxis()) &&
         (this->tileSizes == layout.getTileSizes());
@@ -154,7 +154,7 @@ inferTargetLayout(TensorLayout layoutBase,
 
 static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
   for (size_t i = 0; i < curInputLayouts.size(); ++i) {
-    if (!curInputLayouts[i].isPlainLayout()) {
+    if (!curInputLayouts[i].isPlain()) {
       return i;
     }
   }
@@ -220,6 +220,114 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
   return true;
 }
 
+// if forceBlocking is set, we strictly follow the matmul config and produce a
+// blocking layout; otherwise we follow the query-format logic
+static SmallVector<OperatorLayout>
+queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp,
+                  ArrayRef<TensorLayout> curInputLayouts,
+                  bool forceBlocking = false) {
+  SmallVector<OperatorLayout> ret;
+  // infer layout for linalg contraction named ops
+  auto ARank = matmulOp.getRank(matmulOp.getDpsInputOperand(0));
+  auto BRank = matmulOp.getRank(matmulOp.getDpsInputOperand(1));
+  auto CRank = matmulOp.getRank(matmulOp.getDpsInitOperand(0));
+  auto elementType = getElementTypeOrSelf(matmulOp.getDpsInputs()[0].getType());
+  auto AShape = matmulOp.getShape(matmulOp.getDpsInputOperand(0));
+  auto BShape = matmulOp.getShape(matmulOp.getDpsInputOperand(1));
+  int64_t M = AShape[0], K = AShape[1], N = BShape[1];
+  bool ASideTransposed =
+      isa<linalg::MatmulTransposeAOp, linalg::BatchMatmulTransposeAOp>(
+          matmulOp);
+  bool BSideTransposed =
+      isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(
+          matmulOp);
+  // set outer&inner axis values
+  auto APackInfo = getPackingAxis(ARank, ASideTransposed);
+  auto BPackInfo = getPackingAxis(BRank, BSideTransposed);
+  auto CPackInfo = getPackingAxis(CRank, /*transposed*/ false);
+  // query the cost model for tile sizes
+  MatmulConfig cfg = MatmulConfigAnalysis(matmulOp.getOperation()).getConfig();
+  uint32_t iim = cfg.innerMostMBlock, iin = cfg.innerMostNBlock,
+           iik = cfg.innerMostKBlock;
+  if (forceBlocking) {
+    TensorLayout ALayout(APackInfo.first, APackInfo.second,
+                         SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
+                                                   rewriter.getIndexAttr(iik)});
+    TensorLayout BLayout(BPackInfo.first, BPackInfo.second,
+                         SmallVector<OpFoldResult>{rewriter.getIndexAttr(iik),
+                                                   rewriter.getIndexAttr(iin)});
+    TensorLayout CLayout(CPackInfo.first, CPackInfo.second,
+                         SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
+                                                   rewriter.getIndexAttr(iin)});
+    ret.emplace_back(SmallVector<TensorLayout>{ALayout, BLayout},
+                     SmallVector<TensorLayout>{CLayout});
+    return ret;
+  }
+  // TODO(yifei): add a condition for constant A
+  TensorLayout transposedLayout({1, 0}, {}, {});
+  SmallVector<TensorLayout> ALayouts, BLayouts, CLayouts;
+  if (curInputLayouts[0].isBlocking() || (M % iim) || (K % iik) ||
+      (elementType.isBF16() && curInputLayouts[0] == transposedLayout)) {
+    ALayouts.emplace_back(
+        APackInfo.first, APackInfo.second,
+        SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
+                                  rewriter.getIndexAttr(iik)});
+  } else {
+    ALayouts.emplace_back(APackInfo.first, SmallVector<int64_t>{},
+                          SmallVector<OpFoldResult>{});
+  }
+  if (curInputLayouts[1].isBlocking() || K % iik || N % iin ||
+      elementType.isBF16()) {
+    BLayouts.emplace_back(
+        BPackInfo.first, BPackInfo.second,
+        SmallVector<OpFoldResult>{rewriter.getIndexAttr(iik),
+                                  rewriter.getIndexAttr(iin)});
+  } else {
+    BLayouts.emplace_back(BPackInfo.first, SmallVector<int64_t>{},
+                          SmallVector<OpFoldResult>{});
+  }
+  if (M == iim && M >= 32 && N % iin == 0) {
+    CLayouts.emplace_back(CPackInfo.first, SmallVector<int64_t>{},
+                          SmallVector<OpFoldResult>{});
+  } else if (M % iim || N % iin) {
+    CLayouts.emplace_back(
+        CPackInfo.first, CPackInfo.second,
+        SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
+                                  rewriter.getIndexAttr(iin)});
+  } else {
+    if (BSideTransposed) {
+      CLayouts.emplace_back(CPackInfo.first, SmallVector<int64_t>{},
+                            SmallVector<OpFoldResult>{});
+    } else {
+      // push two possibilities
+      CLayouts.emplace_back(CPackInfo.first, SmallVector<int64_t>{},
+                            SmallVector<OpFoldResult>{});
+      CLayouts.emplace_back(
+          CPackInfo.first, CPackInfo.second,
+          SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
+                                    rewriter.getIndexAttr(iin)});
+      ALayouts.emplace_back(ALayouts[0]);
+      BLayouts.emplace_back(BLayouts[0]);
+    }
+  }
+  for (auto [ALayout, BLayout, CLayout] :
+       llvm::zip(ALayouts, BLayouts, CLayouts)) {
+    ret.emplace_back(SmallVector<TensorLayout>{ALayout, BLayout},
+                     SmallVector<TensorLayout>{CLayout});
+  }
+  return ret;
+}
+
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   IRRewriter rewriter(root);
   root->walk([&](Operation *op) {
@@ -240,47 +348,9 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       }
       // infer current op's output layout accordingly
       if (supportedContractionNamedOpList(linalgOp)) {
-        // infer layout for linalg contraction named ops
-        auto ARank = cast<ShapedType>(linalgOp.getDpsInputs()[0].getType())
-                         .getShape()
-                         .size();
-        auto BRank = cast<ShapedType>(linalgOp.getDpsInputs()[1].getType())
-                         .getShape()
-                         .size();
-        auto CRank =
-            cast<ShapedType>(linalgOp.getOperation()->getResults()[0].getType())
-                .getShape()
-                .size();
-        bool ASideTransposed =
-            isa<linalg::MatmulTransposeAOp, linalg::BatchMatmulTransposeAOp>(
-                linalgOp);
-        bool BSideTransposed =
-            isa<linalg::MatmulTransposeBOp, linalg::BatchMatmulTransposeBOp>(
-                linalgOp);
-        // set outer&inner axis values
-        auto APackInfo = getPackingAxis(ARank, ASideTransposed);
-        auto BPackInfo = getPackingAxis(BRank, BSideTransposed);
-        auto CPackInfo = getPackingAxis(CRank, /*transposed*/ false);
-        // query the cost model for tile sizes
-        MatmulConfig cfg =
-            MatmulConfigAnalysis(linalgOp.getOperation()).getConfig();
-        uint32_t iim = cfg.innerMostKBlock, iin = cfg.innerMostNBlock,
-                 iik = cfg.innerMostKBlock;
-        // current default layout is MKmk, NKkn, MNmn
-        TensorLayout ALayout(
-            APackInfo.first, APackInfo.second,
-            SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
-                                      rewriter.getIndexAttr(iik)});
-        TensorLayout BLayout(
-            BPackInfo.first, BPackInfo.second,
-            SmallVector<OpFoldResult>{rewriter.getIndexAttr(iik),
-                                      rewriter.getIndexAttr(iin)});
-        TensorLayout CLayout(
-            CPackInfo.first, CPackInfo.second,
-            SmallVector<OpFoldResult>{rewriter.getIndexAttr(iim),
-                                      rewriter.getIndexAttr(iin)});
-        OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
-        layoutCache[linalgOp] = suggestedLayout;
+        auto suggestedLayouts =
queryMatmulLayout(rewriter, linalgOp, curInputLayouts, true); + layoutCache[linalgOp] = suggestedLayouts[0]; } else if (mlir::gc::utils::isPackableNamedOp(op)) { // infer layout for non-contraction/non-convolution linalg named ops // and linalg generic ops diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 44d951ea1..5dcc69df4 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -394,7 +394,7 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, return success(); } -static void createAndReplaceWithGenericVNNIMatmul( +static LogicalResult createAndReplaceWithGenericVNNIMatmul( RewriterBase &rewriter, MLIRContext *context, SmallVector inputs, SmallVector inits, int64_t batchDimSize, int64_t blockingFactor, Operation *matmulOp) { @@ -437,10 +437,12 @@ static void createAndReplaceWithGenericVNNIMatmul( rewriter.inlineRegionBefore(matmulOp->getRegion(0), replacementOp.getRegion(), replacementOp.getRegion().begin()); rewriter.replaceOp(matmulOp, replacementOp.getResult(0)); + return success(); } template -static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { +static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp, + bool useNamedOp = false) { auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType()); if (!elementType.isBF16() && !elementType.isInteger(8)) return rewriter.notifyMatchFailure(mmt4dOp, "require bf16/int8 data type"); @@ -462,17 +464,18 @@ static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { OpOperand *RHSOperand = mmt4dOp.getDpsInputOperand(1); Value dest = tensor::PackOp::createDestinationTensor( rewriter, loc, RHSOperand->get(), tileSize, innerPos, outerPerm); - Value VNNIPack = - rewriter.create(loc, RHSOperand->get(), dest, innerPos, - tileSize, std::nullopt, outerPerm); + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + Value VNNIPack = rewriter.create( + loc, RHSOperand->get(), dest, innerPos, tileSize, zero, outerPerm); SmallVector inputsValues{mmt4dOp.getInputs()[0], VNNIPack}; - if (!batchDimSize) { + if (useNamedOp) { auto vnniOp = rewriter.create( loc, mmt4dOp.getDpsInits().getTypes(), inputsValues, mmt4dOp.getDpsInits()); rewriter.replaceOp(mmt4dOp, vnniOp); } else { - mlir::gc::createAndReplaceWithGenericVNNIMatmul( + auto result = mlir::gc::createAndReplaceWithGenericVNNIMatmul( rewriter, mmt4dOp.getContext(), inputsValues, mmt4dOp.getDpsInits(), batchDimSize, blockingFactor, mmt4dOp); } @@ -510,7 +513,8 @@ If possible, pack to Mm2DVnniOp or Mm4DVnniOp. If not possible, pack to GenericOp. */ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, - linalg::GenericOp matmulOp) { + linalg::GenericOp matmulOp, + bool useNamedOp = false) { if (matmulOp.getDpsInputs().size() != 2) return rewriter.notifyMatchFailure(matmulOp, "require 2 inputs"); @@ -548,15 +552,17 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, int64_t weightRank = cast(weight.get().getType()).getShape().size(); auto innerPos = SmallVector{weightRank - 2}; - // pack weight. 
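// A note for readers, as a rough sketch (shapes illustrative, not taken from
// this patch): VNNI packing splits the innermost K axis of the weight by the
// VNNI blocking factor, which the code above derives as 2 for bf16 and 4 for
// int8. A bf16 weight tile of k x n = 32x32 is therefore re-laid-out as
// (k/2) x n x 2 = 16x32x2, making the K elements that feed one VNNI
// instruction contiguous:
//   int64_t factor = elementType.isBF16() ? 2 : 4; // recap of the code above
//   SmallVector<int64_t> vnniWeightTile{K / factor, N, factor}; // hypothetical K, N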
+ // pack weight Value dest = tensor::PackOp::createDestinationTensor( rewriter, loc, weight.get(), tileSize, innerPos, SmallVector{}); - Value VNNIPack = rewriter.create( - loc, weight.get(), dest, innerPos, tileSize, std::nullopt); + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + Value VNNIPack = rewriter.create(loc, weight.get(), dest, + innerPos, tileSize, zero); int64_t batchDimSize = weightRank - 4; SmallVector inputsValues{matmulOp.getInputs()[0], VNNIPack}; - if (!batchDimSize) { + if (useNamedOp) { Value operandC = matmulOp.getDpsInits()[0]; auto VNNIMatmulOp = rewriter.create( loc, operandC.getType(), inputsValues, ValueRange{operandC}); @@ -672,6 +678,7 @@ struct UpliftPackOverBroadcast : public OpRewritePattern { void PropagateLayoutOnNamedOps::runOnOperation() { MLIRContext *ctx = &getContext(); + IRRewriter rewriter(ctx); mlir::Operation *graph = getOperation(); // stage1: pack matmul RewritePatternSet packMatmulPatterns(&getContext()); @@ -685,14 +692,19 @@ void PropagateLayoutOnNamedOps::runOnOperation() { // hardcode to let B side to be NKkn options.rhsTransposeOuterBlocks = true; options.rhsTransposeInnerBlocks = false; - assert(LHSLayout.getTileSizes()[1] == RHSLayout.getTileSizes()[0] && - "Inconsistent matmul tile size."); - options.blockFactors.push_back( - *getConstantIntValue(LHSLayout.getTileSizes()[0])); - options.blockFactors.push_back( - *getConstantIntValue(LHSLayout.getTileSizes()[1])); - options.blockFactors.push_back( - *getConstantIntValue(RHSLayout.getTileSizes()[1])); + // extract tile sizes + OpFoldResult M_block = LHSLayout.getTileSizes().empty() + ? rewriter.getIndexAttr(1) + : LHSLayout.getTileSizes()[0]; + OpFoldResult K_block = LHSLayout.getTileSizes().empty() + ? rewriter.getIndexAttr(1) + : LHSLayout.getTileSizes()[1]; + OpFoldResult N_block = RHSLayout.getTileSizes().empty() + ? 
rewriter.getIndexAttr(1)
+                               : RHSLayout.getTileSizes()[0];
+    options.blockFactors.push_back(*getConstantIntValue(M_block));
+    options.blockFactors.push_back(*getConstantIntValue(K_block));
+    options.blockFactors.push_back(*getConstantIntValue(N_block));
     return options;
   };
   linalg::populateBlockPackMatmulPatterns(packMatmulPatterns,
@@ -708,6 +720,8 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
   if (failed(applyPatternsAndFoldGreedily(graph, std::move(packVNNIPatterns))))
     return signalPassFailure();
 
+  // stage 2.5: revert necessary blocking on matmul op
+
   // stage3: propagate layout on other named ops
   ControlPackNamedOpsFn layoutControlFn =
       [&](Operation *op) -> FailureOr {

From 2e538e0c92ee3d3c914714112650139d585b4f28 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Fri, 13 Sep 2024 01:16:49 -0700
Subject: [PATCH 11/23] sync with MLP benchmarking

---
 include/gc/Analysis/GlobalAnalysis.h          |   6 +-
 include/gc/Analysis/MatmulConfigAnalysis.h    |   6 +
 include/gc/Dialect/Linalgx/Utils.h            |   1 +
 lib/gc/Analysis/GlobalAnalysis.cpp            | 446 ++++++++++--------
 lib/gc/Dialect/Linalgx/Utils.cpp              |   6 +
 lib/gc/Transforms/DeepTileContractionOp.cpp   |   3 +-
 lib/gc/Transforms/PropagateLayout.cpp         | 265 ++++++++---
 test/mlir/test/gc/Transforms/pack-matmul.mlir |  41 +-
 8 files changed, 487 insertions(+), 287 deletions(-)

diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h
index 46b6726ea..824e8c904 100644
--- a/include/gc/Analysis/GlobalAnalysis.h
+++ b/include/gc/Analysis/GlobalAnalysis.h
@@ -90,6 +90,8 @@ class TensorLayout {
 
   bool operator==(const TensorLayout &layout) const;
 
+  bool operator!=(const TensorLayout &layout) const;
+
 private:
   SmallVector outerAxis;
   SmallVector innerAxis;
   SmallVector tileSizes;
@@ -152,8 +154,10 @@ class GlobalAnalysis {
 };
 
 namespace utils {
+bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp);
+
 bool isPackableNamedOp(Operation *op);
-}
+} // namespace utils
 } // namespace gc
 } // namespace mlir

diff --git a/include/gc/Analysis/MatmulConfigAnalysis.h b/include/gc/Analysis/MatmulConfigAnalysis.h
index 2b275f246..3507f6edc 100644
--- a/include/gc/Analysis/MatmulConfigAnalysis.h
+++ b/include/gc/Analysis/MatmulConfigAnalysis.h
@@ -128,6 +128,12 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
         SmallVector{DimType::M, DimType::K, DimType::M, DimType::K},
         SmallVector{DimType::N, DimType::K, DimType::K, DimType::N},
         SmallVector{DimType::M, DimType::N, DimType::M, DimType::N}};
+  } else if (linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(),
+                                              linalgx::PackingType::MM2D4D)) {
+    return SmallVector>{
+        SmallVector{DimType::M, DimType::K},
+        SmallVector{DimType::N, DimType::K, DimType::K, DimType::N},
+        SmallVector{DimType::M, DimType::N}};
   }
   return failure();
 }

diff --git a/include/gc/Dialect/Linalgx/Utils.h b/include/gc/Dialect/Linalgx/Utils.h
index 5bc83b449..d5281b60a 100644
--- a/include/gc/Dialect/Linalgx/Utils.h
+++ b/include/gc/Dialect/Linalgx/Utils.h
@@ -20,6 +20,7 @@ namespace linalgx {
 /// @brief enum of type of matmul packing
 enum class PackingType : int {
   MM4D = 0,    // MKmk x NKkn
+  MM2D4D,      // MK x NKkn
   VNNI_MM2D,   // MK x NKknV
   VNNI_MM4D,   // MKmk x NKknV
   VNNI_BRMM3D, // BMK x BKNV

diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index c254709ca..3365b0fab 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -18,10 +18,10 @@ namespace gc {
 #define DEBUG_TYPE "global-analysis"
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
-                              const TensorLayout &layoutCache) {
-  SmallVector
outerAxis = layoutCache.getOuterAxis(); - SmallVector innerAxis = layoutCache.getInnerAxis(); - SmallVector tileSizes = layoutCache.getTileSizes(); + const TensorLayout &tmpLayoutCache) { + SmallVector outerAxis = tmpLayoutCache.getOuterAxis(); + SmallVector innerAxis = tmpLayoutCache.getInnerAxis(); + SmallVector tileSizes = tmpLayoutCache.getTileSizes(); ss << "["; llvm::interleaveComma(outerAxis, ss); if (!innerAxis.empty()) { @@ -43,6 +43,10 @@ bool TensorLayout::operator==(const TensorLayout &layout) const { (this->tileSizes == layout.getTileSizes()); } +bool TensorLayout::operator!=(const TensorLayout &layout) const { + return !(*this == layout); +} + llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const OperatorLayout &opLayout) { if (!opLayout.getSupportedInputLayouts().empty()) { @@ -161,15 +165,6 @@ static size_t getTargetInputIdx(ArrayRef curInputLayouts) { return 0; } -static bool supportedContractionNamedOpList(linalg::LinalgOp &linalgOp) { - if (isa( - linalgOp)) - return true; - return false; -} - std::pair, SmallVector> getPackingAxis(int64_t numRank, bool transposed) { assert(numRank >= 2 && @@ -220,8 +215,9 @@ static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, return true; } -// if forceBlocking is set, we strictly follow matmul config to block to -// blocking layout; otherwise we follow query format logic +// if forceBlocking is set to true, we will unconditionally convert +// input/weight/output to blocking layout; otherwise we follow the default +// heuristic logic static SmallVector queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, ArrayRef curInputLayouts, @@ -247,7 +243,7 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, auto CPackInfo = getPackingAxis(CRank, /*transposed*/ false); // query the cost model for tile sizes MatmulConfig cfg = MatmulConfigAnalysis(matmulOp.getOperation()).getConfig(); - uint32_t iim = cfg.innerMostKBlock, iin = cfg.innerMostNBlock, + uint32_t iim = cfg.innerMostMBlock, iin = cfg.innerMostNBlock, iik = cfg.innerMostKBlock; if (forceBlocking) { TensorLayout ALayout(APackInfo.first, APackInfo.second, @@ -263,21 +259,12 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, SmallVector{CLayout}); return ret; } - // TODO(yifei): add condition constant_A - TensorLayout transposedLayout({1, 0}, {}, {}); + // TODO(yifei): add detailed check for constant A or B + bool constantA = false, constantB = true; SmallVector ALayouts, BLayouts, CLayouts; - if (curInputLayouts[0].isBlocking() || (M % iim) || (K % iik) || - (elementType.isBF16() && curInputLayouts[0] == transposedLayout)) { - ALayouts.emplace_back( - APackInfo.first, APackInfo.second, - SmallVector{rewriter.getIndexAttr(iim), - rewriter.getIndexAttr(iik)}); - } else { - ALayouts.emplace_back(APackInfo.first, SmallVector{}, - SmallVector{}); - } - if (curInputLayouts[0].isBlocking() || (M % iim) || (K % iik) || - (elementType.isBF16() && curInputLayouts[0] == transposedLayout)) { + if (constantA || curInputLayouts[0].isBlocking() || (M % iim) || (K % iik) || + (elementType.isBF16() && + curInputLayouts[0] == TensorLayout({1, 0}, {}, {}))) { ALayouts.emplace_back( APackInfo.first, APackInfo.second, SmallVector{rewriter.getIndexAttr(iim), @@ -286,7 +273,7 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, ALayouts.emplace_back(APackInfo.first, SmallVector{}, SmallVector{}); } - if (curInputLayouts[1].isBlocking() || K % iik || N % iin || + if (constantB || curInputLayouts[1].isBlocking() || K % iik || N % iin || 
elementType.isBF16()) { BLayouts.emplace_back( BPackInfo.first, BPackInfo.second, @@ -316,6 +303,7 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, CPackInfo.first, CPackInfo.second, SmallVector{rewriter.getIndexAttr(iim), rewriter.getIndexAttr(iin)}); + // duplicate ALayouts and BLayouts ALayouts.emplace_back(ALayouts[0]); BLayouts.emplace_back(BLayouts[0]); } @@ -330,192 +318,272 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, GlobalAnalysis::GlobalAnalysis(Operation *root) { IRRewriter rewriter(root); + int64_t totalLayoutPossibilities = 1; + std::vector possibilities; + int64_t numMatmuls = 0; root->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { - auto curInputs = linalgOp.getDpsInputOperands(); - auto curResults = linalgOp.getOperation()->getResults(); - // get current op's input layouts - SmallVector curInputLayouts; - for (auto input : curInputs) { - auto parent = input->get().getDefiningOp(); - if (layoutCache.find(parent) != layoutCache.end()) { - // TODO(yifei): it is not always 0 here - curInputLayouts.push_back(layoutCache[parent].getOutputLayout(0)); - } else { + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { + auto curInputs = linalgOp.getDpsInputOperands(); + SmallVector curInputLayouts; + for (auto input : curInputs) curInputLayouts.push_back(TensorLayout::createPlainLayout( linalgOp.getMatchingIndexingMap(input).getNumResults())); - } - } - // infer current op's output layout accordingly - if (supportedContractionNamedOpList(linalgOp)) { auto suggestedLayouts = - queryMatmulLayout(rewriter, linalgOp, curInputLayouts, true); - layoutCache[linalgOp] = suggestedLayouts[0]; - } else if (mlir::gc::utils::isPackableNamedOp(op)) { - // infer layout for non-contraction/non-convolution linalg named ops - // and linalg generic ops - SmallVector inputLayouts, outputLayouts; - size_t targetIdx = getTargetInputIdx(curInputLayouts); - for (size_t i = 0; i < curInputs.size(); ++i) { - // getMatchingIndexingMap - if (i != targetIdx) { - auto indexRelation = inferIndexingMapRelation( - linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), - linalgOp.getMatchingIndexingMap(curInputs[i])); - if (failed(indexRelation)) { - return WalkResult::skip(); - } - TensorLayout inputLayout = - inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); - inputLayouts.push_back(inputLayout); + queryMatmulLayout(rewriter, linalgOp, curInputLayouts); + possibilities.push_back(suggestedLayouts.size()); + totalLayoutPossibilities *= possibilities.back(); + numMatmuls++; + } + } + return WalkResult::advance(); + }); + auto computePackingCost = + [&](linalg::LinalgOp linalgOp, ArrayRef curInputLayouts, + ArrayRef suggestedLayout) -> int64_t { + int64_t cost = 0; + for (auto [operand, curLayout, suggestedLayout] : + llvm::zip(linalgOp.getDpsInputOperands(), curInputLayouts, + suggestedLayout)) { + if (curLayout != suggestedLayout) { + ArrayRef shape = linalgOp.getShape(operand); + int64_t inputSize = std::accumulate( + shape.begin(), shape.end(), (int64_t)1, std::multiplies()); + if (suggestedLayout.isBlocking()) + cost += inputSize * 0.9; + else + cost += inputSize; + } + } + return cost; + }; + std::vector curChoice(possibilities.size(), 0); + int64_t bestCost = std::numeric_limits::max(); + for (int64_t trialIdx = 0; trialIdx < totalLayoutPossibilities; ++trialIdx) { + // trialIdx to map + int64_t tmpIdx = trialIdx; + for (size_t i = 0; i < possibilities.size(); i++) { + curChoice[i] = tmpIdx % possibilities[i]; + tmpIdx /= 
possibilities[i]; + } + LLVM_DEBUG(llvm::dbgs() << "Inferring with layout choice: ["); + LLVM_DEBUG(llvm::interleaveComma(curChoice, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "].\n"); + int64_t curMatmulIdx = 0; + int64_t curCost = 0; + DenseMap tmpLayoutCache; + root->walk([&](Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + auto curInputs = linalgOp.getDpsInputOperands(); + auto curResults = linalgOp.getOperation()->getResults(); + // get current op's input layouts + SmallVector curInputLayouts; + for (auto input : curInputs) { + auto parent = input->get().getDefiningOp(); + if (tmpLayoutCache.find(parent) != tmpLayoutCache.end()) { + // TODO(yifei): it is not always 0 here + curInputLayouts.push_back( + tmpLayoutCache[parent].getOutputLayout(0)); } else { - inputLayouts.push_back(curInputLayouts[targetIdx]); + curInputLayouts.push_back(TensorLayout::createPlainLayout( + linalgOp.getMatchingIndexingMap(input).getNumResults())); } } - auto indexRelation = inferIndexingMapRelation( - linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), - linalgOp.getIndexingMapMatchingResult(curResults[0])); - if (failed(indexRelation)) { - return WalkResult::skip(); + // infer current op's output layout accordingly + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { + auto suggestedLayouts = + queryMatmulLayout(rewriter, linalgOp, curInputLayouts, false); + tmpLayoutCache[linalgOp] = + suggestedLayouts[curChoice[curMatmulIdx++]]; + curCost += computePackingCost( + linalgOp, curInputLayouts, + tmpLayoutCache[linalgOp].getSupportedInputLayouts()); + } else if (mlir::gc::utils::isPackableNamedOp(op)) { + // infer layout for non-contraction/non-convolution linalg named ops + // and linalg generic ops + SmallVector inputLayouts, outputLayouts; + size_t targetIdx = getTargetInputIdx(curInputLayouts); + for (size_t i = 0; i < curInputs.size(); ++i) { + // getMatchingIndexingMap + if (i != targetIdx) { + auto indexRelation = inferIndexingMapRelation( + linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + linalgOp.getMatchingIndexingMap(curInputs[i])); + if (failed(indexRelation)) { + return WalkResult::skip(); + } + TensorLayout inputLayout = + inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); + inputLayouts.push_back(inputLayout); + } else { + inputLayouts.push_back(curInputLayouts[targetIdx]); + } + } + auto indexRelation = inferIndexingMapRelation( + linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + linalgOp.getIndexingMapMatchingResult(curResults[0])); + if (failed(indexRelation)) { + return WalkResult::skip(); + } + TensorLayout outputLayout = + inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); + outputLayouts.push_back(outputLayout); + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[linalgOp] = suggestedLayout; + curCost += + computePackingCost(linalgOp, curInputLayouts, inputLayouts); } - TensorLayout outputLayout = - inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); - outputLayouts.push_back(outputLayout); + } else if (auto padOp = dyn_cast(op)) { + auto inputOperand = padOp.getSource(); + auto inputRank = + cast(inputOperand.getType()).getShape().size(); + auto parent = inputOperand.getDefiningOp(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? 
tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputRank); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{curInputLayout}; OperatorLayout suggestedLayout(inputLayouts, outputLayouts); - layoutCache[linalgOp] = suggestedLayout; - } - } else if (auto padOp = dyn_cast(op)) { - auto inputOperand = padOp.getSource(); - auto inputRank = - cast(inputOperand.getType()).getShape().size(); - auto parent = inputOperand.getDefiningOp(); - TensorLayout curInputLayout = - layoutCache.find(parent) != layoutCache.end() - ? layoutCache[parent].getOutputLayout(0) - : TensorLayout::createPlainLayout(inputRank); - SmallVector inputLayouts{curInputLayout}, - outputLayouts{curInputLayout}; - OperatorLayout suggestedLayout(inputLayouts, outputLayouts); - layoutCache[padOp] = suggestedLayout; - } else if (auto expandShapeOp = dyn_cast(op)) { - SmallVector reassocIndices = - expandShapeOp.getReassociationIndices(); - auto staticOutputShape = expandShapeOp.getStaticOutputShape(); - auto parent = expandShapeOp.getSrc().getDefiningOp(); - auto inputShape = expandShapeOp.getSrcType().getShape(); - TensorLayout curInputLayout = - layoutCache.find(parent) != layoutCache.end() - ? layoutCache[parent].getOutputLayout(0) - : TensorLayout::createPlainLayout(inputShape.size()); - SmallVector innerTileSizes; - auto tileSizes = getConstantIntValues(curInputLayout.getTileSizes()); - if (tileSizes) { - innerTileSizes = *tileSizes; - } else { - return WalkResult::skip(); - } - ArrayRef innerPosPos = curInputLayout.getInnerAxis(); - ArrayRef outerDimsPerm = curInputLayout.getOuterAxis(); - SmallVector projectedInnerDimsPos = - projectToInnerMostNonUnitDimsPos(innerPosPos, reassocIndices, - staticOutputShape); - - if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape, - innerTileSizes)) { - return WalkResult::skip(); - } - SmallVector newOuterDimsPerm; - for (auto outerPos : outerDimsPerm) { - newOuterDimsPerm.insert(newOuterDimsPerm.end(), - reassocIndices[outerPos].begin(), - reassocIndices[outerPos].end()); - } - TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, - curInputLayout.getTileSizes()); - SmallVector inputLayouts{curInputLayout}, - outputLayouts{outputLayout}; - OperatorLayout suggestedLayout(inputLayouts, outputLayouts); - layoutCache[expandShapeOp] = suggestedLayout; - } else if (auto collapseShapeOp = dyn_cast(op)) { - SmallVector reassocIndices = - collapseShapeOp.getReassociationIndices(); - auto parent = collapseShapeOp.getSrc().getDefiningOp(); - auto inputShape = collapseShapeOp.getSrcType().getShape(); - TensorLayout curInputLayout = - layoutCache.find(parent) != layoutCache.end() - ? layoutCache[parent].getOutputLayout(0) - : TensorLayout::createPlainLayout(inputShape.size()); - auto innerPos = curInputLayout.getInnerAxis(); - llvm::SetVector innerPosSet(innerPos.begin(), innerPos.end()); - for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { - // For each reassociation, figure out which dimensions get packed if - // any. - llvm::SetVector collapseDimPos(indices.begin(), indices.end()); - llvm::SetVector packedDims = - llvm::set_intersection(innerPosSet, collapseDimPos); - // only one of the collapsed indices can be packed - if (packedDims.size() > 1) - return WalkResult::skip(); - // Only the inner-most expanded dimension should be packed. Otherwise, - // elements order will be affected after operation reordering. 
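// Worked example (illustrative, not from the patch): with reassociation
// [[0, 1], [2]] and inner_dims_pos = {1}, the packed dim 1 is the innermost
// member of its group, so propagation is legal; inner_dims_pos = {0} fails
// the `packedDims[0] != indices.back()` check and the walk skips the op.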
- if (!packedDims.empty() && packedDims[0] != indices.back()) + tmpLayoutCache[padOp] = suggestedLayout; + } else if (auto expandShapeOp = dyn_cast(op)) { + SmallVector reassocIndices = + expandShapeOp.getReassociationIndices(); + auto staticOutputShape = expandShapeOp.getStaticOutputShape(); + auto parent = expandShapeOp.getSrc().getDefiningOp(); + auto inputShape = expandShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + SmallVector innerTileSizes; + auto tileSizes = getConstantIntValues(curInputLayout.getTileSizes()); + if (tileSizes) { + innerTileSizes = *tileSizes; + } else { return WalkResult::skip(); - } + } + ArrayRef innerPosPos = curInputLayout.getInnerAxis(); + ArrayRef outerDimsPerm = curInputLayout.getOuterAxis(); + SmallVector projectedInnerDimsPos = + projectToInnerMostNonUnitDimsPos(innerPosPos, reassocIndices, + staticOutputShape); - // Project pack.inner_dims_pos to positions before shape expansion. - SmallVector projectedInnerDimsPos; - for (auto pos : innerPos) { + if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, + staticOutputShape, innerTileSizes)) { + return WalkResult::skip(); + } + SmallVector newOuterDimsPerm; + for (auto outerPos : outerDimsPerm) { + newOuterDimsPerm.insert(newOuterDimsPerm.end(), + reassocIndices[outerPos].begin(), + reassocIndices[outerPos].end()); + } + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[expandShapeOp] = suggestedLayout; + } else if (auto collapseShapeOp = dyn_cast(op)) { + SmallVector reassocIndices = + collapseShapeOp.getReassociationIndices(); + auto parent = collapseShapeOp.getSrc().getDefiningOp(); + auto inputShape = collapseShapeOp.getSrcType().getShape(); + TensorLayout curInputLayout = + tmpLayoutCache.find(parent) != tmpLayoutCache.end() + ? tmpLayoutCache[parent].getOutputLayout(0) + : TensorLayout::createPlainLayout(inputShape.size()); + auto innerPos = curInputLayout.getInnerAxis(); + llvm::SetVector innerPosSet(innerPos.begin(), innerPos.end()); for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { - if (llvm::any_of(indices, [&](int64_t collapseDim) { - return collapseDim == pos; - })) { - projectedInnerDimsPos.push_back(idx); - break; + // For each reassociation, figure out which dimensions get packed if + // any. + llvm::SetVector collapseDimPos(indices.begin(), + indices.end()); + llvm::SetVector packedDims = + llvm::set_intersection(innerPosSet, collapseDimPos); + // only one of the collapsed indices can be packed + if (packedDims.size() > 1) + return WalkResult::skip(); + // Only the inner-most expanded dimension should be packed. Otherwise, + // elements order will be affected after operation reordering. + if (!packedDims.empty() && packedDims[0] != indices.back()) + return WalkResult::skip(); + } + + // Project pack.inner_dims_pos to positions before shape expansion. 
+ SmallVector projectedInnerDimsPos; + for (auto pos : innerPos) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == pos; + })) { + projectedInnerDimsPos.push_back(idx); + break; + } } } - } - assert(projectedInnerDimsPos.size() == innerPos.size() && - "Invalid dim pos projection"); + assert(projectedInnerDimsPos.size() == innerPos.size() && + "Invalid dim pos projection"); - // outerPerm shall be a permutation of reassocIndices - auto outerPerm = curInputLayout.getOuterAxis(); - SmallVector newOuterDimsPerm; - int64_t axisIdx = 0; - while (axisIdx < outerPerm.size()) { - for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { - if (llvm::any_of(indices, [&](int64_t collapseDim) { - return collapseDim == outerPerm[axisIdx]; - })) { - for (auto collapseDim : indices) { - if (collapseDim != outerPerm[axisIdx++]) - return WalkResult::skip(); + // outerPerm shall be a permutation of reassocIndices + auto outerPerm = curInputLayout.getOuterAxis(); + SmallVector newOuterDimsPerm; + int64_t axisIdx = 0; + while (axisIdx < outerPerm.size()) { + for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { + if (llvm::any_of(indices, [&](int64_t collapseDim) { + return collapseDim == outerPerm[axisIdx]; + })) { + for (auto collapseDim : indices) { + if (collapseDim != outerPerm[axisIdx++]) + return WalkResult::skip(); + } + newOuterDimsPerm.push_back(idx); + break; } - newOuterDimsPerm.push_back(idx); - break; } } + TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, + curInputLayout.getTileSizes()); + SmallVector inputLayouts{curInputLayout}, + outputLayouts{outputLayout}; + OperatorLayout suggestedLayout(inputLayouts, outputLayouts); + tmpLayoutCache[collapseShapeOp] = suggestedLayout; } - TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos, - curInputLayout.getTileSizes()); - SmallVector inputLayouts{curInputLayout}, - outputLayouts{outputLayout}; - OperatorLayout suggestedLayout(inputLayouts, outputLayouts); - layoutCache[collapseShapeOp] = suggestedLayout; - } - if (layoutCache.find(op) != layoutCache.end()) { - LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() - << " is: " << layoutCache[op] << "\n"); + if (tmpLayoutCache.find(op) != tmpLayoutCache.end()) { + LLVM_DEBUG(llvm::dbgs() << "Inferred layout of op: " << op->getName() + << " is: " << tmpLayoutCache[op] << "\n"); + } + return WalkResult::advance(); + }); + if (curCost < bestCost) { + bestCost = curCost; + layoutCache = tmpLayoutCache; + LLVM_DEBUG(llvm::dbgs() + << "Current cost " << curCost + << " is lower than the best cost; update best cost." 
+ << "\n"); } - return WalkResult::advance(); - }); + } } namespace utils { +bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp) { + if (isa( + linalgOp)) + return true; + return false; +} + bool isPackableNamedOp(Operation *op) { if (auto linalgOp = dyn_cast(op)) { if (!mlir::linalg::isaContractionOpInterface(linalgOp) && !isa(linalgOp.getOperation()) && - !supportedContractionNamedOpList(linalgOp)) { + !isSupportedContractionNamedOp(linalgOp)) { return true; } } else if (isa( diff --git a/lib/gc/Dialect/Linalgx/Utils.cpp b/lib/gc/Dialect/Linalgx/Utils.cpp index fe9096fe7..683038940 100644 --- a/lib/gc/Dialect/Linalgx/Utils.cpp +++ b/lib/gc/Dialect/Linalgx/Utils.cpp @@ -341,6 +341,12 @@ PackingAttr getPackingAttr(PackingType opType) { attr.nPacking = {PackingMap{{0}, {1}}, PackingMap{{3}, {3}}}; attr.kPacking = {PackingMap{{1}, {1}}, PackingMap{{3}, {2}}}; } break; + case PackingType::MM2D4D: { + attr.weightDims = 4; + attr.mPacking = {PackingMap{{0}, {0}}}; + attr.nPacking = {PackingMap{{0, 3}, {1}}}; + attr.kPacking = {PackingMap{{1}, {1, 2}}}; + } break; case PackingType::VNNI_MM2D: { attr.isVnni = true; attr.weightDims = 5; diff --git a/lib/gc/Transforms/DeepTileContractionOp.cpp b/lib/gc/Transforms/DeepTileContractionOp.cpp index 21de7b778..8805c9dc5 100644 --- a/lib/gc/Transforms/DeepTileContractionOp.cpp +++ b/lib/gc/Transforms/DeepTileContractionOp.cpp @@ -952,7 +952,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern { return llvm::isa(linalgOp) || linalgx::isGenericPackedMatmulOp( linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D, - linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D); + linalgx::PackingType::VNNI_MM4D, linalgx::PackingType::MM4D, + linalgx::PackingType::MM2D4D); } LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 9e8ea2372..fd2e10caa 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -20,8 +20,10 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/DenseMap.h" +#include "gc/Analysis/MatmulConfigAnalysis.h" #include "gc/Dialect/Linalgx/LinalgxDialect.h" #include "gc/Dialect/Linalgx/LinalgxOps.h" +#include "gc/Dialect/Linalgx/Utils.h" #include "gc/Transforms/Passes.h" namespace mlir { namespace gc { @@ -441,14 +443,16 @@ static LogicalResult createAndReplaceWithGenericVNNIMatmul( } template -static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { +static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp, + bool useNamedOp = false) { auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType()); if (!elementType.isBF16() && !elementType.isInteger(8)) return rewriter.notifyMatchFailure(mmt4dOp, "require bf16/int8 data type"); Location loc = mmt4dOp.getLoc(); // BNKnk --> BNKkn2k - int64_t weightRank = - cast(mmt4dOp.getInputs()[1].getType()).getShape().size(); + auto weightShape = + cast(mmt4dOp.getInputs()[1].getType()).getShape(); + int64_t weightRank = weightShape.size(); // pack innermost k axis SmallVector innerPos{weightRank - 1}; int64_t blockingFactor = elementType.isBF16() ? 
2 : 4; @@ -467,37 +471,26 @@ static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { Value zero = rewriter.create(loc, zeroAttr); Value VNNIPack = rewriter.create( loc, RHSOperand->get(), dest, innerPos, tileSize, zero, outerPerm); + // check whether VNNIPack causes padding + int64_t innermostKDim = weightShape[weightRank - 1]; + int64_t paddingSize = (innermostKDim % blockingFactor) + ? (blockingFactor - innermostKDim % blockingFactor) + : 0; + assert(!paddingSize && "Padding shall not be introduced by VNNI pack."); SmallVector inputsValues{mmt4dOp.getInputs()[0], VNNIPack}; - LogicalResult result = mlir::gc::createAndReplaceWithGenericVNNIMatmul( - rewriter, mmt4dOp.getContext(), inputsValues, mmt4dOp.getDpsInits(), - batchDimSize, blockingFactor, mmt4dOp); - return result; + FailureOr op = linalgx::makeGenericPackedMatmulOp( + rewriter, loc, linalgx::PackingType::VNNI_MM4D, inputsValues, + mmt4dOp.getDpsInits()); + if (failed(op)) + return failure(); + rewriter.replaceOp(mmt4dOp, *op); + return success(); } // strictly check whether the packed matmul is BMKmk & BNKkn static bool isMM4DMatmul(linalg::GenericOp matmulOp) { - SmallVector indexingMaps = matmulOp.getIndexingMapsArray(); - auto iterators = matmulOp.getIteratorTypesArray(); - AffineMap inputMap = indexingMaps[0], weightMap = indexingMaps[1], - outputMap = indexingMaps[2]; - int64_t inputRank = inputMap.getNumResults(), - weightRank = weightMap.getNumResults(), - outputRank = outputMap.getNumResults(); - // check rank - if ((weightRank < 4) || (inputRank != weightRank) || - (weightRank != outputRank)) - return false; - // check mapping --> find batch, M, N, K - FailureOr res = - mlir::linalg::inferContractionDims(matmulOp); - assert(succeeded(res) && "unexpected failure in infer contraction dims"); - unsigned batchDimSize = res->batch.size(); - SmallVector expectedM{batchDimSize, batchDimSize + 3}; - SmallVector expectedN{batchDimSize + 1, batchDimSize + 4}; - SmallVector expectedK{batchDimSize + 2, batchDimSize + 5}; - if (expectedM == res->m && expectedN == res->n && expectedK == res->k) - return true; - return false; + return linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(), + linalgx::PackingType::MM4D); } /* @@ -505,7 +498,8 @@ If possible, pack to Mm2DVnniOp or Mm4DVnniOp. If not possible, pack to GenericOp. */ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, - linalg::GenericOp matmulOp) { + linalg::GenericOp matmulOp, + bool useNamedOp = false) { if (matmulOp.getDpsInputs().size() != 2) return rewriter.notifyMatchFailure(matmulOp, "require 2 inputs"); @@ -539,9 +533,10 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, Location loc = matmulOp.getLoc(); int64_t blockingFactor = elementType.isBF16() ? 
2 : 4;
   SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)};
-  // get weight's rank
-  int64_t weightRank =
-      cast(weight.get().getType()).getShape().size();
+  // BNKkn, get weight's rank
+  auto weightShape =
+      cast(matmulOp.getInputs()[1].getType()).getShape();
+  int64_t weightRank = weightShape.size();
   auto innerPos = SmallVector{weightRank - 2};
   // pack weight
   Value dest = tensor::PackOp::createDestinationTensor(
       rewriter, loc, weight.get(), tileSize, innerPos, SmallVector{});
@@ -553,10 +548,19 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter,
 
   int64_t batchDimSize = weightRank - 4;
   SmallVector inputsValues{matmulOp.getInputs()[0], VNNIPack};
-  LogicalResult result = mlir::gc::createAndReplaceWithGenericVNNIMatmul(
-      rewriter, matmulOp.getContext(), inputsValues, matmulOp.getDpsInits(),
-      batchDimSize, blockingFactor, matmulOp);
-  return result;
+  // check whether VNNIPack causes padding, weightShape is BNKkn
+  int64_t innermostKDim = weightShape[weightRank - 2];
+  int64_t paddingSize = (innermostKDim % blockingFactor)
+                            ? (blockingFactor - innermostKDim % blockingFactor)
+                            : 0;
+  assert(!paddingSize && "Padding shall not be introduced by VNNI pack.");
+  FailureOr op = linalgx::makeGenericPackedMatmulOp(
+      rewriter, loc, linalgx::PackingType::VNNI_MM4D, inputsValues,
+      matmulOp.getDpsInits());
+  if (failed(op))
+    return failure();
+  rewriter.replaceOp(matmulOp, *op);
+  return success();
 }
 
 template struct PackVNNI : public OpRewritePattern {
@@ -584,6 +588,105 @@ struct PackVNNI
   }
 };
 
+static FailureOr
+shallRevertToType(linalg::GenericOp matmulOp) {
+  if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(),
+                                       linalgx::PackingType::MM4D))
+    return linalgx::PackingType::MM2D4D;
+  else if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(),
+                                            linalgx::PackingType::VNNI_MM4D))
+    return linalgx::PackingType::VNNI_MM2D;
+  return failure();
+}
+
+static bool isPlainActivationMatmul(OperatorLayout matmulLayout) {
+  auto inputLayout = matmulLayout.getSupportedInputLayouts()[0];
+  auto outputLayout = matmulLayout.getSupportedOutputLayouts()[0];
+  return !inputLayout.isBlocking() && !outputLayout.isBlocking();
+}
+
+static LogicalResult
+revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph,
+                    const std::vector &matmulLayouts) {
+  IRRewriter rewriter(ctx);
+  uint64_t layoutOffset = 0;
+  graph->walk([&](Operation *op) {
+    if (auto matmulOp = dyn_cast(op)) {
+      FailureOr revertType = shallRevertToType(matmulOp);
+      // consume one pre-collected layout per identifiable packed matmul so
+      // that layoutOffset stays in sync even when the revert is skipped
+      if (succeeded(revertType) &&
+          isPlainActivationMatmul(matmulLayouts[layoutOffset++])) {
+        OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPoint(op);
+        // replace VNNI_MM4D with unpack + VNNI_MM2D + pack
+        // get preceding pack and successive unpack
+        auto packInputOp = matmulOp.getDpsInputOperand(0)
+                               ->get()
+                               .getDefiningOp();
+        auto packInitOp = matmulOp.getDpsInitOperand(0)
+                              ->get()
+                              .getDefiningOp();
+        if (!packInputOp || !packInitOp)
+          return WalkResult::skip();
+        if (!matmulOp.getResults()[0].hasOneUse())
+          return WalkResult::skip();
+        auto consumer = matmulOp.getResults()[0].getUses().begin();
+        auto unPackOp = dyn_cast(consumer->getOwner());
+        if (!unPackOp)
+          return WalkResult::skip();
+        Location loc = matmulOp.getLoc();
+        // unpack input
+        auto packInputInnerTiles = packInputOp.getMixedTiles();
+        auto packInputInnerDimsPos = packInputOp.getInnerDimsPos();
+        auto packInputOuterDimsPerm = packInputOp.getOuterDimsPerm();
+        Value unpackInputDest = tensor::UnPackOp::createDestinationTensor(
+            rewriter, loc, packInputOp, packInputInnerTiles,
+            packInputInnerDimsPos, packInputOuterDimsPerm);
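// Shape-level sketch of this rewrite (tile sizes hypothetical): an MKmk
// activation tensor<4x8x32x32xbf16> is unpacked back to the plain
// tensor<128x256xbf16>, the blocked matmul is re-created in MM2D4D /
// VNNI_MM2D form against the still-blocked weight, and the plain result is
// packed again so downstream consumers keep the layout they expect.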
+        Value reUnpackInput = rewriter.create(
+            loc, packInputOp, unpackInputDest, packInputInnerDimsPos,
+            packInputInnerTiles, packInputOuterDimsPerm);
+        // unpack init
+        auto packInitInnerTiles = packInitOp.getMixedTiles();
+        auto packInitInnerDimsPos = packInitOp.getInnerDimsPos();
+        auto packInitOuterDimsPerm = packInitOp.getOuterDimsPerm();
+        Value unpackInitDest = tensor::UnPackOp::createDestinationTensor(
+            rewriter, loc, packInitOp, packInitInnerTiles, packInitInnerDimsPos,
+            packInitOuterDimsPerm);
+        Value reUnpackInit = rewriter.create(
+            loc, packInitOp, unpackInitDest, packInitInnerDimsPos,
+            packInitInnerTiles, packInitOuterDimsPerm);
+        // replace vnni_4D with vnni_2D
+        auto VNNI2D = linalgx::makeGenericPackedMatmulOp(
+            rewriter, loc, *revertType,
+            ValueRange{reUnpackInput, matmulOp.getDpsInputOperand(1)->get()},
+            ValueRange{reUnpackInit});
+        if (failed(VNNI2D))
+          return WalkResult::interrupt();
+        // insert pack before unpack
+        auto unPackInnerTiles = unPackOp.getMixedTiles();
+        auto unPackInnerDimsPos = unPackOp.getInnerDimsPos();
+        auto unPackOuterDimsPerm = unPackOp.getOuterDimsPerm();
+        Value packDest = tensor::PackOp::createDestinationTensor(
+            rewriter, loc, (*VNNI2D)->getResult(0), unPackInnerTiles,
+            unPackInnerDimsPos, unPackOuterDimsPerm);
+        auto zeroAttr =
+            rewriter.getZeroAttr(getElementTypeOrSelf(packDest.getType()));
+        Value zero = rewriter.create(loc, zeroAttr);
+        Value rePack = rewriter.create(
+            loc, (*VNNI2D)->getResult(0), packDest, unPackInnerDimsPos,
+            unPackInnerTiles, zero, unPackOuterDimsPerm);
+        rewriter.replaceOp(op, rePack);
+      }
+    } else if (auto matmulOp = dyn_cast(op)) {
+      if (mlir::gc::utils::isSupportedContractionNamedOp(matmulOp)) {
+        layoutOffset++;
+      }
+    }
+    return WalkResult::advance();
+  });
+  return success();
+}
+
 /*
 Match patterns like broadcast + pack, uplift pack
 */
@@ -664,31 +767,56 @@
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   IRRewriter rewriter(ctx);
   mlir::Operation *graph = getOperation();
-  // stage1: pack matmul
+  // collect matmul layouts in topological order
+  auto &layoutAnalysisResult = getAnalysis();
+  std::vector matmulLayouts;
+  graph->walk([&](Operation *op) {
+    if (auto linalgOp = dyn_cast(op)) {
+      if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) {
+        matmulLayouts.push_back(*(layoutAnalysisResult.getOpLayout(op)));
+      }
+    }
+    return WalkResult::advance();
+  });
+  // stage 1.1: pack matmul with `BlockPackMatmulPatterns` if any side of it
+  // requires packing; do nothing if the matmul is computed on plain format
+  // TODO(yifei): deal with transposed plain matmul...
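// For orientation, a minimal control callback that forces a fixed blocking
// (hypothetical 32x32x32 factors, sketch only) would look as follows; the
// actual callback below instead derives its factors from MatmulConfigAnalysis
// and returns default options to skip packing when every layout is plain:
//   mlir::linalg::ControlBlockPackMatmulFn fixedControlFn =
//       [](linalg::LinalgOp) -> mlir::linalg::BlockPackMatmulOptions {
//         mlir::linalg::BlockPackMatmulOptions opts;
//         opts.blockFactors = {32, 32, 32};    // minor block sizes
//         opts.rhsTransposeOuterBlocks = true; // B side as NKkn
//         opts.rhsTransposeInnerBlocks = false;
//         return opts;
//       };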
RewritePatternSet packMatmulPatterns(&getContext());
   mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn =
       [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions {
     mlir::linalg::BlockPackMatmulOptions options;
-    auto &layoutAnalysisResult = getAnalysis();
     auto matmulLayout = *(layoutAnalysisResult.getOpLayout(op));
-    TensorLayout LHSLayout = matmulLayout.getSupportedInputLayouts()[0];
-    TensorLayout RHSLayout = matmulLayout.getSupportedInputLayouts()[1];
-    // hardcode to let B side to be NKkn
+    // currently supported combination: plain & blocking & plain OR blocking &
+    // blocking & blocking
+    TensorLayout inputLayout = matmulLayout.getSupportedInputLayouts()[0];
+    TensorLayout weightLayout = matmulLayout.getSupportedInputLayouts()[1];
+    TensorLayout outputLayout = matmulLayout.getSupportedOutputLayouts()[0];
+    if (!inputLayout.isBlocking() && !weightLayout.isBlocking() &&
+        !outputLayout.isBlocking())
+      return options; // return default options to skip packing
+    // specify B side to be NKkn
     options.rhsTransposeOuterBlocks = true;
     options.rhsTransposeInnerBlocks = false;
     // extract tile sizes
-    OpFoldResult M_block = LHSLayout.getTileSizes().empty()
-                               ? rewriter.getIndexAttr(1)
-                               : LHSLayout.getTileSizes()[0];
-    OpFoldResult K_block = LHSLayout.getTileSizes().empty()
-                               ? rewriter.getIndexAttr(1)
-                               : LHSLayout.getTileSizes()[1];
-    OpFoldResult N_block = RHSLayout.getTileSizes().empty()
-                               ? rewriter.getIndexAttr(1)
-                               : RHSLayout.getTileSizes()[0];
-    options.blockFactors.push_back(*getConstantIntValue(M_block));
-    options.blockFactors.push_back(*getConstantIntValue(K_block));
-    options.blockFactors.push_back(*getConstantIntValue(N_block));
+    auto matmulCfg = MatmulConfigAnalysis(op.getOperation()).getConfig();
+    OpFoldResult MBlock = rewriter.getIndexAttr(matmulCfg.innerMostMBlock),
+                 KBlock = rewriter.getIndexAttr(matmulCfg.innerMostKBlock),
+                 NBlock = rewriter.getIndexAttr(matmulCfg.innerMostNBlock);
+    if (!inputLayout.getTileSizes().empty())
+      assert(inputLayout.getTileSizes()[0] == MBlock &&
+             inputLayout.getTileSizes()[1] == KBlock &&
+             "Layout tile size and matmul block size mismatch.");
+    if (!weightLayout.getTileSizes().empty())
+      assert(weightLayout.getTileSizes()[0] == KBlock &&
+             weightLayout.getTileSizes()[1] == NBlock &&
+             "Layout tile size and matmul block size mismatch.");
+    if (!outputLayout.getTileSizes().empty())
+      assert(outputLayout.getTileSizes()[0] == MBlock &&
+             outputLayout.getTileSizes()[1] == NBlock &&
+             "Layout tile size and matmul block size mismatch.");
+    options.blockFactors.push_back(*getConstantIntValue(MBlock));
+    options.blockFactors.push_back(*getConstantIntValue(NBlock));
+    options.blockFactors.push_back(*getConstantIntValue(KBlock));
     return options;
   };
   linalg::populateBlockPackMatmulPatterns(packMatmulPatterns,
@@ -697,25 +825,44 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
           applyPatternsAndFoldGreedily(graph, std::move(packMatmulPatterns))))
     return signalPassFailure();
 
-  // stage2: pack VNNI
+  // stage 1.2: pack VNNI
   RewritePatternSet packVNNIPatterns(&getContext());
   packVNNIPatterns.add, PackVNNI, PackVNNI>(ctx);
   if (failed(applyPatternsAndFoldGreedily(graph, std::move(packVNNIPatterns))))
     return signalPassFailure();
 
-  // stage 2.5: revert necessary blocking on matmul op
+  // stage 1.3: revert blocking on matmul ops whose activation stays plain
+  // (revertMatmulPacking); first double-check that the number of
+  // identifiable matmuls matches the layouts collected in topological order
+  uint64_t numMatmuls = 0;
+  graph->walk([&](Operation *op) {
+    if (auto
linalgOp = dyn_cast(op)) { + if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp) || + linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), + linalgx::PackingType::MM4D) || + linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), + linalgx::PackingType::VNNI_MM4D)) { + numMatmuls += 1; + } + } + return WalkResult::advance(); + }); + assert(matmulLayouts.size() == numMatmuls && + "One to one matmul mapping failed."); + if (failed(revertMatmulPacking(ctx, graph, matmulLayouts))) + return signalPassFailure(); - // stage3: propagate layout on other named ops + // stage 2: propagate layout on other named ops ControlPackNamedOpsFn layoutControlFn = [&](Operation *op) -> FailureOr { - auto &layoutAnalysisResult = getAnalysis(); return layoutAnalysisResult.getOpLayout(op); }; if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn))) return signalPassFailure(); - // stage4: uplift pack through broadcast + // stage 3: uplift pack through broadcast RewritePatternSet upliftPatterns(&getContext()); upliftPatterns.add(ctx); if (failed(applyPatternsAndFoldGreedily(graph, std::move(upliftPatterns)))) diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir index 2c2ab8a5f..9f674bc33 100644 --- a/test/mlir/test/gc/Transforms/pack-matmul.mlir +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -1,4 +1,4 @@ -// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops --post-process-pack-unpack | FileCheck %s // CHECK-LABEL: @single_matmul_f32 func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> { @@ -8,9 +8,9 @@ func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32> %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> return %2 : tensor<128x32xf32> } -// CHECK-COUNT-3: tensor.pack +// CHECK-COUNT-1: tensor.pack // CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-1: tensor.unpack +// CHECK-NOT: tensor.unpack // CHECK-LABEL: @single_matmul_bf16 func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> { @@ -20,39 +20,6 @@ func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> return %2 : tensor<128x32xbf16> } -// CHECK-COUNT-4: tensor.pack -// CHECK-COUNT-1: linalgx.mm4d_vnni -// CHECK-COUNT-1: tensor.unpack - -// CHECK-LABEL: @single_batch_matmul_bf16 -func.func @single_batch_matmul_bf16(%arg0: tensor<64x128x64xbf16>, %arg1: tensor<64x64x32xbf16>) -> tensor<64x128x32xbf16> { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<64x128x32xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16> - %2 = linalg.batch_matmul ins(%arg0, %arg1 : tensor<64x128x64xbf16>, tensor<64x64x32xbf16>) outs(%0 : tensor<64x128x32xbf16>) -> tensor<64x128x32xbf16> - return %2 : tensor<64x128x32xbf16> -} -// CHECK-COUNT-4: tensor.pack -// CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-1: tensor.unpack - -func.func @pack_vnni_mmt4d(%arg0: tensor<4x2x32x32xbf16>, %arg1: tensor<1x2x32x32xbf16>) -> tensor<4x1x32x32xbf16> { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<4x1x32x32xbf16> - %1 = linalg.fill ins(%cst : bf16) 
outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
-  %2 = linalg.mmt4d ins(%arg0, %arg1 : tensor<4x2x32x32xbf16>, tensor<1x2x32x32xbf16>) outs(%0 : tensor<4x1x32x32xbf16>) -> tensor<4x1x32x32xbf16>
-  return %2 : tensor<4x1x32x32xbf16>
-}
-// CHECK-COUNT-1: tensor.pack
-// CHECK-COUNT-1: linalgx.mm4d_vnni
-
-func.func @pack_vnni_batchmmt4d(%arg0: tensor<4x4x2x32x32xbf16>, %arg1: tensor<4x1x2x32x32xbf16>) -> tensor<4x4x1x32x32xbf16> {
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = tensor.empty() : tensor<4x4x1x32x32xbf16>
-  %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
-  %2 = linalg.batch_mmt4d ins(%arg0, %arg1 : tensor<4x4x2x32x32xbf16>, tensor<4x1x2x32x32xbf16>) outs(%0 : tensor<4x4x1x32x32xbf16>) -> tensor<4x4x1x32x32xbf16>
-  return %2 : tensor<4x4x1x32x32xbf16>
-}
 // CHECK-COUNT-1: tensor.pack
 // CHECK-COUNT-1: linalg.generic
-
+// CHECK-NOT: tensor.unpack

From aa9da9f84cdecbb77f83a9f433a81b37850b2538 Mon Sep 17 00:00:00 2001
From: "Zhang, Yifei"
Date: Fri, 13 Sep 2024 01:45:51 -0700
Subject: [PATCH 12/23] fix license

---
 include/gc/Transforms/Transforms.h          |  4 +--
 lib/gc/Analysis/GlobalAnalysis.cpp          |  7 +++--
 lib/gc/Transforms/LowerPackUnpack.cpp       |  7 +++--
 lib/gc/Transforms/PostProcessPackUnpack.cpp |  7 +++--
 lib/gc/Transforms/PropagateLayout.cpp       |  8 +++---
 packMatmul.patch                            | 32 ---------------------
 6 files changed, 18 insertions(+), 47 deletions(-)
 delete mode 100644 packMatmul.patch

diff --git a/include/gc/Transforms/Transforms.h b/include/gc/Transforms/Transforms.h
index 590fe3cfc..e0fdbcffa 100644
--- a/include/gc/Transforms/Transforms.h
+++ b/include/gc/Transforms/Transforms.h
@@ -1,6 +1,6 @@
-//===- Transforms.h - transformation utilities ------------------*- C++ -*-===//
+//===-- Transforms.h - transformation utilities -----------------*- C++ -*-===//
 //
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index 3365b0fab..62e27a1c4 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -1,7 +1,8 @@
-//===- GlobalAnalysis.cpp - Propagate packing on linalg named ops *- C++-*-===//
+//===-- GlobalAnalysis.cpp - Analyze layout on named ops *- C++ ---------*-===//
 //
-// This file is only temporarily used to extend upstream or upcoming utility in
-// TilingInterface, which finally aims for upstream.
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/gc/Transforms/LowerPackUnpack.cpp b/lib/gc/Transforms/LowerPackUnpack.cpp
index 811f8a4ce..e13159aaf 100644
--- a/lib/gc/Transforms/LowerPackUnpack.cpp
+++ b/lib/gc/Transforms/LowerPackUnpack.cpp
@@ -1,7 +1,8 @@
-//===- LowerPackUnpack.cpp - Lower pack unpack into linalg ops *---- C++-*-===//
+//===-- LowerPackUnpack.cpp - Lower pack unpack into linalg ops *--- C++-*-===//
 //
-// This file is only temporarily used to extend upstream or upcoming utility in
-// TilingInterface, which finally aims for upstream.
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp index 3e34865d7..8b3b48ad2 100644 --- a/lib/gc/Transforms/PostProcessPackUnpack.cpp +++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp @@ -1,7 +1,8 @@ -//===- PostProcessPackUnpack.cpp - Fold and simplify pack unpack *-- C++-*-===// +//===-- PostProcessPackUnpack.cpp - Fold and simplify pack unpack *- C++-*-===// // -// This file is only temporarily used to extend upstream or upcoming utility in -// TilingInterface, which finally aims for upstream. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index fd2e10caa..cead19d5e 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -1,8 +1,8 @@ -//===- PropagateLayoutOnNamedOps.cpp - Propagate packing on linalg named ops*- -// C++-*-===// +//===-- PropagateLayout.cpp - Propagate packing on named ops*- C++ ------*-===// // -// This file is only temporarily used to extend upstream or upcoming utility in -// TilingInterface, which finally aims for upstream. +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// diff --git a/packMatmul.patch b/packMatmul.patch deleted file mode 100644 index ef695240c..000000000 --- a/packMatmul.patch +++ /dev/null @@ -1,32 +0,0 @@ -diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp -index 91d4efa3372b..f3f61ff92140 100644 ---- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp -+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp -@@ -210,6 +210,19 @@ linalg::blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - packedMatmul->packOps[1] = packedRhs->transposedPackOp; - packedMatmul->packedLinalgOp = packedRhs->transposedLinalgOp; - -+ // rewrite generic to mmt4d -+ if (!options->lhsTransposeOuterBlocks && !options->lhsTransposeInnerBlocks && -+ options->rhsTransposeOuterBlocks && options->rhsTransposeInnerBlocks && -+ options->mnkOrder == SmallVector{0, 1, 2}) { -+ auto originalLinalgOp = packedMatmul->packedLinalgOp; -+ rewriter.setInsertionPoint(originalLinalgOp); -+ auto mmt4d = rewriter.create( -+ originalLinalgOp.getLoc(), originalLinalgOp.getDpsInits().getTypes(), -+ originalLinalgOp.getDpsInputs(), originalLinalgOp.getDpsInits()); -+ rewriter.replaceOp(originalLinalgOp, mmt4d); -+ packedMatmul->packedLinalgOp = mmt4d; -+ } -+ - return packedMatmul; - } - -@@ -307,6 +320,7 @@ struct LinalgBlockPackMatmul - }; - } // namespace - -+// extend to transform to mmt4d or batch_mmt4d - void linalg::populateBlockPackMatmulPatterns( - RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { - patterns.add, From 9238ad0372415a7cb60c4d378a39d974dd413c82 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Mon, 16 
Subject: [PATCH 13/23] fix clang tidy

---
 include/gc/Transforms/Transforms.h          |  2 +-
 lib/gc/Analysis/GlobalAnalysis.cpp          | 27 ++++----
 lib/gc/Transforms/PostProcessPackUnpack.cpp |  4 +-
 lib/gc/Transforms/PropagateLayout.cpp       | 70 +++------------------
 4 files changed, 25 insertions(+), 78 deletions(-)

diff --git a/include/gc/Transforms/Transforms.h b/include/gc/Transforms/Transforms.h
index e0fdbcffa..1b10cc64f 100644
--- a/include/gc/Transforms/Transforms.h
+++ b/include/gc/Transforms/Transforms.h
@@ -16,7 +16,7 @@ namespace mlir {
 namespace gc {
 
 LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
-                           OperatorLayout opLayout);
+                           const OperatorLayout &opLayout);
 
 LogicalResult namedOpLayoutPropagation(RewriterBase &rewriter,
                                        linalg::LinalgOp linalgOp,
diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp
index 62e27a1c4..b9b532549 100644
--- a/lib/gc/Analysis/GlobalAnalysis.cpp
+++ b/lib/gc/Analysis/GlobalAnalysis.cpp
@@ -19,10 +19,10 @@ namespace gc {
 #define DEBUG_TYPE "global-analysis"
 
 llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
-                              const TensorLayout &tmpLayoutCache) {
-  SmallVector<int64_t> outerAxis = tmpLayoutCache.getOuterAxis();
-  SmallVector<int64_t> innerAxis = tmpLayoutCache.getInnerAxis();
-  SmallVector<OpFoldResult> tileSizes = tmpLayoutCache.getTileSizes();
+                              const TensorLayout &layout) {
+  SmallVector<int64_t> outerAxis = layout.getOuterAxis();
+  SmallVector<int64_t> innerAxis = layout.getInnerAxis();
+  SmallVector<OpFoldResult> tileSizes = layout.getTileSizes();
   ss << "[";
   llvm::interleaveComma(outerAxis, ss);
   if (!innerAxis.empty()) {
@@ -122,7 +122,7 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
 }
 
 static TensorLayout
-inferTargetLayout(TensorLayout layoutBase,
+inferTargetLayout(const TensorLayout &layoutBase,
                   const DenseMap<int64_t, int64_t> &indexMap) {
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
   SmallVector<int64_t> baseInnerAxis = layoutBase.getInnerAxis();
@@ -203,7 +203,6 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
   return projectedDimsPos;
 }
 
-// copied from mlir
 // Check if all dims in dimsPos are divisible by the corresponding tile sizes.
 static bool
 isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, ArrayRef<int64_t> shape,
@@ -464,8 +463,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       } else {
        return WalkResult::skip();
      }
-      ArrayRef<int64_t> innerPosPos = curInputLayout.getInnerAxis();
-      ArrayRef<int64_t> outerDimsPerm = curInputLayout.getOuterAxis();
+      SmallVector<int64_t> innerPosPos = curInputLayout.getInnerAxis();
+      SmallVector<int64_t> outerDimsPerm = curInputLayout.getOuterAxis();
       SmallVector<int64_t> projectedInnerDimsPos =
           projectToInnerMostNonUnitDimsPos(innerPosPos, reassocIndices,
                                            staticOutputShape);
@@ -532,7 +531,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       auto outerPerm = curInputLayout.getOuterAxis();
       SmallVector<int64_t> newOuterDimsPerm;
       int64_t axisIdx = 0;
-      while (axisIdx < outerPerm.size()) {
+      while (axisIdx < static_cast<int64_t>(outerPerm.size())) {
        for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
          if (llvm::any_of(indices, [&](int64_t collapseDim) {
                return collapseDim == outerPerm[axisIdx];
@@ -572,12 +571,10 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
 
 namespace utils {
 bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp) {
-  if (isa(
-          linalgOp))
-    return true;
-  return false;
+  return isa(
+      linalgOp);
 }
 
 bool isPackableNamedOp(Operation *op) {
diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp
index 8b3b48ad2..ebf6d4721 100644
--- a/lib/gc/Transforms/PostProcessPackUnpack.cpp
+++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp
@@ -80,7 +80,7 @@ struct EliminateDummyPack : public OpRewritePattern<tensor::PackOp> {
     if (packOp.getStaticInnerTiles().empty() &&
         packOp.getInnerTiles().empty()) {
       auto outerPerm = packOp.getOuterDimsPerm();
-      for (size_t i = 0; i < outerPerm.size(); ++i) {
+      for (int64_t i = 0; i < static_cast<int64_t>(outerPerm.size()); ++i) {
         if (outerPerm[i] != i) {
           return rewriter.notifyMatchFailure(packOp, "Not dummy");
         }
@@ -103,7 +103,7 @@ struct EliminateDummyUnpack : public OpRewritePattern<tensor::UnPackOp> {
     if (unpackOp.getStaticInnerTiles().empty() &&
         unpackOp.getInnerTiles().empty()) {
       auto outerPerm = unpackOp.getOuterDimsPerm();
-      for (size_t i = 0; i < outerPerm.size(); ++i) {
+      for (int64_t i = 0; i < static_cast<int64_t>(outerPerm.size()); ++i) {
         if (outerPerm[i] != i) {
           return rewriter.notifyMatchFailure(unpackOp, "Not dummy");
         }
diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp
index cead19d5e..73f000bb1 100644
--- a/lib/gc/Transforms/PropagateLayout.cpp
+++ b/lib/gc/Transforms/PropagateLayout.cpp
@@ -78,12 +78,12 @@ static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc,
 }
 
 static SmallVector<int64_t> getPackedAxes(ArrayRef<int64_t> dimensions,
-                                          TensorLayout targetLayout) {
+                                          const TensorLayout &targetLayout) {
   SmallVector<int64_t> result;
   // permuting on outer axis
   auto outerPerm = targetLayout.getOuterAxis();
-  for (size_t i = 0; i < dimensions.size(); ++i) {
-    auto pos = std::find(outerPerm.begin(), outerPerm.end(), dimensions[i]);
+  for (int64_t dim : dimensions) {
+    auto pos = std::find(outerPerm.begin(), outerPerm.end(), dim);
     assert(pos != outerPerm.end() && "dimension must be within output perm.");
     result.push_back(std::distance(outerPerm.begin(), pos));
   }
@@ -138,7 +138,7 @@ static int64_t applyPermutationAndReindexReassoc(
 
 // extends linalg::pack(...)
for named ops LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, - OperatorLayout opLayout) { + const OperatorLayout &opLayout) { if (linalgOp.hasPureBufferSemantics()) return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); LLVM_DEBUG(llvm::dbgs() << "Try packing named op " @@ -166,7 +166,7 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, LLVM_DEBUG(llvm::dbgs() << "At least one input of named op: " << linalgOp.getOperation()->getName() << " is not tensor. Skip.\n"); - return failure("The op does not need packing."); + return failure(); } for (const auto &operandsList : {inputOperands, initOperands}) { for (OpOperand *opOperand : operandsList) { @@ -224,8 +224,7 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, loc, inputs[0], inits[0], packedPermAxes); } else if (isa(linalgOp) || isa(linalgOp) || isa(linalgOp) || isa(linalgOp)) { - return failure( - "Packing logic not implemented for SoftMax/Map/Yield/Index."); + return failure(); } else { packedLinalgOp = mlir::clone( rewriter, linalgOp, SmallVector{inputsAndInits.back().getType()}, @@ -235,7 +234,7 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, // Step 4. Unpack all the op results. for (OpResult result : packedLinalgOp->getResults()) { int64_t resultNum = result.getResultNumber(); - assert(resultNum < initLayouts.size() && + assert(resultNum < static_cast(initLayouts.size()) && "Linalg op results num exceeds inits num."); // Build the symmetrical UnPackOp to the existing PackOp. unPackOps.push_back( @@ -396,55 +395,8 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, return success(); } -static LogicalResult createAndReplaceWithGenericVNNIMatmul( - RewriterBase &rewriter, MLIRContext *context, SmallVector inputs, - SmallVector inits, int64_t batchDimSize, int64_t blockingFactor, - Operation *matmulOp) { - AffineMap mapInput, mapWeight, mapOutput; - int64_t dims = batchDimSize + 7; - SmallVector exprs(dims); - // dims is in order B1, ..., Bn, M, N, K, m, n, k, vnni - bindDimsList(context, exprs); - SmallVector batchExprs(exprs.begin(), - exprs.begin() + batchDimSize); - AffineExpr M = exprs[batchDimSize], N = exprs[batchDimSize + 1], - K = exprs[batchDimSize + 2], m = exprs[batchDimSize + 3], - n = exprs[batchDimSize + 4], k = exprs[batchDimSize + 5], - vnni = exprs[batchDimSize + 6]; - SmallVector resultA{M, K, m, k}; - SmallVector resultB{N, K, k.floorDiv(blockingFactor), n, vnni}; - SmallVector resultC{M, N, m, n}; - resultA.insert(resultA.begin(), batchExprs.begin(), batchExprs.end()); - resultB.insert(resultB.begin(), batchExprs.begin(), batchExprs.end()); - resultC.insert(resultC.begin(), batchExprs.begin(), batchExprs.end()); - mapInput = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultA, context); - mapWeight = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultB, context); - mapOutput = AffineMap::get(/*dims=*/dims, /*symbols=*/0, resultC, context); - SmallVector batchIterators( - batchDimSize, mlir::utils::IteratorType::parallel); - SmallVector iterators{ - mlir::utils::IteratorType::parallel, - mlir::utils::IteratorType::parallel, - mlir::utils::IteratorType::reduction, - mlir::utils::IteratorType::parallel, - mlir::utils::IteratorType::parallel, - mlir::utils::IteratorType::reduction, - mlir::utils::IteratorType::reduction}; - iterators.insert(iterators.begin(), batchIterators.begin(), - batchIterators.end()); - auto replacementOp = rewriter.create( 
- matmulOp->getLoc(), inits[0].getType(), inputs, inits, - ArrayRef{mapInput, mapWeight, mapOutput}, iterators, - /*doc=*/"", /*libraryCall=*/""); - rewriter.inlineRegionBefore(matmulOp->getRegion(0), replacementOp.getRegion(), - replacementOp.getRegion().begin()); - rewriter.replaceOp(matmulOp, replacementOp.getResult(0)); - return success(); -} - template -static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp, - bool useNamedOp = false) { +static LogicalResult packVNNIMMT4D(RewriterBase &rewriter, OpTy mmt4dOp) { auto elementType = getElementTypeOrSelf(mmt4dOp.getInputs()[0].getType()); if (!elementType.isBF16() && !elementType.isInteger(8)) return rewriter.notifyMatchFailure(mmt4dOp, "require bf16/int8 data type"); @@ -498,8 +450,7 @@ If possible, pack to Mm2DVnniOp or Mm4DVnniOp. If not possible, pack to GenericOp. */ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, - linalg::GenericOp matmulOp, - bool useNamedOp = false) { + linalg::GenericOp matmulOp) { if (matmulOp.getDpsInputs().size() != 2) return rewriter.notifyMatchFailure(matmulOp, "require 2 inputs"); @@ -546,7 +497,6 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, Value VNNIPack = rewriter.create(loc, weight.get(), dest, innerPos, tileSize, zero); - int64_t batchDimSize = weightRank - 4; SmallVector inputsValues{matmulOp.getInputs()[0], VNNIPack}; // check whether VNNIPack causes padding, weightShape is BNKkn int64_t innermostKDim = weightShape[weightRank - 2]; @@ -599,7 +549,7 @@ shallRevertToType(linalg::GenericOp matmulOp) { return failure(); } -static bool isPlainActivationMatmul(OperatorLayout matmulLayout) { +static bool isPlainActivationMatmul(const OperatorLayout &matmulLayout) { auto inputLayout = matmulLayout.getSupportedInputLayouts()[0]; auto outputLayout = matmulLayout.getSupportedInputLayouts()[0]; return !inputLayout.isBlocking() && !outputLayout.isBlocking(); From 5d37aab027c455d5bcb8ce67af5b97a71428de04 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Tue, 17 Sep 2024 20:10:48 -0700 Subject: [PATCH 14/23] fix test --- lib/gc/Analysis/GlobalAnalysis.cpp | 12 +-- lib/gc/Transforms/PropagateLayout.cpp | 5 +- test/mlir/test/gc/Transforms/pack-matmul.mlir | 86 ++++++++++++++++++- 3 files changed, 93 insertions(+), 10 deletions(-) diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index b9b532549..98cb2610e 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -207,12 +207,12 @@ projectToInnerMostNonUnitDimsPos(ArrayRef dimsPos, static bool isDimsDivisibleByTileSizes(ArrayRef dimsPos, ArrayRef shape, ArrayRef tileSizes) { - for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { - int64_t dim = shape[pos]; - if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) - return false; - } - return true; + return llvm::all_of(llvm::zip_equal(dimsPos, tileSizes), + [shape](std::tuple sizePair) { + int64_t dim = shape[std::get<0>(sizePair)]; + return !ShapedType::isDynamic(dim) && + (dim % std::get<1>(sizePair)) == 0; + }); } // if forceBlocking is set to true, we will unconditionally convert diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 73f000bb1..2803b0498 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -717,7 +717,7 @@ void PropagateLayoutOnNamedOps::runOnOperation() { MLIRContext *ctx = &getContext(); IRRewriter rewriter(ctx); mlir::Operation *graph = getOperation(); - // collect matmul 
layouts in topological order + // pre-collect matmul layouts in topological order auto &layoutAnalysisResult = getAnalysis(); std::vector matmulLayouts; graph->walk([&](Operation *op) { @@ -791,8 +791,7 @@ void PropagateLayoutOnNamedOps::runOnOperation() { if (auto linalgOp = dyn_cast(op)) { if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp) || linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), - linalgx::PackingType::MM4D) || - linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), + linalgx::PackingType::MM4D, linalgx::PackingType::VNNI_MM4D)) { numMatmuls += 1; } diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir index 9f674bc33..edb8bcb72 100644 --- a/test/mlir/test/gc/Transforms/pack-matmul.mlir +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -1,5 +1,7 @@ // RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops --post-process-pack-unpack | FileCheck %s +// ----- + // CHECK-LABEL: @single_matmul_f32 func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> { %cst = arith.constant 0.000000e+00 : f32 @@ -12,6 +14,8 @@ func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32> // CHECK-COUNT-1: linalg.generic // CHECK-NOT: tensor.unpack +// ----- + // CHECK-LABEL: @single_matmul_bf16 func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> { %cst = arith.constant 0.000000e+00 : bf16 @@ -20,6 +24,86 @@ func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> return %2 : tensor<128x32xbf16> } -// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-2: tensor.pack // CHECK-COUNT-1: linalg.generic // CHECK-NOT: tensor.unpack + +// ----- + +// CHECK-LABEL: @mlp_f32 +func.func @mlp_f32(%arg0: tensor<128x16xf32>, %arg1: tensor<16x512xf32>, %arg2: tensor<512x256xf32>, %arg3: tensor<256x128xf32>, %arg4: tensor<512xf32>, %arg5: tensor<256xf32>, %arg6: tensor<128xf32>) -> tensor<128x128xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<128x512xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x512xf32>) -> tensor<128x512xf32> + %2 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<128x16xf32>, tensor<16x512xf32>) outs(%1 : tensor<128x512xf32>) -> tensor<128x512xf32> + %3 = tensor.empty() : tensor<128x512xf32> + %broadcasted = linalg.broadcast ins(%arg4 : tensor<512xf32>) outs(%3 : tensor<128x512xf32>) dimensions = [0] + %4 = tensor.empty() : tensor<128x512xf32> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x512xf32>, tensor<128x512xf32>) outs(%4 : tensor<128x512xf32>) -> tensor<128x512xf32> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x512xf32> + %6 = tensor.empty() : tensor<128x512xf32> + %7 = linalg.max ins(%5, %cst_0 : tensor<128x512xf32>, tensor<128x512xf32>) outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32> + %8 = tensor.empty() : tensor<128x256xf32> + %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32> + %10 = linalg.matmul {cast = #linalg.type_fn} ins(%7, %arg2 : tensor<128x512xf32>, tensor<512x256xf32>) outs(%9 : tensor<128x256xf32>) -> tensor<128x256xf32> + %11 = tensor.empty() : tensor<128x256xf32> + %broadcasted_1 = linalg.broadcast ins(%arg5 : tensor<256xf32>) outs(%11 : 
tensor<128x256xf32>) dimensions = [0] + %12 = tensor.empty() : tensor<128x256xf32> + %13 = linalg.add ins(%10, %broadcasted_1 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%12 : tensor<128x256xf32>) -> tensor<128x256xf32> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x256xf32> + %14 = tensor.empty() : tensor<128x256xf32> + %15 = linalg.max ins(%13, %cst_2 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%14 : tensor<128x256xf32>) -> tensor<128x256xf32> + %16 = tensor.empty() : tensor<128x128xf32> + %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<128x128xf32>) -> tensor<128x128xf32> + %18 = linalg.matmul {cast = #linalg.type_fn} ins(%15, %arg3 : tensor<128x256xf32>, tensor<256x128xf32>) outs(%17 : tensor<128x128xf32>) -> tensor<128x128xf32> + %19 = tensor.empty() : tensor<128x128xf32> + %broadcasted_3 = linalg.broadcast ins(%arg6 : tensor<128xf32>) outs(%19 : tensor<128x128xf32>) dimensions = [0] + %20 = tensor.empty() : tensor<128x128xf32> + %21 = linalg.add ins(%18, %broadcasted_3 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%20 : tensor<128x128xf32>) -> tensor<128x128xf32> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> + %22 = tensor.empty() : tensor<128x128xf32> + %23 = linalg.max ins(%21, %cst_4 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%22 : tensor<128x128xf32>) -> tensor<128x128xf32> + return %23 : tensor<128x128xf32> +} +// CHECK-COUNT-3: tensor.pack +// CHECK-NOT: tensor.unpack + +// ----- + +// CHECK-LABEL: @mlp_bf16 +func.func @mlp_bf16(%arg0: tensor<32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<4096x11008xbf16>, %arg3: tensor<11008x4096xbf16>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008xbf16>, %arg6: tensor<4096xbf16>) -> tensor<32x4096xbf16> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %3 = tensor.empty() : tensor<32x4096xbf16> + %broadcasted = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%3 : tensor<32x4096xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<32x4096xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%4 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x4096xbf16> + %6 = tensor.empty() : tensor<32x4096xbf16> + %7 = linalg.max ins(%5, %cst_0 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%6 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %8 = tensor.empty() : tensor<32x11008xbf16> + %9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %10 = linalg.matmul {cast = #linalg.type_fn} ins(%7, %arg2 : tensor<32x4096xbf16>, tensor<4096x11008xbf16>) outs(%9 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %11 = tensor.empty() : tensor<32x11008xbf16> + %broadcasted_1 = linalg.broadcast ins(%arg5 : tensor<11008xbf16>) outs(%11 : tensor<32x11008xbf16>) dimensions = [0] + %12 = tensor.empty() : tensor<32x11008xbf16> + %13 = linalg.add ins(%10, %broadcasted_1 : tensor<32x11008xbf16>, tensor<32x11008xbf16>) outs(%12 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x11008xbf16> + %14 = tensor.empty() : tensor<32x11008xbf16> + %15 = linalg.max ins(%13, %cst_2 : 
tensor<32x11008xbf16>, tensor<32x11008xbf16>) outs(%14 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %16 = tensor.empty() : tensor<32x4096xbf16> + %17 = linalg.fill ins(%cst : bf16) outs(%16 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %18 = linalg.matmul {cast = #linalg.type_fn} ins(%15, %arg3 : tensor<32x11008xbf16>, tensor<11008x4096xbf16>) outs(%17 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %19 = tensor.empty() : tensor<32x4096xbf16> + %broadcasted_3 = linalg.broadcast ins(%arg6 : tensor<4096xbf16>) outs(%19 : tensor<32x4096xbf16>) dimensions = [0] + %20 = tensor.empty() : tensor<32x4096xbf16> + %21 = linalg.add ins(%18, %broadcasted_3 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%20 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x4096xbf16> + %22 = tensor.empty() : tensor<32x4096xbf16> + %23 = linalg.max ins(%21, %cst_4 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%22 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + return %23 : tensor<32x4096xbf16> +} +// CHECK-COUNT-6: tensor.pack +// CHECK-NOT: tensor.unpack From 2e3e37761025c1abe8c93e5bbc2e5b4c6fb0715c Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Tue, 17 Sep 2024 21:21:58 -0700 Subject: [PATCH 15/23] fix license --- lib/gc/Analysis/GlobalAnalysis.cpp | 2 +- lib/gc/Transforms/LowerPackUnpack.cpp | 2 +- lib/gc/Transforms/PostProcessPackUnpack.cpp | 2 +- lib/gc/Transforms/PropagateLayout.cpp | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 98cb2610e..4880b0d49 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -1,4 +1,4 @@ -//===-- GlobalAnalysis.cpp - Analyze layout on named ops *- C++ ---------*-===// +//===-- GlobalAnalysis.cpp - Analyze layout on named ops --------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Transforms/LowerPackUnpack.cpp b/lib/gc/Transforms/LowerPackUnpack.cpp index e13159aaf..efd96a0dd 100644 --- a/lib/gc/Transforms/LowerPackUnpack.cpp +++ b/lib/gc/Transforms/LowerPackUnpack.cpp @@ -1,4 +1,4 @@ -//===-- LowerPackUnpack.cpp - Lower pack unpack into linalg ops *--- C++-*-===// +//===-- LowerPackUnpack.cpp - Lower pack unpack into linalg ops -*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. diff --git a/lib/gc/Transforms/PostProcessPackUnpack.cpp b/lib/gc/Transforms/PostProcessPackUnpack.cpp index ebf6d4721..6b5c2336c 100644 --- a/lib/gc/Transforms/PostProcessPackUnpack.cpp +++ b/lib/gc/Transforms/PostProcessPackUnpack.cpp @@ -1,4 +1,4 @@ -//===-- PostProcessPackUnpack.cpp - Fold and simplify pack unpack *- C++-*-===// +//===-- PostProcessPackUnpack.cpp - Simplify pack unpack --------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 2803b0498..d9f19f84e 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -1,4 +1,4 @@ -//===-- PropagateLayout.cpp - Propagate packing on named ops*- C++ ------*-===// +//===-- PropagateLayout.cpp - Propagate packing on named ops ----*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From 4b07acfdfc89f3df4e93ec4485a501c58d9043ff Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Tue, 17 Sep 2024 23:13:07 -0700 Subject: [PATCH 16/23] temp fix correctness check --- lib/gc/Analysis/GlobalAnalysis.cpp | 3 +- lib/gc/Transforms/PropagateLayout.cpp | 85 ++++++++++++++------------- 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 4880b0d49..45c0c586c 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -572,8 +572,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { namespace utils { bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp) { return isa( + linalg::MatmulTransposeBOp>( linalgOp); } diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index d9f19f84e..e74286b29 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -251,36 +251,36 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, } // check whether the op is already packed or not -static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { - // check whether rank match - if (auto linalgOp = dyn_cast(op)) { - assert(linalgOp.getDpsInits().size() == - opLayout.getSupportedOutputLayouts().size() && - linalgOp.getDpsInputs().size() == - opLayout.getSupportedInputLayouts().size()); - for (auto [index, layout] : - llvm::enumerate(opLayout.getSupportedInputLayouts())) { - // if dimension mismatch, then the op itself is already packed - if (layout.getOuterAxis().size() != - cast(linalgOp.getDpsInputs()[index].getType()) - .getShape() - .size()) - return true; - } - for (auto [index, layout] : - llvm::enumerate(opLayout.getSupportedOutputLayouts())) { - // if dimension mismatch, then the op itself is already packed - if (layout.getOuterAxis().size() != - cast(linalgOp.getDpsInits()[index].getType()) - .getShape() - .size()) - return true; - } - } else { - assert(op->getNumOperands() == 1 && op->getNumResults() == 1); - } - return false; -} +// static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { +// // check whether rank match +// if (auto linalgOp = dyn_cast(op)) { +// assert(linalgOp.getDpsInits().size() == +// opLayout.getSupportedOutputLayouts().size() && +// linalgOp.getDpsInputs().size() == +// opLayout.getSupportedInputLayouts().size()); +// for (auto [index, layout] : +// llvm::enumerate(opLayout.getSupportedInputLayouts())) { +// // if dimension mismatch, then the op itself is already packed +// if (layout.getOuterAxis().size() != +// cast(linalgOp.getDpsInputs()[index].getType()) +// .getShape() +// .size()) +// return true; +// } +// for (auto [index, layout] : +// llvm::enumerate(opLayout.getSupportedOutputLayouts())) { +// // if dimension mismatch, then the op itself is already packed +// if (layout.getOuterAxis().size() != +// cast(linalgOp.getDpsInits()[index].getType()) +// .getShape() +// .size()) +// 
return true; +// } +// } else { +// assert(op->getNumOperands() == 1 && op->getNumResults() == 1); +// } +// return false; +// } using ControlPackNamedOpsFn = std::function(Operation *)>; @@ -317,11 +317,11 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, // insert pack OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - if (checkPacked(op, *opLayout)) { - LLVM_DEBUG(llvm::dbgs() - << "Op " << op->getName() << " already packed.\n"); - return WalkResult::advance(); - } + // if (checkPacked(op, *opLayout)) { + // LLVM_DEBUG(llvm::dbgs() + // << "Op " << op->getName() << " already packed.\n"); + // return WalkResult::advance(); + // } if (auto linalgOp = dyn_cast(op)) { if (failed(packLinalgOp(rewriter, linalgOp, *opLayout))) { return WalkResult::skip(); @@ -735,12 +735,15 @@ void PropagateLayoutOnNamedOps::runOnOperation() { mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn = [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions { mlir::linalg::BlockPackMatmulOptions options; - auto matmulLayout = *(layoutAnalysisResult.getOpLayout(op)); - // currently supported combination: plain & blocking & plain OR blocking & - // blocking & blocking - TensorLayout inputLayout = matmulLayout.getSupportedInputLayouts()[0]; - TensorLayout weightLayout = matmulLayout.getSupportedInputLayouts()[1]; - TensorLayout outputLayout = matmulLayout.getSupportedOutputLayouts()[0]; + FailureOr matmulLayout = + layoutAnalysisResult.getOpLayout(op); + if (failed(matmulLayout)) + return options; // return default options to skip packing + // currently supported combination: plain & blocking & plain || + // blocking & blocking & blocking + TensorLayout inputLayout = matmulLayout->getSupportedInputLayouts()[0]; + TensorLayout weightLayout = matmulLayout->getSupportedInputLayouts()[1]; + TensorLayout outputLayout = matmulLayout->getSupportedOutputLayouts()[0]; if (!inputLayout.isBlocking() && !weightLayout.isBlocking() && !outputLayout.isBlocking()) return options; // return default options to skip packing From 95846387fbf5250b161883fea9da8fede24a0f8e Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Wed, 18 Sep 2024 02:47:43 -0700 Subject: [PATCH 17/23] fix ci --- include/gc/Transforms/Passes.td | 5 +++++ lib/gc/Analysis/GlobalAnalysis.cpp | 10 +++++----- lib/gc/Transforms/PropagateLayout.cpp | 21 +++++++++++++++++---- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index 905c4f4eb..de0a4e6ff 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -179,6 +179,11 @@ def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> { "mlir::linalg::LinalgDialect", "mlir::linalgx::LinalgxDialect" ]; + let options = [ + Option<"forceBlocking", "force-blocking", "bool", + /*default=*/"false", + "Choose blocking layout for all matmul op, override the default matmul layout heuristic.">, + ]; } def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> { diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 45c0c586c..39f66ec53 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -144,8 +144,8 @@ inferTargetLayout(const TensorLayout &layoutBase, newDimAxis.push_back(pair.first); } } - // TODO(yifei): double consider the performance, whether to push all new axis - // at the beginning of outer perm + // TODO(yifei): double consider the performance + // whether to 
push all new axis at the beginning of outer perm targetOuterAxis.insert(targetOuterAxis.begin(), newDimAxis.begin(), newDimAxis.end()); for (auto &&[ia, ts] : llvm::zip(baseInnerAxis, baseTileSizes)) { @@ -381,7 +381,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { for (auto input : curInputs) { auto parent = input->get().getDefiningOp(); if (tmpLayoutCache.find(parent) != tmpLayoutCache.end()) { - // TODO(yifei): it is not always 0 here + // TODO(yifei): extend to cases with multiple outputs curInputLayouts.push_back( tmpLayoutCache[parent].getOutputLayout(0)); } else { @@ -570,10 +570,10 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { } namespace utils { +// TODO(yifei): extend to batch matmuls, sync with deep tile matmul bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp) { return isa( - linalgOp); + linalg::MatmulTransposeBOp>(linalgOp); } bool isPackableNamedOp(Operation *op) { diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index e74286b29..3c378deb8 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -480,7 +480,7 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, "require packed MM4D matmul semantics"); OpOperand &weight = matmulOp->getOpOperand(1); - // TODO(yifei): check ISA + // TODO(yifei): check ISA feasibility Location loc = matmulOp.getLoc(); int64_t blockingFactor = elementType.isBF16() ? 2 : 4; SmallVector tileSize{rewriter.getIndexAttr(blockingFactor)}; @@ -588,12 +588,25 @@ revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph, auto packInputInnerTiles = packInputOp.getMixedTiles(); auto packInputInnerDimsPos = packInputOp.getInnerDimsPos(); auto packInputOuterDimsPerm = packInputOp.getInnerDimsPos(); + llvm::SmallVector unpackInputInnerDimsPos( + packInputInnerDimsPos); + // eliminate the transpose semantic in unpack + llvm::SmallDenseMap axisMapping; + if (!packInputOuterDimsPerm.empty()) { + for (auto [index, axis] : llvm::enumerate(packInputOuterDimsPerm)) { + axisMapping[axis] = index; + } + for (size_t i = 0; i < packInputOuterDimsPerm.size(); ++i) { + unpackInputInnerDimsPos[i] = + axisMapping[unpackInputInnerDimsPos[i]]; + } + } Value unpackInputDest = tensor::UnPackOp::createDestinationTensor( rewriter, loc, packInputOp, packInputInnerTiles, - packInputInnerDimsPos, packInputOuterDimsPerm); + unpackInputInnerDimsPos, ArrayRef{}); Value reUnpackInput = rewriter.create( - loc, packInputOp, unpackInputDest, packInputInnerDimsPos, - packInputInnerTiles, packInputOuterDimsPerm); + loc, packInputOp, unpackInputDest, unpackInputInnerDimsPos, + packInputInnerTiles); // unpack init auto packInitInnerTiles = packInitOp.getMixedTiles(); auto packInitInnerDimsPos = packInitOp.getInnerDimsPos(); From e979fe25381d6a7fb03000ccd7416549ac2a4cfc Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Thu, 19 Sep 2024 00:44:02 -0700 Subject: [PATCH 18/23] refactor 1 --- include/gc/Analysis/GlobalAnalysis.h | 10 +- lib/gc/Analysis/GlobalAnalysis.cpp | 196 ++++++++++++++++---------- lib/gc/Transforms/PropagateLayout.cpp | 110 +++++++-------- 3 files changed, 181 insertions(+), 135 deletions(-) diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h index 824e8c904..bee0306aa 100644 --- a/include/gc/Analysis/GlobalAnalysis.h +++ b/include/gc/Analysis/GlobalAnalysis.h @@ -88,9 +88,9 @@ class TensorLayout { friend llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const TensorLayout &layout); - bool operator==(const 
TensorLayout &layout) const; + bool operator==(const TensorLayout &other) const; - bool operator!=(const TensorLayout &layout) const; + bool operator!=(const TensorLayout &other) const; private: SmallVector outerAxis; @@ -154,9 +154,11 @@ class GlobalAnalysis { }; namespace utils { -bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp); +bool isSupportedContractionNamedOp(const linalg::LinalgOp &linalgOp); -bool isPackableNamedOp(Operation *op); +bool isPackableOp(Operation *op); + +bool hasAllTensorSemantics(linalg::LinalgOp linalgOp); } // namespace utils } // namespace gc } // namespace mlir diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 39f66ec53..7612671b6 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -1,4 +1,4 @@ -//===-- GlobalAnalysis.cpp - Analyze layout on named ops --------*- C++ -*-===// +//===-- GlobalAnalysis.cpp - Infer layout on packable ops -------*- C++ -*-===// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -11,13 +11,47 @@ #include "gc/Analysis/GlobalAnalysis.h" #include "gc/Analysis/MatmulConfigAnalysis.h" #include "llvm/ADT/SetOperations.h" -#include "llvm/ADT/SetVector.h" namespace mlir { namespace gc { #define DEBUG_TYPE "global-analysis" +namespace utils { +// TODO(yifei): extend to batch matmuls, sync with deep tile matmul +bool isSupportedContractionNamedOp(const linalg::LinalgOp &linalgOp) { + return isa(linalgOp); +} + +bool isPackableOp(Operation *op) { + if (auto linalgOp = dyn_cast(op)) { + if (!mlir::linalg::isaContractionOpInterface(linalgOp) && + !mlir::linalg::isaConvolutionOpInterface(linalgOp) && + !isSupportedContractionNamedOp(linalgOp)) { + return true; + } + } else if (isa( + op)) + return true; + return false; +} + +bool hasAllTensorSemantics(linalg::LinalgOp linalgOp) { + SmallVector initOperands = llvm::to_vector(llvm::map_range( + linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); + SmallVector inputOperands = linalgOp.getDpsInputOperands(); + return llvm::all_of(inputOperands, + [](OpOperand *opOperand) { + return mlir::isa( + opOperand->get().getType()); + }) && + llvm::all_of(initOperands, [](OpOperand *opOperand) { + return mlir::isa(opOperand->get().getType()); + }); +} +} // namespace utils + llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const TensorLayout &layout) { SmallVector outerAxis = layout.getOuterAxis(); @@ -38,14 +72,14 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, return ss; } -bool TensorLayout::operator==(const TensorLayout &layout) const { - return (this->outerAxis == layout.getOuterAxis()) && - (this->innerAxis == layout.getInnerAxis()) && - (this->tileSizes == layout.getTileSizes()); +bool TensorLayout::operator==(const TensorLayout &other) const { + return (this->outerAxis == other.getOuterAxis()) && + (this->innerAxis == other.getInnerAxis()) && + (this->tileSizes == other.getTileSizes()); } -bool TensorLayout::operator!=(const TensorLayout &layout) const { - return !(*this == layout); +bool TensorLayout::operator!=(const TensorLayout &other) const { + return !(*this == other); } llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, @@ -90,7 +124,6 @@ inferIndexingMapRelation(AffineMap indexingMapBase, if (res.find(j) == res.end()) res[j] = -1; } - // check res DenseSet indexSet; for (auto pair : res) { if (indexSet.find(pair.second) != indexSet.end()) { @@ -121,28 +154,30 @@ 
getReversedIndexMap(const DenseMap &indexMap, return res; } -static TensorLayout -inferTargetLayout(const TensorLayout &layoutBase, - const DenseMap &indexMap) { +static FailureOr inferTargetLayout(const TensorLayout &layoutBase, + AffineMap indexingMapBase, + AffineMap indexingMapTarget) { SmallVector baseOuterAxis = layoutBase.getOuterAxis(); SmallVector baseInnerAxis = layoutBase.getInnerAxis(); SmallVector baseTileSizes = layoutBase.getTileSizes(); SmallVector targetOuterAxis; SmallVector targetInnerAxis; SmallVector targetTileSizes; + FailureOr> indexMap = + inferIndexingMapRelation(indexingMapBase, indexingMapTarget); + if (failed(indexMap)) + return failure(); DenseMap reverseIndexMap = - getReversedIndexMap(indexMap, layoutBase.getRank()); - for (auto oa : baseOuterAxis) { - if (reverseIndexMap[oa] >= 0) { + getReversedIndexMap(*indexMap, layoutBase.getRank()); + for (int64_t oa : baseOuterAxis) { + if (reverseIndexMap[oa] >= 0) targetOuterAxis.push_back(reverseIndexMap[oa]); - } } // filling up new j axes SmallVector newDimAxis; - for (auto pair : indexMap) { - if (pair.second < 0) { + for (const auto &pair : *indexMap) { + if (pair.second < 0) newDimAxis.push_back(pair.first); - } } // TODO(yifei): double consider the performance // whether to push all new axis at the beginning of outer perm @@ -157,7 +192,8 @@ inferTargetLayout(const TensorLayout &layoutBase, return TensorLayout(targetOuterAxis, targetInnerAxis, targetTileSizes); } -static size_t getTargetInputIdx(ArrayRef curInputLayouts) { +// TODO(yifei): enhance the logic for choose base input index +static size_t getBaseInputIdx(ArrayRef curInputLayouts) { for (size_t i = 0; i < curInputLayouts.size(); ++i) { if (!curInputLayouts[i].isPlain()) { return i; @@ -318,9 +354,9 @@ queryMatmulLayout(IRRewriter &rewriter, linalg::LinalgOp matmulOp, GlobalAnalysis::GlobalAnalysis(Operation *root) { IRRewriter rewriter(root); + // stage 1: calculate the total number of layout combination int64_t totalLayoutPossibilities = 1; std::vector possibilities; - int64_t numMatmuls = 0; root->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { @@ -333,20 +369,23 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { queryMatmulLayout(rewriter, linalgOp, curInputLayouts); possibilities.push_back(suggestedLayouts.size()); totalLayoutPossibilities *= possibilities.back(); - numMatmuls++; } } return WalkResult::advance(); }); + // define cost function auto computePackingCost = [&](linalg::LinalgOp linalgOp, ArrayRef curInputLayouts, - ArrayRef suggestedLayout) -> int64_t { + ArrayRef suggestedLayouts = {}) -> int64_t { int64_t cost = 0; - for (auto [operand, curLayout, suggestedLayout] : - llvm::zip(linalgOp.getDpsInputOperands(), curInputLayouts, - suggestedLayout)) { + auto inputOperands = linalgOp.getDpsInputOperands(); + for (auto [index, curLayout] : llvm::enumerate(curInputLayouts)) { + TensorLayout suggestedLayout = + suggestedLayouts.empty() + ? 
TensorLayout::createPlainLayout(curLayout.getRank()) + : suggestedLayouts[index]; if (curLayout != suggestedLayout) { - ArrayRef shape = linalgOp.getShape(operand); + ArrayRef shape = linalgOp.getShape(inputOperands[index]); int64_t inputSize = std::accumulate( shape.begin(), shape.end(), (int64_t)1, std::multiplies()); if (suggestedLayout.isBlocking()) @@ -359,8 +398,9 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { }; std::vector curChoice(possibilities.size(), 0); int64_t bestCost = std::numeric_limits::max(); + // stage 2: infer layout for each possibility for (int64_t trialIdx = 0; trialIdx < totalLayoutPossibilities; ++trialIdx) { - // trialIdx to map + // stage 2.1: get the current layout choice int64_t tmpIdx = trialIdx; for (size_t i = 0; i < possibilities.size(); i++) { curChoice[i] = tmpIdx % possibilities[i]; @@ -371,11 +411,22 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { LLVM_DEBUG(llvm::dbgs() << "].\n"); int64_t curMatmulIdx = 0; int64_t curCost = 0; + // stage 2.2: infer the current temp layout for the whole graph DenseMap tmpLayoutCache; root->walk([&](Operation *op) { + LLVM_DEBUG(llvm::dbgs() + << "Try inferring layout for op: " << op->getName() << "\n"); if (auto linalgOp = dyn_cast(op)) { auto curInputs = linalgOp.getDpsInputOperands(); auto curResults = linalgOp.getOperation()->getResults(); + // if any input/output is not tensor, skip it + if (!gc::utils::hasAllTensorSemantics(linalgOp)) { + LLVM_DEBUG( + llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << " contains non-tensor operand. Skip layout inference.\n"); + return WalkResult::skip(); + } // get current op's input layouts SmallVector curInputLayouts; for (auto input : curInputs) { @@ -389,7 +440,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { linalgOp.getMatchingIndexingMap(input).getNumResults())); } } - // infer current op's output layout accordingly + // start infer current op's layout if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp)) { auto suggestedLayouts = queryMatmulLayout(rewriter, linalgOp, curInputLayouts, false); @@ -398,36 +449,46 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { curCost += computePackingCost( linalgOp, curInputLayouts, tmpLayoutCache[linalgOp].getSupportedInputLayouts()); - } else if (mlir::gc::utils::isPackableNamedOp(op)) { + } else if (mlir::gc::utils::isPackableOp(op)) { // infer layout for non-contraction/non-convolution linalg named ops // and linalg generic ops SmallVector inputLayouts, outputLayouts; - size_t targetIdx = getTargetInputIdx(curInputLayouts); + size_t baseIdx = getBaseInputIdx(curInputLayouts); + // infer layout for inputs for (size_t i = 0; i < curInputs.size(); ++i) { - // getMatchingIndexingMap - if (i != targetIdx) { - auto indexRelation = inferIndexingMapRelation( - linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + if (i != baseIdx) { + FailureOr inferredLayout = inferTargetLayout( + curInputLayouts[baseIdx], + linalgOp.getMatchingIndexingMap(curInputs[baseIdx]), linalgOp.getMatchingIndexingMap(curInputs[i])); - if (failed(indexRelation)) { + if (failed(inferredLayout)) { + LLVM_DEBUG( + llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << "'s input " << i + << "'s layout cannot be inferred. 
Choose plain layout.\n"); + curCost += computePackingCost(linalgOp, curInputLayouts); return WalkResult::skip(); } - TensorLayout inputLayout = - inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); - inputLayouts.push_back(inputLayout); + inputLayouts.push_back(*inferredLayout); } else { - inputLayouts.push_back(curInputLayouts[targetIdx]); + inputLayouts.push_back(curInputLayouts[baseIdx]); } } - auto indexRelation = inferIndexingMapRelation( - linalgOp.getMatchingIndexingMap(curInputs[targetIdx]), + // infer layout for output + FailureOr inferredOutputLayout = inferTargetLayout( + curInputLayouts[baseIdx], + linalgOp.getMatchingIndexingMap(curInputs[baseIdx]), linalgOp.getIndexingMapMatchingResult(curResults[0])); - if (failed(indexRelation)) { + if (failed(inferredOutputLayout)) { + LLVM_DEBUG(llvm::dbgs() + << "Op " << linalgOp.getOperation()->getName() + << "'s output layout cannot be inferred. Choose plain " + "layout.\n"); + curCost += computePackingCost(linalgOp, curInputLayouts); return WalkResult::skip(); } - TensorLayout outputLayout = - inferTargetLayout(curInputLayouts[targetIdx], *indexRelation); - outputLayouts.push_back(outputLayout); + outputLayouts.push_back(*inferredOutputLayout); OperatorLayout suggestedLayout(inputLayouts, outputLayouts); tmpLayoutCache[linalgOp] = suggestedLayout; curCost += @@ -461,6 +522,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { if (tileSizes) { innerTileSizes = *tileSizes; } else { + LLVM_DEBUG(llvm::dbgs() + << "ExpandShapeOp's layout cannot be penetrated. Skip.\n"); return WalkResult::skip(); } SmallVector innerPosPos = curInputLayout.getInnerAxis(); @@ -471,6 +534,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape, innerTileSizes)) { + LLVM_DEBUG(llvm::dbgs() + << "ExpandShapeOp's layout cannot be penetrated. Skip.\n"); return WalkResult::skip(); } SmallVector newOuterDimsPerm; @@ -504,14 +569,21 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { llvm::SetVector packedDims = llvm::set_intersection(innerPosSet, collapseDimPos); // only one of the collapsed indices can be packed - if (packedDims.size() > 1) + if (packedDims.size() > 1) { + LLVM_DEBUG( + llvm::dbgs() + << "CollapseShapeOp's layout cannot be penetrated. Skip.\n"); return WalkResult::skip(); + } // Only the inner-most expanded dimension should be packed. Otherwise, // elements order will be affected after operation reordering. - if (!packedDims.empty() && packedDims[0] != indices.back()) + if (!packedDims.empty() && packedDims[0] != indices.back()) { + LLVM_DEBUG( + llvm::dbgs() + << "CollapseShapeOp's layout cannot be penetrated. Skip.\n"); return WalkResult::skip(); + } } - // Project pack.inner_dims_pos to positions before shape expansion. SmallVector projectedInnerDimsPos; for (auto pos : innerPos) { @@ -537,8 +609,11 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { return collapseDim == outerPerm[axisIdx]; })) { for (auto collapseDim : indices) { - if (collapseDim != outerPerm[axisIdx++]) + if (collapseDim != outerPerm[axisIdx++]) { + LLVM_DEBUG(llvm::dbgs() << "CollapseShapeOp's layout cannot " + "be penetrated. 
Skip.\n"); return WalkResult::skip(); + } } newOuterDimsPerm.push_back(idx); break; @@ -568,26 +643,5 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { } } } - -namespace utils { -// TODO(yifei): extend to batch matmuls, sync with deep tile matmul -bool isSupportedContractionNamedOp(linalg::LinalgOp &linalgOp) { - return isa(linalgOp); -} - -bool isPackableNamedOp(Operation *op) { - if (auto linalgOp = dyn_cast(op)) { - if (!mlir::linalg::isaContractionOpInterface(linalgOp) && - !isa(linalgOp.getOperation()) && - !isSupportedContractionNamedOp(linalgOp)) { - return true; - } - } else if (isa( - op)) - return true; - return false; -} -} // namespace utils } // namespace gc } // namespace mlir diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 3c378deb8..b6647e1d4 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -139,8 +139,6 @@ static int64_t applyPermutationAndReindexReassoc( // extends linalg::pack(...) for named ops LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const OperatorLayout &opLayout) { - if (linalgOp.hasPureBufferSemantics()) - return rewriter.notifyMatchFailure(linalgOp, "require tensor semantics"); LLVM_DEBUG(llvm::dbgs() << "Try packing named op " << linalgOp.getOperation()->getName() << ".\n"); Location loc = linalgOp->getLoc(); @@ -154,18 +152,10 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, SmallVector initLayouts = opLayout.getSupportedOutputLayouts(); // check all inputs and inits are tensor, otherwise no need for layout // propagation - bool allTensor = - llvm::all_of(inputOperands, - [](OpOperand *opOperand) { - return mlir::isa(opOperand->get().getType()); - }) && - llvm::all_of(initOperands, [](OpOperand *opOperand) { - return mlir::isa(opOperand->get().getType()); - }); - if (!allTensor) { - LLVM_DEBUG(llvm::dbgs() << "At least one input of named op: " + if (!gc::utils::hasAllTensorSemantics(linalgOp)) { + LLVM_DEBUG(llvm::dbgs() << "All inputs and outputs of linalg op: " << linalgOp.getOperation()->getName() - << " is not tensor. Skip.\n"); + << " shall be tensor. Skip layout packing.\n"); return failure(); } for (const auto &operandsList : {inputOperands, initOperands}) { @@ -199,7 +189,7 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, } } - // Step 3. Build the packed op, use the type of `inits` as result types. + // Step 3. 
Build the packed op ValueRange inputs = ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); ValueRange inits = @@ -250,37 +240,37 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, return success(); } -// check whether the op is already packed or not -// static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { -// // check whether rank match -// if (auto linalgOp = dyn_cast(op)) { -// assert(linalgOp.getDpsInits().size() == -// opLayout.getSupportedOutputLayouts().size() && -// linalgOp.getDpsInputs().size() == -// opLayout.getSupportedInputLayouts().size()); -// for (auto [index, layout] : -// llvm::enumerate(opLayout.getSupportedInputLayouts())) { -// // if dimension mismatch, then the op itself is already packed -// if (layout.getOuterAxis().size() != -// cast(linalgOp.getDpsInputs()[index].getType()) -// .getShape() -// .size()) -// return true; -// } -// for (auto [index, layout] : -// llvm::enumerate(opLayout.getSupportedOutputLayouts())) { -// // if dimension mismatch, then the op itself is already packed -// if (layout.getOuterAxis().size() != -// cast(linalgOp.getDpsInits()[index].getType()) -// .getShape() -// .size()) -// return true; -// } -// } else { -// assert(op->getNumOperands() == 1 && op->getNumResults() == 1); -// } -// return false; -// } +// check whether non-contraction packable ops are already packed or not +static bool checkPacked(Operation *op, const OperatorLayout &opLayout) { + // check whether rank match + if (auto linalgOp = dyn_cast(op)) { + assert(linalgOp.getDpsInits().size() == + opLayout.getSupportedOutputLayouts().size() && + linalgOp.getDpsInputs().size() == + opLayout.getSupportedInputLayouts().size()); + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedInputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInputs()[index].getType()) + .getShape() + .size()) + return true; + } + for (auto [index, layout] : + llvm::enumerate(opLayout.getSupportedOutputLayouts())) { + // if dimension mismatch, then the op itself is already packed + if (layout.getOuterAxis().size() != + cast(linalgOp.getDpsInits()[index].getType()) + .getShape() + .size()) + return true; + } + } else { + assert(op->getNumOperands() == 1 && op->getNumResults() == 1); + } + return false; +} using ControlPackNamedOpsFn = std::function(Operation *)>; @@ -297,15 +287,15 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, ControlPackNamedOpsFn controlFn) { IRRewriter rewriter(ctx); graph->walk([&](Operation *op) { - if (mlir::gc::utils::isPackableNamedOp(op)) { + if (mlir::gc::utils::isPackableOp(op)) { LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n"); - FailureOr opLayout = controlFn(op); - if (failed(opLayout)) { + if (failed(controlFn(op))) { LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " does not have layout information.\n"); return WalkResult::skip(); } - if ((*opLayout).isPlain()) { + OperatorLayout opLayout = *controlFn(op); + if (opLayout.isPlain()) { LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " has plain layout, skip packing.\n"); return WalkResult::advance(); @@ -313,23 +303,23 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, // pack op into ideal layout LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << "'s inferred layout:\n" - << *opLayout << "\n"); + << opLayout << "\n"); // insert pack OpBuilder::InsertionGuard 
guard(rewriter); rewriter.setInsertionPoint(op); - // if (checkPacked(op, *opLayout)) { - // LLVM_DEBUG(llvm::dbgs() - // << "Op " << op->getName() << " already packed.\n"); - // return WalkResult::advance(); - // } + if (checkPacked(op, opLayout)) { + LLVM_DEBUG(llvm::dbgs() + << "Op " << op->getName() << " already packed.\n"); + return WalkResult::advance(); + } if (auto linalgOp = dyn_cast(op)) { - if (failed(packLinalgOp(rewriter, linalgOp, *opLayout))) { + if (failed(packLinalgOp(rewriter, linalgOp, opLayout))) { return WalkResult::skip(); } } else if (auto expandShapeOp = dyn_cast(op)) { Location loc = expandShapeOp->getLoc(); - auto inputLayout = opLayout->getSupportedInputLayouts()[0]; - auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; + auto inputLayout = opLayout.getSupportedInputLayouts()[0]; + auto outputLayout = opLayout.getSupportedOutputLayouts()[0]; Value curSrc = expandShapeOp.getSrc(); Value curDst = expandShapeOp.getResult(); Value dest = tensor::PackOp::createDestinationTensor( @@ -359,8 +349,8 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, rewriter.replaceOp(expandShapeOp, newUnPackOp); } else if (auto collapseShapeOp = dyn_cast(op)) { Location loc = collapseShapeOp->getLoc(); - auto inputLayout = opLayout->getSupportedInputLayouts()[0]; - auto outputLayout = opLayout->getSupportedOutputLayouts()[0]; + auto inputLayout = opLayout.getSupportedInputLayouts()[0]; + auto outputLayout = opLayout.getSupportedOutputLayouts()[0]; Value curSrc = collapseShapeOp.getSrc(); Value curDst = collapseShapeOp.getResult(); Value dest = tensor::PackOp::createDestinationTensor( From 65ac7ef111dff714b780989d776985b722434ebc Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Thu, 19 Sep 2024 22:42:11 -0700 Subject: [PATCH 19/23] update test --- .../test/gc/Transforms/pack-llama-mlp.mlir | 122 ++++++++++++++++++ test/mlir/test/gc/Transforms/pack-matmul.mlir | 14 +- 2 files changed, 134 insertions(+), 2 deletions(-) create mode 100644 test/mlir/test/gc/Transforms/pack-llama-mlp.mlir diff --git a/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir b/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir new file mode 100644 index 000000000..c4e5ce89f --- /dev/null +++ b/test/mlir/test/gc/Transforms/pack-llama-mlp.mlir @@ -0,0 +1,122 @@ +// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops --post-process-pack-unpack | FileCheck %s + +// ----- + +// CHECK-LABEL: @llama2_mlp +func.func @llama2_mlp(%arg0: tensor<1x32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<1x32x4096xbf16>, %arg3: tensor<1xf32>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008x4096xbf16>, %arg6: tensor<11008x4096xbf16>, %arg7: tensor<4096x11008xbf16>, %arg8: tensor<1xf32>, %arg9: tensor<4096xbf16>) -> (tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) { + %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst = arith.constant 0.000000e+00 : bf16 + %cst_0 = arith.constant 1.000000e+00 : bf16 + %0 = tensor.empty() : tensor<32x4096xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %2 = linalg.matmul_transpose_b ins(%collapsed, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded = tensor.expand_shape %2 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %3 = tensor.empty() : tensor<1x32x4096xbf16> + %4 = linalg.add ins(%arg2, %expanded : 
tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%3 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %5 = tensor.empty() : tensor<1x32x4096xf32> + %6 = linalg.copy ins(%4 : tensor<1x32x4096xbf16>) outs(%5 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_1 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %7 = tensor.empty() : tensor<1x32x4096xf32> + %8 = linalg.powf ins(%6, %cst_1 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%7 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_2 = arith.constant 0.000000e+00 : f32 + %9 = tensor.empty() : tensor<1x32xf32> + %10 = linalg.fill ins(%cst_2 : f32) outs(%9 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced = linalg.reduce ins(%8 : tensor<1x32x4096xf32>) outs(%10 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %72 = arith.addf %in, %init : f32 + linalg.yield %72 : f32 + } + %cst_3 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %11 = tensor.empty() : tensor<1x32xf32> + %12 = linalg.div ins(%reduced, %cst_3 : tensor<1x32xf32>, tensor<1x32xf32>) outs(%11 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_4 = tensor.expand_shape %12 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %13 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%13 : tensor<1x32x1xf32>) dimensions = [0, 1] + %14 = tensor.empty() : tensor<1x32x1xf32> + %15 = linalg.add ins(%expanded_4, %broadcasted : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%14 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_5 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %16 = tensor.empty() : tensor<1x32x1xf32> + %17 = linalg.powf ins(%15, %cst_5 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%16 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_6 = tensor.collapse_shape %17 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %18 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_7 = linalg.broadcast ins(%collapsed_6 : tensor<1x32xf32>) outs(%18 : tensor<1x32x4096xf32>) dimensions = [2] + %19 = tensor.empty() : tensor<1x32x4096xf32> + %20 = linalg.mul ins(%6, %broadcasted_7 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%19 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %21 = tensor.empty() : tensor<1x32x4096xbf16> + %22 = linalg.copy ins(%20 : tensor<1x32x4096xf32>) outs(%21 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %23 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_8 = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%23 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %24 = tensor.empty() : tensor<1x32x4096xbf16> + %25 = linalg.mul ins(%broadcasted_8, %22 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%24 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %collapsed_9 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_10 = arith.constant 0.000000e+00 : bf16 + %26 = tensor.empty() : tensor<32x11008xbf16> + %27 = linalg.fill ins(%cst_10 : bf16) outs(%26 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %28 = linalg.matmul_transpose_b ins(%collapsed_9, %arg5 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%27 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_11 = tensor.expand_shape %28 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %29 = tensor.empty() : tensor<1x32x11008xbf16> + %30 = linalg.negf ins(%expanded_11 : tensor<1x32x11008xbf16>) 
outs(%29 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %31 = tensor.empty() : tensor<1x32x11008xbf16> + %32 = linalg.exp ins(%30 : tensor<1x32x11008xbf16>) outs(%31 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %33 = tensor.empty() : tensor<1x32x11008xbf16> + %34 = tensor.empty() : tensor<1x32x11008xbf16> + %35 = linalg.fill ins(%cst_0 : bf16) outs(%34 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %36 = linalg.add ins(%32, %35 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%33 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %37 = tensor.empty() : tensor<1x32x11008xbf16> + %38 = linalg.reciprocal ins(%36 : tensor<1x32x11008xbf16>) outs(%37 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %39 = tensor.empty() : tensor<1x32x11008xbf16> + %40 = linalg.mul ins(%30, %expanded_11 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%39 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_12 = tensor.collapse_shape %25 [[0, 1], [2]] : tensor<1x32x4096xbf16> into tensor<32x4096xbf16> + %cst_13 = arith.constant 0.000000e+00 : bf16 + %41 = tensor.empty() : tensor<32x11008xbf16> + %42 = linalg.fill ins(%cst_13 : bf16) outs(%41 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %43 = linalg.matmul_transpose_b ins(%collapsed_12, %arg6 : tensor<32x4096xbf16>, tensor<11008x4096xbf16>) outs(%42 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> + %expanded_14 = tensor.expand_shape %43 [[0, 1], [2]] output_shape [1, 32, 11008] : tensor<32x11008xbf16> into tensor<1x32x11008xbf16> + %44 = tensor.empty() : tensor<1x32x11008xbf16> + %45 = linalg.mul ins(%40, %expanded_14 : tensor<1x32x11008xbf16>, tensor<1x32x11008xbf16>) outs(%44 : tensor<1x32x11008xbf16>) -> tensor<1x32x11008xbf16> + %collapsed_15 = tensor.collapse_shape %45 [[0, 1], [2]] : tensor<1x32x11008xbf16> into tensor<32x11008xbf16> + %cst_16 = arith.constant 0.000000e+00 : bf16 + %46 = tensor.empty() : tensor<32x4096xbf16> + %47 = linalg.fill ins(%cst_16 : bf16) outs(%46 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %48 = linalg.matmul_transpose_b ins(%collapsed_15, %arg7 : tensor<32x11008xbf16>, tensor<4096x11008xbf16>) outs(%47 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> + %expanded_17 = tensor.expand_shape %48 [[0, 1], [2]] output_shape [1, 32, 4096] : tensor<32x4096xbf16> into tensor<1x32x4096xbf16> + %49 = tensor.empty() : tensor<1x32x4096xbf16> + %50 = linalg.add ins(%4, %expanded_17 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%49 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %51 = tensor.empty() : tensor<1x32x4096xf32> + %52 = linalg.copy ins(%50 : tensor<1x32x4096xbf16>) outs(%51 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_18 = arith.constant dense<2.000000e+00> : tensor<1x32x4096xf32> + %53 = tensor.empty() : tensor<1x32x4096xf32> + %54 = linalg.powf ins(%52, %cst_18 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%53 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %cst_19 = arith.constant 0.000000e+00 : f32 + %55 = tensor.empty() : tensor<1x32xf32> + %56 = linalg.fill ins(%cst_19 : f32) outs(%55 : tensor<1x32xf32>) -> tensor<1x32xf32> + %reduced_20 = linalg.reduce ins(%54 : tensor<1x32x4096xf32>) outs(%56 : tensor<1x32xf32>) dimensions = [2] + (%in: f32, %init: f32) { + %72 = arith.addf %in, %init : f32 + linalg.yield %72 : f32 + } + %cst_21 = arith.constant dense<4.096000e+03> : tensor<1x32xf32> + %57 = tensor.empty() : tensor<1x32xf32> + %58 = linalg.div ins(%reduced_20, %cst_21 : tensor<1x32xf32>, tensor<1x32xf32>) 
outs(%57 : tensor<1x32xf32>) -> tensor<1x32xf32> + %expanded_22 = tensor.expand_shape %58 [[0], [1, 2]] output_shape [1, 32, 1] : tensor<1x32xf32> into tensor<1x32x1xf32> + %59 = tensor.empty() : tensor<1x32x1xf32> + %broadcasted_23 = linalg.broadcast ins(%arg8 : tensor<1xf32>) outs(%59 : tensor<1x32x1xf32>) dimensions = [0, 1] + %60 = tensor.empty() : tensor<1x32x1xf32> + %61 = linalg.add ins(%expanded_22, %broadcasted_23 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%60 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %cst_24 = arith.constant dense<-5.000000e-01> : tensor<1x32x1xf32> + %62 = tensor.empty() : tensor<1x32x1xf32> + %63 = linalg.powf ins(%61, %cst_24 : tensor<1x32x1xf32>, tensor<1x32x1xf32>) outs(%62 : tensor<1x32x1xf32>) -> tensor<1x32x1xf32> + %collapsed_25 = tensor.collapse_shape %63 [[0], [1, 2]] : tensor<1x32x1xf32> into tensor<1x32xf32> + %64 = tensor.empty() : tensor<1x32x4096xf32> + %broadcasted_26 = linalg.broadcast ins(%collapsed_25 : tensor<1x32xf32>) outs(%64 : tensor<1x32x4096xf32>) dimensions = [2] + %65 = tensor.empty() : tensor<1x32x4096xf32> + %66 = linalg.mul ins(%52, %broadcasted_26 : tensor<1x32x4096xf32>, tensor<1x32x4096xf32>) outs(%65 : tensor<1x32x4096xf32>) -> tensor<1x32x4096xf32> + %67 = tensor.empty() : tensor<1x32x4096xbf16> + %68 = linalg.copy ins(%66 : tensor<1x32x4096xf32>) outs(%67 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + %69 = tensor.empty() : tensor<1x32x4096xbf16> + %broadcasted_27 = linalg.broadcast ins(%arg9 : tensor<4096xbf16>) outs(%69 : tensor<1x32x4096xbf16>) dimensions = [0, 1] + %70 = tensor.empty() : tensor<1x32x4096xbf16> + %71 = linalg.mul ins(%broadcasted_27, %68 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16>) outs(%70 : tensor<1x32x4096xbf16>) -> tensor<1x32x4096xbf16> + return %71, %50 : tensor<1x32x4096xbf16>, tensor<1x32x4096xbf16> +} +// CHECK-COUNT-8: tensor.pack diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir index edb8bcb72..27e752359 100644 --- a/test/mlir/test/gc/Transforms/pack-matmul.mlir +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -65,7 +65,12 @@ func.func @mlp_f32(%arg0: tensor<128x16xf32>, %arg1: tensor<16x512xf32>, %arg2: %23 = linalg.max ins(%21, %cst_4 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%22 : tensor<128x128xf32>) -> tensor<128x128xf32> return %23 : tensor<128x128xf32> } -// CHECK-COUNT-3: tensor.pack +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-1: linalg.generic // CHECK-NOT: tensor.unpack // ----- @@ -105,5 +110,10 @@ func.func @mlp_bf16(%arg0: tensor<32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %23 = linalg.max ins(%21, %cst_4 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%22 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> return %23 : tensor<32x4096xbf16> } -// CHECK-COUNT-6: tensor.pack +// CHECK-COUNT-2: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-2: tensor.pack +// CHECK-COUNT-1: linalg.generic +// CHECK-COUNT-2: tensor.pack +// CHECK-COUNT-1: linalg.generic // CHECK-NOT: tensor.unpack From 6d126508fec4980e741d78092beff7eddf3a592f Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Fri, 20 Sep 2024 01:58:58 -0700 Subject: [PATCH 20/23] refactor 2 --- include/gc/Analysis/GlobalAnalysis.h | 9 +- include/gc/Transforms/Passes.td | 5 - lib/gc/Analysis/GlobalAnalysis.cpp | 20 ++- lib/gc/Transforms/PropagateLayout.cpp | 235 ++++++++++++-------------- 4 files 
changed, 122 insertions(+), 147 deletions(-) diff --git a/include/gc/Analysis/GlobalAnalysis.h b/include/gc/Analysis/GlobalAnalysis.h index bee0306aa..1e7d6beac 100644 --- a/include/gc/Analysis/GlobalAnalysis.h +++ b/include/gc/Analysis/GlobalAnalysis.h @@ -66,11 +66,10 @@ class TensorLayout { return axisMapping; } - FailureOr getPlainAxis(int64_t idx) { + int64_t getPlainAxis(int64_t idx) { int64_t totalRank = outerAxis.size() + innerAxis.size(); - if (idx >= totalRank || idx < 0) { - return failure(); - } else if (idx >= static_cast(outerAxis.size())) { + assert(idx >= 0 && idx < totalRank && "Provided plain axis out of bounds"); + if (idx >= static_cast(outerAxis.size())) { return innerAxis[idx - outerAxis.size()]; } else { return outerAxis[idx]; @@ -146,7 +145,7 @@ class GlobalAnalysis { if (layoutCache.find(op) != layoutCache.end()) return layoutCache[op]; else - return failure("Current op does not have layout information."); + return failure(); } private: diff --git a/include/gc/Transforms/Passes.td b/include/gc/Transforms/Passes.td index de0a4e6ff..905c4f4eb 100644 --- a/include/gc/Transforms/Passes.td +++ b/include/gc/Transforms/Passes.td @@ -179,11 +179,6 @@ def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> { "mlir::linalg::LinalgDialect", "mlir::linalgx::LinalgxDialect" ]; - let options = [ - Option<"forceBlocking", "force-blocking", "bool", - /*default=*/"false", - "Choose blocking layout for all matmul op, override the default matmul layout heuristic.">, - ]; } def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> { diff --git a/lib/gc/Analysis/GlobalAnalysis.cpp b/lib/gc/Analysis/GlobalAnalysis.cpp index 7612671b6..44aba2e75 100644 --- a/lib/gc/Analysis/GlobalAnalysis.cpp +++ b/lib/gc/Analysis/GlobalAnalysis.cpp @@ -375,17 +375,20 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { }); // define cost function auto computePackingCost = - [&](linalg::LinalgOp linalgOp, ArrayRef curInputLayouts, + [&](Operation *op, ArrayRef curInputLayouts, ArrayRef suggestedLayouts = {}) -> int64_t { int64_t cost = 0; - auto inputOperands = linalgOp.getDpsInputOperands(); + assert(op->getOperands().size() >= curInputLayouts.size() && + "curInputLayouts size out of range."); for (auto [index, curLayout] : llvm::enumerate(curInputLayouts)) { TensorLayout suggestedLayout = suggestedLayouts.empty() ? TensorLayout::createPlainLayout(curLayout.getRank()) : suggestedLayouts[index]; if (curLayout != suggestedLayout) { - ArrayRef shape = linalgOp.getShape(inputOperands[index]); + ArrayRef shape = + cast(op->getOperands()[index].getType()) + .getShape(); int64_t inputSize = std::accumulate( shape.begin(), shape.end(), (int64_t)1, std::multiplies()); if (suggestedLayout.isBlocking()) @@ -447,7 +450,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { tmpLayoutCache[linalgOp] = suggestedLayouts[curChoice[curMatmulIdx++]]; curCost += computePackingCost( - linalgOp, curInputLayouts, + op, curInputLayouts, tmpLayoutCache[linalgOp].getSupportedInputLayouts()); } else if (mlir::gc::utils::isPackableOp(op)) { // infer layout for non-contraction/non-convolution linalg named ops @@ -467,7 +470,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { << "Op " << linalgOp.getOperation()->getName() << "'s input " << i << "'s layout cannot be inferred.
Choose plain layout.\n"); - curCost += computePackingCost(linalgOp, curInputLayouts); + curCost += computePackingCost(op, curInputLayouts); return WalkResult::skip(); } inputLayouts.push_back(*inferredLayout); @@ -485,14 +488,13 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { << "Op " << linalgOp.getOperation()->getName() << "'s output layout cannot be inferred. Choose plain " "layout.\n"); - curCost += computePackingCost(linalgOp, curInputLayouts); + curCost += computePackingCost(op, curInputLayouts); return WalkResult::skip(); } outputLayouts.push_back(*inferredOutputLayout); OperatorLayout suggestedLayout(inputLayouts, outputLayouts); tmpLayoutCache[linalgOp] = suggestedLayout; - curCost += - computePackingCost(linalgOp, curInputLayouts, inputLayouts); + curCost += computePackingCost(op, curInputLayouts, inputLayouts); } } else if (auto padOp = dyn_cast(op)) { auto inputOperand = padOp.getSource(); @@ -524,6 +526,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) { } else { LLVM_DEBUG(llvm::dbgs() << "ExpandShapeOp's layout cannot be penetrated. Skip.\n"); + curCost += + computePackingCost(op, SmallVector{curInputLayout}); return WalkResult::skip(); } SmallVector innerPosPos = curInputLayout.getInnerAxis(); diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index b6647e1d4..ec40892f8 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -112,8 +112,7 @@ static SmallVector getPackedPermAxes(ArrayRef plainPermAxes, auto axisPlainToPacked = inputLayout.getPlainToPackedAxisMapping(); for (int64_t i = 0; i < packedRank; ++i) { // packedOutput[i] --> originalOutputAxis --> originalInputAxis - // TODO: add failed check here - int64_t originalOutputAxis = *outputLayout.getPlainAxis(i); + int64_t originalOutputAxis = outputLayout.getPlainAxis(i); int64_t originalInputAxis = plainPermAxes[originalOutputAxis]; SmallVector packedInputAxes = axisPlainToPacked[originalInputAxis]; result[i] = packedInputAxes[inputCount[originalInputAxis]++]; @@ -136,7 +135,6 @@ static int64_t applyPermutationAndReindexReassoc( return nextPos; } -// extends linalg::pack(...) for named ops LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, const OperatorLayout &opLayout) { LLVM_DEBUG(llvm::dbgs() << "Try packing named op " @@ -454,13 +452,6 @@ static LogicalResult packVNNIGeneric(RewriterBase &rewriter, if (matmulOp.hasPureBufferSemantics()) return rewriter.notifyMatchFailure(matmulOp, "require tensor semantics"); - // isContractionInterfaceImpl checks the following restrictions: - // 1. has 2 inputs && 1 outputs - // 2. has >=1 reduction loop - // 3. all affine maps are projected permutations: - // a. no symbols or zeros in result - // b. result is a non-duplicated subset of input - // 4. 
op body contains both mul&&add if (!mlir::linalg::isaContractionOpInterface(matmulOp)) return rewriter.notifyMatchFailure(matmulOp, "require matmul semantics"); @@ -528,15 +519,16 @@ struct PackVNNI } }; -static FailureOr -shallRevertToType(linalg::GenericOp matmulOp) { +static linalgx::PackingType revertToPackingType(linalg::GenericOp matmulOp) { if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(), linalgx::PackingType::MM4D)) return linalgx::PackingType::MM2D4D; else if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(), linalgx::PackingType::VNNI_MM4D)) return linalgx::PackingType::VNNI_MM2D; - return failure(); + else + assert(false && + "Unexpected generic op encountered in matmul reversion stage."); } static bool isPlainActivationMatmul(const OperatorLayout &matmulLayout) { @@ -549,94 +541,108 @@ static LogicalResult revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph, const std::vector &matmulLayouts) { IRRewriter rewriter(ctx); - uint64_t layoutOffset = 0; - graph->walk([&](Operation *op) { + uint64_t layoutIndex = 0; + auto result = graph->walk([&](Operation *op) { if (auto matmulOp = dyn_cast(op)) { - FailureOr revertType = shallRevertToType(matmulOp); - if (succeeded(revertType) && - isPlainActivationMatmul(matmulLayouts[layoutOffset])) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(op); - // replace VNNI_MM4D with unpack + VNNI_MM2D + pack - // get preceding pack and successive unpack - auto packInputOp = matmulOp.getDpsInputOperand(0) - ->get() - .getDefiningOp(); - auto packInitOp = matmulOp.getDpsInitOperand(0) - ->get() - .getDefiningOp(); - if (!packInputOp || !packInitOp) - return WalkResult::skip(); - if (!matmulOp.getResults()[0].hasOneUse()) - return WalkResult::skip(); - auto consumer = matmulOp.getResults()[0].getUses().begin(); - auto unPackOp = dyn_cast(consumer->getOwner()); - if (!unPackOp) - return WalkResult::skip(); - Location loc = matmulOp.getLoc(); - // unpack input - auto packInputInnerTiles = packInputOp.getMixedTiles(); - auto packInputInnerDimsPos = packInputOp.getInnerDimsPos(); - auto packInputOuterDimsPerm = packInputOp.getInnerDimsPos(); - llvm::SmallVector unpackInputInnerDimsPos( - packInputInnerDimsPos); - // eliminate the transpose semantic in unpack - llvm::SmallDenseMap axisMapping; - if (!packInputOuterDimsPerm.empty()) { - for (auto [index, axis] : llvm::enumerate(packInputOuterDimsPerm)) { - axisMapping[axis] = index; + if (linalgx::isGenericPackedMatmulOp(matmulOp.getOperation(), + linalgx::PackingType::MM4D, + linalgx::PackingType::VNNI_MM4D)) { + if (isPlainActivationMatmul(matmulLayouts[layoutIndex])) { + linalgx::PackingType revertType = revertToPackingType(matmulOp); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(op); + // replace matmul 4D with unpack + matmul 2D + pack + auto packInputOp = matmulOp.getDpsInputOperand(0) + ->get() + .getDefiningOp(); + auto packInitOp = matmulOp.getDpsInitOperand(0) + ->get() + .getDefiningOp(); + if (!packInputOp || !packInitOp) + return WalkResult::skip(); + if (!matmulOp.getResults()[0].hasOneUse()) + return WalkResult::skip(); + auto consumer = matmulOp.getResults()[0].getUses().begin(); + auto unPackOp = dyn_cast(consumer->getOwner()); + if (!unPackOp) + return WalkResult::skip(); + Location loc = matmulOp.getLoc(); + // unpack input + auto packInputInnerTiles = packInputOp.getMixedTiles(); + auto packInputInnerDimsPos = packInputOp.getInnerDimsPos(); + auto packInputOuterDimsPerm = packInputOp.getOuterDimsPerm(); + llvm::SmallVector unpackInputInnerDimsPos( + packInputInnerDimsPos); + // eliminate the transpose semantic in unpack + llvm::SmallDenseMap axisMapping; + if (!packInputOuterDimsPerm.empty()) { + for (auto [index, axis] : llvm::enumerate(packInputOuterDimsPerm)) { + axisMapping[axis] = index; + } + for (size_t i = 0; i < packInputOuterDimsPerm.size(); ++i) { + unpackInputInnerDimsPos[i] = + axisMapping[unpackInputInnerDimsPos[i]]; + } } Value unpackInputDest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packInputOp, packInputInnerTiles, + unpackInputInnerDimsPos, ArrayRef{}); + Value reUnpackInput = rewriter.create( + loc, packInputOp, unpackInputDest, unpackInputInnerDimsPos, + packInputInnerTiles); + // unpack init + auto packInitInnerTiles = packInitOp.getMixedTiles(); + auto packInitInnerDimsPos = packInitOp.getInnerDimsPos(); + auto packInitOuterDimsPerm = packInitOp.getOuterDimsPerm(); + // assert packInitOuterDimsPerm is not permuted + if (!packInitOuterDimsPerm.empty()) { + for (auto [index, dim] : llvm::enumerate(packInitOuterDimsPerm)) { + if (static_cast(index) != dim) + assert(false && "Packed matmul's init pack shall not contain " + "permutation semantics."); + } + } + Value unpackInitDest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packInitOp, packInitInnerTiles, + packInitInnerDimsPos, packInitOuterDimsPerm); + Value reUnpackInit = rewriter.create( + loc, packInitOp, unpackInitDest, packInitInnerDimsPos, + packInitInnerTiles, packInitOuterDimsPerm); + // replace matmul 4D with matmul 2D + auto matmul2D = linalgx::makeGenericPackedMatmulOp( + rewriter, loc, revertType, + ValueRange{reUnpackInput, matmulOp.getDpsInputOperand(1)->get()}, + ValueRange{reUnpackInit}); + if (failed(matmul2D)) + return WalkResult::interrupt(); + // insert pack before unpack + auto unPackInnerTiles = unPackOp.getMixedTiles(); + auto unPackInnerDimsPos = unPackOp.getInnerDimsPos(); + auto unPackOuterDimsPerm = unPackOp.getOuterDimsPerm(); + Value packDest = tensor::PackOp::createDestinationTensor( + rewriter, loc, (*matmul2D)->getResult(0), unPackInnerTiles, + unPackInnerDimsPos, unPackOuterDimsPerm); + auto zeroAttr = + rewriter.getZeroAttr(getElementTypeOrSelf(packDest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + Value rePack = rewriter.create( + loc, (*matmul2D)->getResult(0), packDest, unPackInnerDimsPos, + unPackInnerTiles, zero, unPackOuterDimsPerm); + rewriter.replaceOp(op, rePack); } - Value unpackInputDest = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, packInputOp, packInputInnerTiles, - unpackInputInnerDimsPos, ArrayRef{}); - Value reUnpackInput = rewriter.create( - loc, packInputOp, unpackInputDest, unpackInputInnerDimsPos, - packInputInnerTiles); - // unpack init - auto packInitInnerTiles = packInitOp.getMixedTiles(); - auto packInitInnerDimsPos = packInitOp.getInnerDimsPos(); - auto packInitOuterDimsPerm = packInitOp.getInnerDimsPos(); - Value unpackInitDest = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, packInitOp, packInitInnerTiles, packInitInnerDimsPos, - packInitOuterDimsPerm); - Value reUnpackInit = rewriter.create( - loc, packInitOp, unpackInitDest, packInitInnerDimsPos, - packInitInnerTiles, packInitOuterDimsPerm); - // replace vnni_4D with vnni_2D - auto VNNI2D = linalgx::makeGenericPackedMatmulOp( - rewriter, loc, *revertType, - ValueRange{reUnpackInput,
matmulOp.getDpsInputOperand(1)->get()}, - ValueRange{reUnpackInit}); - if (failed(VNNI2D)) - return WalkResult::interrupt(); - // insert pack before unpack - auto unPackInnerTiles = unPackOp.getMixedTiles(); - auto unPackInnerDimsPos = unPackOp.getInnerDimsPos(); - auto unPackOuterDimsPerm = unPackOp.getInnerDimsPos(); - Value packDest = tensor::PackOp::createDestinationTensor( - rewriter, loc, (*VNNI2D)->getResult(0), unPackInnerTiles, - unPackInnerDimsPos, unPackOuterDimsPerm); - auto zeroAttr = - rewriter.getZeroAttr(getElementTypeOrSelf(packDest.getType())); - Value zero = rewriter.create(loc, zeroAttr); - Value rePack = rewriter.create( - loc, (*VNNI2D)->getResult(0), packDest, unPackInnerDimsPos, - unPackInnerTiles, zero, unPackOuterDimsPerm); - rewriter.replaceOp(op, rePack); - layoutOffset++; + layoutIndex++; } } else if (auto matmulOp = dyn_cast(op)) { if (mlir::gc::utils::isSupportedContractionNamedOp(matmulOp)) { - layoutOffset++; + layoutIndex++; } } return WalkResult::advance(); }); + if (result.wasInterrupted() || result.wasSkipped()) + return failure(); // reversion not performed as expected + if (layoutIndex != matmulLayouts.size()) + return failure(); // layout index mismatch, reversion failed return success(); } @@ -720,8 +726,9 @@ void PropagateLayoutOnNamedOps::runOnOperation() { MLIRContext *ctx = &getContext(); IRRewriter rewriter(ctx); mlir::Operation *graph = getOperation(); - // pre-collect matmul layouts in topological order auto &layoutAnalysisResult = getAnalysis(); + + // pre-collect matmul layouts std::vector matmulLayouts; graph->walk([&](Operation *op) { if (auto linalgOp = dyn_cast(op)) { @@ -731,9 +738,9 @@ void PropagateLayoutOnNamedOps::runOnOperation() { } return WalkResult::advance(); }); - // stage 1.1: pack matmul with `BlockPackMatmulPatterns` if any side of it - // requires packing; do nothing if the matmul is computed on plain format - // TODO(yifei): deal with transposed plain matmul... 
+ + // stage 1.1: pack matmul with `BlockPackMatmulPatterns` if any side of the + // matmul op requires packing RewritePatternSet packMatmulPatterns(&getContext()); mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn = [&](linalg::LinalgOp op) -> mlir::linalg::BlockPackMatmulOptions { @@ -742,8 +749,6 @@ void PropagateLayoutOnNamedOps::runOnOperation() { layoutAnalysisResult.getOpLayout(op); if (failed(matmulLayout)) return options; // return default options to skip packing - // currently supported combination: plain & blocking & plain || - // blocking & blocking & blocking TensorLayout inputLayout = matmulLayout->getSupportedInputLayouts()[0]; TensorLayout weightLayout = matmulLayout->getSupportedInputLayouts()[1]; TensorLayout outputLayout = matmulLayout->getSupportedOutputLayouts()[0]; @@ -758,21 +763,9 @@ void PropagateLayoutOnNamedOps::runOnOperation() { OpFoldResult MBlock = rewriter.getIndexAttr(matmulCfg.innerMostMBlock), KBlock = rewriter.getIndexAttr(matmulCfg.innerMostKBlock), NBlock = rewriter.getIndexAttr(matmulCfg.innerMostNBlock); - if (!inputLayout.getTileSizes().empty()) - assert(inputLayout.getTileSizes()[0] == MBlock && - inputLayout.getTileSizes()[1] == KBlock && - "Layout tile size and matmul block size mismatch."); - if (!weightLayout.getTileSizes().empty()) - assert(weightLayout.getTileSizes()[0] == KBlock && - weightLayout.getTileSizes()[1] == NBlock && - "Layout tile size and matmul block size mismatch."); - if (!outputLayout.getTileSizes().empty()) - assert(outputLayout.getTileSizes()[0] == MBlock && - outputLayout.getTileSizes()[1] == NBlock && - "Layout tile size and matmul block size mismatch."); - options.blockFactors.push_back(*getConstantIntValue(MBlock)); - options.blockFactors.push_back(*getConstantIntValue(NBlock)); - options.blockFactors.push_back(*getConstantIntValue(KBlock)); + options.blockFactors = SmallVector{ + *getConstantIntValue(MBlock), *getConstantIntValue(NBlock), + *getConstantIntValue(KBlock)}; return options; }; linalg::populateBlockPackMatmulPatterns(packMatmulPatterns, @@ -788,24 +781,8 @@ void PropagateLayoutOnNamedOps::runOnOperation() { if (failed(applyPatternsAndFoldGreedily(graph, std::move(packVNNIPatterns)))) return signalPassFailure(); - // stage 1.3: revert necessary blocking on matmul op - // RevertMatmulPacking - // double confirm the number of identifiable matmuls - // collect matmul layouts in topological order - uint64_t numMatmuls = 0; - graph->walk([&](Operation *op) { - if (auto linalgOp = dyn_cast(op)) { - if (mlir::gc::utils::isSupportedContractionNamedOp(linalgOp) || - linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(), - linalgx::PackingType::MM4D, - linalgx::PackingType::VNNI_MM4D)) { - numMatmuls += 1; - } - } - return WalkResult::advance(); - }); - assert(matmulLayouts.size() == numMatmuls && - "One to one matmul mapping failed."); + // stage 1.3: revert packed matmul from blocking activation to plain + // activation if (failed(revertMatmulPacking(ctx, graph, matmulLayouts))) return signalPassFailure(); From 69a37a485a3e27cb4761d0ac2097e21beda9fec0 Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Fri, 20 Sep 2024 02:01:19 -0700 Subject: [PATCH 21/23] remove test --- .../gc/Transforms/named-op-layout-propagation.mlir | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir diff --git a/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir b/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir deleted file mode 
100644 index d3ca62e73..000000000 --- a/test/mlir/test/gc/Transforms/named-op-layout-propagation.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: gc-opt %s --split-input-file --propagate-layout-on-named-ops | FileCheck %s - -// CHECK-LABEL: @matmul_add -func.func @matmul_add(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<32xf32>) -> tensor<128x32xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<128x32xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%1 : tensor<128x32xf32>) -> tensor<128x32xf32> - %3 = linalg.broadcast ins(%arg2 : tensor<32xf32>) outs(%0 : tensor<128x32xf32>) dimensions = [0] - %4 = linalg.add ins(%2, %3 : tensor<128x32xf32>, tensor<128x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> - return %4 : tensor<128x32xf32> -} From effb13215d26579caf4d9d7ba401c1060539ee7e Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Sun, 22 Sep 2024 22:58:08 -0700 Subject: [PATCH 22/23] refactor 3 --- lib/gc/Transforms/PropagateLayout.cpp | 139 ++++++++++++++------------ 1 file changed, 74 insertions(+), 65 deletions(-) diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index ec40892f8..7e16ec782 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -281,6 +281,44 @@ class PropagateLayoutOnNamedOps void runOnOperation() final; }; +template +static void packReshapeOp(T reshapeOp, IRRewriter &rewriter, + const OperatorLayout &opLayout) { + Location loc = reshapeOp->getLoc(); + TensorLayout inputLayout = opLayout.getSupportedInputLayouts()[0]; + TensorLayout outputLayout = opLayout.getSupportedOutputLayouts()[0]; + Value curSrc = reshapeOp.getSrc(); + Value curDst = reshapeOp.getResult(); + Value dest = tensor::PackOp::createDestinationTensor( + rewriter, loc, curSrc, inputLayout.getTileSizes(), + inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); + Value packedSource = + insertLayoutPack(rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), + inputLayout.getTileSizes(), inputLayout.getOuterAxis()); + SmallVector newReassocIndices = + reshapeOp.getReassociationIndices(); + TensorLayout shorterSide = inputLayout.getRank() > outputLayout.getRank() + ? outputLayout + : inputLayout; + int64_t nextPos = applyPermutationAndReindexReassoc( + newReassocIndices, shorterSide.getOuterAxis()); + // Then add direct mapping for the inner tile dims. 
+ for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { + newReassocIndices.push_back({nextPos}); + nextPos += 1; + } + RankedTensorType newExpandType = tensor::PackOp::inferPackedType( + dyn_cast(curDst.getType()), + *getConstantIntValues(outputLayout.getTileSizes()), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); + Value packedExpandShape = + rewriter.create(loc, newExpandType, packedSource, newReassocIndices); + Value newUnPackOp = insertLayoutUnpack( + rewriter, loc, packedExpandShape, outputLayout.getInnerAxis(), + outputLayout.getTileSizes(), outputLayout.getOuterAxis()); + rewriter.replaceOp(reshapeOp, newUnPackOp); +} + LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, ControlPackNamedOpsFn controlFn) { IRRewriter rewriter(ctx); @@ -298,84 +336,55 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, << " has plain layout, skip packing.\n"); return WalkResult::advance(); } + if (checkPacked(op, opLayout)) { + LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() + << " is already packed, skip packing.\n"); + return WalkResult::advance(); + } // pack op into ideal layout LLVM_DEBUG(llvm::dbgs() - << "Op " << op->getName() << "'s inferred layout:\n" + << "Packing op " << op->getName() << " into inferred layout:\n" << opLayout << "\n"); // insert pack OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); - if (checkPacked(op, opLayout)) { - LLVM_DEBUG(llvm::dbgs() - << "Op " << op->getName() << " already packed.\n"); - return WalkResult::advance(); - } if (auto linalgOp = dyn_cast(op)) { if (failed(packLinalgOp(rewriter, linalgOp, opLayout))) { return WalkResult::skip(); } } else if (auto expandShapeOp = dyn_cast(op)) { - Location loc = expandShapeOp->getLoc(); - auto inputLayout = opLayout.getSupportedInputLayouts()[0]; - auto outputLayout = opLayout.getSupportedOutputLayouts()[0]; - Value curSrc = expandShapeOp.getSrc(); - Value curDst = expandShapeOp.getResult(); - Value dest = tensor::PackOp::createDestinationTensor( - rewriter, loc, curSrc, inputLayout.getTileSizes(), - inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); - Value packedSource = insertLayoutPack( - rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), - inputLayout.getTileSizes(), inputLayout.getOuterAxis()); - SmallVector newReassocIndices = - expandShapeOp.getReassociationIndices(); - int64_t nextPos = applyPermutationAndReindexReassoc( - newReassocIndices, inputLayout.getOuterAxis()); - // Then add direct mapping for the inner tile dims. 
- for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { - newReassocIndices.push_back({nextPos}); - nextPos += 1; - } - RankedTensorType newExpandType = tensor::PackOp::inferPackedType( - dyn_cast(curDst.getType()), - *getConstantIntValues(outputLayout.getTileSizes()), - outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - Value packedExpandShape = rewriter.create( - loc, newExpandType, packedSource, newReassocIndices); - Value newUnPackOp = insertLayoutUnpack( - rewriter, loc, packedExpandShape, outputLayout.getInnerAxis(), - outputLayout.getTileSizes(), outputLayout.getOuterAxis()); - rewriter.replaceOp(expandShapeOp, newUnPackOp); + packReshapeOp(expandShapeOp, rewriter, opLayout); } else if (auto collapseShapeOp = dyn_cast(op)) { - Location loc = collapseShapeOp->getLoc(); - auto inputLayout = opLayout.getSupportedInputLayouts()[0]; - auto outputLayout = opLayout.getSupportedOutputLayouts()[0]; - Value curSrc = collapseShapeOp.getSrc(); - Value curDst = collapseShapeOp.getResult(); + packReshapeOp(collapseShapeOp, rewriter, + opLayout); + } else if (auto padOp = dyn_cast(op)) { + Location loc = padOp->getLoc(); + TensorLayout inputLayout = opLayout.getSupportedInputLayouts()[0]; + Value curSrc = padOp.getSource(); + SmallVector outerDimsPerm = inputLayout.getOuterAxis(); + SmallVector innerDimsPos = inputLayout.getInnerAxis(); + SmallVector tileSizes = inputLayout.getTileSizes(); Value dest = tensor::PackOp::createDestinationTensor( - rewriter, loc, curSrc, inputLayout.getTileSizes(), - inputLayout.getInnerAxis(), inputLayout.getOuterAxis()); - Value packedSource = insertLayoutPack( - rewriter, loc, curSrc, dest, inputLayout.getInnerAxis(), - inputLayout.getTileSizes(), inputLayout.getOuterAxis()); - SmallVector newReassocIndices = - collapseShapeOp.getReassociationIndices(); - int64_t nextPos = applyPermutationAndReindexReassoc( - newReassocIndices, outputLayout.getOuterAxis()); - // Then add direct mapping for the inner tile dims. 
- for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) { - newReassocIndices.push_back({nextPos}); - nextPos += 1; - } - RankedTensorType newCollapseType = tensor::PackOp::inferPackedType( - dyn_cast(curDst.getType()), - *getConstantIntValues(outputLayout.getTileSizes()), - outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - Value packedCollapseShape = rewriter.create( - loc, newCollapseType, packedSource, newReassocIndices); - Value newUnPackOp = insertLayoutUnpack( - rewriter, loc, packedCollapseShape, outputLayout.getInnerAxis(), - outputLayout.getTileSizes(), outputLayout.getOuterAxis()); - rewriter.replaceOp(collapseShapeOp, newUnPackOp); + rewriter, loc, curSrc, tileSizes, innerDimsPos, outerDimsPerm); + Value packedSource = + insertLayoutPack(rewriter, loc, curSrc, dest, innerDimsPos, + tileSizes, outerDimsPerm); + // update lowPad and highPad + SmallVector lowPad = padOp.getMixedLowPad(); + SmallVector highPad = padOp.getMixedHighPad(); + applyPermutationToVector(lowPad, outerDimsPerm); + applyPermutationToVector(highPad, outerDimsPerm); + lowPad.append(innerDimsPos.size(), rewriter.getIndexAttr(0)); + highPad.append(innerDimsPos.size(), rewriter.getIndexAttr(0)); + auto packedPadOp = rewriter.create( + loc, /*result=*/Type(), packedSource, lowPad, highPad, + padOp.getConstantPaddingValue(), padOp.getNofold()); + auto unpackEmpty = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedPadOp, tileSizes, innerDimsPos, outerDimsPerm); + Value unpackedPad = rewriter.create( + loc, packedPadOp, unpackEmpty, innerDimsPos, tileSizes, + outerDimsPerm); + rewriter.replaceOp(padOp, unpackedPad); } } return WalkResult::advance(); From 6dc7c743bf89afd018552666656e9cdcbd77697d Mon Sep 17 00:00:00 2001 From: "Zhang, Yifei" Date: Mon, 23 Sep 2024 19:44:00 -0700 Subject: [PATCH 23/23] update blocking activation test --- lib/gc/Transforms/PropagateLayout.cpp | 82 ++++++----- test/mlir/test/gc/Transforms/pack-matmul.mlir | 131 +++++------------- 2 files changed, 81 insertions(+), 132 deletions(-) diff --git a/lib/gc/Transforms/PropagateLayout.cpp b/lib/gc/Transforms/PropagateLayout.cpp index 7e16ec782..df534c387 100644 --- a/lib/gc/Transforms/PropagateLayout.cpp +++ b/lib/gc/Transforms/PropagateLayout.cpp @@ -42,25 +42,26 @@ static Value insertLayoutPack(RewriterBase &rewriter, Location loc, Value input, Value dest, ArrayRef innerDimsPos, ArrayRef innerTiles, ArrayRef outerDimsPerm) { - if (!innerDimsPos.empty()) - return rewriter.create( - loc, input, dest, innerDimsPos, innerTiles, - /*padding=*/std::nullopt, outerDimsPerm); - if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) { + if (!innerDimsPos.empty()) { + auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); + Value zero = rewriter.create(loc, zeroAttr); + // TODO(yifei): correct the padding value here + return rewriter.create(loc, input, dest, innerDimsPos, + innerTiles, zero, outerDimsPerm); + } + if (!TensorLayout::isPlainOuterAxis(outerDimsPerm)) return rewriter.create(loc, input, dest, outerDimsPerm) .getResults()[0]; - } return input; } // insert unpack when innerPosDims is non-empty // insert linalg.transpose otherwise static Value insertLayoutUnpack(RewriterBase &rewriter, Location loc, - Value input, ArrayRef innerDimsPos, + Value input, Value dest, + ArrayRef innerDimsPos, ArrayRef innerTiles, ArrayRef outerDimsPerm) { - Value dest = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, input, innerTiles, innerDimsPos, outerDimsPerm); if (!innerDimsPos.empty()) { return
rewriter.create(loc, input, dest, innerDimsPos, innerTiles, outerDimsPerm); @@ -173,14 +174,16 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, llvm::all_of(innerPackSizes, [](OpFoldResult tile) { return getConstantIntValue(tile).has_value(); }); - if (areConstantTiles && operandType.hasStaticShape() && - !tensor::PackOp::requirePaddingValue( - operandType.getShape(), innerPos, - cast(dest.getType()).getShape(), {}, - innerPackSizes)) { + if (areConstantTiles && operandType.hasStaticShape()) { + // TODO(yifei): use masked operation or choose the correct padding value + // to ensure computation correctness packOps.push_back(insertLayoutPack( rewriter, loc, operand, dest, innerPos, innerPackSizes, outerPerm)); } else { + LLVM_DEBUG( + llvm::dbgs() + << "Packing of linalg op " << linalgOp.getOperation()->getName() + << " failed due to non-constant tile sizes or dynamic shape.\n"); return failure(); } inputsAndInits.push_back(packOps.back()); @@ -225,11 +228,11 @@ LogicalResult packLinalgOp(RewriterBase &rewriter, linalg::LinalgOp linalgOp, assert(resultNum < static_cast(initLayouts.size()) && "Linalg op results num exceeds inits num."); // Build the symmetrical UnPackOp to the existing PackOp. - unPackOps.push_back( - insertLayoutUnpack(rewriter, packedLinalgOp->getLoc(), result, - initLayouts[resultNum].getInnerAxis(), - initLayouts[resultNum].getTileSizes(), - initLayouts[resultNum].getOuterAxis())); + unPackOps.push_back(insertLayoutUnpack( + rewriter, packedLinalgOp->getLoc(), result, + initOperands[resultNum]->get(), initLayouts[resultNum].getInnerAxis(), + initLayouts[resultNum].getTileSizes(), + initLayouts[resultNum].getOuterAxis())); results.push_back(unPackOps.back()); } @@ -307,15 +310,19 @@ static void packReshapeOp(T reshapeOp, IRRewriter &rewriter, newReassocIndices.push_back({nextPos}); nextPos += 1; } - RankedTensorType newExpandType = tensor::PackOp::inferPackedType( + RankedTensorType newReshapeType = tensor::PackOp::inferPackedType( dyn_cast(curDst.getType()), *getConstantIntValues(outputLayout.getTileSizes()), outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); - Value packedExpandShape = - rewriter.create(loc, newExpandType, packedSource, newReassocIndices); + Value packedReshapeShape = + rewriter.create(loc, newReshapeType, packedSource, newReassocIndices); + Value unpackDest = tensor::UnPackOp::createDestinationTensor( + rewriter, loc, packedReshapeShape, outputLayout.getTileSizes(), + outputLayout.getInnerAxis(), outputLayout.getOuterAxis()); Value newUnPackOp = insertLayoutUnpack( - rewriter, loc, packedExpandShape, outputLayout.getInnerAxis(), - outputLayout.getTileSizes(), outputLayout.getOuterAxis()); + rewriter, loc, packedReshapeShape, unpackDest, + outputLayout.getInnerAxis(), outputLayout.getTileSizes(), + outputLayout.getOuterAxis()); rewriter.replaceOp(reshapeOp, newUnPackOp); } @@ -331,6 +338,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, return WalkResult::skip(); } OperatorLayout opLayout = *controlFn(op); + LLVM_DEBUG(llvm::dbgs() + << "Packing op " << op->getName() << " into inferred layout:\n" + << opLayout << "\n"); if (opLayout.isPlain()) { LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " has plain layout, skip packing.\n"); @@ -361,6 +371,7 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, Location loc = padOp->getLoc(); TensorLayout inputLayout = opLayout.getSupportedInputLayouts()[0]; Value curSrc = padOp.getSource(); + Value 
curDest = padOp.getResult(); SmallVector outerDimsPerm = inputLayout.getOuterAxis(); SmallVector innerDimsPos = inputLayout.getInnerAxis(); SmallVector tileSizes = inputLayout.getTileSizes(); @@ -379,12 +390,10 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph, auto packedPadOp = rewriter.create( loc, /*result=*/Type(), packedSource, lowPad, highPad, padOp.getConstantPaddingValue(), padOp.getNofold()); - auto unpackEmpty = tensor::UnPackOp::createDestinationTensor( - rewriter, loc, packedPadOp, tileSizes, innerDimsPos, outerDimsPerm); - Value unpackedPad = rewriter.create( - loc, packedPadOp, unpackEmpty, innerDimsPos, tileSizes, - outerDimsPerm); - rewriter.replaceOp(padOp, unpackedPad); + Value newUnPackOp = + insertLayoutUnpack(rewriter, loc, packedPadOp, curDest, + innerDimsPos, tileSizes, outerDimsPerm); + rewriter.replaceOp(padOp, newUnPackOp); } } return WalkResult::advance(); @@ -596,9 +605,10 @@ revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph, Value unpackInputDest = tensor::UnPackOp::createDestinationTensor( rewriter, loc, packInputOp, packInputInnerTiles, unpackInputInnerDimsPos, ArrayRef{}); - Value reUnpackInput = rewriter.create( - loc, packInputOp, unpackInputDest, unpackInputInnerDimsPos, - packInputInnerTiles); + Value reUnpackInput = + insertLayoutUnpack(rewriter, loc, packInputOp, unpackInputDest, + unpackInputInnerDimsPos, packInputInnerTiles, + ArrayRef{}); // unpack init auto packInitInnerTiles = packInitOp.getMixedTiles(); auto packInitInnerDimsPos = packInitOp.getInnerDimsPos(); @@ -614,9 +624,9 @@ revertMatmulPacking(MLIRContext *ctx, mlir::Operation *graph, Value unpackInitDest = tensor::UnPackOp::createDestinationTensor( rewriter, loc, packInitOp, packInitInnerTiles, packInitInnerDimsPos, packInitOuterDimsPerm); - Value reUnpackInit = rewriter.create( - loc, packInitOp, unpackInitDest, packInitInnerDimsPos, - packInitInnerTiles, packInitOuterDimsPerm); + Value reUnpackInit = insertLayoutUnpack( + rewriter, loc, packInitOp, unpackInitDest, packInitInnerDimsPos, + packInitInnerTiles, ArrayRef{}); // replace matmul 4D with matmul 2D auto matmul2D = linalgx::makeGenericPackedMatmulOp( rewriter, loc, revertType, diff --git a/test/mlir/test/gc/Transforms/pack-matmul.mlir b/test/mlir/test/gc/Transforms/pack-matmul.mlir index 27e752359..de13e0a6a 100644 --- a/test/mlir/test/gc/Transforms/pack-matmul.mlir +++ b/test/mlir/test/gc/Transforms/pack-matmul.mlir @@ -2,118 +2,57 @@ // ----- -// CHECK-LABEL: @single_matmul_f32 -func.func @single_matmul_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x32xf32>) -> tensor<128x32xf32> { +// CHECK-LABEL: @matmul_add_plain_activation_f32 +func.func @matmul_add_plain_activation_f32(%arg0: tensor<128x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64xf32>) -> tensor<128x64xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<128x32xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x32xf32>) outs(%0 : tensor<128x32xf32>) -> tensor<128x32xf32> - return %2 : tensor<128x32xf32> + %0 = tensor.empty() : tensor<128x64xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xf32>, tensor<64x64xf32>) outs(%0 : tensor<128x64xf32>) -> tensor<128x64xf32> + %3 = tensor.empty() : tensor<128x64xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xf32>) outs(%3 : 
tensor<128x64xf32>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xf32> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xf32>, tensor<128x64xf32>) outs(%4 : tensor<128x64xf32>) -> tensor<128x64xf32> + return %5 : tensor<128x64xf32> } // CHECK-COUNT-1: tensor.pack // CHECK-COUNT-1: linalg.generic +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}xf32>, tensor<{{.*}}x{{.*}}xf32>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}xf32>) -> tensor<{{.*}}x{{.*}}xf32> // CHECK-NOT: tensor.unpack // ----- -// CHECK-LABEL: @single_matmul_bf16 -func.func @single_matmul_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x32xbf16>) -> tensor<128x32xbf16> { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x32xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x32xbf16>) outs(%0 : tensor<128x32xbf16>) -> tensor<128x32xbf16> - return %2 : tensor<128x32xbf16> -} -// CHECK-COUNT-2: tensor.pack -// CHECK-COUNT-1: linalg.generic -// CHECK-NOT: tensor.unpack - -// ----- - -// CHECK-LABEL: @mlp_f32 -func.func @mlp_f32(%arg0: tensor<128x16xf32>, %arg1: tensor<16x512xf32>, %arg2: tensor<512x256xf32>, %arg3: tensor<256x128xf32>, %arg4: tensor<512xf32>, %arg5: tensor<256xf32>, %arg6: tensor<128xf32>) -> tensor<128x128xf32> attributes {llvm.emit_c_interface} { +// CHECK-LABEL: @matmul_add_blocking_activation_f32 +func.func @matmul_add_blocking_activation_f32(%arg0: tensor<128x511xf32>, %arg1: tensor<511x255xf32>, %arg2: tensor<255xf32>) -> tensor<128x255xf32> { %cst = arith.constant 0.000000e+00 : f32 - %0 = tensor.empty() : tensor<128x512xf32> - %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x512xf32>) -> tensor<128x512xf32> - %2 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<128x16xf32>, tensor<16x512xf32>) outs(%1 : tensor<128x512xf32>) -> tensor<128x512xf32> - %3 = tensor.empty() : tensor<128x512xf32> - %broadcasted = linalg.broadcast ins(%arg4 : tensor<512xf32>) outs(%3 : tensor<128x512xf32>) dimensions = [0] - %4 = tensor.empty() : tensor<128x512xf32> - %5 = linalg.add ins(%2, %broadcasted : tensor<128x512xf32>, tensor<128x512xf32>) outs(%4 : tensor<128x512xf32>) -> tensor<128x512xf32> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x512xf32> - %6 = tensor.empty() : tensor<128x512xf32> - %7 = linalg.max ins(%5, %cst_0 : tensor<128x512xf32>, tensor<128x512xf32>) outs(%6 : tensor<128x512xf32>) -> tensor<128x512xf32> - %8 = tensor.empty() : tensor<128x256xf32> - %9 = linalg.fill ins(%cst : f32) outs(%8 : tensor<128x256xf32>) -> tensor<128x256xf32> - %10 = linalg.matmul {cast = #linalg.type_fn} ins(%7, %arg2 : tensor<128x512xf32>, tensor<512x256xf32>) outs(%9 : tensor<128x256xf32>) -> tensor<128x256xf32> - %11 = tensor.empty() : tensor<128x256xf32> - %broadcasted_1 = linalg.broadcast ins(%arg5 : tensor<256xf32>) outs(%11 : tensor<128x256xf32>) dimensions = [0] - %12 = tensor.empty() : tensor<128x256xf32> - %13 = linalg.add ins(%10, %broadcasted_1 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%12 : tensor<128x256xf32>) -> tensor<128x256xf32> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x256xf32> - %14 = tensor.empty() : tensor<128x256xf32> - %15 = linalg.max ins(%13, %cst_2 : tensor<128x256xf32>, tensor<128x256xf32>) outs(%14 : tensor<128x256xf32>) -> tensor<128x256xf32> - %16 = tensor.empty() : tensor<128x128xf32> - %17 = linalg.fill ins(%cst : f32) outs(%16 : tensor<128x128xf32>) -> tensor<128x128xf32> - 
%18 = linalg.matmul {cast = #linalg.type_fn} ins(%15, %arg3 : tensor<128x256xf32>, tensor<256x128xf32>) outs(%17 : tensor<128x128xf32>) -> tensor<128x128xf32> - %19 = tensor.empty() : tensor<128x128xf32> - %broadcasted_3 = linalg.broadcast ins(%arg6 : tensor<128xf32>) outs(%19 : tensor<128x128xf32>) dimensions = [0] - %20 = tensor.empty() : tensor<128x128xf32> - %21 = linalg.add ins(%18, %broadcasted_3 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%20 : tensor<128x128xf32>) -> tensor<128x128xf32> - %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> - %22 = tensor.empty() : tensor<128x128xf32> - %23 = linalg.max ins(%21, %cst_4 : tensor<128x128xf32>, tensor<128x128xf32>) outs(%22 : tensor<128x128xf32>) -> tensor<128x128xf32> - return %23 : tensor<128x128xf32> + %0 = tensor.empty() : tensor<128x255xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<128x255xf32>) -> tensor<128x255xf32> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x511xf32>, tensor<511x255xf32>) outs(%0 : tensor<128x255xf32>) -> tensor<128x255xf32> + %3 = tensor.empty() : tensor<128x255xf32> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<255xf32>) outs(%3 : tensor<128x255xf32>) dimensions = [0] + %4 = tensor.empty() : tensor<128x255xf32> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x255xf32>, tensor<128x255xf32>) outs(%4 : tensor<128x255xf32>) -> tensor<128x255xf32> + return %5 : tensor<128x255xf32> } -// CHECK-COUNT-1: tensor.pack -// CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-1: tensor.pack -// CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-1: tensor.pack +// CHECK-COUNT-2: tensor.pack // CHECK-COUNT-1: linalg.generic -// CHECK-NOT: tensor.unpack +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>, tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32>) -> tensor<{{.*}}x{{.*}}x{{.*}}x{{.*}}xf32> +// CHECK-COUNT-1: tensor.unpack // ----- -// CHECK-LABEL: @mlp_bf16 -func.func @mlp_bf16(%arg0: tensor<32x4096xbf16>, %arg1: tensor<4096x4096xbf16>, %arg2: tensor<4096x11008xbf16>, %arg3: tensor<11008x4096xbf16>, %arg4: tensor<4096xbf16>, %arg5: tensor<11008xbf16>, %arg6: tensor<4096xbf16>) -> tensor<32x4096xbf16> attributes {llvm.emit_c_interface} { +// CHECK-LABEL: @matmul_add_plain_activation_bf16 +func.func @matmul_add_plain_activation_bf16(%arg0: tensor<128x64xbf16>, %arg1: tensor<64x64xbf16>, %arg2: tensor<64xbf16>) -> tensor<128x64xbf16> { %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<32x4096xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %2 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x4096xbf16>, tensor<4096x4096xbf16>) outs(%1 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %3 = tensor.empty() : tensor<32x4096xbf16> - %broadcasted = linalg.broadcast ins(%arg4 : tensor<4096xbf16>) outs(%3 : tensor<32x4096xbf16>) dimensions = [0] - %4 = tensor.empty() : tensor<32x4096xbf16> - %5 = linalg.add ins(%2, %broadcasted : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%4 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %cst_0 = arith.constant dense<0.000000e+00> : tensor<32x4096xbf16> - %6 = tensor.empty() : tensor<32x4096xbf16> - %7 = linalg.max ins(%5, %cst_0 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%6 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %8 = tensor.empty() : tensor<32x11008xbf16> - %9 = linalg.fill ins(%cst : bf16) outs(%8 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> - %10 = 
linalg.matmul {cast = #linalg.type_fn} ins(%7, %arg2 : tensor<32x4096xbf16>, tensor<4096x11008xbf16>) outs(%9 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> - %11 = tensor.empty() : tensor<32x11008xbf16> - %broadcasted_1 = linalg.broadcast ins(%arg5 : tensor<11008xbf16>) outs(%11 : tensor<32x11008xbf16>) dimensions = [0] - %12 = tensor.empty() : tensor<32x11008xbf16> - %13 = linalg.add ins(%10, %broadcasted_1 : tensor<32x11008xbf16>, tensor<32x11008xbf16>) outs(%12 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> - %cst_2 = arith.constant dense<0.000000e+00> : tensor<32x11008xbf16> - %14 = tensor.empty() : tensor<32x11008xbf16> - %15 = linalg.max ins(%13, %cst_2 : tensor<32x11008xbf16>, tensor<32x11008xbf16>) outs(%14 : tensor<32x11008xbf16>) -> tensor<32x11008xbf16> - %16 = tensor.empty() : tensor<32x4096xbf16> - %17 = linalg.fill ins(%cst : bf16) outs(%16 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %18 = linalg.matmul {cast = #linalg.type_fn} ins(%15, %arg3 : tensor<32x11008xbf16>, tensor<11008x4096xbf16>) outs(%17 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %19 = tensor.empty() : tensor<32x4096xbf16> - %broadcasted_3 = linalg.broadcast ins(%arg6 : tensor<4096xbf16>) outs(%19 : tensor<32x4096xbf16>) dimensions = [0] - %20 = tensor.empty() : tensor<32x4096xbf16> - %21 = linalg.add ins(%18, %broadcasted_3 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%20 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - %cst_4 = arith.constant dense<0.000000e+00> : tensor<32x4096xbf16> - %22 = tensor.empty() : tensor<32x4096xbf16> - %23 = linalg.max ins(%21, %cst_4 : tensor<32x4096xbf16>, tensor<32x4096xbf16>) outs(%22 : tensor<32x4096xbf16>) -> tensor<32x4096xbf16> - return %23 : tensor<32x4096xbf16> + %0 = tensor.empty() : tensor<128x64xbf16> + %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x64xbf16>, tensor<64x64xbf16>) outs(%0 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + %3 = tensor.empty() : tensor<128x64xbf16> + %broadcasted = linalg.broadcast ins(%arg2 : tensor<64xbf16>) outs(%3 : tensor<128x64xbf16>) dimensions = [0] + %4 = tensor.empty() : tensor<128x64xbf16> + %5 = linalg.add ins(%2, %broadcasted : tensor<128x64xbf16>, tensor<128x64xbf16>) outs(%4 : tensor<128x64xbf16>) -> tensor<128x64xbf16> + return %5 : tensor<128x64xbf16> } // CHECK-COUNT-2: tensor.pack // CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-2: tensor.pack -// CHECK-COUNT-1: linalg.generic -// CHECK-COUNT-2: tensor.pack -// CHECK-COUNT-1: linalg.generic +// CHECK: linalg.add ins(%{{.*}}, %{{.*}} : tensor<{{.*}}x{{.*}}xbf16>, tensor<{{.*}}x{{.*}}xbf16>) outs(%{{.*}} : tensor<{{.*}}x{{.*}}xbf16>) -> tensor<{{.*}}x{{.*}}xbf16> // CHECK-NOT: tensor.unpack
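Note on the blocked form these tests check for: the linalg.generic counted by the CHECK-COUNT lines is an MM4D-style blocked contraction. The snippet below is a minimal hand-written sketch of that form, not output of the pass; the @mm4d_sketch name, the 2x2x32x32 shapes, and the 32-sized tiles are illustrative assumptions (the actual blocking factors come from the matmulCfg heuristic queried in PropagateLayout.cpp).

#mapA = affine_map<(mo, no, ko, mi, ni, ki) -> (mo, ko, mi, ki)>
#mapB = affine_map<(mo, no, ko, mi, ni, ki) -> (no, ko, ki, ni)>
#mapC = affine_map<(mo, no, ko, mi, ni, ki) -> (mo, no, mi, ni)>
func.func @mm4d_sketch(%a: tensor<2x2x32x32xf32>, %b: tensor<2x2x32x32xf32>,
                       %c: tensor<2x2x32x32xf32>) -> tensor<2x2x32x32xf32> {
  // Outer dims (mo, no, ko) step over 32x32 tiles of the packed operands;
  // inner dims (mi, ni, ki) multiply-accumulate within one tile, with ko/ki
  // as the reduction dimensions.
  %0 = linalg.generic {
         indexing_maps = [#mapA, #mapB, #mapC],
         iterator_types = ["parallel", "parallel", "reduction",
                           "parallel", "parallel", "reduction"]}
       ins(%a, %b : tensor<2x2x32x32xf32>, tensor<2x2x32x32xf32>)
       outs(%c : tensor<2x2x32x32xf32>) {
  ^bb0(%in_a: f32, %in_b: f32, %out: f32):
    %m = arith.mulf %in_a, %in_b : f32
    %acc = arith.addf %out, %m : f32
    linalg.yield %acc : f32
  } -> tensor<2x2x32x32xf32>
  return %0 : tensor<2x2x32x32xf32>
}

This also explains the pack/unpack counts above: in the plain-activation cases only the weight side is packed into such a tiled form (one tensor.pack for f32, two for the VNNI bf16 path, and no tensor.unpack), while the blocking-activation case packs both sides and unpacks the result once.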