
Commit

fix and add test
yangxinyu committed Apr 28, 2024
1 parent f4593cd commit 49d4bbb
Showing 2 changed files with 165 additions and 0 deletions.
73 changes: 73 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2415,6 +2415,79 @@ class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {

} // namespace

// rrelu(x) = max(0, x) + min(0, alpha * x)
// In training mode, alpha is sampled from a uniform distribution on
// (lower, upper); in eval mode, alpha is the constant (lower + upper) / 2.
namespace {
class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenRreluOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value self = op.getSelf();
Value lower = op.getLower();
Value upper = op.getUpper();
auto resType = cast<BaseTensorType>(op.getType());
if (!resType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}

bool training;
if (!matchPattern(op.getTraining(), m_TorchConstantBool(&training))) {
return rewriter.notifyMatchFailure(op, "training should be a constant");
}

Value constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value constantOneFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
Value constantTwoFloat =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
Value alpha;
if (training) {
// Create a uniform random op with low and high set to `lower` and
// `upper`, respectively.
Value none = rewriter.create<ConstantNoneOp>(loc);
Value zero =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
Value emptyTensor = rewriter.create<AtenFullLikeOp>(
loc, resType, self, zero, /*dtype=*/none, /*layout=*/none,
          /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
/*from=*/lower, /*to=*/upper,
/*generator=*/none);
    } else {
      // In eval mode, alpha is the fixed float scalar (lower + upper) / 2.
      Value sum = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
                                             lower, upper);
      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), sum,
                                         constantTwoFloat);
    }

Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
Value positiveOutput =
rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
    // Negative part of rrelu: min(0, alpha * x).
    Value scaledSelf;
    if (training) {
      // alpha is a tensor of per-element samples here.
      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
    } else {
      // alpha is a float scalar here.
      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
    }
    Value negativeOutput =
        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
Value rreluOutput = rewriter.create<AtenAddTensorOp>(
loc, resType, positiveOutput, negativeOutput, constantOneFloat);
rewriter.replaceOp(op, rreluOutput);
return success();
}
};
} // namespace

// CELU(x)=max(0,x)+min(0,alpha∗(exp(x/alpha)−1))
namespace {
class DecomposeAtenCeluOp : public OpRewritePattern<AtenCeluOp> {
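
The decomposition above rewrites aten.rrelu into max/min/mul ops rather than a piecewise select. As a quick sanity check of that identity, here is a minimal eager-mode PyTorch sketch (illustrative only, covering eval semantics, where alpha is deterministic):

import torch

x = torch.linspace(-2.0, 2.0, steps=9)
lower, upper = 0.4, 0.6
alpha = (lower + upper) / 2  # eval-mode alpha

# Decomposed form: max(0, x) + min(0, alpha * x).
zeros = torch.zeros_like(x)
decomposed = torch.maximum(zeros, x) + torch.minimum(zeros, alpha * x)

expected = torch.nn.functional.rrelu(x, lower=lower, upper=upper, training=False)
assert torch.allclose(decomposed, expected)

In training mode the same identity holds, except alpha is a tensor of per-element samples drawn from U(lower, upper).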
92 changes: 92 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1039,6 +1039,98 @@ def ElementwiseCeluModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseRreluTrainModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.rrelu(x, 0.4, 0.6, True)


@register_test_case(module_factory=lambda: ElementwiseRreluTrainModule())
def ElementwiseRreluTrainModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseRreluTrainStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([5, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.rrelu(x, 0.1, 0.9, True)


@register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule())
def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseRreluEvalModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.rrelu(x, 0.4, 0.6, False)


@register_test_case(module_factory=lambda: ElementwiseRreluEvalModule())
def ElementwiseRreluEvalModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseRreluEvalStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([5, 3], torch.float32, True),
]
)
def forward(self, x):
return torch.ops.aten.rrelu(x, 0.1, 0.9, False)


@register_test_case(module_factory=lambda: ElementwiseRreluEvalStaticModule())
def ElementwiseRreluEvalStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 3, low=-1, high=1))


# ==============================================================================


class ElementwiseCeluStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()
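
The eval-mode tests above are deterministic: with (lower, upper) = (0.4, 0.6) or (0.1, 0.9), alpha is 0.5 in both cases. The train-mode tests exercise the random path, where exact element-for-element agreement with a separate eager run is not expected; the property that path should satisfy looks like this in eager PyTorch (illustrative sketch only, not how the e2e harness compares results):

import torch

x = torch.rand(5, 3) * 2 - 1  # inputs in [-1, 1), as in the tests above
lower, upper = 0.4, 0.6
y = torch.ops.aten.rrelu(x, lower, upper, True)  # training=True

# Positive inputs pass through unchanged.
assert torch.equal(y[x >= 0], x[x >= 0])

# Negative inputs are scaled by some per-element alpha in [lower, upper];
# allow a small tolerance for floating-point rounding.
neg = x < 0
ratio = y[neg] / x[neg]
assert torch.all((ratio >= lower - 1e-6) & (ratio <= upper + 1e-6))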
