forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMapOpLowering.cpp
More file actions
130 lines (111 loc) · 5.42 KB
/
MapOpLowering.cpp
File metadata and controls
130 lines (111 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*
* Copyright 2023 The DAPHNE Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "compiler/utils/CompilerUtils.h"
#include "compiler/utils/LoweringUtils.h"
#include "ir/daphneir/Daphne.h"
#include "ir/daphneir/Passes.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;
/**
 * @brief Conversion pattern that lowers a `daphne::MapOp` to a nest of two
 * affine loops calling the op's UDF once per matrix element.
 *
 * The argument matrix is exposed as a memref and read element by element;
 * every UDF result is stored into a freshly allocated output memref whose
 * element type is the UDF's result type. That buffer is finally converted back
 * to a DenseMatrix which replaces the MapOp. Writing into a separate buffer
 * (instead of back into the argument's buffer) leaves the argument matrix
 * unmodified and allows UDFs whose result type differs from the argument's
 * element type.
 *
 * All preconditions are checked before any IR is created, so a failed match
 * leaves the IR untouched. Requirements: the argument is a matrix with a
 * statically known shape, and the UDF named by the op exists in the enclosing
 * module, takes exactly one argument of the matrix element type and returns
 * exactly one value.
 */
class InlineMapOpLowering : public mlir::OpConversionPattern<mlir::daphne::MapOp> {
  public:
    using OpConversionPattern::OpConversionPattern;

    mlir::LogicalResult matchAndRewrite(mlir::daphne::MapOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter &rewriter) const override {
        auto loc = op->getLoc();

        // --- Preconditions: nothing is created before all of these pass. ---
        auto argMatrixType = op.getArg().getType().dyn_cast<mlir::daphne::MatrixType>();
        if (!argMatrixType)
            return rewriter.notifyMatchFailure(op, "map argument is not a matrix");

        const int64_t numRows = argMatrixType.getNumRows();
        const int64_t numCols = argMatrixType.getNumCols();
        // A negative extent marks an unknown dimension (assumed DAPHNE convention: -1);
        // the static memref types and constant loop bounds below need both extents.
        if (numRows < 0 || numCols < 0)
            return rewriter.notifyMatchFailure(op, "map argument has an unknown shape");

        auto module = op->getParentOfType<mlir::ModuleOp>();
        func::FuncOp udfFuncOp = module ? module.lookupSymbol<func::FuncOp>(op.getFunc()) : func::FuncOp();
        if (!udfFuncOp)
            return rewriter.notifyMatchFailure(op, "map UDF not found in the enclosing module");

        mlir::FunctionType udfType = udfFuncOp.getFunctionType();
        if (udfType.getNumInputs() != 1 || udfType.getNumResults() != 1)
            return rewriter.notifyMatchFailure(op, "map UDF must take one argument and return one value");
        if (udfType.getInput(0) != argMatrixType.getElementType())
            return rewriter.notifyMatchFailure(op, "map UDF argument type differs from the matrix element type");

        // --- Input: view the argument matrix as a rows x cols memref. ---
        auto inputMemRefType = mlir::MemRefType::get({numRows, numCols}, argMatrixType.getElementType());
        mlir::Value input =
            rewriter.create<mlir::daphne::ConvertDenseMatrixToMemRef>(loc, inputMemRefType, adaptor.getArg());

        // --- Output: a separate buffer typed by the UDF's result. ---
        // NOTE(review): nothing here deallocates this buffer; confirm that the DenseMatrix
        // produced by convertMemRefToDenseMatrix (or a later pass) is responsible for it.
        auto outputMemRefType = mlir::MemRefType::get({numRows, numCols}, udfType.getResult(0));
        mlir::Value output = rewriter.create<memref::AllocOp>(loc, outputMemRefType);

        // --- Loop nest computing output[i][j] = udf(input[i][j]). ---
        // Created without a body builder, each AffineForOp body is populated with its
        // affine.yield terminator by the op's builder; inserting at the start of a body
        // keeps the new ops in front of it, so no terminator must be erased or re-created.
        auto rowLoop = rewriter.create<AffineForOp>(loc, 0, numRows, 1);
        rewriter.setInsertionPointToStart(rowLoop.getBody());
        auto colLoop = rewriter.create<AffineForOp>(loc, 0, numCols, 1);
        rewriter.setInsertionPointToStart(colLoop.getBody());

        SmallVector<Value, 2> ivs{rowLoop.getInductionVar(), colLoop.getInductionVar()};
        mlir::Value element = rewriter.create<AffineLoadOp>(loc, input, ivs);
        mlir::Value mapped = rewriter.create<func::CallOp>(loc, udfFuncOp, ValueRange{element})->getResult(0);
        rewriter.create<AffineStoreOp>(loc, mapped, output, ivs);

        // --- Wrap the output buffer as the op's DenseMatrix result. ---
        rewriter.setInsertionPointAfter(rowLoop);
        mlir::Value result = convertMemRefToDenseMatrix(loc, rewriter, output, op.getType());
        rewriter.replaceOp(op, result);
        return mlir::success();
    }
};
namespace {
/**
 * @brief The MapOpLoweringPass rewrites the daphne::MapOp operator
 * to a set of perfectly nested affine loops and inserts for each element a call
 * to the UDF assigned to the daphne::MapOp.
 *
 * This rewrite enables a subsequent inlining pass to completely replace
 * the daphne::MapOp by inlining the CallOps produced by this pass.
 *
 * Registered on the command line under the argument `lower-map`.
 */
struct MapOpLoweringPass : public mlir::PassWrapper<MapOpLoweringPass, mlir::OperationPass<mlir::ModuleOp>> {
    explicit MapOpLoweringPass() {}

    /// Declares the dialects this pass depends on so that they are loaded into
    /// the context before the pass runs.
    void getDependentDialects(mlir::DialectRegistry &registry) const override {
        registry.insert<mlir::LLVM::LLVMDialect, mlir::AffineDialect, mlir::memref::MemRefDialect,
                        mlir::daphne::DaphneDialect, mlir::func::FuncDialect>();
    }

    void runOnOperation() final;

    StringRef getArgument() const final { return "lower-map"; }

    StringRef getDescription() const final {
        // Adjacent literals are concatenated; each piece but the last ends in a
        // space so the words do not run together.
        return "Lowers the daphne.mapOp operation to "
               "a set of affine loops, directly calling the UDF. "
               "Subsequent use of the inlining pass may inline the call to the "
               "UDF.";
    }
};
} // end anonymous namespace
void MapOpLoweringPass::runOnOperation() {
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
mlir::LowerToLLVMOptions llvmOptions(&getContext());
mlir::LLVMTypeConverter typeConverter(&getContext(), llvmOptions);
target.addLegalDialect<mlir::AffineDialect, arith::ArithDialect, memref::MemRefDialect, mlir::daphne::DaphneDialect,
mlir::func::FuncDialect>();
target.addIllegalOp<mlir::daphne::MapOp>();
patterns.insert<InlineMapOpLowering>(&getContext());
auto module = getOperation();
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}
/// Factory for the `lower-map` pass; returns a fresh MapOpLoweringPass instance.
std::unique_ptr<mlir::Pass> mlir::daphne::createMapOpLoweringPass() {
    std::unique_ptr<mlir::Pass> pass = std::make_unique<MapOpLoweringPass>();
    return pass;
}