Skip to content

Commit

Permalink
add celeu
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Apr 28, 2024
1 parent e6988e3 commit c666eb6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
12 changes: 7 additions & 5 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2434,20 +2434,22 @@ class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value constantOne =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));

// positiveOutput = max(0,x)
Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value maxZeroX =
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
Value positiveOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, maxZeroX, input, constantOne);

// negativeOutput = min(0,alpha∗(exp(x/alpha)−1))
Value scaledInput =
rewriter.create<AtenDivScalarOp>(loc, resType, input, alpha);
Value expX = rewriter.create<AtenExpOp>(loc, resType, scaledInput);
Value expXM1 = rewriter.create<AtenSubScalarOp>(loc, resType, expX,
constantOne, constantOne);
Value scaledExpXM1 =
rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
Value negativeOutput = rewriter.create<AtenMulTensorOp>(
loc, resType, scaledExpXM1, zeroTensor);
Value negativeOutput =
rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
Value celuOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOne);

Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@
"ElementwiseBitwiseRightShiftInt64Module_basic",
"ElementwiseBitwiseRightShiftInt8Module_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
Expand Down Expand Up @@ -1571,6 +1572,8 @@
"ElementwiseBitwiseXorModule_basic",
"ElementwiseBitwiseXorStaticShapeModule_basic",
"ElementwiseCeilModule_basic",
"ElementwiseCeluModule_basic",
"ElementwiseCeluStaticModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampModule_basic",
Expand Down

0 comments on commit c666eb6

Please sign in to comment.