[MLIR][ARITH] Adds missing foldings for truncf #128096

Merged (6 commits) on Feb 21, 2025

Conversation

@zahimoud zahimoud commented Feb 21, 2025

This patch mainly adds the following folds for truncf:
truncf(extf(a)) -> a, if a has the same bitwidth as the result
truncf(extf(a)) -> truncf(a), if a has larger bitwidth than the result
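
The two rules above can be sketched as a small decision function. This is only a model, not the MLIR code: `FloatKind`, `TruncOfExtFold`, and `classifyTruncOfExt` are hypothetical names, and a float type is reduced to a name plus a bit width. The sketch follows the diff in this PR, which checks the width comparison first and then compares element types for equality (so two distinct 16-bit types would not fold to `a`).

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR float element type.
struct FloatKind {
  std::string name; // e.g. "f16", "bf16", "f32"
  unsigned width;   // bit width
};

enum class TruncOfExtFold { ForwardSource, TruncSourceDirectly, NoFold };

// Classify truncf(extf(a)), given the type of `a` (src) and the
// result type of the truncf (dst).
TruncOfExtFold classifyTruncOfExt(const FloatKind &src, const FloatKind &dst) {
  // truncf(extf(a)) -> truncf(a): the source is wider than the result.
  if (src.width > dst.width)
    return TruncOfExtFold::TruncSourceDirectly;
  // truncf(extf(a)) -> a: the source already has the result type.
  if (src.name == dst.name && src.width == dst.width)
    return TruncOfExtFold::ForwardSource;
  return TruncOfExtFold::NoFold;
}

// Usage example with assumed type names and widths.
void checkClassifyExamples() {
  const FloatKind f16{"f16", 16}, bf16{"bf16", 16}, f32{"f32", 32};
  assert(classifyTruncOfExt(f32, f32) == TruncOfExtFold::ForwardSource);
  assert(classifyTruncOfExt(f32, f16) == TruncOfExtFold::TruncSourceDirectly);
  assert(classifyTruncOfExt(f16, bf16) == TruncOfExtFold::NoFold);
}
```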

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir-emitc
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Zahi Moudallal (zahimoud)

Changes

This patch mainly adds the following folds for truncf:
truncf(extf(a)) -> a, if a has the same bitwidth as the result
truncf(extf(a)) -> truncf(a), if a has larger bitwidth than the result
truncf(truncf(a)) -> truncf(a), in any case
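
As a rough illustration of why the first rule is value-preserving, the sketch below models `extf` followed by `truncf` with host `float` and `double`. It assumes these are IEEE-754 binary32 and binary64, which the C++ standard does not guarantee; `truncOfExt` and `roundTripIsIdentity` are hypothetical helper names and are not part of MLIR. Widening is exact because every `float` value is representable as a `double`, so narrowing back returns the original value.

```cpp
#include <cassert>

// Hypothetical helper: models arith.extf (f32 -> f64) followed by
// arith.truncf (f64 -> f32) on host floating-point types.
float truncOfExt(float a) {
  const double widened = static_cast<double>(a); // exact under the assumption above
  return static_cast<float>(widened);            // narrows back to the same value
}

// True when the round trip reproduces the input value.
// NaN compares unequal to itself, so this reports false for NaN inputs.
bool roundTripIsIdentity(float a) { return truncOfExt(a) == a; }

// Usage example: spot-check a few finite values.
void checkRoundTripExamples() {
  assert(roundTripIsIdentity(0.1f));
  assert(roundTripIsIdentity(-2.5f));
  assert(roundTripIsIdentity(3.4028235e38f));
}
```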


Full diff: https://github.com/llvm/llvm-project/pull/128096.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+22)
  • (modified) mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir (+1-2)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (+29)
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 8a9f223089794..494985fbce94e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1517,6 +1517,28 @@ LogicalResult arith::TruncIOp::verify() {
 /// Perform safe const propagation for truncf, i.e., only propagate if FP value
 /// can be represented without precision loss.
 OpFoldResult arith::TruncFOp::fold(FoldAdaptor adaptor) {
+  if (auto extOp = getOperand().getDefiningOp<arith::ExtFOp>()) {
+    Value src = extOp.getIn();
+    Type srcType = getElementTypeOrSelf(src.getType());
+    Type dstType = getElementTypeOrSelf(getType());
+    // truncf(extf(a)) -> truncf(a)
+    if (llvm::cast<FloatType>(srcType).getWidth() >
+        llvm::cast<FloatType>(dstType).getWidth()) {
+      setOperand(src);
+      return getResult();
+    }
+
+    // truncf(extf(a)) -> a
+    if (srcType == dstType)
+      return src;
+  }
+
+  // truncf(truncf(a)) -> truncf(a)
+  if (auto truncOp = getOperand().getDefiningOp<arith::TruncFOp>()) {
+    setOperand(truncOp.getIn());
+    return getResult();
+  }
+
   auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
   const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
   return constFoldCastOp<FloatAttr, FloatAttr>(
diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
index cb1d092918f03..cebdebef85dc9 100644
--- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
+++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
@@ -764,9 +764,8 @@ func.func @arith_extf(%arg0: f16) -> f64 {
 func.func @arith_truncf(%arg0: f64) -> f16 {
   // CHECK-LABEL: arith_truncf
   // CHECK-SAME: (%[[Arg0:[^ ]*]]: f64)
-  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f32
+  // CHECK: %[[Truncd0:.*]] = emitc.cast %[[Arg0]] : f64 to f16
   %truncd0 = arith.truncf %arg0 : f64 to f32
-  // CHECK: %[[Truncd1:.*]] = emitc.cast %[[Truncd0]] : f32 to f16
   %truncd1 = arith.truncf %truncd0 : f32 to f16
 
   return %truncd1 : f16
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index e3750bb020cad..aa4136cd6361e 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -714,6 +714,35 @@ func.func @extFPVectorConstant() -> vector<2xf128> {
   return %0 : vector<2xf128>
 }
 
+// CHECK-LABEL: @truncExtf
+//       CHECK-NOT:  truncf
+//       CHECK:   return  %arg0
+func.func @truncExtf(%arg0: f32) -> f32 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %trunc = arith.truncf %extf : f64 to f32
+  return %trunc : f32
+}
+
+// CHECK-LABEL: @truncExtf2
+//       CHECK:  %[[ARG0:.+]]: f32
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f32 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncExtf2(%arg0: f32) -> f16 {
+  %extf = arith.extf %arg0 : f32 to f64
+  %truncf = arith.truncf %extf : f64 to f16
+  return %truncf : f16
+}
+
+// CHECK-LABEL: @truncTruncf
+//       CHECK:  %[[ARG0:.+]]: f64
+//       CHECK:  %[[CST:.*]] = arith.truncf %[[ARG0:.+]] : f64 to f16
+//       CHECK:   return  %[[CST:.*]]
+func.func @truncTruncf(%arg0: f64) -> f16 {
+  %truncf = arith.truncf %arg0 : f64 to f32
+  %truncf1 = arith.truncf %truncf : f32 to f16
+  return %truncf1 : f16
+}
+
 // TODO: We should also add a test for not folding arith.extf on information loss.
 // This may happen when extending f8E5M2FNUZ to f16.
 

@zahimoud zahimoud requested a review from kuhar February 21, 2025 01:22
@ThomasRaoux ThomasRaoux requested a review from krzysz00 February 21, 2025 05:25

github-actions bot commented Feb 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@zahimoud zahimoud force-pushed the zahi/truncf-folding branch from b8c749c to d822206 Compare February 21, 2025 19:49

@krzysz00 krzysz00 left a comment

On further consideration, and having realized that some of my comments were really about the ext(trunc(x)) => x folding, this is fine. Approved.


@kuhar kuhar left a comment

Thanks for the fixes. This looks alright now. I'm still not 100% convinced that this is correct over all small fp types, but I don't have a counterexample at hand and don't want to block this.
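
One general effect worth keeping in mind when probing this kind of concern is double rounding: rounding in two steps can give a different result than rounding once. The sketch below shows this with a toy integer model using round-to-nearest, ties-to-even. It does not use MLIR or APFloat, the helper names `roundToNearestEven` and `roundTwice` are hypothetical, and it makes no claim about which specific pairs of MLIR float types are affected.

```cpp
#include <cassert>
#include <cstdint>

// Round `value` to the nearest multiple of 2^dropBits, ties to even.
// This mimics discarding `dropBits` low bits of a significand.
std::uint64_t roundToNearestEven(std::uint64_t value, unsigned dropBits) {
  assert(dropBits < 64 && "shift amount must be smaller than the word size");
  if (dropBits == 0)
    return value;
  const std::uint64_t unit = std::uint64_t{1} << dropBits;
  const std::uint64_t half = unit >> 1;
  const std::uint64_t remainder = value & (unit - 1);
  std::uint64_t quotient = value >> dropBits;
  // Round up when above the midpoint, or exactly at it with an odd quotient.
  if (remainder > half || (remainder == half && (quotient & 1)))
    ++quotient;
  return quotient << dropBits;
}

// Two-step narrowing: first drop `midBits`, then drop `finalBits` in total.
std::uint64_t roundTwice(std::uint64_t value, unsigned midBits,
                         unsigned finalBits) {
  return roundToNearestEven(roundToNearestEven(value, midBits), finalBits);
}
```

In this model, rounding 23 directly to a multiple of 16 gives 16, while rounding it first to a multiple of 4 (giving 24) and then to a multiple of 16 hits an exact tie and gives 32.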

@ThomasRaoux ThomasRaoux merged commit 5d0c5c6 into llvm:main Feb 21, 2025
11 checks passed