[MLIR][Math] Add erfc to math dialect #126439

Merged: 9 commits, Feb 18, 2025

Contributor

@jsjodin jsjodin commented Feb 9, 2025

This patch adds the erfc op to the math dialect and lowers math.erfc to libm calls. It also adds an f32 polynomial approximation of the function, based on
https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
which is in turn based on
M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253.
The approximation was tested to have an error below 3 ULP, and the MLIR test values were verified against the C implementation.

Member

llvmbot commented Feb 9, 2025

@llvm/pr-subscribers-mlir-math

@llvm/pr-subscribers-mlir

Author: Jan Leyonberg (jsjodin)

Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/126439.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Math/IR/MathOps.td (+22)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Approximation.h (+8)
  • (modified) mlir/include/mlir/Dialect/Math/Transforms/Passes.h (+1)
  • (modified) mlir/lib/Conversion/MathToLibm/MathToLibm.cpp (+1)
  • (modified) mlir/lib/Dialect/Math/IR/MathOps.cpp (+18)
  • (modified) mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp (+109-9)
  • (modified) mlir/test/Dialect/Math/polynomial-approximation.mlir (+112)
  • (modified) mlir/test/mlir-runner/math-polynomial-approx.mlir (+72)
  • (modified) mlir/utils/vim/syntax/mlir.vim (+1)
diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
index 5990a9f0d2e442b..67d2d5168fe374e 100644
--- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td
+++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td
@@ -543,6 +543,28 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// ErfcOp
+//===----------------------------------------------------------------------===//
+
+def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
+  let summary = "complementary error function of the specified value";
+  let description = [{
+    The `erfc` operation computes the complementary error function.
+    It takes one operand of floating point type (i.e., scalar, tensor or
+    vector) and returns one result of the same type.
+    It has no standard attributes.
+
+    Example:
+
+    ```mlir
+    // Scalar complementary error function value.
+    %a = math.erfc %b : f64
+    ```
+  }];
+  let hasFolder = 1;
+}
+
 
 //===----------------------------------------------------------------------===//
 // ExpOp
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
index b4ebc2f0f8fcd28..ecfdb71817dffd9 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
@@ -23,6 +23,14 @@ struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
                                 PatternRewriter &rewriter) const final;
 };
 
+struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::ErfcOp op,
+                                PatternRewriter &rewriter) const final;
+};
+
 } // namespace math
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index f0f17c6adcb088e..b8055354645244b 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {
 
 void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
 void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
 
 void populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
index a2488dc600f51af..1a568d549548083 100644
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -175,6 +175,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns) {
   populatePatternsForOp<math::CosOp>(patterns, ctx, "cosf", "cos");
   populatePatternsForOp<math::CoshOp>(patterns, ctx, "coshf", "cosh");
   populatePatternsForOp<math::ErfOp>(patterns, ctx, "erff", "erf");
+  populatePatternsForOp<math::ErfcOp>(patterns, ctx, "erfcf", "erfc");
   populatePatternsForOp<math::ExpOp>(patterns, ctx, "expf", "exp");
   populatePatternsForOp<math::Exp2Op>(patterns, ctx, "exp2f", "exp2");
   populatePatternsForOp<math::ExpM1Op>(patterns, ctx, "expm1f", "expm1");
diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp
index 1690585e78c5dad..93d19a69a701474 100644
--- a/mlir/lib/Dialect/Math/IR/MathOps.cpp
+++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp
@@ -318,6 +318,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
       });
 }
 
+//===----------------------------------------------------------------------===//
+// ErfcOp folder
+//===----------------------------------------------------------------------===//
+
+OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
+  return constFoldUnaryOpConditional<FloatAttr>(
+      adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
+        switch (a.getSizeInBits(a.getSemantics())) {
+        case 64:
+          return APFloat(erfc(a.convertToDouble()));
+        case 32:
+          return APFloat(erfcf(a.convertToFloat()));
+        default:
+          return {};
+        }
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // IPowIOp folder
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 24c892f68b50316..1585ebc74e7908f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
 // Helper functions to create constants.
 //----------------------------------------------------------------------------//
 
+static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
+  return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
+}
+
 static Value floatCst(ImplicitLocOpBuilder &builder, float value,
                       Type elementType) {
   assert((elementType.isF16() || elementType.isF32()) &&
@@ -1118,12 +1122,102 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
   return success();
 }
 
+// Approximates erfc(x) with
+LogicalResult
+ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
+                                             PatternRewriter &rewriter) const {
+  Value x = op.getOperand();
+  Type et = getElementTypeOrSelf(x);
+
+  if (!et.isF32())
+    return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
+  std::optional<VectorShape> shape = vectorShape(x);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  Value trueValue = bcast(boolCst(builder, true));
+  Value zero = bcast(floatCst(builder, 0.0f, et));
+  Value one = bcast(floatCst(builder, 1.0f, et));
+  Value onehalf = bcast(floatCst(builder, 0.5f, et));
+  Value neg4 = bcast(floatCst(builder, -4.0f, et));
+  Value neg2 = bcast(floatCst(builder, -2.0f, et));
+  Value pos2 = bcast(floatCst(builder, 2.0f, et));
+  Value posInf = bcast(f32FromBits(builder, 0x7f800000u));
+  Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
+
+  // Get abs(x)
+  Value isNegativeArg =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
+  Value negArg = builder.create<arith::NegFOp>(x);
+  Value a = builder.create<arith::SelectOp>(isNegativeArg, negArg, x);
+  Value p = builder.create<arith::AddFOp>(a, pos2);
+  Value r = builder.create<arith::DivFOp>(one, p);
+  Value q = builder.create<math::FmaOp>(neg4, r, one);
+  Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
+                                        neg2, a);
+  Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
+  q = builder.create<math::FmaOp>(r, e, q);
+
+  p = bcast(floatCst(builder, -0x1.a4a000p-12f, et));        // -4.01139259e-4
+  Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
+  p = builder.create<math::FmaOp>(p, q, c1);
+  Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); //  1.31355342e-3
+  p = builder.create<math::FmaOp>(p, q, c2);
+  Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
+  p = builder.create<math::FmaOp>(p, q, c3);
+  Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
+  p = builder.create<math::FmaOp>(p, q, c4);
+  Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
+  p = builder.create<math::FmaOp>(p, q, c5);
+  Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); //  1.64055392e-1
+  p = builder.create<math::FmaOp>(p, q, c6);
+  Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
+  p = builder.create<math::FmaOp>(p, q, c7);
+  Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
+  p = builder.create<math::FmaOp>(p, q, c8);
+  Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
+  p = builder.create<math::FmaOp>(p, q, c9);
+
+  Value d = builder.create<math::FmaOp>(pos2, a, one);
+  r = builder.create<arith::DivFOp>(one, d);
+  q = builder.create<math::FmaOp>(p, r, r);
+  e = builder.create<math::FmaOp>(
+      builder.create<math::FmaOp>(q, builder.create<arith::NegFOp>(a), onehalf),
+      pos2, builder.create<arith::SubFOp>(p, q));
+  r = builder.create<math::FmaOp>(e, r, q);
+
+  Value s = builder.create<arith::MulFOp>(a, a);
+  e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
+
+  t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
+  r = builder.create<math::FmaOp>(
+      r, e,
+      builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
+
+  Value isNotLessThanInf = builder.create<arith::XOrIOp>(
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
+      trueValue);
+  r = builder.create<arith::SelectOp>(isNotLessThanInf,
+                                      builder.create<arith::AddFOp>(x, x), r);
+  Value isGreaterThanClamp =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
+  r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
+
+  Value isNegative =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
+  r = builder.create<arith::SelectOp>(
+      isNegative, builder.create<arith::SubFOp>(pos2, r), r);
+
+  rewriter.replaceOp(op, r);
+  return success();
+}
 //----------------------------------------------------------------------------//
 // Exp approximation.
 //----------------------------------------------------------------------------//
-
 namespace {
-
 Value clampWithNormals(ImplicitLocOpBuilder &builder,
                        const std::optional<VectorShape> shape, Value value,
                        float lowerBound, float upperBound) {
@@ -1667,6 +1761,11 @@ void mlir::populatePolynomialApproximateErfPattern(
   patterns.add<ErfPolynomialApproximation>(patterns.getContext());
 }
 
+void mlir::populatePolynomialApproximateErfcPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
+}
+
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
@@ -1680,13 +1779,14 @@ void mlir::populateMathPolynomialApproximationPatterns(
            ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
           patterns.getContext());
 
-  patterns
-      .add<AtanApproximation, Atan2Approximation, TanhApproximation,
-           LogApproximation, Log2Approximation, Log1pApproximation,
-           ErfPolynomialApproximation, AsinPolynomialApproximation,
-           AcosPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-           CbrtApproximation, SinAndCosApproximation<true, math::SinOp>,
-           SinAndCosApproximation<false, math::CosOp>>(patterns.getContext());
+  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
+               LogApproximation, Log2Approximation, Log1pApproximation,
+               ErfPolynomialApproximation, ErfcPolynomialApproximation,
+               AsinPolynomialApproximation, AcosPolynomialApproximation,
+               ExpApproximation, ExpM1Approximation, CbrtApproximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>>(
+      patterns.getContext());
   if (options.enableAvx2) {
     patterns.add<RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
         patterns.getContext());
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
index 81d071e6bbba368..badc95fa2d4aac6 100644
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -81,6 +81,118 @@ func.func @erf_scalar(%arg0: f32) -> f32 {
   return %0 : f32
 }
 
+// CHECK-LABEL: func @erfc_scalar(
+// CHECK-SAME:    %[[val_arg0:.*]]: f32) -> f32 {
+// CHECK-DAG:     %[[c127_i32:.*]] = arith.constant 127 : i32
+// CHECK-DAG:     %[[c23_i32:.*]] = arith.constant 23 : i32
+// CHECK-DAG:     %[[cst:.*]] = arith.constant 1.270000e+02 : f32
+// CHECK-DAG:     %[[cst_0:.*]] = arith.constant -1.270000e+02 : f32
+// CHECK-DAG:     %[[cst_1:.*]] = arith.constant 8.880000e+01 : f32
+// CHECK-DAG:     %[[cst_2:.*]] = arith.constant -8.780000e+01 : f32
+// CHECK-DAG:     %[[cst_3:.*]] = arith.constant 0.166666657 : f32
+// CHECK-DAG:     %[[cst_4:.*]] = arith.constant 0.0416657962 : f32
+// CHECK-DAG:     %[[cst_5:.*]] = arith.constant 0.00833345205 : f32
+// CHECK-DAG:     %[[cst_6:.*]] = arith.constant 0.00139819994 : f32
+// CHECK-DAG:     %[[cst_7:.*]] = arith.constant 1.98756912E-4 : f32
+// CHECK-DAG:     %[[cst_8:.*]] = arith.constant 2.12194442E-4 : f32
+// CHECK-DAG:     %[[cst_9:.*]] = arith.constant -0.693359375 : f32
+// CHECK-DAG:     %[[cst_10:.*]] = arith.constant 1.44269502 : f32
+// CHECK-DAG:     %[[cst_11:.*]] = arith.constant 0.276978403 : f32
+// CHECK-DAG:     %[[cst_12:.*]] = arith.constant -0.0927639827 : f32
+// CHECK-DAG:     %[[cst_13:.*]] = arith.constant -0.166031361 : f32
+// CHECK-DAG:     %[[cst_14:.*]] = arith.constant 0.164055392 : f32
+// CHECK-DAG:     %[[cst_15:.*]] = arith.constant -0.0542046614 : f32
+// CHECK-DAG:     %[[cst_16:.*]] = arith.constant -8.059920e-03 : f32
+// CHECK-DAG:     %[[cst_17:.*]] = arith.constant 0.00863227434 : f32
+// CHECK-DAG:     %[[cst_18:.*]] = arith.constant 0.00131355342 : f32
+// CHECK-DAG:     %[[cst_19:.*]] = arith.constant -0.0012307521 : f32
+// CHECK-DAG:     %[[cst_20:.*]] = arith.constant -4.01139259E-4 : f32
+// CHECK-DAG:     %[[cst_true:.*]] = arith.constant true
+// CHECK-DAG:     %[[cst_21:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG:     %[[cst_22:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG:     %[[cst_23:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG:     %[[cst_24:.*]] = arith.constant -4.000000e+00 : f32
+// CHECK-DAG:     %[[cst_25:.*]] = arith.constant -2.000000e+00 : f32
+// CHECK-DAG:     %[[cst_26:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG:     %[[cst_27:.*]] = arith.constant 0x7F800000 : f32
+// CHECK-DAG:     %[[cst_28:.*]] = arith.constant 10.0546875 : f32
+// CHECK:         %[[val_0:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
+// CHECK:         %[[val_1:.*]] = arith.negf %[[val_arg0]] : f32
+// CHECK:         %[[val_2:.*]] = arith.select %[[val_0]], %[[val_1]], %[[val_arg0]] : f32
+// CHECK:         %[[val_3:.*]] = arith.addf %[[val_2]], %[[cst_26]] : f32
+// CHECK:         %[[val_4:.*]] = arith.divf %[[cst_22]], %[[val_3]] : f32
+// CHECK:         %[[val_5:.*]] = math.fma %[[cst_24]], %[[val_4]], %[[cst_22]] : f32
+// CHECK:         %[[val_6:.*]] = arith.addf %[[val_5]], %[[cst_22]] : f32
+// CHECK:         %[[val_7:.*]] = math.fma %[[val_6]], %[[cst_25]], %[[val_2]] : f32
+// CHECK:         %[[val_8:.*]] = arith.negf %[[val_2]] : f32
+// CHECK:         %[[val_9:.*]] = math.fma %[[val_8]], %[[val_5]], %[[val_7]] : f32
+// CHECK:         %[[val_10:.*]] = math.fma %[[val_4]], %[[val_9]], %[[val_5]] : f32
+// CHECK:         %[[val_11:.*]] = math.fma %[[cst_20]], %[[val_10]], %[[cst_19]] : f32
+// CHECK:         %[[val_12:.*]] = math.fma %[[val_11]], %[[val_10]], %[[cst_18]] : f32
+// CHECK:         %[[val_13:.*]] = math.fma %[[val_12]], %[[val_10]], %[[cst_17]] : f32
+// CHECK:         %[[val_14:.*]] = math.fma %[[val_13]], %[[val_10]], %[[cst_16]] : f32
+// CHECK:         %[[val_15:.*]] = math.fma %[[val_14]], %[[val_10]], %[[cst_15]] : f32
+// CHECK:         %[[val_16:.*]] = math.fma %[[val_15]], %[[val_10]], %[[cst_14]] : f32
+// CHECK:         %[[val_17:.*]] = math.fma %[[val_16]], %[[val_10]], %[[cst_13]] : f32
+// CHECK:         %[[val_18:.*]] = math.fma %[[val_17]], %[[val_10]], %[[cst_12]] : f32
+// CHECK:         %[[val_19:.*]] = math.fma %[[val_18]], %[[val_10]], %[[cst_11]] : f32
+// CHECK:         %[[val_20:.*]] = math.fma %[[cst_26]], %[[val_2]], %[[cst_22]] : f32
+// CHECK:         %[[val_21:.*]] = arith.divf %[[cst_22]], %[[val_20]] : f32
+// CHECK:         %[[val_22:.*]] = math.fma %[[val_19]], %[[val_21]], %[[val_21]] : f32
+// CHECK:         %[[val_23:.*]] = arith.subf %[[val_19]], %[[val_22]] : f32
+// CHECK:         %[[val_24:.*]] = arith.negf %[[val_2]] : f32
+// CHECK:         %[[val_25:.*]] = math.fma %[[val_22]], %[[val_24]], %[[cst_23]] : f32
+// CHECK:         %[[val_26:.*]] = math.fma %[[val_25]], %[[cst_26]], %[[val_23]] : f32
+// CHECK:         %[[val_27:.*]] = math.fma %[[val_26]], %[[val_21]], %[[val_22]] : f32
+// CHECK:         %[[val_28:.*]] = arith.mulf %[[val_2]], %[[val_2]] : f32
+// CHECK:         %[[val_29:.*]] = arith.negf %[[val_28]] : f32
+// CHECK:         %[[val_30:.*]] = arith.cmpf uge, %[[val_29]], %[[cst_2]] : f32
+// CHECK:         %[[val_31:.*]] = arith.select %[[val_30]], %[[val_29]], %[[cst_2]] : f32
+// CHECK:         %[[val_32:.*]] = arith.cmpf ule, %[[val_31]], %[[cst_1]] : f32
+// CHECK:         %[[val_33:.*]] = arith.select %[[val_32]], %[[val_31]], %[[cst_1]] : f32
+// CHECK:         %[[val_34:.*]] = math.fma %[[val_33]], %[[cst_10]], %[[cst_23]] : f32
+// CHECK:         %[[val_35:.*]] = math.floor %[[val_34]] : f32
+// CHECK:         %[[val_36:.*]] = arith.cmpf uge, %[[val_35]], %[[cst_0]] : f32
+// CHECK:         %[[val_37:.*]] = arith.select %[[val_36]], %[[val_35]], %[[cst_0]] : f32
+// CHECK:         %[[val_38:.*]] = arith.cmpf ule, %[[val_37]], %[[cst]] : f32
+// CHECK:         %[[val_39:.*]] = arith.select %[[val_38]], %[[val_37]], %[[cst]] : f32
+// CHECK:         %[[val_40:.*]] = math.fma %[[cst_9]], %[[val_39]], %[[val_33]] : f32
+// CHECK:         %[[val_41:.*]] = math.fma %[[cst_8]], %[[val_39]], %[[val_40]] : f32
+// CHECK:         %[[val_42:.*]] = math.fma %[[val_41]], %[[cst_7]], %[[cst_6]] : f32
+// CHECK:         %[[val_43:.*]] = math.fma %[[val_42]], %[[val_41]], %[[cst_5]] : f32
+// CHECK:         %[[val_44:.*]] = math.fma %[[val_43]], %[[val_41]], %[[cst_4]] : f32
+// CHECK:         %[[val_45:.*]] = math.fma %[[val_44]], %[[val_41]], %[[cst_3]] : f32
+// CHECK:         %[[val_46:.*]] = math.fma %[[val_45]], %[[val_41]], %[[cst_23]] : f32
+// CHECK:         %[[val_47:.*]] = arith.mulf %[[val_41]], %[[val_41]] : f32
+// CHECK:         %[[val_48:.*]] = math.fma %[[val_46]], %[[val_47]], %[[val_41]] : f32
+// CHECK:         %[[val_49:.*]] = arith.addf %[[val_48]], %[[cst_22]] : f32
+// CHECK:         %[[val_50:.*]] = arith.fptosi %[[val_39]] : f32 to i32
+// CHECK:         %[[val_51:.*]] = arith.addi %[[val_50]], %[[c127_i32]] : i32
+// CHECK:         %[[val_52:.*]] = arith.shli %[[val_51]], %[[c23_i32]] : i32
+// CHECK:         %[[val_53:.*]] = arith.bitcast %[[val_52]] : i32 to f32
+// CHECK:         %[[val_54:.*]] = arith.mulf %[[val_49]], %[[val_53]] : f32
+// CHECK:         %[[val_55:.*]] = arith.negf %[[val_2]] : f32
+// CHECK:         %[[val_56:.*]] = math.fma %[[val_55]], %[[val_2]], %[[val_28]] : f32
+// CHECK:         %[[val_57:.*]] = arith.mulf %[[val_27]], %[[val_54]] : f32
+// CHECK:         %[[val_58:.*]] = arith.mulf %[[val_57]], %[[val_56]] : f32
+// CHECK:         %[[val_59:.*]] = math.fma %[[val_27]], %[[val_54]], %[[val_58]] : f32
+// CHECK:         %[[val_60:.*]] = arith.cmpf olt, %[[val_2]], %[[cst_27]] : f32
+// CHECK:         %[[val_61:.*]] = arith.xori %[[val_60]], %[[cst_true]] : i1
+// CHECK:         %[[val_62:.*]] = arith.addf %[[val_arg0]], %[[val_arg0]] : f32
+// CHECK:         %[[val_63:.*]] = arith.select %[[val_61]], %[[val_62]], %[[val_59]] : f32
+// CHECK:         %[[val_64:.*]] = arith.cmpf ogt, %[[val_2]], %[[cst_28]] : f32
+// CHECK:         %[[val_65:.*]] = arith.select %[[val_64]], %[[cst_21]], %[[val_63]] : f32
+// CHECK:         %[[val_66:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32
+// CHECK:         %[[val_67:.*]] = arith.subf %[[cst_26]], %[[val_65]] : f32
+// CHECK:         %[[val_68:.*]] = arith.select %[[val_66]], %[[val_67]], %[[val_65]]...
[truncated]

@@ -1118,12 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
return success();
}

// Approximates erfc(x) with
Member

@pashu123 pashu123 Feb 10, 2025

Please describe how it's being approximated and include the StackOverflow URL.
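
For readers following the request above, here is a scalar C++ sketch of what the pattern computes, transcribed from the ops that `ErfcPolynomialApproximation` emits in the diff. The function name `erfcApprox` is hypothetical, the decimal constants are the ones given in the patch's comments, and `std::fabs` stands in for the select-based absolute value; treat it as an illustration, not the merged code.

```cpp
#include <cmath>

// Scalar f32 sketch of the erfc approximation built by the rewrite pattern.
// erfcApprox is a hypothetical name; the steps mirror the ops in the diff.
float erfcApprox(float x) {
  const float a = std::fabs(x);

  // Map a in [0, inf) to q = (a - 2) / (a + 2), then refine q with the
  // division residual e = (a - 2) - q * (a + 2).
  float p = a + 2.0f;
  float r = 1.0f / p;
  float q = std::fma(-4.0f, r, 1.0f);
  float t = std::fma(q + 1.0f, -2.0f, a);
  float e = std::fma(-a, q, t);
  q = std::fma(r, e, q);

  // Horner evaluation of the degree-9 polynomial in q, so that
  // 1 + p approximates (1 + 2a) * exp(a^2) * erfc(a).
  const float coeffs[] = {-1.23075210e-3f, 1.31355342e-3f,  8.63227434e-3f,
                          -8.05991981e-3f, -5.42046614e-2f, 1.64055392e-1f,
                          -1.66031361e-1f, -9.27639827e-2f, 2.76978403e-1f};
  p = -4.01139259e-4f;
  for (float c : coeffs)
    p = std::fma(p, q, c);

  // Divide (1 + p) by (1 + 2a), again with one residual correction step.
  const float d = std::fma(2.0f, a, 1.0f);
  r = 1.0f / d;
  q = std::fma(p, r, r);
  e = std::fma(std::fma(q, -a, 0.5f), 2.0f, p - q);
  r = std::fma(e, r, q);

  // Multiply by exp(-a^2); t is the rounding error of s = a * a, and
  // exp(-a^2) is approximated as exp(-s) * (1 + t).
  const float s = a * a;
  e = std::exp(-s);
  t = std::fma(-a, a, s);
  r = std::fma(r, e, r * e * t);

  if (!(a < INFINITY)) r = x + x; // NaN or infinite input
  if (a > 10.0546875f) r = 0.0f;  // result clamps to zero past this point
  if (x < 0.0f) r = 2.0f - r;     // erfc(-x) = 2 - erfc(x)
  return r;
}
```

Comparing `erfcApprox(x)` with `std::erfc(x)` over a range of inputs is one way to check a transcription like this against the sub-3-ULP error stated in the PR description.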

@jsjodin jsjodin requested a review from pashu123 February 10, 2025 22:43
@pashu123 pashu123 requested a review from bjacob February 12, 2025 22:05
Member

@pashu123 pashu123 left a comment

LGTM! Please wait for other reviewers.

The `erfc` operation computes the complementary error function.
It takes one operand of floating point type (i.e., scalar, tensor or
vector) and returns one result of the same type.
It has no standard attributes.
Member

Could you add a detailed description of the erfc op here?

Contributor

Maybe also a word on why it exists as a separate op, as opposed to letting people write 1 - erf(x): it avoids the floating-point cancellation that occurs when erf(x) is close to 1, and it is standard in the C math.h library.

Contributor Author

I improved the wording a bit. Hope it helps!
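
The cancellation point raised in this thread can be illustrated with a small standard-library sketch. The helper names `erfcViaErf` and `relativeError` are hypothetical and only serve the comparison; nothing here is part of the patch.

```cpp
#include <cmath>

// Hypothetical helper: the "obvious" way to derive erfc from erf in f32.
float erfcViaErf(float x) { return 1.0f - std::erf(x); }

// Relative error of an f32 result against a double-precision reference.
double relativeError(float approx, float x) {
  const double ref = std::erfc(static_cast<double>(x));
  return std::fabs(static_cast<double>(approx) - ref) / ref;
}
```

For an input such as `5.0f`, erf(x) differs from 1 by far less than the f32 spacing near 1, so `erfcViaErf(5.0f)` loses essentially all significant digits, while `std::erfc(5.0f)` computes the small value directly and keeps them.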

Comment on lines 328 to 331
switch (a.getSizeInBits(a.getSemantics())) {
case 64:
return APFloat(erfc(a.convertToDouble()));
case 32:
Contributor

I am not familiar with fltSemantics, but at least the idea is that there could be multiple semantics sharing the same bit width, like there is at 16 bit between fp16 and bf16. So maybe there is something a bit more righteous that could be done here, matching exact semantics of IEEE single / double precision rather than merely the size in bits.

Contributor Author

I see your point. It would perhaps have been better to have a function that returns an enum, but that belongs in a separate PR, since this is the pattern already in use.

Contributor

Sorry, I meant something more specific to the code being added here.

Instead of switching on the size in bits, switch on the semantics itself.

Reading APFloat.h, I see that there is:

  static Semantics SemanticsToEnum(const llvm::fltSemantics &Sem);

so here you could do:

switch (APFloat::SemanticsToEnum(a.getSemantics())) {

and have your case: statement based on enumerators.

Contributor Author

Okay, that's much better. I wasn't aware they already had an enum.

Value neg4 = bcast(floatCst(builder, -4.0f, et));
Value neg2 = bcast(floatCst(builder, -2.0f, et));
Value pos2 = bcast(floatCst(builder, 2.0f, et));
Value posInf = bcast(f32FromBits(builder, 0x7f800000u));
Contributor

Do we really need to get down to hex bits? How about floatCst(INFINITY) or some such?

Contributor Author

I tried using INFINITY and std::numeric_limits<float>::infinity(), but it requires this specific inf, which seems to be used in other functions as well.

Contributor

There should be only one positive infinity. But the difference between floatCst(builder, INFINITY, et) and f32FromBits(builder, 0x7f800000u) is that the former uses the element type et while the latter is explicitly f32.

Have you tried floatCst(builder, INFINITY, builder.getF32Type()) ?

Contributor Author

Actually, INFINITY worked fine. I must have made some error before. Thanks!
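
The claim in this thread that there is only one positive infinity can be checked outside MLIR with a short sketch. The helpers `floatFromBits` and `bitsFromFloat` are hypothetical names, and the sketch assumes the host `float` is IEEE 754 binary32, as f32 is.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Hypothetical helper: reinterpret a 32-bit pattern as a float.
float floatFromBits(std::uint32_t bits) {
  float value;
  static_assert(sizeof(value) == sizeof(bits), "float must be 32 bits wide");
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

// Hypothetical helper: read back the 32-bit pattern of a float.
std::uint32_t bitsFromFloat(float value) {
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}
```

Under that assumption, `floatFromBits(0x7f800000u)`, `INFINITY`, and `std::numeric_limits<float>::infinity()` all denote the same value, which is consistent with the author's finding that the macro works in place of the hex constant.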

Comment on lines 1153 to 1157
// Get abs(x)
Value isNegativeArg =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
Value negArg = builder.create<arith::NegFOp>(x);
Value a = builder.create<arith::SelectOp>(isNegativeArg, negArg, x);
Contributor

Use math.absf.

Contributor Author

Done.
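
As a side note on this change, the compare-and-select form and a true absolute value differ only on negative zero. The sketch below shows this in plain C++; `absViaSelect` is a hypothetical name mirroring the original `cmpf olt` / `negf` / `select` sequence, and `std::fabs` stands in for math.absf on the assumption that it clears the sign bit in the same way.

```cpp
#include <cmath>

// Hypothetical scalar mirror of the cmpf olt / negf / select sequence.
float absViaSelect(float x) { return x < 0.0f ? -x : x; }
```

For `-0.0f` the comparison is false, so `absViaSelect` returns the input with its sign bit still set, whereas `std::fabs` returns positive zero. Since erfc at either zero is 1, this should not affect the approximation; the dedicated op is simply shorter and clearer.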

@jsjodin jsjodin merged commit 8806311 into llvm:main Feb 18, 2025
8 checks passed
wldfngrs pushed a commit to wldfngrs/llvm-project that referenced this pull request Feb 19, 2025
@jsjodin jsjodin deleted the jleyonberg/erfc branch February 22, 2025 02:42
4 participants