Skip to content

Commit 222f4e4

Browse files
[mlir] Add FP software implementation lowering pass: arith-to-apfloat (#166618)
This commit adds a new pass that lowers floating-point `arith` operations to calls into the execution engine runtime library. Currently supported operations: `addf`, `subf`, `mulf`, `divf`, `remf`. All floating-point types that have an APFloat semantics are supported. This includes low-precision floating-point types such as `f4E2M1FN` that cannot execute natively on CPUs. This commit also improves the `vector.print` lowering pattern to call into the runtime library for floating-point types that are not supported by LLVM. This is necessary to write a meaningful integration test. The way it works is ```mlir func.func @full_example() { %a = arith.constant 1.4 : f8E4M3FN %b = func.call @foo() : () -> (f8E4M3FN) %c = arith.addf %a, %b : f8E4M3FN vector.print %c : f8E4M3FN return } ``` gets transformed to ```mlir func.func private @__mlir_apfloat_add(i32, i64, i64) -> i6 func.func @full_example() { %cst = arith.constant 1.375000e+00 : f8E4M3FN %0 = call @foo() : () -> f8E4M3FN // bitcast operand A to integer of equal width %1 = arith.bitcast %cst : f8E4M3FN to i8 // zext A to i64 %2 = arith.extui %1 : i8 to i64 // same for operand B %3 = arith.bitcast %0 : f8E4M3FN to i8 %4 = arith.extui %3 : i8 to i64 // get the llvm::fltSemantics(f8E4M3FN) as an enum %c10_i32 = arith.constant 10 : i32 // call the impl against APFloat in mlir_apfloat_wrappers %5 = call @__mlir_apfloat_add(%c10_i32, %2, %4) : (i32, i64, i64) -> i64 // "cast" back to the original fp type %6 = arith.trunci %5 : i64 to i8 %7 = arith.bitcast %6 : i8 to f8E4M3FN vector.print %7 : f8E4M3FN } ``` Note, `llvm::fltSemantics(f8E4M3FN)` is emitted by the pattern each time an `arith` op is transformed, thereby making the call to `__mlir_apfloat_add` correct (i.e., no name mangling on type necessary). RFC: https://discourse.llvm.org/t/rfc-software-implementation-for-unsupported-fp-types-in-convert-arith-to-llvm/88785 --------- Co-authored-by: Matthias Springer <[email protected]>
1 parent 8751f26 commit 222f4e4

File tree

16 files changed

+533
-0
lines changed

16 files changed

+533
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- ArithToAPFloat.h - Arith to APFloat impl conversion ---*- C++ ----*-===//
2+
//
3+
// Part of the APFloat Project, under the Apache License v2.0 with APFloat
4+
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH APFloat-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
10+
#define MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_ARITHTOAPFLOATCONVERSIONPASS
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_ARITHTOAPFLOAT_ARITHTOAPFLOAT_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
1516
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
1617
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1718
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,21 @@ def ArithToLLVMConversionPass : Pass<"convert-arith-to-llvm"> {
186186
];
187187
}
188188

189+
//===----------------------------------------------------------------------===//
190+
// ArithToAPFloat
191+
//===----------------------------------------------------------------------===//
192+
193+
def ArithToAPFloatConversionPass
194+
: Pass<"convert-arith-to-apfloat", "ModuleOp"> {
195+
let summary = "Convert Arith ops to APFloat runtime library calls";
196+
let description = [{
197+
This pass converts supported Arith ops to APFloat-based runtime library
198+
calls (APFloatWrappers.cpp). APFloat is a software implementation of
199+
floating-point arithmetic operations.
200+
}];
201+
let dependentDialects = ["func::FuncDialect"];
202+
}
203+
189204
//===----------------------------------------------------------------------===//
190205
// ArithToSPIRV
191206
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Func/Utils/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ mlir::FailureOr<std::pair<mlir::func::FuncOp, mlir::func::CallOp>>
6060
deduplicateArgsOfFuncOp(mlir::RewriterBase &rewriter, mlir::func::FuncOp funcOp,
6161
mlir::ModuleOp moduleOp);
6262

