From 6f104b1457d7f3669e8633db876442394060c328 Mon Sep 17 00:00:00 2001 From: yangxinyu Date: Sun, 28 Apr 2024 17:34:29 +0000 Subject: [PATCH] decompose AtenLerpTensorOp --- .../Torch/Transforms/DecomposeComplexOps.cpp | 30 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + 2 files changed, 31 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 6cb02297d497..d8ef033a880e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2444,6 +2444,35 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenLerpTensorOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpTensorOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = cast(op.getType()); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = cast(start.getType()); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, inputType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + // Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1) namespace { class DecomposeAtenEluOp : public OpRewritePattern { @@ -7780,6 +7809,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c5855a1fa092..37168b95ee34 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -497,6 +497,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp();