[Experimental][Transform] Split Compute Intensive Op #154

Open
Wants to merge 12 commits into base: xurui/add_benchmark
Changes from 1 commit
fix color and recursive logic
Zhang Yan committed Jun 25, 2024

commit 62fc8bffe548737655d5b4212b19beb2bdea8051
35 changes: 28 additions & 7 deletions lib/gc/Transforms/SplitComputeIntensivePatterns.cpp
@@ -35,7 +35,7 @@ namespace gc {
#include "gc/Transforms/Passes.h.inc"
} // namespace gc

size_t NUM_OF_NUMA = 2;
size_t NUM_OF_NUMA = 3;
size_t SUPPORTED_RANK = 2;

void printValueType(Value value) {
@@ -149,11 +149,14 @@ void SplitMMonN(SmallVector<Value>& outputs, SmallVector<Value>& inputs, TensorT
loc, ArrayRef<int64_t> {M, weight.getType().cast<RankedTensorType>().getDimSize(1)}, resultTy.getElementType());
Value tensor =
rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
outputs.push_back(rewriter.create<linalg::MatmulOp>(
auto newMM = rewriter.create<linalg::MatmulOp>(
/*location=*/loc,
/*resultTensorTypes=*/tensor.getType().cast<RankedTensorType>(),
/*inputs=*/ValueRange{inputs[0], weight},
/*outputs=*/tensor)->getResult(0));
/*outputs=*/tensor);
mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true);
newMM->setAttr("splited", boolAttr);
outputs.push_back(newMM->getResult(0));
}
}

@@ -179,11 +182,15 @@ void SplitMMonK(SmallVector<Value>& outputs, SmallVector<Value>& inputs, TensorT
loc, resultTy.getShape(), resultTy.getElementType());
Value tensor =
rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0);
outputs.push_back(rewriter.create<linalg::MatmulOp>(
auto newMM = rewriter.create<linalg::MatmulOp>(
/*location=*/loc,
/*resultTensorTypes=*/tensor.getType().cast<RankedTensorType>(),
/*inputs=*/ValueRange{data, weight},
/*outputs=*/tensor)->getResult(0));
/*outputs=*/tensor);
mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true);
newMM->setAttr("splited", boolAttr);
outputs.push_back(newMM->getResult(0));
outputs.push_back(newMM->getResult(0));
}
}
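
Both hunks tag the newly created matmuls with a boolean "splited" attribute. A minimal sketch of how a pattern could later consult that marker (the guard below is an assumption for illustration; this commit only shows the attribute being set, not read):

  // Hypothetical guard: bail out on matmuls this pass has already produced,
  // so the pattern does not keep re-splitting its own output.
  if (auto splitedAttr = op->getAttrOfType<mlir::BoolAttr>("splited"))
    if (splitedAttr.getValue())
      return failure();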

@@ -203,8 +210,10 @@ bool isSupportedPostOp(Operation *op) {
void getUnOps(Operation *op, SmallVectorImpl<Operation *> &postOps) {
for (auto user : op->getUsers()) {
if (isSupportedPostOp(user)) postOps.push_back(user);
// Recursively search for unary ops
if (isa<linalg::MatmulOp>(user)) return;
// Recursively search for unary ops, unless it's a matmul op
getUnOps(user, postOps);
// }
}
}
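
Because the old and new lines are interleaved above, the post-change control flow is easier to see written out in one piece. A sketch of how getUnOps reads after this hunk, with comments added for illustration:

  void getUnOps(Operation *op, SmallVectorImpl<Operation *> &postOps) {
    for (auto user : op->getUsers()) {
      if (isSupportedPostOp(user))
        postOps.push_back(user);
      // Stop at the next matmul: its post-op chain is collected separately
      // when the pattern visits that matmul itself.
      if (isa<linalg::MatmulOp>(user))
        return;
      // Otherwise keep walking the chain of element-wise post-ops.
      getUnOps(user, postOps);
    }
  }

On the basic_mlp test added below, and assuming linalg.add, linalg.mul, and linalg.max are all accepted by isSupportedPostOp (its body is outside this diff), starting from the first linalg.matmul this collects the add, mul, and max ops and stops once the chain reaches the second matmul.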

@@ -296,7 +305,12 @@ Value addN(Value& initTensor, SmallVector<Value>& ins, TensorType& resultTy, Loc

LogicalResult splitSingleMM(linalg::MatmulOp& op,
PatternRewriter &rewriter) {
SmallVector<Operation *> postOps;
// rewriter.updateRootInPlace(op, [&]() {
// mlir::BoolAttr boolAttr = rewriter.getBoolAttr(true);
// op->setAttr("splited", boolAttr);
// });

SmallVector<Operation *> postOps = {};
getUnOps(op, postOps);
auto loc = op->getLoc();
auto resultTy = dyn_cast<TensorType>(op->getResultTypes().front());
@@ -321,6 +335,7 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op,
if (splites_res.size() != NUM_OF_NUMA) return failure();
SmallVector<Value> Outputs = splites_res;
auto lastInput = op->getResult(0);
llvm::outs() << "postOps num: " << postOps.size() << "\n";
for (auto postOp : postOps) {
llvm::outs() << "Operation name: " << postOp->getName().getStringRef() << "\n";
auto opInputs = postOp->getOperands().drop_back();
@@ -401,6 +416,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op,
duplicateBinary<linalg::MaxOp>(Outputs, Inputs, resultTy, rewriter);
llvm::outs() << "post op creation and deletion done \n";
lastInput = postOp->getResult(0);
if(auto lastop = lastInput.getDefiningOp())
std::cout << "lastInput operation name: " << lastop->getName().getStringRef().str() << std::endl;
}
// Concatenate the two halves back together on N axis
auto newop = rewriter.create<tensor::ConcatOp>(
@@ -415,6 +432,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op,
}
deleteOperands(replaced_op);
rewriter.replaceOp(replaced_op, newop);
postOps = {};
llvm::outs() << "after duplicate, postOps num: " << postOps.size() << "\n";
} else {
SplitMMonK(splites_res, input_tensors, resultTy, loc, rewriter);
if (splites_res.size() != NUM_OF_NUMA) return failure();
@@ -430,6 +449,8 @@ LogicalResult splitSingleMM(linalg::MatmulOp& op,
// Replace the original operation with the new linalg.map operation
rewriter.replaceOp(op, newop);
}
llvm::outs() << "exit duplicate mm.\n";
llvm::outs() << "==================================================\n";
return success();
}
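
The split results are rejoined along the N axis with a tensor::ConcatOp; only the start of that call is visible in this hunk. A minimal sketch of the rejoin step, assuming the result-type-inferring dim/inputs builder of tensor::ConcatOp and the Outputs vector from the surrounding code:

  // Sketch: concatenate the NUM_OF_NUMA column slices back into one M x N tensor.
  // dim = 1 is the N axis of the rank-2 result.
  Value rejoined = rewriter
                       .create<tensor::ConcatOp>(loc, /*dim=*/1,
                                                 /*inputs=*/ValueRange(Outputs))
                       .getResult();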

Second changed file: the FileCheck test for --split-compute-intensive-patterns.
@@ -1,4 +1,24 @@
// RUN: gc-opt %s --split-compute-intensive-patterns | FileCheck %s
func.func @basic_mlp(%in: tensor<128x512xbf16>,
%weight: tensor<512x256xbf16>,
%offset: tensor<128x256xbf16>,
%scale: tensor<128x256xbf16>,
%weight2: tensor<256x1024xbf16>) -> tensor<128x1024xbf16> {
%0 = tensor.empty() : tensor<128x256xbf16>
%cst = arith.constant 0.000000e+00 : bf16
%1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%2 = linalg.matmul ins(%in, %weight : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%3 = tensor.empty() : tensor<128x256xbf16>
%4 = linalg.add ins(%2, %offset : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%3 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%5 = tensor.empty() : tensor<128x256xbf16>
%6 = linalg.mul ins(%4, %scale : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%5 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%9 = tensor.empty() : tensor<128x256xbf16>
%10 = linalg.max ins(%6, %1 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%9 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
%11 = tensor.empty() : tensor<128x1024xbf16>
%12 = linalg.fill ins(%cst : bf16) outs(%11 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
%13 = linalg.matmul ins(%10, %weight2 : tensor<128x256xbf16>, tensor<256x1024xbf16>) outs(%12 : tensor<128x1024xbf16>) -> tensor<128x1024xbf16>
return %13 : tensor<128x1024xbf16>
}

func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: tensor<64xbf16>, %arg3: tensor<64x256xbf16>, %arg4: tensor<256xbf16>) -> tensor<128x256xbf16> {
%cst = arith.constant 0.000000e+00 : bf16
@@ -20,5 +40,8 @@ func.func @mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x64xbf16>, %arg2: t
%broadcasted_2 = linalg.broadcast ins(%arg4 : tensor<256xbf16>) outs(%11 : tensor<128x256xbf16>) dimensions = [0]
%12 = tensor.empty() : tensor<128x256xbf16>
%13 = linalg.add ins(%10, %broadcasted_2 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%12 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
return %13 : tensor<128x256xbf16>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x256xbf16>
%14 = tensor.empty() : tensor<128x256xbf16>
%15 = linalg.max ins(%13, %cst_3 : tensor<128x256xbf16>, tensor<128x256xbf16>) outs(%14 : tensor<128x256xbf16>) -> tensor<128x256xbf16>
return %15 : tensor<128x256xbf16>
}