diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index da71cf99cb2d..296c812637b1 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2415,7 +2415,7 @@ class DecomposeAtenPreluOp : public OpRewritePattern<AtenPreluOp> {
 
 } // namespace
 
-// rrelu = max(0, x) + min(0, alpha * (exp(x) - 1))
+// rrelu = max(0, x) + min(0, alpha * x)
 // if in training mode, the alpha is sampled from uniform distribution (lower,
 // upper) if in testing mode, the alpha is (lower + upper) / 2
 namespace {
@@ -2438,48 +2438,46 @@ class DecomposeAtenRreluOp : public OpRewritePattern<AtenRreluOp> {
       return rewriter.notifyMatchFailure(op, "training should be a constant");
     }
 
-    Value constantZero =
-        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    Value constantZeroFloat =
+        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
     Value constantOneFloat =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
     Value constantTwoFloat =
         rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(2.0));
+
     Value alpha;
     if (training) {
       // Create a uniform random op with low and high set to `lower` and
       // `upper`, respectively.
       Value none = rewriter.create<ConstantNoneOp>(loc);
-      Value zero =
-          rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
       Value emptyTensor = rewriter.create<AtenFullLikeOp>(
-          loc, resType, self, zero, /*dtype=*/none, /*layout=*/none,
+          loc, resType, self, constantZeroFloat, /*dtype=*/none,
+          /*layout=*/none,
           /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none);
       alpha = rewriter.create<AtenUniformOp>(loc, resType, emptyTensor,
                                              /*from=*/lower, /*to=*/upper,
                                              /*generator=*/none);
     } else {
-      Value half = rewriter.create<AtenAddScalarOp>(loc, resType, lower, upper,
-                                                    constantOneFloat);
-      alpha = rewriter.create<AtenDivScalarOp>(loc, resType, half,
-                                               constantTwoFloat);
+      Value half = rewriter.create<AtenAddOp>(loc, constantTwoFloat.getType(),
+                                              lower, upper);
+      alpha = rewriter.create<AtenDivOp>(loc, constantTwoFloat.getType(), half,
+                                         constantTwoFloat);
     }
 
-    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);
+    Value zeroTensor =
+        createRank0Tensor(rewriter, loc, resType, constantZeroFloat);
     Value positiveOutput =
         rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, self);
-    Value expX = rewriter.create<AtenExpOp>(loc, resType, self);
-    Value expXM1 = rewriter.create<AtenSubScalarOp>(
-        loc, resType, expX, constantOneFloat, constantOneFloat);
-    Value scaledExpXM1;
+
+    Value scaledSelf;
     if (training) {
-      scaledExpXM1 =
-          rewriter.create<AtenMulTensorOp>(loc, resType, expXM1, alpha);
+      scaledSelf = rewriter.create<AtenMulTensorOp>(loc, resType, self, alpha);
     } else {
-      scaledExpXM1 =
-          rewriter.create<AtenMulScalarOp>(loc, resType, expXM1, alpha);
+      scaledSelf = rewriter.create<AtenMulScalarOp>(loc, resType, self, alpha);
     }
+
     Value negativeOutput =
-        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledExpXM1);
+        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, scaledSelf);
     Value rreluOutput = rewriter.create<AtenAddTensorOp>(
         loc, resType, positiveOutput, negativeOutput, constantOneFloat);
     rewriter.replaceOp(op, rreluOutput);
@@ -7822,6 +7820,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenRreluOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index e7bed6463552..f9e2c2fd57d9 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -474,6 +474,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenRreluOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 276cc47c1cc6..e31b79a88e39 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1004,6 +1004,8 @@
     "ElementwiseRemainderTensorModule_Float_basic",
     "ElementwiseRemainderTensorModule_Int_Float_basic",
     "ElementwiseRemainderTensorModule_Int_basic",
+    "ElementwiseRreluEvalStaticModule_basic",
+    "ElementwiseRreluTrainStaticModule_basic",
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSigmoidModule_basic",
     "ElementwiseSinModule_basic",
@@ -1658,6 +1660,8 @@
     "ElementwiseRemainderScalarModule_Int_Float_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
+    "ElementwiseRreluEvalModule_basic",
+    "ElementwiseRreluEvalStaticModule_basic",
     "ElementwiseRsqrtModule_basic",
     "ElementwiseSeluModule_basic",
     "ElementwiseSigmoidModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 95f742f314e5..33e23c4fc915 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1051,12 +1051,13 @@ def __init__(self):
         ]
     )
     def forward(self, x):
-        return torch.ops.aten.rrelu(x, 0.4, 0.6, True)
+        res = torch.ops.aten.rrelu(x, 0.4, 0.6, True)
+        return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluTrainModule())
 def ElementwiseRreluTrainModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1))
+    module.forward(tu.rand(1024, 1536))
 
 
 # ==============================================================================
@@ -1070,16 +1071,17 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([5, 3], torch.float32, True),
+            ([1024, 1536], torch.float32, True),
         ]
     )
     def forward(self, x):
-        return torch.ops.aten.rrelu(x, 0.1, 0.9, True)
+        res = torch.ops.aten.rrelu(x, 0.1, 0.9, True)
+        return torch.mean(res), torch.std(res)
 
 
 @register_test_case(module_factory=lambda: ElementwiseRreluTrainStaticModule())
 def ElementwiseRreluTrainStaticModule_basic(module, tu: TestUtils):
-    module.forward(tu.rand(5, 3, low=-1, high=1))
+    module.forward(tu.rand(1024, 1536))
 
 
 # ==============================================================================
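
Note: the decomposition change above replaces the old ELU-style formula, max(0, x) + min(0, alpha * (exp(x) - 1)), with the correct rrelu = max(0, x) + min(0, alpha * x), where alpha is sampled elementwise from U(lower, upper) in training mode and fixed to (lower + upper) / 2 in eval mode. As a reference for the intended semantics, here is a minimal PyTorch sketch (illustrative only; rrelu_reference is a hypothetical helper, not part of the patch):

    import torch

    def rrelu_reference(x: torch.Tensor, lower: float, upper: float,
                        training: bool) -> torch.Tensor:
        # rrelu = max(0, x) + min(0, alpha * x)
        if training:
            # alpha sampled elementwise from U(lower, upper), mirroring the
            # full_like + uniform sequence in the decomposition above
            alpha = torch.empty_like(x).uniform_(lower, upper)
        else:
            # in eval mode alpha collapses to the scalar (lower + upper) / 2
            alpha = (lower + upper) / 2
        zero = torch.zeros_like(x)
        return torch.maximum(zero, x) + torch.minimum(zero, alpha * x)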
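Note on the test changes: training-mode rrelu is stochastic, so its output cannot be checked elementwise against a golden tensor. The updated e2e tests therefore compare torch.mean and torch.std over a large 1024x1536 input, which are stable across runs. A quick sanity check of that idea (illustrative only, not part of the test suite):

    x = torch.rand(1024, 1536)
    res = torch.ops.aten.rrelu(x, 0.1, 0.9, True)
    # individual elements vary from run to run, but the aggregate
    # statistics of a tensor this large are nearly deterministic
    print(torch.mean(res), torch.std(res))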