From 5f54d6eeda20b634ab60f20ba23520ac689067e1 Mon Sep 17 00:00:00 2001
From: Ivana Mitreski
Date: Wed, 16 Apr 2025 13:23:12 +0200
Subject: [PATCH] Implement Decomposition for aten.outer

---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 54 +++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 projects/pt1/e2e_testing/xfail_sets.py        |  1 +
 .../torch_mlir_e2e_test/test_suite/matmul.py  | 24 +++++++++
 4 files changed, 80 insertions(+)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index f18b424a6ed2..20beb614a748 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1894,6 +1894,59 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern<AtenAtleast1dOp> {
 };
 } // namespace
 
+// Decompose 'aten.outer' into 'aten.unsqueeze', 'aten.matmul'
+
+namespace {
+class DecomposeAtenOuterOp : public OpRewritePattern<AtenOuterOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOuterOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    Value input = op.getSelf();
+    Value vec2 = op.getVec2();
+    Type opType = op.getType();
+
+    auto inputType = cast<BaseTensorType>(input.getType());
+    auto vec2Type = cast<BaseTensorType>(vec2.getType());
+
+    // Check if both tensors are 1-dimensional
+    SmallVector<int64_t> inputShape(inputType.getSizes());
+    SmallVector<int64_t> vec2Shape(vec2Type.getSizes());
+
+    if (inputShape.size() == 1 && vec2Shape.size() == 1) {
+
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1)); // Dimension index
+      SmallVector<int64_t> inputMatrixShape = {inputShape[0], 1};
+      Type inputMatrixType = inputType.getWithSizesAndDtype(
+          inputMatrixShape, inputType.getOptionalDtype());
+
+      Value inputMatrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, inputMatrixType, input, one);
+
+      Value zero = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(0));
+      SmallVector<int64_t> vec2MatrixShape = {1, vec2Shape[0]};
+      Type vec2MatrixType = vec2Type.getWithSizesAndDtype(
+          vec2MatrixShape, vec2Type.getOptionalDtype());
+
+      Value vec2Matrix =
+          rewriter.create<AtenUnsqueezeOp>(loc, vec2MatrixType, vec2, zero);
+
+      rewriter.replaceOpWithNewOp<AtenMatmulOp>(op, opType, inputMatrix,
+                                                vec2Matrix);
+      return success();
+    } else {
+      return failure();
+    }
+
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Decompose aten.atleast_2d into: aten.reshape. See
 // https://github.com/pytorch/pytorch/blob/9a8ab778d34bd24c5caceb340837483decc4c311/torch/_refs/__init__.py#L2604
@@ -11955,6 +12008,7 @@ class DecomposeComplexOpsPass
                                                  patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenOuterOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index ff9c5d969977..c067c91ced0e 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -399,6 +399,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
+    target.addIllegalOp<AtenOuterOp>();
     target.addIllegalOp();
     target.addIllegalOp();
     target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 202378d1f9ac..6c9b6cb09fb7 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3855,6 +3855,7 @@
 ONNX_TOSA_XFAIL_SET = {
     "AtenFftRfft2DLastDim_basic",
+    "AtenOuter_basic",
     "AtenFftRfft2DMiddleDim_basic",
     "AtenStftCenter1D_basic",
     "AtenStftCenter1DUnkSigLen_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
index 17240cf953df..7948da1cd04b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py
@@ -36,6 +36,30 @@ def AtenDotModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, x, y):
+        return torch.outer(x, y)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(4), tu.rand(3))
+
+
+# ==============================================================================
+
+
 class MatmulDot(torch.nn.Module):
     def __init__(self):
         super().__init__()
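
The pattern rewrites `aten.outer` of two rank-1 tensors as an [n, 1] x [1, m] matrix product: a `Torch::ConstantIntOp` for each unsqueeze dimension, two `AtenUnsqueezeOp`s with the [n, 1] and [1, m] result types, and a final `AtenMatmulOp` that replaces the original op; operands that are not both rank 1 make the rewrite return failure. A minimal PyTorch sketch of that identity, using the same shapes as the `AtenOuter_basic` test above (the helper name `outer_via_matmul` is illustrative and not part of the patch):

```python
import torch


def outer_via_matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Mirror of the decomposition: unsqueeze both 1-D inputs, then matmul.
    # x: [n] -> [n, 1], y: [m] -> [1, m], result: [n, m].
    return torch.matmul(x.unsqueeze(1), y.unsqueeze(0))


x, y = torch.rand(4), torch.rand(3)
assert torch.allclose(torch.outer(x, y), outer_via_matmul(x, y))
```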