From c666eb68de07d035e86694bd25ed4631db6e2134 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 06:58:56 +0000 Subject: [PATCH] add celeu --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 12 +++++++----- projects/pt1/e2e_testing/xfail_sets.py | 3 +++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index ed7274ddf3fc..677ccc4f241b 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2434,11 +2434,13 @@ class DecomposeAtenCeluOp : public OpRewritePattern { rewriter.create(loc, rewriter.getI64IntegerAttr(0)); Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + + // positiveOutput = max(0,x) Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); - Value maxZeroX = + Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); - Value positiveOutput = rewriter.create( - loc, resType, maxZeroX, input, constantOne); + + // negativeOutput = min(0,alpha∗(exp(x/alpha)−1)) Value scaledInput = rewriter.create(loc, resType, input, alpha); Value expX = rewriter.create(loc, resType, scaledInput); @@ -2446,8 +2448,8 @@ class DecomposeAtenCeluOp : public OpRewritePattern { constantOne, constantOne); Value scaledExpXM1 = rewriter.create(loc, resType, expXM1, alpha); - Value negativeOutput = rewriter.create( - loc, resType, scaledExpXM1, zeroTensor); + Value negativeOutput = + rewriter.create(loc, resType, zeroTensor, scaledExpXM1); Value celuOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOne); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 87344fb99b59..276cc47c1cc6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -951,6 +951,7 @@ "ElementwiseBitwiseRightShiftInt64Module_basic", "ElementwiseBitwiseRightShiftInt8Module_basic", "ElementwiseCeilModule_basic", + "ElementwiseCeluStaticModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampMinTensorFloatModule_basic", @@ -1571,6 +1572,8 @@ "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseCeilModule_basic", + "ElementwiseCeluModule_basic", + "ElementwiseCeluStaticModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampModule_basic",