Skip to content

Commit 75751f3

Browse files
authored
Reapply "Reapply "[mlir] Add FP software implementation lowering pass: arith-to-apfloat (#166618)" (#167431)" (#167436)
Reland #166618 by fixing missing symbol issues by explicitly loading `--shared-libs=%mlir_apfloat_wrappers` as well as `--shared-libs=%mlir_c_runner_utils`.
1 parent 7f81869 commit 75751f3

File tree

17 files changed

+542
-0
lines changed

17 files changed

+542
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
10+
#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
1516
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
1617
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1718
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
186186
];
187187
}
188188

189+
//===----------------------------------------------------------------------===//
190+
// ArithToAPFloat
191+
//===----------------------------------------------------------------------===//
192+
193+
def ArithToAPFloatConversionPass
194+
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
195+
let summary = "Convert Arith ops to APFloat runtime library calls";
196+
let description = [{
197+
This pass converts supported Arith ops to APFloat-based runtime library
198+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
199+
floating-point arithmetic operations.
200+
}];
201+
let dependentDialects = ["func::FuncDialect"];
202+
}
203+
189204
//===----------------------------------------------------------------------===//
190205
// ArithToSPIRV
191206
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
6060
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
6161
mlir::ModuleOp moduleOp);
6262

63+
/// Look up a FuncOp with signature `resultTypes`(`paramTypes`) and name
64+
/// `name`. Return a failure if the FuncOp is found but with a different
65+
/// signature.
66+
FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
67+
FunctionType funcT,
68+
SymbolTableCollection *symbolTables = nullptr);
69+
6370
} // namespace func
6471
} // namespace mlir
6572

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
5252
FailureOr<LLVM::LLVMFuncOp>
5353
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
5454
SymbolTableCollection *symbolTables = nullptr);
55+
FailureOr<LLVM::LLVMFuncOp>
56+
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
57+
SymbolTableCollection *symbolTables = nullptr);
58+
5559
/// Declares a function to print a C-string.
5660
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
5761
/// have the signature void(char const*). The default function is `printString`.
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Func/Utils/Utils.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/IR/Verifier.h"
17+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
21+
#include "mlir/Conversion/Passes.h.inc"
22+
} // namespace mlir
23+
24+
using namespace mlir;
25+
using namespace mlir::func;
26+
27+
/// Create a declaration for function `name` with type `funcT` at the start
/// of the first block of region 0 of `symTable`.
///
/// The declaration is marked private when `setPrivate` is true. If
/// `symbolTables` is non-null, the new op is additionally inserted into the
/// SymbolTable that the collection returns for `symTable`. The builder's
/// insertion point is restored on return.
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
                           StringRef name, FunctionType funcT, bool setPrivate,
                           SymbolTableCollection *symbolTables = nullptr) {
  OpBuilder::InsertionGuard restoreInsertionPoint(b);

  auto &body = symTable->getRegion(0);
  assert(!body.empty() && "expected non-empty region");
  auto &entryBlock = body.front();

  b.setInsertionPointToStart(&entryBlock);
  FuncOp decl = FuncOp::create(b, symTable->getLoc(), name, funcT);
  if (setPrivate)
    decl.setPrivate();

  if (!symbolTables)
    return decl;

  // Keep the caller-provided symbol table collection in sync with the IR.
  symbolTables->getSymbolTable(symTable).insert(decl, entryBlock.begin());
  return decl;
}
42+
43+
/// Helper function to look up or create the symbol for a runtime library
44+
/// function for a binary arithmetic operation.
45+
///
46+
/// Parameter 1: APFloat semantics
47+
/// Parameter 2: Left-hand side operand
48+
/// Parameter 3: Right-hand side operand
49+
///
50+
/// This function will return a failure if the function is found but has an
51+
/// unexpected signature.
52+
///
53+
static FailureOr<FuncOp>
54+
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
55+
SymbolTableCollection *symbolTables = nullptr) {
56+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
57+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
58+
59+
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
60+
FunctionType funcT =
61+
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
62+
FailureOr<FuncOp> func =
63+
lookupFnDecl(symTable, funcName, funcT, symbolTables);
64+
// Failed due to type mismatch.
65+
if (failed(func))
66+
return func;
67+
// Successfully matched existing decl.
68+
if (*func)
69+
return *func;
70+
71+
return createFnDecl(b, symTable, funcName, funcT,
72+
/*setPrivate=*/true, symbolTables);
73+
}
74+
75+
/// Rewrite a binary arithmetic operation to an APFloat function call.
76+
template <typename OpTy, const char *APFloatName>
77+
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
78+
BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
79+
SymbolOpInterface symTable)
80+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};
81+
82+
LogicalResult matchAndRewrite(OpTy op,
83+
PatternRewriter &rewriter) const override {
84+
// Get APFloat function from runtime library.
85+
FailureOr<FuncOp> fn =
86+
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
87+
if (failed(fn))
88+
return fn;
89+
90+
rewriter.setInsertionPoint(op);
91+
// Cast operands to 64-bit integers.
92+
Location loc = op.getLoc();
93+
auto floatTy = cast<FloatType>(op.getType());
94+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
95+
auto int64Type = rewriter.getI64Type();
96+
Value lhsBits = arith::ExtUIOp::create(
97+
rewriter, loc, int64Type,
98+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
99+
Value rhsBits = arith::ExtUIOp::create(
100+
rewriter, loc, int64Type,
101+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
102+
103+
// Call APFloat function.
104+
int32_t sem =
105+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
106+
Value semValue = arith::ConstantOp::create(
107+
rewriter, loc, rewriter.getI32Type(),
108+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
109+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
110+
auto resultOp =
111+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
112+
SymbolRefAttr::get(*fn), params);
113+
114+
// Truncate result to the original width.
115+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
116+
resultOp->getResult(0));
117+
rewriter.replaceOp(
118+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
119+
return success();
120+
}
121+
122+
SymbolOpInterface symTable;
123+
};
124+
125+
namespace {
126+
struct ArithToAPFloatConversionPass final
127+
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
128+
using Base::Base;
129+
130+
void runOnOperation() override {
131+
MLIRContext *context = &getContext();
132+
RewritePatternSet patterns(context);
133+
static const char add[] = "add";
134+
static const char subtract[] = "subtract";
135+
static const char multiply[] = "multiply";
136+
static const char divide[] = "divide";
137+
static const char remainder[] = "remainder";
138+
patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
139+
BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
140+
BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
141+
BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
142+
BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
143+
context, 1, getOperation());
144+
LogicalResult result = success();
145+
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
146+
if (diag.getSeverity() == DiagnosticSeverity::Error) {
147+
result = failure();
148+
}
149+
// NB: if you don't return failure, no other diag handlers will fire (see
150+
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
151+
return failure();
152+
});
153+
walkAndApplyPatterns(getOperation(), std::move(patterns));
154+
if (failed(result))
155+
return signalPassFailure();
156+
}
157+
};
158+
} // namespace
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Conversion library implementing the arith-to-apfloat lowering pass.
add_mlir_conversion_library(MLIRArithToAPFloat
  ArithToAPFloat.cpp

  # Public headers of this library (not those of ArithToLLVM).
  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToAPFloat

  DEPENDS
  MLIRConversionPassIncGen

  LINK_COMPONENTS
  Core

  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRArithTransforms
  MLIRFuncDialect
  MLIRFuncUtils
  )

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/TypeUtilities.h"

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
22
add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
5+
add_subdirectory(ArithToAPFloat)
56
add_subdirectory(ArithToArmSME)
67
add_subdirectory(ArithToEmitC)
78
add_subdirectory(ArithToLLVM)

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16541654
return failure();
16551655
}
16561656
}
1657+
} else if (auto floatTy = dyn_cast<FloatType>(printType)) {
1658+
// Print other floating-point types using the APFloat runtime library.
1659+
int32_t sem =
1660+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1661+
Value semValue = LLVM::ConstantOp::create(
1662+
rewriter, loc, rewriter.getI32Type(),
1663+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1664+
Value floatBits =
1665+
LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1666+
printer =
1667+
LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
1668+
emitCall(rewriter, loc, printer.value(),
1669+
ValueRange({semValue, floatBits}));
1670+
return success();
16571671
} else {
16581672
return failure();
16591673
}

0 commit comments

Comments
 (0)