63+
/// Look up a FuncOp with signature `resultTypes`(`paramTypes`)` and name
64+
/// `name`. Return a failure if the FuncOp is found but with a different
65+
/// signature.
66+
FailureOr<FuncOp> lookupFnDecl(SymbolOpInterface symTable, StringRef name,
67+
FunctionType funcT,
68+
SymbolTableCollection *symbolTables = nullptr);
69+
6370
} // namespace func
6471
} // namespace mlir
6572

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ lookupOrCreatePrintF32Fn(OpBuilder &b, Operation *moduleOp,
5252
FailureOr<LLVM::LLVMFuncOp>
5353
lookupOrCreatePrintF64Fn(OpBuilder &b, Operation *moduleOp,
5454
SymbolTableCollection *symbolTables = nullptr);
55+
FailureOr<LLVM::LLVMFuncOp>
56+
lookupOrCreateApFloatPrintFn(OpBuilder &b, Operation *moduleOp,
57+
SymbolTableCollection *symbolTables = nullptr);
58+
5559
/// Declares a function to print a C-string.
5660
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
5761
/// have the signature void(char const*). The default function is `printString`.
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArithToAPFloat/ArithToAPFloat.h"
10+
11+
#include "mlir/Dialect/Arith/IR/Arith.h"
12+
#include "mlir/Dialect/Arith/Transforms/Passes.h"
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Dialect/Func/Utils/Utils.h"
15+
#include "mlir/IR/PatternMatch.h"
16+
#include "mlir/IR/Verifier.h"
17+
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
18+
#include "llvm/Support/Debug.h"
19+
20+
#define DEBUG_TYPE "arith-to-apfloat"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
24+
#include "mlir/Conversion/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
using namespace mlir::func;
29+
30+
static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
31+
StringRef name, FunctionType funcT, bool setPrivate,
32+
SymbolTableCollection *symbolTables = nullptr) {
33+
OpBuilder::InsertionGuard g(b);
34+
assert(!symTable->getRegion(0).empty() && "expected non-empty region");
35+
b.setInsertionPointToStart(&symTable->getRegion(0).front());
36+
FuncOp funcOp = FuncOp::create(b, symTable->getLoc(), name, funcT);
37+
if (setPrivate)
38+
funcOp.setPrivate();
39+
if (symbolTables) {
40+
SymbolTable &symbolTable = symbolTables->getSymbolTable(symTable);
41+
symbolTable.insert(funcOp, symTable->getRegion(0).front().begin());
42+
}
43+
return funcOp;
44+
}
45+
46+
/// Helper function to look up or create the symbol for a runtime library
47+
/// function for a binary arithmetic operation.
48+
///
49+
/// Parameter 1: APFloat semantics
50+
/// Parameter 2: Left-hand side operand
51+
/// Parameter 3: Right-hand side operand
52+
///
53+
/// This function will return a failure if the function is found but has an
54+
/// unexpected signature.
55+
///
56+
static FailureOr<FuncOp>
57+
lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
58+
SymbolTableCollection *symbolTables = nullptr) {
59+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
60+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
61+
62+
std::string funcName = (llvm::Twine("__mlir_apfloat_") + name).str();
63+
FunctionType funcT =
64+
FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type});
65+
FailureOr<FuncOp> func =
66+
lookupFnDecl(symTable, funcName, funcT, symbolTables);
67+
// Failed due to type mismatch.
68+
if (failed(func))
69+
return func;
70+
// Successfully matched existing decl.
71+
if (*func)
72+
return *func;
73+
74+
return createFnDecl(b, symTable, funcName, funcT,
75+
/*setPrivate=*/true, symbolTables);
76+
}
77+
78+
/// Rewrite a binary arithmetic operation to an APFloat function call.
79+
template <typename OpTy, const char *APFloatName>
80+
struct BinaryArithOpToAPFloatConversion final : OpRewritePattern<OpTy> {
81+
BinaryArithOpToAPFloatConversion(MLIRContext *context, PatternBenefit benefit,
82+
SymbolOpInterface symTable)
83+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {};
84+
85+
LogicalResult matchAndRewrite(OpTy op,
86+
PatternRewriter &rewriter) const override {
87+
// Get APFloat function from runtime library.
88+
FailureOr<FuncOp> fn =
89+
lookupOrCreateBinaryFn(rewriter, symTable, APFloatName);
90+
if (failed(fn))
91+
return fn;
92+
93+
rewriter.setInsertionPoint(op);
94+
// Cast operands to 64-bit integers.
95+
Location loc = op.getLoc();
96+
auto floatTy = cast<FloatType>(op.getType());
97+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
98+
auto int64Type = rewriter.getI64Type();
99+
Value lhsBits = arith::ExtUIOp::create(
100+
rewriter, loc, int64Type,
101+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
102+
Value rhsBits = arith::ExtUIOp::create(
103+
rewriter, loc, int64Type,
104+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
105+
106+
// Call APFloat function.
107+
int32_t sem =
108+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
109+
Value semValue = arith::ConstantOp::create(
110+
rewriter, loc, rewriter.getI32Type(),
111+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
112+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
113+
auto resultOp =
114+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
115+
SymbolRefAttr::get(*fn), params);
116+
117+
// Truncate result to the original width.
118+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
119+
resultOp->getResult(0));
120+
rewriter.replaceOp(
121+
op, arith::BitcastOp::create(rewriter, loc, floatTy, truncatedBits));
122+
return success();
123+
}
124+
125+
SymbolOpInterface symTable;
126+
};
127+
128+
namespace {
129+
struct ArithToAPFloatConversionPass final
130+
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
131+
using Base::Base;
132+
133+
void runOnOperation() override {
134+
MLIRContext *context = &getContext();
135+
RewritePatternSet patterns(context);
136+
static const char add[] = "add";
137+
static const char subtract[] = "subtract";
138+
static const char multiply[] = "multiply";
139+
static const char divide[] = "divide";
140+
static const char remainder[] = "remainder";
141+
patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp, add>,
142+
BinaryArithOpToAPFloatConversion<arith::SubFOp, subtract>,
143+
BinaryArithOpToAPFloatConversion<arith::MulFOp, multiply>,
144+
BinaryArithOpToAPFloatConversion<arith::DivFOp, divide>,
145+
BinaryArithOpToAPFloatConversion<arith::RemFOp, remainder>>(
146+
context, 1, getOperation());
147+
LogicalResult result = success();
148+
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
149+
if (diag.getSeverity() == DiagnosticSeverity::Error) {
150+
result = failure();
151+
}
152+
// NB: if you don't return failure, no other diag handlers will fire (see
153+
// mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
154+
return failure();
155+
});
156+
walkAndApplyPatterns(getOperation(), std::move(patterns));
157+
if (failed(result))
158+
return signalPassFailure();
159+
}
160+
};
161+
} // namespace
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRArithToAPFloat
2+
ArithToAPFloat.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToLLVM
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRArithDialect
15+
MLIRArithTransforms
16+
MLIRFuncDialect
17+
)

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/TypeUtilities.h"

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_subdirectory(AffineToStandard)
22
add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
5+
add_subdirectory(ArithToAPFloat)
56
add_subdirectory(ArithToArmSME)
67
add_subdirectory(ArithToEmitC)
78
add_subdirectory(ArithToLLVM)

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,20 @@ class VectorPrintOpConversion : public ConvertOpToLLVMPattern<vector::PrintOp> {
16541654
return failure();
16551655
}
16561656
}
1657+
} else if (auto floatTy = dyn_cast<FloatType>(printType)) {
1658+
// Print other floating-point types using the APFloat runtime library.
1659+
int32_t sem =
1660+
llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics());
1661+
Value semValue = LLVM::ConstantOp::create(
1662+
rewriter, loc, rewriter.getI32Type(),
1663+
rewriter.getIntegerAttr(rewriter.getI32Type(), sem));
1664+
Value floatBits =
1665+
LLVM::ZExtOp::create(rewriter, loc, rewriter.getI64Type(), value);
1666+
printer =
1667+
LLVM::lookupOrCreateApFloatPrintFn(rewriter, parent, symbolTables);
1668+
emitCall(rewriter, loc, printer.value(),
1669+
ValueRange({semValue, floatBits}));
1670+
return success();
16571671
} else {
16581672
return failure();
16591673
}

0 commit comments

Comments
 (0)