diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 5dfce4b465e2..ff4e1637ee90 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -8038,105 +8038,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
 } // namespace
 
 namespace {
+// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d`.
 // Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
-// op.
-class DecomposeAtenAdaptiveMaxPool1dOp
-    : public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    MLIRContext *context = op.getContext();
-
-    Value input = op.getSelf();
-    std::optional<unsigned> maybeRank = getTensorRank(input);
-    if (!maybeRank) {
-      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
-    }
-    unsigned rank = *maybeRank;
-    Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(rank - 1));
-    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
-
-    Value outputShape = op.getOutputSize();
-    SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
-    Value outputSize = outputShapeSizesTorchInt[0];
-
-    Value constantOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(1));
-    Value constantZero = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(0));
-    Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-
-    int64_t outputSizeInt;
-    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
-      return rewriter.notifyMatchFailure(
-          op, "the output size of adaptive_max_pool1d must be a constant int");
-    }
-
-    SmallVector<Value> kernelSize;
-    if (outputSizeInt == 1) {
-      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
-      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-      kernelSize.push_back(
-          inputShape[rank - 1] == kUnknownSize
-              ? inputSize
-              : rewriter.create<Torch::ConstantIntOp>(
-                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
-    } else {
-      if (!isAssumingStrictSymbolicShapes(rewriter)) {
-        Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
-        rewriter.create<RuntimeAssertOp>(
-            loc, cond,
-            "unimplemented: only support cases where input and output size are "
-            "equal for non-unit output size");
-      }
-      kernelSize.push_back(constantOne);
-    }
-
-    Value kernelSizeList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
-    Value strideList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantOne});
-    Value paddingSizeList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantZero});
-    Value dialationList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(Torch::IntType::get(context)),
-        ValueRange{constantOne});
-
-    if (op.getResult(1).use_empty()) {
-      auto maxPool = rewriter.create<AtenMaxPool1dOp>(
-          loc, op.getType(0), input, kernelSizeList, strideList,
-          paddingSizeList, dialationList,
-          /*ceil_mode=*/constantFalse);
-      rewriter.replaceOp(op, {maxPool.getResult(), Value()});
-    } else {
-      auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
-          loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
-          paddingSizeList, dialationList,
-          /*ceil_mode=*/constantFalse);
-      rewriter.replaceOp(op, maxPool.getResults());
-    }
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
-
-// The logic of this decomposition is totally same with
-// the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
-// cases are supported:
+// or `aten.max_pool1d`.
+//
+// Only the following two cases are supported:
 // 1. inputSize = outputSize
 // 2. outputSize = 1
-class DecomposeAtenAdaptiveAvgPool1dOp
-    : public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
+template <typename AtenOpT>
+class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern<AtenOpT> {
+  using OpRewritePattern<AtenOpT>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOpT op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op->getLoc();
     MLIRContext *context = op.getContext();
@@ -8149,11 +8061,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
     unsigned rank = *maybeRank;
     Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 1));
-    Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
+    Value inputSize = rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim);
 
-    Value outputShape = op.getOutputSize();
     SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
     Value outputSize = outputShapeSizesTorchInt[0];
 
     Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -8166,18 +8077,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
     int64_t outputSizeInt;
    if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
       return rewriter.notifyMatchFailure(
-          op, "the output size of adaptive_pool_1d must be a constant int");
+          op, "the output size of adaptive pool1d must be a constant int");
     }
 
     SmallVector<Value> kernelSize;
     if (outputSizeInt == 1) {
-      BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
-      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
-      kernelSize.push_back(
-          inputShape[rank - 1] == kUnknownSize
-              ? inputSize
-              : rewriter.create<Torch::ConstantIntOp>(
-                    loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
+      kernelSize.push_back(inputSize);
     } else {
       if (!isAssumingStrictSymbolicShapes(rewriter)) {
         Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
@@ -8198,16 +8103,40 @@ class DecomposeAtenAdaptiveAvgPool1dOp
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero});
 
-    rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
-        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
-        /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
-    return success();
+    if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool1dOp>) {
+      rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
+          op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+          /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
+      return success();
+    } else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool1dOp>) {
+      Value dilationList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          ValueRange{constantOne});
+      if (op.getResult(1).use_empty()) {
+        auto maxPool = rewriter.create<AtenMaxPool1dOp>(
+            loc, op.getType(0), input, kernelSizeList, strideList,
+            paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, {maxPool.getResult(), Value()});
+      } else {
+        auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
+            loc, op.getType(0), op.getType(1), input, kernelSizeList,
+            strideList, paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, maxPool.getResults());
+      }
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: unsupported template op");
   }
 };
 } // namespace
 
 namespace {
-// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
+// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
+// Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
+// `aten.max_pool2d_with_indices` op.
 //
 // For AdaptiveAvgPool2d op, when the input size is an integer multiple of
 // output size the kernelSize, stride and padding is calculated as follows:
 // strideH = inH // outH
 // strideW = inH // outH
 // kernelH = inH - [(outH - 1) * strideH] = strideH
@@ -8217,10 +8146,10 @@ namespace {
 // kernelW = inW - [(outW - 1) * strideW] = strideW
 // paddingH = 0, paddingW = 0
 //
-class DecomposeAtenAdaptiveAvgPool2dOp
-    : public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
+template <typename AtenOpT>
+class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern<AtenOpT> {
+  using OpRewritePattern<AtenOpT>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenOpT op,
                                 PatternRewriter &rewriter) const override {
 
     Location loc = op.getLoc();
@@ -8236,15 +8165,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
     Value dimH = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 2));
     inputHW.push_back(
-        /*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
+        /*inH=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimH));
 
     Value dimW = rewriter.create<Torch::ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(rank - 1));
     inputHW.push_back(
-        /*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
+        /*inW=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimW));
 
-    Value outputShape = op.getOutputSize();
     SmallVector<Value> outputShapeSizesTorchInt;
-    getListConstructElements(outputShape, outputShapeSizesTorchInt);
+    getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
 
     // TODO: Add support for cases other than:
     // inH % outH != 0 or inW % outW != 0 where
@@ -8325,11 +8253,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
         loc, Torch::ListType::get(Torch::IntType::get(context)),
         ValueRange{constantZero, constantZero});
 
-    rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
-        op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
-        /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
-        /*divisorOverride=*/constantNone);
-    return success();
+    if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
+      rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
+          op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
+          /*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
+          /*divisorOverride=*/constantNone);
+      return success();
+    } else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
+      Value dilationList = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(Torch::IntType::get(context)),
+          ValueRange{constantOne, constantOne});
+      if (op.getResult(1).use_empty()) {
+        auto maxPool = rewriter.create<AtenMaxPool2dOp>(
+            loc, op.getType(0), input, kernelSizeList, strideList,
+            paddingSizeList, dilationList, /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, {maxPool.getResult(), Value()});
+      } else {
+        auto maxPool = rewriter.create<AtenMaxPool2dWithIndicesOp>(
+            loc, op.getType(0), op.getType(1), input, kernelSizeList,
+            strideList, paddingSizeList, dilationList,
+            /*ceil_mode=*/constantFalse);
+        rewriter.replaceOp(op, maxPool.getResults());
+      }
+      return success();
+    }
+    return rewriter.notifyMatchFailure(
+        op, "unimplemented: unsupported template op");
   }
 };
 } // namespace
@@ -11760,9 +11709,14 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 9e1cb530aa0b..6695f2964b65 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -509,6 +509,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
+  target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index fc4673d4d1ab..c19d329a521e 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2702,6 +2702,7 @@
     "AdaptiveMaxPool2dDynamicNoBatch_basic",
     "AdaptiveMaxPool2dDynamicWithIndices_basic",
     "AdaptiveMaxPool2dDynamic_basic",
+    "AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic",
     "AdaptiveMaxPool2dStaticWithIndices_basic",
     "AdaptiveMaxPool2dStatic_basic",
     "AdaptiveMaxPool3dDynamicNoBatch_basic",
diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py
index c6cf625e4fe1..cf979838f0f0 100644
--- a/projects/pt1/python/torch_mlir/torchscript.py
+++ b/projects/pt1/python/torch_mlir/torchscript.py
@@ -145,34 +145,6 @@ def _get_for_tracing(
     return result
 
 
-# The set of ops that are considered legal for each backend.
-# These are currently quite load-bearing, since different backends might be
-# missing patterns for decomposed forms of certain ops.
-# TODO: Tighten up the definition of these "conditionally legal for backends"
-# ops in the backend contract, and move these lists somewhere deeper in the
-# compiler where each backend can "own" its set of legal ops.
-BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: [
-        "aten.flatten.using_ints",
-        "aten.native_layer_norm",
-        "aten.linear",
-    ],
-    OutputType.LINALG_ON_TENSORS: [
-        "aten.flatten.using_ints",
-        "aten.adaptive_avg_pool1d",
-        "aten.adaptive_avg_pool2d",
-        "aten.unflatten.int",
-    ],
-    OutputType.STABLEHLO: [
-        "aten.amax",
-        "aten.amin",
-        "aten.randn.generator",
-        "aten.normal_functional",
-        "aten.fmod.Tensor",
-    ],
-}
-
-
 def _canon_extra_library(
     extra_library, extra_library_file_name="custom_op_extra_library.mlir"
 ):
@@ -249,19 +221,10 @@ def compile(
     if ignore_traced_shapes and not use_tracing:
         raise Exception("`ignore_traced_shapes` requires `use_tracing`")
 
-    # We only allow `backend_legal_ops` to be specified for the `"torch"`
-    # output type because the other output types actually invoke their
-    # respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
-    # very specific requirements about the ops which are legal.
-    # See `BACKEND_LEGAL_OPS` for more details.
     if backend_legal_ops is not None:
-        if output_type != OutputType.TORCH:
-            raise Exception(
-                "`backend_legal_ops` is only valid with the " "`torch` output type"
-            )
         backend_legal_ops = list(sorted(set(backend_legal_ops)))
     else:
-        backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
+        backend_legal_ops = []
 
     # For FX-based models, automatically strip overloads.
     if isinstance(model, torch.fx.GraphModule):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
index 9793f60cd683..10e06411db31 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
@@ -11,6 +11,7 @@
 from torch._dynamo.backends.common import aot_autograd
 
 from torch_mlir import fx
+from torch_mlir.compiler_utils import OutputType
 from torch_mlir_e2e_test.configs.utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
@@ -19,6 +20,13 @@
 from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME
 
 
+BACKEND_LEGAL_OPS = {
+    OutputType.LINALG_ON_TENSORS: [
+        "aten.adaptive_max_pool2d",
+    ],
+}
+
+
 def refine_result_type(_result):
     if isinstance(_result, tuple):
         return tuple(refine_result_type(x) for x in _result)
@@ -38,6 +46,9 @@ def __init__(self, backend, output_type="linalg-on-tensors", torch_compile=False
         self._backend = backend
         self._torch_compile = torch_compile
         self._output_type = output_type
+        self._backend_legal_ops = BACKEND_LEGAL_OPS.get(
+            OutputType.get(self._output_type), []
+        )
 
     def compile(
         self, program: torch.nn.Module, verbose: bool = False
@@ -86,6 +97,7 @@ def _base_backend(gm: torch.fx.GraphModule, example_inputs):
                 output_type=self._output_type,
                 model_name=artifact.__class__.__name__,
                 verbose=self._verbose,
+                backend_legal_ops=self._backend_legal_ops,
             )
             module = self._backend.compile(module)
             backend_module = self._backend.load(module)
@@ -131,6 +143,7 @@ def _export_run(self, artifact: torch.nn.Module, trace: Trace) -> Trace:
             # enabling this here ensures they don't regress either.
             import_symbolic_shape_expressions=True,
             verbose=self._verbose,
+            backend_legal_ops=self._backend_legal_ops,
         )
         module = self._backend.compile(module)
         backend_module = self._backend.load(module)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py
index dda29badec89..4f547d531294 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/jit_importer_backend.py
@@ -7,6 +7,7 @@
 import torch
 
 from torch_mlir import torchscript
+from torch_mlir.compiler_utils import OutputType
 from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem
 from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders
 
@@ -17,6 +18,34 @@
 )
 
 
+# The set of ops that are considered legal for each backend.
+# These are currently quite load-bearing, since different backends might be
+# missing patterns for decomposed forms of certain ops.
+# TODO: Tighten up the definition of these "conditionally legal for backends"
+# ops in the backend contract, and move these lists somewhere deeper in the
+# compiler where each backend can "own" its set of legal ops.
+BACKEND_LEGAL_OPS = {
+    OutputType.TOSA: [
+        "aten.flatten.using_ints",
+        "aten.native_layer_norm",
+        "aten.linear",
+    ],
+    OutputType.LINALG_ON_TENSORS: [
+        "aten.flatten.using_ints",
+        "aten.adaptive_avg_pool1d",
+        "aten.adaptive_avg_pool2d",
+        "aten.unflatten.int",
+    ],
+    OutputType.STABLEHLO: [
+        "aten.amax",
+        "aten.amin",
+        "aten.randn.generator",
+        "aten.normal_functional",
+        "aten.fmod.Tensor",
+    ],
+}
+
+
 class JITImporterTestConfig(TestConfig):
     """TestConfig that runs the torch.nn.Module with JIT Importer"""
 
@@ -24,11 +53,18 @@ def __init__(self, backend, output_type="linalg-on-tensors"):
         super().__init__()
         self.backend = backend
         self.output_type = output_type
+        self.backend_legal_ops = BACKEND_LEGAL_OPS.get(
+            OutputType.get(self.output_type), []
+        )
 
     def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any:
         example_args = convert_annotations_to_placeholders(program.forward)
         module = torchscript.compile(
-            program, example_args, output_type=self.output_type, verbose=verbose
+            program,
+            example_args,
+            output_type=self.output_type,
+            backend_legal_ops=self.backend_legal_ops,
+            verbose=verbose,
         )
 
         return self.backend.compile(module)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
index fcea6d87de6f..2854519edddf 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py
@@ -24,9 +24,11 @@
 from torch_mlir.dynamo import _get_decomposition_table
 from torch_mlir.torchscript import (
     _example_args,
-    BACKEND_LEGAL_OPS,
     _canon_extra_library,
 )
+from torch_mlir_e2e_test.configs.jit_importer_backend import (
+    BACKEND_LEGAL_OPS,
+)
 from torch_mlir_e2e_test.configs.utils import (
     recursively_convert_to_numpy,
     recursively_convert_from_numpy,
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index eed71bccdbc6..c6cc264d6aff 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -2107,6 +2107,52 @@ def AdaptiveMaxPool2dStaticWithIndices_basic(module, tu: TestUtils):
     module.forward(tu.rand(1, 512, 10, 16))
 
 
+class AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.amp2d = torch.nn.AdaptiveMaxPool2d((2, 2))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 3, 7, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.amp2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule()
+)
+def AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 3, 7, 7))
+
+
+class AdaptiveMaxPool2dUnitOutputSizeStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.amp2d = torch.nn.AdaptiveMaxPool2d((1, 1))
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([1, 512, 7, 7], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return self.amp2d(x)
+
+
+@register_test_case(
+    module_factory=lambda: AdaptiveMaxPool2dUnitOutputSizeStaticModule()
+)
+def AdaptiveMaxPool2dUnitOutputSizeStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(1, 512, 7, 7))
+
+
 # AdaptiveMaxPool3d
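
Reviewer note (not part of the patch): the kernel/stride arithmetic the decompositions rely on can be sanity-checked against PyTorch's eager adaptive pooling. The sketch below mirrors the two new e2e tests — a unit (1, 1) output, where the kernel must cover the whole spatial extent, and a 7x7 input with a (2, 2) output, where stride = 7 // 2 = 3 and kernel = 7 - (2 - 1) * 3 = 4 — plus the integer-multiple avg-pool case described in the 2d pattern's comment. All shapes are illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 7, 7)

# Unit output size: adaptive max pool == max pool with kernelSize = inputSize.
assert torch.equal(
    F.adaptive_max_pool2d(x, (1, 1)),
    F.max_pool2d(x, kernel_size=(7, 7)),
)

# Fixed kernel/stride case from the new test: stride = in // out = 3 and
# kernel = in - (out - 1) * stride = 4, so the windows are [0:4] and [3:7],
# matching adaptive pooling's floor/ceil window bounds for this shape.
assert torch.equal(
    F.adaptive_max_pool2d(x, (2, 2)),
    F.max_pool2d(x, kernel_size=4, stride=3),
)

# Avg-pool analogue when the input is an integer multiple of the output:
# stride = (6 // 2, 8 // 2) = (3, 4) and kernel = in - (out - 1) * stride.
y = torch.randn(2, 4, 6, 8)
assert torch.allclose(
    F.adaptive_avg_pool2d(y, (2, 2)),
    F.avg_pool2d(y, kernel_size=(3, 4), stride=(3, 4)),
)
```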
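The torchscript.py change is also user-visible: `backend_legal_ops` was previously rejected for any output type other than `torch`, and the per-backend defaults lived in `BACKEND_LEGAL_OPS` inside the compiler API. With this patch the defaults move into the e2e test configs, and callers may pass `backend_legal_ops` for backend output types directly. A hypothetical invocation (the module and input are stand-ins, not taken from the patch):

```python
import torch
from torch_mlir import torchscript


class Pool(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.adaptive_max_pool2d(x, (2, 2))


# Keep aten.adaptive_max_pool2d legal for the linalg backend, the same way
# JITImporterTestConfig now does via its BACKEND_LEGAL_OPS table.
module = torchscript.compile(
    Pool(),
    torch.randn(1, 3, 7, 7),
    output_type="linalg-on-tensors",
    backend_legal_ops=["aten.adaptive_max_pool2d"],
)
print(module)
```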