Skip to content

Commit

Permalink
decompose AtenLerpTensorOp
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxinyu committed Apr 28, 2024
1 parent 46c0f3c commit 6f104b1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
30 changes: 30 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2444,6 +2444,35 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern<AtenLerpScalarOp> {
};
} // namespace

namespace {
class DecomposeAtenLerpTensorOp : public OpRewritePattern<AtenLerpTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLerpTensorOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto start = op.getSelf();
auto inputType = cast<BaseTensorType>(start.getType());

auto delta = rewriter.create<AtenSubTensorOp>(loc, inputType, op.getEnd(),
start, cstOne);

auto weightedDelta =
rewriter.create<AtenMulTensorOp>(loc, inputType, delta, op.getWeight());
auto lerp = rewriter.create<AtenAddTensorOp>(loc, inputType, start,
weightedDelta, cstOne);
rewriter.replaceOp(op, lerp);
return success();
}
};
} // namespace

// Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1)
namespace {
class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
Expand Down Expand Up @@ -7780,6 +7809,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLerpTensorOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_EmbeddingBagOp>();
target.addIllegalOp<AtenLiftFreshCopyOp>();
target.addIllegalOp<AtenLerpScalarOp>();
target.addIllegalOp<AtenLerpTensorOp>();
target.addIllegalOp<AtenIndexTensorOp>();
target.addIllegalOp<AtenMseLossOp>();
target.addIllegalOp<AtenRandintLowOp>();
Expand Down

0 comments on commit 6f104b1

Please sign in to comment.