forked from daphne-project/daphne
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDaphneOptPass.cpp
More file actions
95 lines (79 loc) · 3.87 KB
/
DaphneOptPass.cpp
File metadata and controls
95 lines (79 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#include "compiler/utils/CompilerUtils.h"
#include "compiler/utils/LoweringUtils.h"
#include "ir/daphneir/Daphne.h"
#include "ir/daphneir/Passes.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "dm-opt"
using namespace mlir;
/**
 * @brief Rewrites `lhs % rhs` into `lhs & (rhs - 1)` when `rhs` is a
 * compile-time constant of unsigned integer type that is a power of two.
 *
 * For an unsigned divisor 2^k, the remainder equals the low k bits of the
 * dividend, so the modulo can be replaced by a subtraction on the constant and
 * a bitwise AND.
 */
class IntegerModOpt : public mlir::OpConversionPattern<mlir::daphne::EwModOp> {
  public:
    using OpConversionPattern::OpConversionPattern;

    /**
     * @brief Checks whether the power-of-two rewrite applies to `op`.
     *
     * @param op The modulo operation to inspect.
     * @return `true` iff the right-hand side has an unsigned integer type and
     *         is a constant, non-zero power of two.
     */
    [[nodiscard]] static bool optimization_viable(mlir::daphne::EwModOp op) {
        if (!op.getRhs().getType().isUnsignedInteger())
            return false;

        // NOTE(review): only the rhs type is checked. For a negative signed
        // lhs, `%` and `&` disagree (e.g. -3 % 4 == -3 but -3 & 3 == 1 in
        // C/C++). Confirm that a signed lhs cannot reach this rewrite.
        const auto [isConst, divisor] = CompilerUtils::isConstant<uint64_t>(op.getRhs());

        // Zero must be excluded explicitly. Otherwise 0 & (0 - 1) == 0 would
        // classify it as a power of two, and `x % 0` would become
        // `x & UINT64_MAX`, silently hiding a division by zero.
        return isConst && divisor != 0 && (divisor & (divisor - 1)) == 0;
    }

    /**
     * @brief Replaces `op` with `EwBitwiseAndOp(lhs, EwSubOp(rhs, 1))`.
     *
     * @return `success()` after replacing `op`, or `failure()` if the rewrite
     *         is not applicable (see optimization_viable).
     */
    mlir::LogicalResult matchAndRewrite(mlir::daphne::EwModOp op, OpAdaptor adaptor,
                                        mlir::ConversionPatternRewriter &rewriter) const override {
        // Re-check here so the pattern stays correct even when it is driven
        // without the dynamic legality callback that normally filters ops.
        if (!optimization_viable(op))
            return mlir::failure();

        const mlir::Location loc = op.getLoc();

        // New ops are created with UnknownType results, which are then passed
        // through CompilerUtils::retValWithInferredType.
        mlir::Type unknownTy = daphne::UnknownType::get(getContext());

        mlir::Value one = rewriter.create<mlir::daphne::ConstantOp>(loc, static_cast<uint64_t>(1));

        // mask = rhs - 1, i.e. all bits below the single set bit of rhs.
        mlir::Value mask = CompilerUtils::retValWithInferredType(
            rewriter.create<mlir::daphne::EwSubOp>(loc, unknownTy, adaptor.getRhs(), one));

        // lhs % rhs == lhs & mask for a power-of-two rhs.
        mlir::Value masked = CompilerUtils::retValWithInferredType(
            rewriter.create<mlir::daphne::EwBitwiseAndOp>(loc, unknownTy, adaptor.getLhs(), mask));

        rewriter.replaceOp(op, masked);
        return mlir::success();
    }
};
namespace {
/**
 * @brief This pass transforms operations (currently limited to the EwModOp) in
 * the DaphneDialect to a different set of operations also from the
 * DaphneDialect.
 *
 * Concretely, an `EwModOp` whose right-hand side is a constant unsigned power
 * of two is rewritten into an `EwSubOp` followed by an `EwBitwiseAndOp` (see
 * IntegerModOpt). The pass runs on the whole module.
 */
struct DenseMatrixOptPass : public mlir::PassWrapper<DenseMatrixOptPass, mlir::OperationPass<mlir::ModuleOp>> {
    DenseMatrixOptPass() = default;

    /**
     * @brief Declares the dialects this pass may need loaded.
     *
     * NOTE(review): the LLVM dialect is presumably listed because
     * runOnOperation constructs an mlir::LLVMTypeConverter. Confirm before
     * removing it.
     *
     * @param registry Registry receiving the dependent dialects.
     */
    void getDependentDialects(mlir::DialectRegistry &registry) const override {
        registry.insert<mlir::LLVM::LLVMDialect, mlir::arith::ArithDialect, mlir::daphne::DaphneDialect>();
    }

    void runOnOperation() final;

    /// Command-line argument naming this pass.
    StringRef getArgument() const final { return "opt-daphne"; }

    /// Human-readable description shown for this pass.
    StringRef getDescription() const final {
        return "Performs optimizations on the DaphneIR by transforming "
               "operations in the DaphneDialect to a set of other operation "
               "also from the DaphneDialect.";
    }
};
} // end anonymous namespace
void DenseMatrixOptPass::runOnOperation() {
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
mlir::LowerToLLVMOptions llvmOptions(&getContext());
mlir::LLVMTypeConverter typeConverter(&getContext(), llvmOptions);
typeConverter.addConversion([](Type type) { return type; });
target.addLegalDialect<mlir::BuiltinDialect>();
target.addLegalDialect<mlir::arith::ArithDialect>();
target.addLegalDialect<mlir::daphne::DaphneDialect>();
target.addDynamicallyLegalOp<mlir::daphne::EwModOp>(
[&](mlir::daphne::EwModOp op) { return !IntegerModOpt::optimization_viable(op); });
patterns.insert<IntegerModOpt>(typeConverter, &getContext());
auto module = getOperation();
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}
/**
 * @brief Creates a new instance of the DaphneIR optimization pass, whose
 * command-line argument is `opt-daphne`.
 *
 * @return Owning pointer to the newly created pass.
 */
std::unique_ptr<mlir::Pass> mlir::daphne::createDaphneOptPass() {
    std::unique_ptr<mlir::Pass> pass = std::make_unique<DenseMatrixOptPass>();
    return pass;
}