-
Notifications
You must be signed in to change notification settings - Fork 12.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Math] Add erfc to math dialect #126439
Conversation
@llvm/pr-subscribers-mlir-math @llvm/pr-subscribers-mlir Author: Jan Leyonberg (jsjodin) ChangesThis patch adds the erfc op to the math dialect. It also does lowering of the math.erfc op to libm calls. There is also a f32 polynomial approximation for the function based on Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126439.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 5990a9f0d2e442b..67d2d5168fe374e 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -543,6 +543,28 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
let hasFolder = 1;
}
+//===----------------------------------------------------------------------===//
+// ErfcOp
+//===----------------------------------------------------------------------===//
+
+def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
+ let summary = "complementary error function of the specified value";
+ let description = [{
+ The `erfc` operation computes the complementary error function.
+ It takes one operand of floating point type (i.e., scalar, tensor or
+ vector) and returns one result of the same type.
+ It has no standard attributes.
+
+ Example:
+
+ ```mlir
+ // Scalar error function value.
+ %a = math.erfc %b : f64
+ ```
+ }];
+ let hasFolder = 1;
+}
+
//===----------------------------------------------------------------------===//
// ExpOp
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
index b4ebc2f0f8fcd28..ecfdb71817dffd9 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
@@ -23,6 +23,14 @@ struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
PatternRewriter &rewriter) const final;
};
+struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(math::ErfcOp op,
+ PatternRewriter &rewriter) const final;
+};
+
} // namespace math
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..b8055354645244b 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51af..1a568d549548083 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -175,6 +175,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
+ populatePatternsForOp<math::ErfcOp>(patterns, ctx, "erfcf", "erfc");
populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 1690585e78c5dad..93d19a69a701474 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -318,6 +318,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
});
}
+//===----------------------------------------------------------------------===//
+// ErfcOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
+ return constFoldUnaryOpConditional<FloatAttr>(
+ adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+ switch (a.getSizeInBits(a.getSemantics())) {
+ case 64:
+ return APFloat(erfc(a.convertToDouble()));
+ case 32:
+ return APFloat(erfcf(a.convertToFloat()));
+ default:
+ return {};
+ }
+ });
+}
+
//===----------------------------------------------------------------------===//
// IPowIOp folder
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..1585ebc74e7908f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Helper functions to create constants.
//----------------------------------------------------------------------------//
+static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
+ return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
+}
+
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
Type elementType) {
assert((elementType.isF16() || elementType.isF32()) &&
@@ -1118,12 +1122,102 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
return success();
}
+// Approximates erfc(x) with
+LogicalResult
+ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
+ PatternRewriter &rewriter) const {
+ Value x = op.getOperand();
+ Type et = getElementTypeOrSelf(x);
+
+ if (!et.isF32())
+ return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
+ std::optional<VectorShape> shape = vectorShape(x);
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ auto bcast = [&](Value value) -> Value {
+ return broadcast(builder, value, shape);
+ };
+
+ Value trueValue = bcast(boolCst(builder, true));
+ Value zero = bcast(floatCst(builder, 0.0f, et));
+ Value one = bcast(floatCst(builder, 1.0f, et));
+ Value onehalf = bcast(floatCst(builder, 0.5f, et));
+ Value neg4 = bcast(floatCst(builder, -4.0f, et));
+ Value neg2 = bcast(floatCst(builder, -2.0f, et));
+ Value pos2 = bcast(floatCst(builder, 2.0f, et));
+ Value posInf = bcast(f32FromBits(builder, 0x7f800000u));
+ Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
+
+ // Get abs(x)
+ Value isNegativeArg =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
+ Value negArg = builder.create<arith::NegFOp>(x);
+ Value a = builder.create<arith::SelectOp>(isNegativeArg, negArg, x);
+ Value p = builder.create<arith::AddFOp>(a, pos2);
+ Value r = builder.create<arith::DivFOp>(one, p);
+ Value q = builder.create<math::FmaOp>(neg4, r, one);
+ Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
+ neg2, a);
+ Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
+ q = builder.create<math::FmaOp>(r, e, q);
+
+ p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
+ Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
+ p = builder.create<math::FmaOp>(p, q, c1);
+ Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
+ p = builder.create<math::FmaOp>(p, q, c2);
+ Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
+ p = builder.create<math::FmaOp>(p, q, c3);
+ Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
+ p = builder.create<math::FmaOp>(p, q, c4);
+ Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
+ p = builder.create<math::FmaOp>(p, q, c5);
+ Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
+ p = builder.create<math::FmaOp>(p, q, c6);
+ Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
+ p = builder.create<math::FmaOp>(p, q, c7);
+ Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
+ p = builder.create<math::FmaOp>(p, q, c8);
+ Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
+ p = builder.create<math::FmaOp>(p, q, c9);
+
+ Value d = builder.create<math::FmaOp>(pos2, a, one);
+ r = builder.create<arith::DivFOp>(one, d);
+ q = builder.create<math::FmaOp>(p, r, r);
+ e = builder.create<math::FmaOp>(
+ builder.create<math::FmaOp>(q, builder.create<arith::NegFOp>(a), onehalf),
+ pos2, builder.create<arith::SubFOp>(p, q));
+ r = builder.create<math::FmaOp>(e, r, q);
+
+ Value s = builder.create<arith::MulFOp>(a, a);
+ e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
+
+ t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
+ r = builder.create<math::FmaOp>(
+ r, e,
+ builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
+
+ Value isNotLessThanInf = builder.create<arith::XOrIOp>(
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
+ trueValue);
+ r = builder.create<arith::SelectOp>(isNotLessThanInf,
+ builder.create<arith::AddFOp>(x, x), r);
+ Value isGreaterThanClamp =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
+ r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
+
+ Value isNegative =
+ builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
+ r = builder.create<arith::SelectOp>(
+ isNegative, builder.create<arith::SubFOp>(pos2, r), r);
+
+ rewriter.replaceOp(op, r);
+ return success();
+}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
-
namespace {
-
Value clampWithNormals(ImplicitLocOpBuilder &builder,
const std::optional<VectorShape> shape, Value value,
float lowerBound, float upperBound) {
@@ -1667,6 +1761,11 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}
+void mlir::populatePolynomialApproximateErfcPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
+}
+
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
@@ -1680,13 +1779,14 @@ void mlir::populateMathPolynomialApproximationPatterns(
ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
patterns.getContext());
- patterns
- .add<AtanApproximation, Atan2Approximation, TanhApproximation,
- LogApproximation, Log2Approximation, Log1pApproximation,
- ErfPolynomialApproximation, AsinPolynomialApproximation,
- AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
- CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
- SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+ patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
+ LogApproximation, Log2Approximation, Log1pApproximation,
+ ErfPolynomialApproximation, ErfcPolynomialApproximation,
+ AsinPolynomialApproximation, AcosPolynomialApproximation,
+ ExpApproximation, ExpM1Approximation, CbrtApproximation,
+ SinAndCosApproximation<true, math::SinOp>,
+ SinAndCosApproximation<false, math::CosOp>>(
+ patterns.getContext());
if (options.enableAvx2) {
patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
patterns.getContext());
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 81d071e6bbba368..badc95fa2d4aac6 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -81,6 +81,118 @@ func.func @erf_scalar(%arg0: f32) -> f32 {
return %0 : f32
}
+// CHECK-LABEL: func @erfc_scalar(
+// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 {
+// CHECK-DAG: %[[c127_i32:.*]] = arith.constant 127 : i32
+// CHECK-DAG: %[[c23_i32:.*]] = arith.constant 23 : i32
+// CHECK-DAG: %[[cst:.*]] = arith.constant 1.270000e+02 : f32
+// CHECK-DAG: %[[cst_0:.*]] = arith.constant -1.270000e+02 : f32
+// CHECK-DAG: %[[cst_1:.*]] = arith.constant 8.880000e+01 : f32
+// CHECK-DAG: %[[cst_2:.*]] = arith.constant -8.780000e+01 : f32
+// CHECK-DAG: %[[cst_3:.*]] = arith.constant 0.166666657 : f32
+// CHECK-DAG: %[[cst_4:.*]] = arith.constant 0.0416657962 : f32
+// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.00833345205 : f32
+// CHECK-DAG: %[[cst_6:.*]] = arith.constant 0.00139819994 : f32
+// CHECK-DAG: %[[cst_7:.*]] = arith.constant 1.98756912E-4 : f32
+// CHECK-DAG: %[[cst_8:.*]] = arith.constant 2.12194442E-4 : f32
+// CHECK-DAG: %[[cst_9:.*]] = arith.constant -0.693359375 : f32
+// CHECK-DAG: %[[cst_10:.*]] = arith.constant 1.44269502 : f32
+// CHECK-DAG: %[[cst_11:.*]] = arith.constant 0.276978403 : f32
+// CHECK-DAG: %[[cst_12:.*]] = arith.constant -0.0927639827 : f32
+// CHECK-DAG: %[[cst_13:.*]] = arith.constant -0.166031361 : f32
+// CHECK-DAG: %[[cst_14:.*]] = arith.constant 0.164055392 : f32
+// CHECK-DAG: %[[cst_15:.*]] = arith.constant -0.0542046614 : f32
+// CHECK-DAG: %[[cst_16:.*]] = arith.constant -8.059920e-03 : f32
+// CHECK-DAG: %[[cst_17:.*]] = arith.constant 0.00863227434 : f32
+// CHECK-DAG: %[[cst_18:.*]] = arith.constant 0.00131355342 : f32
+// CHECK-DAG: %[[cst_19:.*]] = arith.constant -0.0012307521 : f32
+// CHECK-DAG: %[[cst_20:.*]] = arith.constant -4.01139259E-4 : f32
+// CHECK-DAG: %[[cst_true:.*]] = arith.constant true
+// CHECK-DAG: %[[cst_21:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[cst_22:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[cst_23:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[cst_24:.*]] = arith.constant -4.000000e+00 : f32
+// CHECK-DAG: %[[cst_25:.*]] = arith.constant -2.000000e+00 : f32
+// CHECK-DAG: %[[cst_26:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[cst_27:.*]] = arith.constant 0x7F800000 : f32
+// CHECK-DAG: %[[cst_28:.*]] = arith.constant 10.0546875 : f32
+// CHECK: %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
+// CHECK: %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32
+// CHECK: %[[val_2:.*]] = arith.select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32
+// CHECK: %[[val_3:.*]] = arith.addf %[[val_2]], %[[cst_26]] : f32
+// CHECK: %[[val_4:.*]] = arith.divf %[[cst_22]], %[[val_3]] : f32
+// CHECK: %[[val_5:.*]] = math.fma %[[cst_24]], %[[val_4]], %[[cst_22]] : f32
+// CHECK: %[[val_6:.*]] = arith.addf %[[val_5]], %[[cst_22]] : f32
+// CHECK: %[[val_7:.*]] = math.fma %[[val_6]], %[[cst_25]], %[[val_2]] : f32
+// CHECK: %[[val_8:.*]] = arith.negf %[[val_2]] : f32
+// CHECK: %[[val_9:.*]] = math.fma %[[val_8]], %[[val_5]], %[[val_7]] : f32
+// CHECK: %[[val_10:.*]] = math.fma %[[val_4]], %[[val_9]], %[[val_5]] : f32
+// CHECK: %[[val_11:.*]] = math.fma %[[cst_20]], %[[val_10]], %[[cst_19]] : f32
+// CHECK: %[[val_12:.*]] = math.fma %[[val_11]], %[[val_10]], %[[cst_18]] : f32
+// CHECK: %[[val_13:.*]] = math.fma %[[val_12]], %[[val_10]], %[[cst_17]] : f32
+// CHECK: %[[val_14:.*]] = math.fma %[[val_13]], %[[val_10]], %[[cst_16]] : f32
+// CHECK: %[[val_15:.*]] = math.fma %[[val_14]], %[[val_10]], %[[cst_15]] : f32
+// CHECK: %[[val_16:.*]] = math.fma %[[val_15]], %[[val_10]], %[[cst_14]] : f32
+// CHECK: %[[val_17:.*]] = math.fma %[[val_16]], %[[val_10]], %[[cst_13]] : f32
+// CHECK: %[[val_18:.*]] = math.fma %[[val_17]], %[[val_10]], %[[cst_12]] : f32
+// CHECK: %[[val_19:.*]] = math.fma %[[val_18]], %[[val_10]], %[[cst_11]] : f32
+// CHECK: %[[val_20:.*]] = math.fma %[[cst_26]], %[[val_2]], %[[cst_22]] : f32
+// CHECK: %[[val_21:.*]] = arith.divf %[[cst_22]], %[[val_20]] : f32
+// CHECK: %[[val_22:.*]] = math.fma %[[val_19]], %[[val_21]], %[[val_21]] : f32
+// CHECK: %[[val_23:.*]] = arith.subf %[[val_19]], %[[val_22]] : f32
+// CHECK: %[[val_24:.*]] = arith.negf %[[val_2]] : f32
+// CHECK: %[[val_25:.*]] = math.fma %[[val_22]], %[[val_24]], %[[cst_23]] : f32
+// CHECK: %[[val_26:.*]] = math.fma %[[val_25]], %[[cst_26]], %[[val_23]] : f32
+// CHECK: %[[val_27:.*]] = math.fma %[[val_26]], %[[val_21]], %[[val_22]] : f32
+// CHECK: %[[val_28:.*]] = arith.mulf %[[val_2]], %[[val_2]] : f32
+// CHECK: %[[val_29:.*]] = arith.negf %[[val_28]] : f32
+// CHECK: %[[val_30:.*]] = arith.cmpf uge, %[[val_29]], %[[cst_2]] : f32
+// CHECK: %[[val_31:.*]] = arith.select %[[val_30]], %[[val_29]], %[[cst_2]] : f32
+// CHECK: %[[val_32:.*]] = arith.cmpf ule, %[[val_31]], %[[cst_1]] : f32
+// CHECK: %[[val_33:.*]] = arith.select %[[val_32]], %[[val_31]], %[[cst_1]] : f32
+// CHECK: %[[val_34:.*]] = math.fma %[[val_33]], %[[cst_10]], %[[cst_23]] : f32
+// CHECK: %[[val_35:.*]] = math.floor %[[val_34]] : f32
+// CHECK: %[[val_36:.*]] = arith.cmpf uge, %[[val_35]], %[[cst_0]] : f32
+// CHECK: %[[val_37:.*]] = arith.select %[[val_36]], %[[val_35]], %[[cst_0]] : f32
+// CHECK: %[[val_38:.*]] = arith.cmpf ule, %[[val_37]], %[[cst]] : f32
+// CHECK: %[[val_39:.*]] = arith.select %[[val_38]], %[[val_37]], %[[cst]] : f32
+// CHECK: %[[val_40:.*]] = math.fma %[[cst_9]], %[[val_39]], %[[val_33]] : f32
+// CHECK: %[[val_41:.*]] = math.fma %[[cst_8]], %[[val_39]], %[[val_40]] : f32
+// CHECK: %[[val_42:.*]] = math.fma %[[val_41]], %[[cst_7]], %[[cst_6]] : f32
+// CHECK: %[[val_43:.*]] = math.fma %[[val_42]], %[[val_41]], %[[cst_5]] : f32
+// CHECK: %[[val_44:.*]] = math.fma %[[val_43]], %[[val_41]], %[[cst_4]] : f32
+// CHECK: %[[val_45:.*]] = math.fma %[[val_44]], %[[val_41]], %[[cst_3]] : f32
+// CHECK: %[[val_46:.*]] = math.fma %[[val_45]], %[[val_41]], %[[cst_23]] : f32
+// CHECK: %[[val_47:.*]] = arith.mulf %[[val_41]], %[[val_41]] : f32
+// CHECK: %[[val_48:.*]] = math.fma %[[val_46]], %[[val_47]], %[[val_41]] : f32
+// CHECK: %[[val_49:.*]] = arith.addf %[[val_48]], %[[cst_22]] : f32
+// CHECK: %[[val_50:.*]] = arith.fptosi %[[val_39]] : f32 to i32
+// CHECK: %[[val_51:.*]] = arith.addi %[[val_50]], %[[c127_i32]] : i32
+// CHECK: %[[val_52:.*]] = arith.shli %[[val_51]], %[[c23_i32]] : i32
+// CHECK: %[[val_53:.*]] = arith.bitcast %[[val_52]] : i32 to f32
+// CHECK: %[[val_54:.*]] = arith.mulf %[[val_49]], %[[val_53]] : f32
+// CHECK: %[[val_55:.*]] = arith.negf %[[val_2]] : f32
+// CHECK: %[[val_56:.*]] = math.fma %[[val_55]], %[[val_2]], %[[val_28]] : f32
+// CHECK: %[[val_57:.*]] = arith.mulf %[[val_27]], %[[val_54]] : f32
+// CHECK: %[[val_58:.*]] = arith.mulf %[[val_57]], %[[val_56]] : f32
+// CHECK: %[[val_59:.*]] = math.fma %[[val_27]], %[[val_54]], %[[val_58]] : f32
+// CHECK: %[[val_60:.*]] = arith.cmpf olt, %[[val_2]], %[[cst_27]] : f32
+// CHECK: %[[val_61:.*]] = arith.xori %[[val_60]], %[[cst_true]] : i1
+// CHECK: %[[val_62:.*]] = arith.addf %[[val_arg0]], %[[val_arg0]] : f32
+// CHECK: %[[val_63:.*]] = arith.select %[[val_61]], %[[val_62]], %[[val_59]] : f32
+// CHECK: %[[val_64:.*]] = arith.cmpf ogt, %[[val_2]], %[[cst_28]] : f32
+// CHECK: %[[val_65:.*]] = arith.select %[[val_64]], %[[cst_21]], %[[val_63]] : f32
+// CHECK: %[[val_66:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
+// CHECK: %[[val_67:.*]] = arith.subf %[[cst_26]], %[[val_65]] : f32
+// CHECK: %[[val_68:.*]] = arith.select %[[val_66]], %[[val_67]], %[[val_65]]...
[truncated]
|
75497da
to
e780841
Compare
@@ -1118,12 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, | |||
return success(); | |||
} | |||
|
|||
// Approximates erfc(x) with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please describe how it's being approximated and the StackOverflow URL.
9543f7d
to
24c86ec
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Please wait for other reviewers.
The `erfc` operation computes the complementary error function. | ||
It takes one operand of floating point type (i.e., scalar, tensor or | ||
vector) and returns one result of the same type. | ||
It has no standard attributes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a detailed description of erfc op here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe also a word on why it exists as a separate op, as opposed to letting people write 1 - erf(x)
(it's about floating point accuracy vs cancellation when erf(x) is close to 1, and it's standard in the C math.h library etc).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I improved the wording a bit. Hope it helps!
mlir/lib/Dialect/Math/IR/MathOps.cpp
Outdated
switch (a.getSizeInBits(a.getSemantics())) { | ||
case 64: | ||
return APFloat(erfc(a.convertToDouble())); | ||
case 32: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not familiar with fltSemantics, but at least the idea is that there could be multiple semantics sharing the same bit width, like there is at 16 bit between fp16 and bf16. So maybe there is something a bit more righteous that could be done here, matching exact semantics of IEEE single / double precision rather than merely the size in bits.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point, it would have been better to have a function that would return an enum perhaps, but this would be better in a separate PR, since this is the pattern that is used.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I meant something more specific to the code being added here.
Instead of switching on the size in bits, switch on the semantics itself.
Reading APFloat.h, I see that there is:
static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);
so here you could do:
switch (APFloat::SemanticsToEnum(a.getSemantics)) {
and have your case:
statement based on enumerators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, that's much better. I wasn't aware they already had an enum.
Value neg4 = bcast(floatCst(builder, -4.0f, et)); | ||
Value neg2 = bcast(floatCst(builder, -2.0f, et)); | ||
Value pos2 = bcast(floatCst(builder, 2.0f, et)); | ||
Value posInf = bcast(f32FromBits(builder, 0x7f800000u)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we really need to get down to hex bits! how about floatCst(INFINITY)
or some such.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried using INFINITY and std::numeric_limits::infinity() , but it requires this specific inf, which seems to be used in other functions as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be only one positive infinity. But the difference between floatCst(builder, INFINITY, et)
and f32FromBits(builder, 0x7f800000u)
is that the former uses the element type et
while the latter is explicitly f32
.
Have you tried floatCst(builder, INFINITY, builder.getF32Type())
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, INFINITY worked fine. I must have make some error before. Thanks!
// Get abs(x) | ||
Value isNegativeArg = | ||
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero); | ||
Value negArg = builder.create<arith::NegFOp>(x); | ||
Value a = builder.create<arith::SelectOp>(isNegativeArg, negArg, x); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use math.absf
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
adadfbc
to
68574f7
Compare
This patch adds the erfc op to the math dialect. It also does lowering of the math.erfc op to libm calls. There is also a f32 polynomial approximation for the function based on https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf This is in turn based on M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253. The code has a ULP error less than 3, which was tested, and MLIR test values were verified against the C implementation.
…ce is created for different Linux systems.
6abca7b
to
4a0b672
Compare
This patch adds the erfc op to the math dialect. It also does lowering of the math.erfc op to libm calls. There is also a f32 polynomial approximation for the function based on https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf This is in turn based on M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253. The code has a ULP error less than 3, which was tested, and MLIR test values were verified against the C implementation.
This patch adds the erfc op to the math dialect. It also does lowering of the math.erfc op to libm calls. There is also a f32 polynomial approximation for the function based on
https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf This is in turn based on
M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253.
The code has a ULP error less than 3, which was tested, and MLIR test values were verified against the C implementation.