-
Notifications
You must be signed in to change notification settings - Fork 12.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][ARITH] Adds missing foldings for truncf #128096
Conversation
@llvm/pr-subscribers-mlir-emitc @llvm/pr-subscribers-mlir Author: Zahi Moudallal (zahimoud) ChangesThis patch is mainly to deal with folding truncf, as follows: Full diff: https://github.com/llvm/llvm-project/pull/128096.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
/// Perform safe const propagation for truncf, i.e., only propagate if FP value
/// can be represented without precision loss.
OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+ if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+ Value src = extOp.getIn();
+ Type srcType = getElementTypeOrSelf(src.getType());
+ Type dstType = getElementTypeOrSelf(getType());
+ // truncf(extf(a)) -> truncf(a)
+ if (llvm::cast<FloatType>(srcType).getWidth() >
+ llvm::cast<FloatType>(dstType).getWidth()) {
+ setOperand(src);
+ return getResult();
+ }
+
+ // truncf(extf(a)) -> a
+ if (srcType == dstType)
+ return src;
+ }
+
+ // truncf(truncf(a)) -> truncf(a)
+ if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+ setOperand(truncOp.getIn());
+ return getResult();
+ }
+
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
func.func @arith_truncf(%arg0: f64) -> f16 {
// CHECK-LABEL: arith_truncf
// CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
- // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+ // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
%truncd0 = arith.truncf %arg0 : f64 to f32
- // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
%truncd1 = arith.truncf %truncd0 : f32 to f16
return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
return %0 : vector<2xf128>
}
+// CHECK-LABEL: @truncExtf
+// CHECK-NOT: truncf
+// CHECK: return %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %trunc = arith.truncf %extf : f64 to f32
+ return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+// CHECK: %[[ARG0:.+]]: f32
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+ %extf = arith.extf %arg0 : f32 to f64
+ %truncf = arith.truncf %extf : f64 to f16
+ return %truncf : f16
+}
+
+// CHECK-LABEL: @truncTruncf
+// CHECK: %[[ARG0:.+]]: f64
+// CHECK: %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+// CHECK: return %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+ %truncf = arith.truncf %arg0 : f64 to f32
+ %truncf1 = arith.truncf %truncf : f32 to f16
+ return %truncf1 : f16
+}
+
// TODO: We should also add a test for not folding arith.extf on information loss.
// This may happen when extending f8E5M2FNUZ to f16.
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
b8c749c
to
d822206
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On further consideration and the realization that I was thinking of ext(trunc(x)) => x
folding in some of my comments, this is fine, approved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fixes. Look alright now, I'm still not 100% convinced that this is correct over all small fp types, but I don't have a counterexample at hand and don't want to block this.
Co-authored-by: Jakub Kuderski <[email protected]>
This patch is mainly to deal with folding
truncf
, as follows:truncf(extf(a))
->a
, ifa
has the same bitwidth as the resulttruncf(extf(a))
->truncf(a)
, ifa
has larger bitwidth than the result