diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index 7ed50e866492..a21c9a1d7296 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -27,7 +27,7 @@ runs: steps: - name: Set up Python if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + uses: actions/setup-python@v4 with: python-version: '3.11' @@ -74,7 +74,7 @@ runs: - name: Enable ccache if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache@v3 with: path: ${{ github.workspace }}/.ccache key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 8c571893e145..3c8b95a3181a 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -22,7 +22,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'false' token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -95,7 +95,7 @@ jobs: - name: Post issue comment on build failure if: failure() - uses: peter-evans/create-or-update-comment@71345be0265236311c031f5c7866368bd1eff043 # v4.0.0 + uses: peter-evans/create-or-update-comment@v2 with: issue-number: 1690 body: | @@ -111,7 +111,7 @@ jobs: - name: Update PyTorch Build Cache (if running on main branch) if: github.ref_name == 'main' id: cache-pytorch - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache@v3 with: path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} @@ -127,7 +127,7 @@ jobs: git pull origin main - name: Create pull request - uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f # v7.0.5 + uses: 
peter-evans/create-pull-request@v5.0.1 with: author: Roll PyTorch Action branch: rollpytorch diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index 4eeef0b9bb5e..23f2addbe5af 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -22,7 +22,7 @@ concurrency: jobs: ubuntu-build: name: ubuntu-x86_64 - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Prepare workspace @@ -32,7 +32,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checkout torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'true' @@ -40,7 +40,7 @@ jobs: # restore to avoid the cache going stale over time # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - name: Setup cache for bazel - uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0 + uses: actions/cache@v3 with: path: ~/.cache/bazel key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} @@ -102,7 +102,7 @@ jobs: - name: Send mail if: failure() - uses: dawidd6/action-send-mail@2cea9617b09d79a095af21254fbcb7ae95903dde # v3.12.0 + uses: dawidd6/action-send-mail@v3 with: server_address: ${{ secrets.SMTP_SERVER }} server_port: ${{ secrets.SMTP_PORT }} diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index a304672b474f..e84aabb4b388 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -28,7 +28,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'true' fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 + uses: eregon/publish-release@v1 env: 
GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -75,7 +75,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v2 with: name: wheels path: dist @@ -96,7 +96,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Get torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'true' fetch-depth: 0 @@ -127,7 +127,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 + uses: eregon/publish-release@v1 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -143,7 +143,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v2 with: name: wheels path: dist @@ -156,7 +156,7 @@ jobs: package: [torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -187,7 +187,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 + uses: eregon/publish-release@v1 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -203,7 +203,7 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v2 with: name: wheels path: dist @@ -216,7 +216,7 @@ jobs: package: 
[torch-mlir] steps: - name: Get torch-mlir - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: submodules: 'true' - uses: ./.github/actions/setup-build @@ -250,7 +250,7 @@ jobs: - name: Publish Release (if requested) if: github.event.inputs.release_id != '' id: publish_release - uses: eregon/publish-release@01df127f5e9a3c26935118e22e738d95b59d10ce # v1.0.6 + uses: eregon/publish-release@v1 env: GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: @@ -267,13 +267,13 @@ jobs: # # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - name: Store the binary wheel - uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 + uses: actions/upload-artifact@v2 with: name: wheels path: dist publish_releases: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest needs: - build_linux - build_linux_arm64 @@ -285,7 +285,7 @@ jobs: steps: - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 + uses: benc-uk/workflow-dispatch@v1 with: workflow: Publish releases page token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index e87630edb28c..a0eb45257b11 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -8,7 +8,7 @@ on: jobs: scrape_and_publish_releases: name: "Scrape and publish releases" - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' @@ -20,7 +20,7 @@ jobs: # existing lock files. 
sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: Run scrape releases script diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index e335f1fdfd7d..58a91fd1d409 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -9,7 +9,7 @@ on: jobs: merge-pr: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest if: | github.repository == 'llvm/torch-mlir' && github.event.workflow_run.actor.login == 'stellaraccident' && @@ -18,7 +18,7 @@ jobs: steps: # Fetch the repo first so that the gh command knows where to look for the PR - name: Fetch Repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index 92d732cea3a6..ec1878606624 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -18,7 +18,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -43,15 +43,16 @@ jobs: - name: Create Release id: create_release - uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: - tag: ${{ env.tag_name }} - name: torch-mlir snapshot ${{ env.tag_name }} + tag_name: ${{ env.tag_name }} + release_name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. 
draft: true prerelease: false - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: "Invoke workflow :: Build and Test" uses: benc-uk/workflow-dispatch@v1 diff --git a/.github/workflows/pre-commit-all.yml b/.github/workflows/pre-commit-all.yml index 2c0d61e92747..e17d4ebdbb43 100644 --- a/.github/workflows/pre-commit-all.yml +++ b/.github/workflows/pre-commit-all.yml @@ -6,10 +6,10 @@ on: jobs: pre-commit: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 with: extra_args: --color=always --all-files diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 6a848fe8674f..29733c2e5d45 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,13 +5,13 @@ on: jobs: pre-commit: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/checkout@v3 with: # requites to grab the history of the PR fetch-depth: 0 - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.1 with: extra_args: --color=always --from-ref ${{ github.event.pull_request.base.sha }} --to-ref ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 7b575764ac8e..8a0ec914440f 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -9,7 +9,7 @@ on: jobs: release_snapshot_package: name: "Tag snapshot 
release" - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' steps: @@ -21,7 +21,7 @@ jobs: sudo rm -rf $GITHUB_WORKSPACE/* - name: Checking out repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + uses: actions/checkout@v3 with: token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} @@ -46,25 +46,26 @@ jobs: - name: Create Release id: create_release - uses: ncipollo/release-action@2c591bcc8ecdcd2db72b97d6147f871fcd833ba5 # v1.14.0 + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} with: - tag: ${{ env.tag_name }} - name: torch-mlir snapshot ${{ env.tag_name }} + tag_name: ${{ env.tag_name }} + release_name: torch-mlir snapshot ${{ env.tag_name }} body: | Automatic snapshot release of torch-mlir. draft: true prerelease: false - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - name: "Invoke workflow :: Build and Test" - uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 + uses: benc-uk/workflow-dispatch@v1 with: workflow: Build and Test token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} ref: "${{ env.tag_name }}" - name: "Invoke workflow :: Release Build" - uses: benc-uk/workflow-dispatch@e2e5e9a103e331dad343f381a29e654aea3cf8fc # v1.2.4 + uses: benc-uk/workflow-dispatch@v1 with: workflow: Release Build token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh index 36e9057c973f..ea3e570c8b7e 100755 --- a/build_tools/ci/build_posix.sh +++ b/build_tools/ci/build_posix.sh @@ -50,7 +50,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \ -DLLVM_TARGETS_TO_BUILD=host \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_ENABLE_LTC=ON \ -DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS=ON echo "::endgroup::" diff --git 
a/externals/llvm-project b/externals/llvm-project index 813f7c3820d0..6c64c8a6f3f7 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 813f7c3820d00349fe23bfc6ba26159764541540 +Subproject commit 6c64c8a6f3f77c30745c751d4163ff6bf2fc323b diff --git a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h index 221745b1c26e..a6d774a64db1 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TorchToTosa.h @@ -12,25 +12,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" - #include namespace mlir { namespace torch { - -/// Collect a set of legal/illegal ops for converting Torch operations to Tosa -/// dialect. -void populateTorchToTosaConversionLegalOps(ConversionTarget &target); - -/// Collect a set of patterns to convert Torch operations to Tosa dialect + -/// return the set of illegalOps -std::set -populateTorchToTosaConversionPatternsAndIllegalOps(TypeConverter &typeConverter, - RewritePatternSet &patterns); - std::unique_ptr> createConvertTorchToTosaPass(); -} // namespace torch +} } // namespace mlir #endif // TORCHMLIR_CONVERSION_TORCHTOTOSA_TORCHTOTOSA_H diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 264fb4966d39..d21dd5504dcd 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -97,15 +97,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, Value defaultValue, Value dimSize); -// Helper function to unsqueeze the input tensor at given dim. -// Returns the unsqueezed tensor or failure. -FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value input, int64_t dim); - -// Helper function to squeeze the input tensor at given dim. 
-// Returns the squeezed tensor or failure. -FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, - Value input, int64_t dim); } // namespace Torch } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ff1ffd7e2b62..a86474551eb1 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -310,7 +310,9 @@ def Torch_AtenRrelu_Op : Torch_Op<"aten.rrelu_", [ } def Torch_AtenRreluWithNoiseOp : Torch_Op<"aten.rrelu_with_noise", [ - AllowsTypeRefinement + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly ]> { let summary = "Generated op for `aten::rrelu_with_noise : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor)`"; let arguments = (ins @@ -4608,29 +4610,6 @@ def Torch_AtenTrunc_Op : Torch_Op<"aten.trunc_", [ }]; } -def Torch_AtenSpecialExpm1Op : Torch_Op<"aten.special_expm1", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::special_expm1 : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenSpecialExpm1Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenSpecialExpm1Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenSignOp : Torch_Op<"aten.sign", [ AllowsTypeRefinement, HasValueSemantics, @@ -6705,35 +6684,6 @@ def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ }]; } -def Torch_AtenConv3dPaddingOp : Torch_Op<"aten.conv3d.padding", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, 
int[], int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - Torch_StringType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_IntType:$groups - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenConv3dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void AtenConv3dPaddingOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6763,35 +6713,6 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ }]; } -def Torch_AtenConv2dPaddingOp : Torch_Op<"aten.conv2d.padding", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - Torch_StringType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_IntType:$groups - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenConv2dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void AtenConv2dPaddingOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -6821,35 +6742,6 @@ def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ }]; } -def Torch_AtenConv1dPaddingOp : Torch_Op<"aten.conv1d.padding", [ - 
AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$input, - AnyTorchTensorType:$weight, - AnyTorchOptionalTensorType:$bias, - AnyTorchListOfTorchIntType:$stride, - Torch_StringType:$padding, - AnyTorchListOfTorchIntType:$dilation, - Torch_IntType:$groups - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenConv1dPaddingOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 7, 1); - } - void AtenConv1dPaddingOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 7, 1); - } - }]; -} - def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -9491,31 +9383,6 @@ def Torch_AtenMseLossBackwardOp : Torch_Op<"aten.mse_loss_backward", [ }]; } -def Torch_AtenL1LossOp : Torch_Op<"aten.l1_loss", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$target, - Torch_IntType:$reduction - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenL1LossOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); - } - void AtenL1LossOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); - } - }]; -} - def Torch_AtenUpsampleNearest2dBackwardOp : Torch_Op<"aten.upsample_nearest2d_backward", [ AllowsTypeRefinement, HasValueSemantics, @@ -13431,32 +13298,6 @@ def Torch_AtenFftFftOp : Torch_Op<"aten.fft_fft", [ }]; } -def Torch_AtenFftRfftOp : 
Torch_Op<"aten.fft_rfft", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalIntType:$n, - Torch_IntType:$dim, - AnyTorchOptionalStringType:$norm - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenFftRfftOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 4, 1); - } - void AtenFftRfftOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 4, 1); - } - }]; -} - def Torch_AtenFftIfftOp : Torch_Op<"aten.fft_ifft", [ AllowsTypeRefinement, HasValueSemantics, @@ -14859,29 +14700,6 @@ def Torch_AtenHstackOp : Torch_Op<"aten.hstack", [ }]; } -def Torch_AtenColumnStackOp : Torch_Op<"aten.column_stack", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::column_stack : (Tensor[]) -> (Tensor)`"; - let arguments = (ins - AnyTorchListOfTensorType:$tensors - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenColumnStackOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenColumnStackOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenAppendTOp : Torch_Op<"aten.append.t", [ AllowsTypeRefinement ]> { @@ -16019,31 +15837,6 @@ def Torch_AtenMulIntOp : Torch_Op<"aten.mul.int", [ let hasCanonicalizer = 1; } -def Torch_AtenMulIntFloatOp : Torch_Op<"aten.mul.int_float", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::mul.int_float : (int, float) -> (float)`"; - let arguments = (ins - Torch_IntType:$a, - 
Torch_FloatType:$b - ); - let results = (outs - Torch_FloatType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMulIntFloatOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMulIntFloatOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_AtenDivIntOp : Torch_Op<"aten.div.int", [ AllowsTypeRefinement, HasValueSemantics, @@ -16141,31 +15934,6 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [ let hasFolder = 1; } -def Torch_AtenMulFloatIntOp : Torch_Op<"aten.mul.float_int", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::mul.float_int : (float, int) -> (float)`"; - let arguments = (ins - Torch_FloatType:$a, - Torch_IntType:$b - ); - let results = (outs - Torch_FloatType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMulFloatIntOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMulFloatIntOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [ AllowsTypeRefinement, HasValueSemantics, @@ -16510,31 +16278,6 @@ def Torch_Aten__And__BoolOp : Torch_Op<"aten.__and__.bool", [ }]; } -def Torch_AtenEqBoolOp : Torch_Op<"aten.eq.bool", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::eq.bool : (bool, bool) -> (bool)`"; - let arguments = (ins - Torch_BoolType:$a, - Torch_BoolType:$b - ); - let results = (outs - Torch_BoolType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenEqBoolOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, 
result, 2, 1); - } - void AtenEqBoolOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasFolder = 1; -} - def Torch_AtenNeBoolOp : Torch_Op<"aten.ne.bool", [ AllowsTypeRefinement, HasValueSemantics, @@ -16682,31 +16425,6 @@ def Torch_AtenLenTOp : Torch_Op<"aten.len.t", [ let hasCanonicalizer = 1; } -def Torch_AtenMulLeftTOp : Torch_Op<"aten.mul.left_t", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::mul.left_t : (t[], int) -> (t[])`"; - let arguments = (ins - AnyTorchListType:$l, - Torch_IntType:$n - ); - let results = (outs - AnyTorchListType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenMulLeftTOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenMulLeftTOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; - let hasCanonicalizer = 1; -} - def Torch_Aten__Getitem__TOp : Torch_Op<"aten.__getitem__.t", [ AllowsTypeRefinement, ReadOnly @@ -17132,29 +16850,6 @@ def Torch_AtenTrilIndicesOp : Torch_Op<"aten.tril_indices", [ let hasVerifier = 1; } -def Torch_AtenDeg2radOp : Torch_Op<"aten.deg2rad", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::deg2rad : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchOptionalTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenDeg2radOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenDeg2radOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_Aten_SoftmaxBackwardDataOp : Torch_Op<"aten._softmax_backward_data", [ AllowsTypeRefinement, HasValueSemantics, @@ -17517,35 +17212,6 @@ 
def Torch_AtenRreluWithNoiseBackwardOp : Torch_Op<"aten.rrelu_with_noise_backwar }]; } -def Torch_AtenRreluWithNoiseFunctionalOp : Torch_Op<"aten.rrelu_with_noise_functional", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) -> (Tensor, Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$noise, - AnyTorchScalarType:$lower, - AnyTorchScalarType:$upper, - Torch_BoolType:$training, - AnyTorchOptionalGeneratorType:$generator - ); - let results = (outs - AnyTorchOptionalTensorType:$result0, - AnyTorchOptionalTensorType:$noise_out - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRreluWithNoiseFunctionalOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 6, 2); - } - void AtenRreluWithNoiseFunctionalOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 6, 2); - } - }]; -} - def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index b0a40e35c652..cf31c8f9735a 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -20,8 +20,6 @@ namespace Torch { int64_t toPositiveDim(int64_t dim, int64_t inputRank); bool isValidDim(int64_t dim, int64_t inputRank); -Value toIntListConstruct(PatternRewriter &rewriter, Location loc, - ArrayRef cstInput); bool getListConstructElements(Value v, SmallVectorImpl &elems); /// Returns the index indicated by `v` for a list of given `length`. /// If the index is negative, it is adjusted to `length` + `v`. 
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index d8517fbd156d..85dbfdac1961 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -7,10 +7,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/Support/FormatVariadic.h" #include using namespace mlir; @@ -1290,7 +1292,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( }); patterns.onOp( "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Location loc = binder.getLoc(); Torch::ValueTensorType resultType; Value input, weight; int64_t group; @@ -1315,6 +1316,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unsupported conversion: kernel_shape list size should have " "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } } } @@ -1371,11 +1380,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( ArrayRef inputShape = inputTensorType.getSizes(); padding.resize_for_overwrite(2 * spatialRank); for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { - if (weightShape[dimIdx + 2] == Torch::kUnknownSize || - inputShape[dimIdx + 2] == Torch::kUnknownSize) - return rewriter.notifyMatchFailure( - binder.op, - "expected weight and input tensor to have static shape"); const int64_t dilatedKernelSize = dilations[dimIdx] * (weightShape[dimIdx + 2] - 1) + 1; 
int64_t totalPad = ((inputShape[dimIdx + 2] + strides[dimIdx] - 1) / @@ -1401,10 +1405,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { cstPadding.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(i))); } paddingList = rewriter.create( - loc, + binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1427,10 +1431,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (matchedPads) { for (unsigned i = 0; i < padding.size() / 2; i++) { cstPadding.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(padding[i]))); + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); } paddingList = rewriter.create( - loc, + binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1439,40 +1443,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector inputPaddingList; for (uint32_t i = 0; i < padding.size() / 2; i++) { padsRearrange.emplace_back(rewriter.create( - loc, rewriter.getI64IntegerAttr( - padding[padding.size() / 2 - i - 1]))); + binder.getLoc(), rewriter.getI64IntegerAttr( + padding[padding.size() / 2 - i - 1]))); padsRearrange.emplace_back(rewriter.create( - loc, + binder.getLoc(), rewriter.getI64IntegerAttr(padding[padding.size() - i - 1]))); inputPaddingList.emplace_back( rewriter.create( - loc, rewriter.getI64IntegerAttr(0))); + binder.getLoc(), rewriter.getI64IntegerAttr(0))); } // The conv op itself will have no padding since the actual padding // is performed using the torch.pad preceding it. 
paddingList = rewriter.create( - loc, + binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), inputPaddingList); Value padsSizeList = rewriter .create( - loc, + binder.getLoc(), Torch::ListType::get( rewriter.getType()), padsRearrange) .getResult(); Value modeVal = rewriter.create( - loc, rewriter.getStringAttr("constant")); + binder.getLoc(), rewriter.getStringAttr("constant")); Value constantValue; if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + binder.getLoc(), rewriter.getI64IntegerAttr(0)); if (isa(inputTensorType.getDtype())) constantValue = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0f)); + binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); // Pad output shape must be computed explicitly from the pad values SmallVector newInputShape(inputTensorType.getSizes()); for (uint32_t i = 0; i < padding.size() / 2; i++) { @@ -1482,44 +1486,46 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto padTy = rewriter.getType( newInputShape, inputTensorType.getDtype()); paddedInput = rewriter.create( - loc, padTy, input, padsSizeList, modeVal, constantValue); + binder.getLoc(), padTy, input, padsSizeList, modeVal, + constantValue); } } for (int64_t i : dilations) { cstDilations.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { cstStrides.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + binder.getLoc(), rewriter.getI64IntegerAttr(i))); } Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + binder.getLoc(), rewriter.getI64IntegerAttr(0)); cstOutputPadding = {cstZero, cstZero}; Value dilationsList = rewriter.create( - loc, + binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); Value stridesList = rewriter.create( - loc, + binder.getLoc(), 
Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value outputPaddingList = rewriter.create( - loc, + binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); - Value transposed = rewriter.create(loc, false); + Value transposed = + rewriter.create(binder.getLoc(), false); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { - bias = rewriter.create(loc); + bias = rewriter.create(binder.getLoc()); } Value cstGroup = rewriter.create( - loc, rewriter.getI64IntegerAttr(group)); + binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, paddedInput, weight, bias, stridesList, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 12d8683bc9d1..1f3ff7ac2346 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1093,35 +1093,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOp(binder.op, nllLoss); return success(); }); - patterns.onOp( - "NonZero", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) { - return failure(); - } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Value one = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); - auto rawSize = resultType.getSizes(); - SmallVector torchResultSize(rawSize.rbegin(), rawSize.rend()); - auto torchResultType = rewriter.getType( - torchResultSize, resultType.getDtype()); - auto nonZero = rewriter.create( - binder.getLoc(), torchResultType, operand); - // The output 
tensor has a shape of ((n, z)), where (n) is the - // number of dimensions in the input tensor and (z) is the - // number of non-zero elements2. This is different from - // PyTorch's default behavior, where the dimensions are - // reversed. - rewriter.replaceOpWithNewOp( - binder.op, resultType, nonZero, zero, one); - return success(); - }); + patterns.onOp("NonZero", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; @@ -2560,33 +2543,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.s64IntegerAttr(stashType, "stash_type", 1)) return failure(); + // Since the support for `stash_type` arg does not exist in + // the torch op so we just check for the stash_type to be same + // as the input dtype since that won't require us to do any + // input type conversion and hence can be supported. 
+ auto xType = cast(x.getType()); std::optional stashTypeIntTorch = onnxDtypeIntToTorchDtypeInt(stashType); if (!stashTypeIntTorch.has_value()) return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given stash_type"); + FailureOr stashDtype = Torch::getTypeForScalarType( binder.op->getContext(), (torch_upstream::ScalarType)stashTypeIntTorch.value()); if (failed(stashDtype)) return failure(); - - // Convert dtype if stash_type is different from input dtype - auto xType = cast(x.getType()); - Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value none = rewriter.create(binder.getLoc()); - if (*stashDtype != xType.getOptionalDtype()) { - auto newXType = - xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype); - Value dtypeValue = rewriter.create( - binder.getLoc(), - rewriter.getI64IntegerAttr(stashTypeIntTorch.value())); - x = rewriter.create( - binder.getLoc(), newXType, x, /*dtype=*/dtypeValue, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); - } + if (*stashDtype != xType.getOptionalDtype()) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: stash_type should be same " + "as the input dtype"); Value constEpsilon = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -2610,43 +2586,33 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), normalized); - SmallVector reducedShape(rank, 1); - for (int64_t i = 0; i < axis; i++) - reducedShape[i] = xShape[i]; - auto reducedType = - xType.getWithSizesAndDtype(reducedShape, *stashDtype); - auto y = rewriter.create( - binder.getLoc(), yType, /*meanType=*/reducedType, - /*invStdDevType=*/reducedType, x, normalized_shape, scale, b, - constEpsilon); - int64_t numResults = binder.op->getNumResults(); if (numResults == 1) { - rewriter.replaceOp(binder.op, y.getResult0()); + SmallVector reducedShape(rank, 1); + for (int64_t i = 0; i < axis; i++) + reducedShape[i] = 
xShape[i]; + auto reducedType = xType.getWithSizesAndDtype( + reducedShape, xType.getOptionalDtype()); + Value y = rewriter + .create( + binder.getLoc(), yType, /*meanType=*/reducedType, + /*invStdDevType=*/reducedType, x, normalized_shape, + scale, b, constEpsilon) + .getResult0(); + rewriter.replaceOp(binder.op, y); return success(); } - - Value meanOutput = y.getResult1(); - Value varOutput = y.getResult2(); - // Convert meanType and varType back if stash_dtype is different - if (binder.tensorResultTypeAtIndex(meanType, 1) || - binder.tensorResultTypeAtIndex(invStdDevType, 2)) - return failure(); - if (*stashDtype != meanType.getOptionalDtype()) { - Value constDtype = Torch::getDtypeIntValueForType( - rewriter, binder.getLoc(), meanType.getDtype()); - meanOutput = rewriter.create( - binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); - varOutput = rewriter.create( - binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); + if (numResults == 3) { + if (binder.tensorResultTypeAtIndex(meanType, 1) || + binder.tensorResultTypeAtIndex(invStdDevType, 2)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, yType, meanType, invStdDevType, x, normalized_shape, + scale, b, constEpsilon); + return success(); } - rewriter.replaceOp(binder.op, {y.getResult0(), meanOutput, varOutput}); - - return success(); + return rewriter.notifyMatchFailure( + binder.op, "Unimplemented: expected either 1 or 3 results"); }); patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -3031,9 +2997,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( }); patterns.onOp( "Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - // ONNX specifies that the result types matches the type of lhs. 
- // In torch, the result type is integer when both operands are integer, - // and otherwise operand types are promoted to f64. Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -3042,14 +3005,35 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto loc = binder.getLoc(); + auto lhsTy = cast(lhs.getType()); + auto rhsTy = cast(rhs.getType()); Value cstFalse = rewriter.create( loc, rewriter.getBoolAttr(false)); Value none = rewriter.create(loc); + auto torchDtype = Torch::getScalarTypeForType(rewriter.getF32Type()); + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchDtype))); + + if (isa(lhsTy.getDtype())) { + lhsTy = rewriter.getType( + lhsTy.getSizes(), rewriter.getF32Type()); + lhs = rewriter.create(loc, lhsTy, lhs, tyConst, + cstFalse, cstFalse, none); + } + + if (isa(rhsTy.getDtype())) { + rhsTy = rewriter.getType( + rhsTy.getSizes(), rewriter.getF32Type()); + rhs = rewriter.create(loc, rhsTy, rhs, tyConst, + cstFalse, cstFalse, none); + } auto powType = resultType; if (isa(resultType.getDtype())) { powType = rewriter.getType( - resultType.getSizes(), rewriter.getF64Type()); + resultType.getSizes(), rewriter.getF32Type()); } Value pow = rewriter.create(loc, powType, @@ -3688,7 +3672,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( patterns.onOp( "NonMaxSuppression", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Location loc = binder.getLoc(); Torch::ValueTensorType resultType; SmallVector operands; int64_t centerPointBox; @@ -3703,132 +3686,89 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "unimplemented: expected center_point_box " "attribute value to be 0"); - // TODO: Support multiple batches and classes + // TODO: Add support for optional arguments to be absent. 
+ if (operands.size() != 5) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected all 5 args to be present"); + // Squeeze the boxes and scores tensor. // In Onnx, the shape of boxes is [BxNx4] while the // torchvision expects it to be of shape [Nx4]. Similarly, for // the scores tensor shape in Onnx is [BxCxN] while the // torchvision expects it to be of shape [N]. Value boxes = operands[0], scores = operands[1]; - FailureOr squeezedBoxes = - Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes); + FailureOr squeezedBoxes = Torch::squeezeTensor( + rewriter, binder.op, binder.getLoc(), 0, boxes); if (failed(squeezedBoxes)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze boxes tensor"); - FailureOr squeezedScores = - Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores); + + FailureOr squeezedScores = Torch::squeezeTensor( + rewriter, binder.op, binder.getLoc(), 0, scores); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); - squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0, - squeezedScores.value()); + squeezedScores = Torch::squeezeTensor( + rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value()); if (failed(squeezedScores)) return rewriter.notifyMatchFailure(binder.op, "failed to squeeze scores tensor"); + boxes = squeezedBoxes.value(); scores = squeezedScores.value(); - // TODO: Support score_threshold input - // Filter out the boxes if the score < score_threshold - if (operands.size() == 5) { - Value scoreThreshold = rewriter.create( - loc, rewriter.getType(), operands[4]); - Value minScores = rewriter.create( - loc, - Torch::ValueTensorType::get(binder.op->getContext(), - SmallVector{}, - rewriter.getF32Type()), - scores); - minScores = rewriter.create( - loc, rewriter.getType(), minScores); - - Value scoresCond = rewriter.create( - loc, minScores, scoreThreshold); - rewriter.create( - loc, scoresCond, - rewriter.getStringAttr( - 
"unimplemented: score_threshold should be <= min(scores)")); - } - - // Get max_output_boxes_per_class and iou_threshold - Value cst0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cst1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value maxOutputBoxesPerClass = cst0; - Value iouThreshold = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0)); - if (operands.size() > 3 && - !isa(operands[3].getType())) { - iouThreshold = rewriter.create( - loc, rewriter.getType(), operands[3]); - } - if (operands.size() > 2 && - !isa(operands[2].getType())) { - maxOutputBoxesPerClass = rewriter.create( - loc, rewriter.getType(), operands[2]); - } - - auto nmsTy = Torch::ValueTensorType::get( - binder.op->getContext(), SmallVector{-1}, - rewriter.getIntegerType(64, /*signed=*/true)); + // TODO: Add support for handling score_threshold arg. + // If score_threshold > min(scores) then the op can't be lowered since + // the torchvision::nms op doesn't have support for handling the + // score_threshold arg. 
+ Value scoreThreshold = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[4]); + Value minScores = rewriter.create( + binder.getLoc(), + Torch::ValueTensorType::get(binder.op->getContext(), {}, + rewriter.getF32Type()), + scores); + minScores = rewriter.create( + binder.getLoc(), rewriter.getType(), minScores); + + Value scoresCond = rewriter.create( + binder.getLoc(), minScores, scoreThreshold); + rewriter.create( + binder.getLoc(), scoresCond, + rewriter.getStringAttr( + "unimplemented: score_threshold should be <= min(scores)")); + + Value iouThreshold = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[3]); Value result = rewriter.create( - loc, nmsTy, boxes, scores, iouThreshold); - - // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class - Value numOutputBoxes = - rewriter.create(loc, result, cst0); - Value boxesCond = rewriter.create( - loc, numOutputBoxes, maxOutputBoxesPerClass); - - auto nmsResultTy = Torch::ValueTensorType::get( - binder.op->getContext(), - SmallVector{resultType.getSizes()[0]}, - rewriter.getIntegerType(64, /*signed=*/true)); - auto ifSlice = rewriter.create( - loc, TypeRange({nmsResultTy}), boxesCond); - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifSlice.getThenRegion(), - ifSlice.getThenRegion().begin()); - - Value curResult = rewriter.create( - loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, - /*end=*/maxOutputBoxesPerClass, /*step=*/cst1); - rewriter.create(loc, curResult); - } - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifSlice.getElseRegion(), - ifSlice.getElseRegion().begin()); - - Value curResult = rewriter.create( - loc, nmsResultTy, result); - rewriter.create(loc, curResult); - } - result = ifSlice.getResult(0); + binder.getLoc(), resultType, boxes, scores, iouThreshold); // The result generated by torchvision.nms op is of shape [n], while the // onnx expects it to be of shape [n, 3]. 
Hence, we unsqueeze the tensor // and make it of shape [n, 1] and then concatenate it with a zero // tensor of shape [n, 2] to make it of shape [n, 3]. + Value dim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); FailureOr unsqueezedResult = - Torch::unsqueezeTensor(rewriter, binder.op, result, cst1); + Torch::unsqueezeTensor(rewriter, binder.op, result, dim); if (failed(unsqueezedResult)) return rewriter.notifyMatchFailure( binder.op, "failed to unsqueeze result tensor"); result = unsqueezedResult.value(); - numOutputBoxes = - rewriter.create(loc, result, cst0); + Value numOutputBoxes = rewriter.create( + binder.getLoc(), result, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); SmallVector zerosShapeValues{numOutputBoxes}; zerosShapeValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(2))); + binder.getLoc(), rewriter.getI64IntegerAttr(2))); Value zerosShapeList = rewriter.create( - loc, + binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); + std::optional> resultShape = cast(result.getType()).getOptionalSizes(); if (!resultShape.has_value()) @@ -3837,9 +3777,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( llvm::SmallVector zerosShape = {resultShape->front(), 2}; auto zerosTy = Torch::ValueTensorType::get( resultType.getContext(), zerosShape, resultType.getOptionalDtype()); - Value cstNone = rewriter.create(loc); + Value cstNone = rewriter.create(binder.getLoc()); Value zeros = rewriter.create( - loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone); + binder.getLoc(), zerosTy, zerosShapeList, cstNone, cstNone, cstNone, + cstNone); Type listElemType = cast(resultType) @@ -3847,9 +3788,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); Value tensorList = rewriter.create( - loc, listType, SmallVector{zeros, result}); + binder.op->getLoc(), listType, SmallVector{result, zeros}); 
+ + // TODO: Add support for handling max_output_boxes_per_class arg. + // If numOutputBoxes (N) > max_output_boxes_per_class then the op can't + // be lowered since the torchvision::nms op doesn't have support for + // handling the max_output_boxes_per_class arg. Also, we have already + // constrained the number of classes to be 1 above, so the number of + // output boxes inferred from the result is num_output_boxes_per_class. + Value maxOutputBoxesPerClass = rewriter.create( + binder.getLoc(), rewriter.getType(), operands[2]); + Value boxesCond = rewriter.create( + binder.getLoc(), numOutputBoxes, maxOutputBoxesPerClass); + rewriter.create( + binder.getLoc(), boxesCond, + rewriter.getStringAttr( + "unimplemented: number of output boxes per class should be " + "<= max_output_boxes_per_class")); + rewriter.replaceOpWithNewOp(binder.op, resultType, - tensorList, cst1); + tensorList, dim); return success(); }); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 85b51ca7efaa..ea2e0452eb7f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -36,24 +36,21 @@ namespace { // we provide the original operand through storeResult, which will be modified // if the result will be passed onto another operation, and will be used for // noop_with_empty_axes handling before that. 
-template -LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, - Value data, Torch::ValueTensorType resultType, - Value &storeResult, int64_t keepDims, - int64_t noop_with_empty_axes, - bool isIntermediateOp) { - - auto inputType = dyn_cast(data.getType()); - if (!inputType) - return failure(); +LogicalResult reducedSumImpl(OpBinder binder, + ConversionPatternRewriter &rewriter, Value data, + Torch::ValueTensorType resultType, + Value &storeResult, int64_t keepDims, + int64_t noop_with_empty_axes, + bool isIntermediateOp) { + SmallVector axesList; Value axesVal; if (!binder.tensorOperandAtIndex(axesVal, 1)) { - auto axesTy = dyn_cast(axesVal.getType()); - if (!axesTy || !axesTy.areAllSizesKnown() || axesTy.getSizes().size() > 1) - return failure(); - auto axesShape = axesTy.getSizes(); - uint64_t numAxes = (axesShape.empty()) ? 1 : axesShape.front(); + auto inputType = dyn_cast(data.getType()); + if (!inputType.hasSizes() || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected input and result to have shapes"); + } if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { SmallVector inputShape{inputType.getSizes()}; @@ -80,25 +77,22 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, } else { reduceDims.push_back(i); if (resultShapeCounter < resultShape.size() && - resultShape[resultShapeCounter] == 1 && keepDims == 1) + resultShape[resultShapeCounter] == 1) resultShapeCounter++; } } - if (reduceDims.size() == numAxes) { - for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - } else - binder.op->emitWarning( - "Number of inferred reduce dims, " + - std::to_string(reduceDims.size()) + - ", does not match the provided number of axes, " + - std::to_string(numAxes) + "."); + for (auto i : reduceDims) { + axesList.push_back(rewriter.create( + binder.getLoc(), 
rewriter.getI64IntegerAttr(i))); + } } } if (axesList.empty()) { - if (axesTy.getSizes()[0] == Torch::kUnknownSize) + Torch::BaseTensorType axesType = + cast(axesVal.getType()); + auto axesTy = dyn_cast(axesVal.getType()); + auto axesShape = axesTy.getSizes(); + if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) return failure(); Value zero = rewriter.create( @@ -106,8 +100,9 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, rewriter.getI64IntegerAttr(0)); SmallVector selectSizes{1}; auto selType = rewriter.getType( - selectSizes, axesTy.getOptionalDtype()); - for (uint64_t i = 0; i < numAxes; ++i) { + selectSizes, axesType.getOptionalDtype()); + int64_t numAxes = axesShape[0]; + for (int64_t i = 0; i < numAxes; ++i) { Value iv = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i)); @@ -122,60 +117,38 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, SmallVector axesInts; if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { - for (int64_t i : axesInts) { - axesList.push_back( - rewriter.create(binder.getLoc(), i)); + for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axesInts[i])); + axesList.push_back(iv); } } // Do not include absolute value in the noop - if (axesList.empty() && noop_with_empty_axes == 1) { - if (!isIntermediateOp) - rewriter.replaceOp(binder.op, data); - else - storeResult = data; + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, storeResult); return success(); } - // if the axes list is still empty, reduce everything. 
- if (axesList.empty()) { - if (keepDims == 0 && !resultType.getSizes().empty()) - return rewriter.notifyMatchFailure( - binder.op, - "no axes provided & no keepdim: expected result to be rank zero."); - if (keepDims == 1 && - (resultType.getSizes().size() != inputType.getSizes().size() || - llvm::any_of(resultType.getSizes(), - [](int64_t size) { return size != 1; }))) - return rewriter.notifyMatchFailure( - binder.op, "no axes provided & keepdim: expected result to have all " - "dimensions equal to 1."); - for (uint64_t i = 0; i < inputType.getSizes().size(); i++) { - axesList.push_back( - rewriter.create(binder.getLoc(), i)); - } - } - Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), axesList); Value keepDimBool = rewriter.create(binder.getLoc(), keepDims); - // If we are using the reduction op as an intermediate op to be passed into + Value dType = rewriter.create(binder.getLoc()); + // If we are using the ReducedSum as an intermediate op to be passed into // another operation, we might not want to replace the Op. So we create a new // Op and store the result in a variable. 
- SmallVector operands = {data, dimValueList, keepDimBool}; - if (llvm::is_one_of()) - operands.push_back( - /*dtype=*/rewriter.create(binder.getLoc())); if (!isIntermediateOp) { - rewriter.replaceOpWithNewOp(binder.op, resultType, - operands); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/dType); } else { - storeResult = rewriter.create(binder.getLoc(), - resultType, operands); + storeResult = rewriter.create( + binder.getLoc(), resultType, data, dimValueList, keepDimBool, + /*dtype=*/dType); } return success(); } @@ -1066,25 +1039,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); - patterns.onOp( - "ReduceL1", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - int64_t keepDims, noop_with_empty_axes; - Value operand; - if (binder.tensorOperandAtIndex(operand, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); + patterns.onOp("ReduceL1", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + int64_t keepDims, noop_with_empty_axes; + Value operand; + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); - Value data = rewriter.create( - binder.getLoc(), operand.getType(), operand); + Value data = rewriter.create( + binder.getLoc(), operand.getType(), operand); - return reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/operand, keepDims, noop_with_empty_axes, false); - }); + return reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/operand, keepDims, + 
noop_with_empty_axes, false); + }); patterns.onOp( "ReduceL2", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1102,9 +1075,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value squareOfOperand = rewriter.create( binder.getLoc(), operand.getType(), operand, operand); - auto reducedSum = reduceOpImpl( - binder, rewriter, squareOfOperand, resultType, operand, keepDims, - noop_with_empty_axes, true); + auto reducedSum = + reducedSumImpl(binder, rewriter, squareOfOperand, resultType, + operand, keepDims, noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( binder.op, @@ -1139,32 +1112,32 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp( - "ReduceLogSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); + patterns.onOp("ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); - auto reducedSumBool = reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, true); + auto reducedSumBool = + reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - 
binder.op, - "Failed to perform sum operation on square of operand"); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); - rewriter.replaceOpWithNewOp(binder.op, resultType, - data); - return success(); - }); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data); + return success(); + }); patterns.onOp( "ReduceLogSumExp", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1196,7 +1169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, dataCast); auto f64ReduceType = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF64Type()); - auto reducedSumBool = reduceOpImpl( + auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSumBool)) @@ -1213,23 +1186,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); - patterns.onOp( - "ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - - return reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, false); - }); - patterns.onOp("ReduceSumSquare", 1, + patterns.onOp("ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -1241,15 +1198,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - Value dataSquare = rewriter.create( - binder.getLoc(), data.getType(), data, data); - - return reduceOpImpl( - 
binder, rewriter, dataSquare, resultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, - false); + return reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); }); - patterns.onOp("ReduceMean", 1, + patterns.onOp("ReduceSumSquare", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -1261,18 +1214,140 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - Value reduceSum = data; - return reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/reduceSum, keepDims, noop_with_empty_axes, - false); + Value dataSquare = rewriter.create( + binder.getLoc(), data.getType(), data, data); + + return reducedSumImpl(binder, rewriter, dataSquare, + resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); }); + patterns.onOp( + "ReduceMean", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + SmallVector axesList; + + Value axesVal; + if (!binder.tensorOperandAtIndex(axesVal, 1)) { + auto inputType = dyn_cast(data.getType()); + if (!inputType.hasSizes() || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: expected input and result to have shapes"); + } + + // If the input shape and result shape is statically known then the + // list of dims to be squeezed can be derived from those shapes. As a + // result, we don't have to wait for the dim values to be known at + // runtime which is also expected by the downstream pipeline. 
+ if (inputType.areAllSizesKnown() && resultType.areAllSizesKnown()) { + SmallVector inputShape{inputType.getSizes()}; + SmallVector resultShape{resultType.getSizes()}; + if (llvm::equal(inputShape, resultShape)) { + // Case: none of the dimension is reduced. + rewriter.replaceOp(binder.op, data); + return success(); + } + if (areAllElementsDistinct(inputShape)) { + // The check for the input shape elements to be distinct is added + // for the cases like: + // Input: [3, 2, 2] -> Output: [3, 2] + // For the above case, from the input and output shape it can't be + // inferred whether the dim:1 is reduced or dim:2. To avoid these + // type of cases, the check has been placed. + SmallVector reduceDims; + unsigned resultShapeCounter = 0; + for (unsigned i = 0; i < inputShape.size(); i++) { + if (resultShapeCounter < resultShape.size() && + inputShape[i] == resultShape[resultShapeCounter]) { + resultShapeCounter++; + } else { + reduceDims.push_back(i); + if (resultShapeCounter < resultShape.size() && + resultShape[resultShapeCounter] == 1) + resultShapeCounter++; + } + } + for (auto i : reduceDims) { + axesList.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } + } + + if (axesList.empty()) { + Torch::BaseTensorType axesType = + cast(axesVal.getType()); + auto axesTy = dyn_cast(axesVal.getType()); + auto axesShape = axesTy.getSizes(); + if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + return failure(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + SmallVector selectSizes{1}; + auto selType = rewriter.getType( + selectSizes, axesType.getOptionalDtype()); + int64_t numAxes = axesShape[0]; + for (int64_t i = 0; i < numAxes; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + binder.getLoc(), selType, axesVal, zero, iv); + Value dim = rewriter.create( + 
binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + } + + SmallVector axesInts; + if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { + for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axesInts[i])); + axesList.push_back(iv); + } + } + + // deal with case when axes is empty + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + Value noneVal = rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); patterns.onOp( "ReduceMax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // AtenAmaxOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; + Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1337,9 +1412,87 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - return reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + cast(axes.getType()); + SmallVector selectSizes{1}; + Type 
selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + rewriter.replaceOpWithNewOp( + binder.op, 
resultType, data, dimValueList, keepDimBool); + return success(); }); patterns.onOp( @@ -1348,6 +1501,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // AtenAminOp allows us to pass a list of dims Torch::ValueTensorType resultType; Value data; + Value axes; int64_t keepDims; int64_t noop_with_empty_axes; if (binder.tensorOperandAtIndex(data, 0) || @@ -1411,9 +1565,87 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - return reduceOpImpl( - binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, false); + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + cast(axes.getType()); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, 
data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); }); patterns.onOp( @@ -2690,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; - float extrapolation_value, cubic_coeff_a; + float extrapolation_value; Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { @@ -2715,8 +2947,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.f32FloatAttr(extrapolation_value, "extrapolation_value", 0.0) || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", - "round_prefer_floor") || - binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) + "round_prefer_floor")) return failure(); if (antialias != 0) { return rewriter.notifyMatchFailure( @@ -2745,11 +2976,6 
@@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "except asymmetric and half_pixel"); } - if (mode == "cubic" && cubic_coeff_a != -0.75) { - return rewriter.notifyMatchFailure( - binder.op, "unimplemented: cubic coeff must be -0.75"); - } - unsigned rank = dyn_cast(operands[0].getType()) .getSizes() .size(); @@ -2765,11 +2991,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; if (mode == "cubic") { - std::string modeStr = "cubic"; - if (coordTfMode != "half_pixel") - modeStr = modeStr + "_" + coordTfMode; - modeStrValue = - rewriter.create(binder.getLoc(), modeStr); + return rewriter.notifyMatchFailure(binder.op, + "unimplemented: bicubic mode"); } // supported modes: // bilinear (half_pixel), bilinear with align_corners, diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index baed74fed6dc..143b46694030 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -74,36 +74,14 @@ class ConvertAtenBinaryOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Value a = adaptor.getA(); Value b = adaptor.getB(); - if (llvm::is_one_of::value || - llvm::is_one_of::value) + if (llvm::is_one_of::value) b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType()); - if (llvm::is_one_of::value) - a = convertScalarToDtype(rewriter, op.getLoc(), a, b.getType()); rewriter.template replaceOpWithNewOp(op, a, b); return success(); } }; } // namespace -namespace { -class ConvertAtenNegIntOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenNegIntOp op, - typename OpConversionPattern::OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Value a = adaptor.getA(); - rewriter.replaceOpWithNewOp( - op, - rewriter.create(op.getLoc(), /*value=*/0, - 
/*bitwidth=*/64), - a); - return success(); - } -}; -} // namespace - namespace { template class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { @@ -454,14 +432,11 @@ class ConvertTorchToArith patterns.add< ConvertAtenIntComparisonOp>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add< ConvertAtenFloatComparisonOp>( typeConverter, context); - patterns.add< - ConvertAtenFloatComparisonOp>( - typeConverter, context); patterns.add>( typeConverter, context); @@ -490,25 +465,17 @@ class ConvertTorchToArith target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + AtenMulIntOp>(); patterns.add>( typeConverter, context); - patterns.add>( - typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); patterns.add>( typeConverter, context); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); target.addIllegalOp(); patterns.add>( typeConverter, context); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b8c20bc73f65..a18c0bae01fc 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1642,18 +1642,69 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); + Value input = adaptor.getSelf(); + auto inputType = cast(input.getType()); + int64_t inputRank = inputType.getRank(); + + if (inputRank == 0) { + return rewriter.notifyMatchFailure( + op, "zero input rank should have been handled by the folder"); + } + int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, 
inputRank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - auto squeezeTensorInfo = - squeezeTensor(rewriter, op, adaptor.getSelf(), dim); - if (failed(squeezeTensorInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor"); + // assert dynamic squeeze dim size == 1 + if (inputType.isDynamicDim(dim)) { + Value cstDim = rewriter.create(op.getLoc(), dim); + Value dimVal = rewriter.create(op.getLoc(), input, cstDim); + Value cstOne = rewriter.create(op.getLoc(), 1); + Value cmp = rewriter.create( + op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne); + rewriter.create( + op.getLoc(), cmp, + rewriter.getStringAttr( + "Expected dynamic squeeze dim size to be statically 1")); } - rewriter.replaceOp(op, squeezeTensorInfo.value()); + const TypeConverter *typeConverter = getTypeConverter(); + auto resultType = + cast(typeConverter->convertType(op.getType())); + int64_t resultRank = resultType.getRank(); + + // If the dim(th) dimension of operand tensor type is not statically unit, + // `aten.squeeze` will behave as an identity operation. + if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { + rewriter.replaceOpWithNewOp(op, resultType, input); + return success(); + } + + SmallVector reassociationMap(resultRank); + bool alreadyCrossedSqueezedDim = false; + for (int i = 0; i != resultRank; i++) { + if (alreadyCrossedSqueezedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (dim != 0 && i != dim - 1) + continue; + + alreadyCrossedSqueezedDim = true; + if (dim == 0) + reassociationMap[0].push_back(1); + if (i == dim - 1) + reassociationMap[i].push_back(dim); + } + } + // Note: In case the operand tensor type is of unit rank and is statically + // shaped with unit dimension, the `reassociationMap` will be empty and the + // input will be collapsed to a 0-D tensor. 
+ rewriter.replaceOpWithNewOp(op, resultType, input, + reassociationMap); return success(); } }; @@ -1671,15 +1722,36 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern { int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return rewriter.notifyMatchFailure(op, "dim must be constant"); + auto inputRank = + cast(adaptor.getSelf().getType()).getRank(); + dim = toPositiveDim(dim, inputRank + 1); + if (!isValidDim(dim, inputRank + 1)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - auto unsqueezeTensorInfo = - unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim); - if (failed(unsqueezeTensorInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor"); + SmallVector reassociationMap(inputRank); + // From the perspective of the reassociation map, the situation of + // unsqueezing before or after the last dimension is symmetrical. + // Normalize it to the "before" case. + // The 0 case is special here, since there is no last dimension to insert + // before -- we simply rely on the loop below iterating 0 times. 
+ if (dim == inputRank && inputRank != 0) + dim = inputRank - 1; + bool alreadyCrossedExpandedDim = false; + for (int i = 0; i != inputRank; i++) { + if (alreadyCrossedExpandedDim) { + reassociationMap[i].push_back(i + 1); + } else { + reassociationMap[i].push_back(i); + if (i == dim) { + reassociationMap[i].push_back(i + 1); + alreadyCrossedExpandedDim = true; + } + } } - - rewriter.replaceOp(op, unsqueezeTensorInfo.value()); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + rewriter.replaceOpWithNewOp( + op, resultType, adaptor.getSelf(), reassociationMap); return success(); } }; diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 9073c5846f33..9c914690bbf4 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -11,7 +11,6 @@ #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/IR/Matchers.h" @@ -727,21 +726,15 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Check the matrixs shapes are valid for mulplication. 
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); - Type accumulatorDType = getDefaultAccType(rewriter, resultElementType); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, + resultElementType); Value bmm = rewriter .create(loc, initTensor0.getType(), ValueRange{lhs, rhs}, initTensor0) .getResult(0); - - if (accumulatorDType != resultElementType) { - bmm = torch_to_linalg::convertTensorToElementType(rewriter, loc, bmm, - resultElementType); - } - rewriter.replaceOpWithNewOp(op, newResultType, bmm); return success(); } @@ -856,48 +849,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int dilations"); - // Checks for valid group size - int64_t numGroups; - if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) - return rewriter.notifyMatchFailure(op, - "only constant group size supported."); - Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); - - // Adding support for 1d group convolution by converting the 1d-conv to - // 2d-conv. - // TODO: Replace this logic with the appropriate linalg op for 1-d group - // convolution once that support is added. - bool is1DGroupConv = (numSpatialDims == 1 && numGroups != 1); - if (is1DGroupConv) { - // Unsqueezing the last dim of input and weight. Also extending the - // dilation, stride, padding, and output padding lists. 
- auto unsqueezeInputInfo = - unsqueezeTensor(rewriter, op, input, /*dim=*/-1); - if (failed(unsqueezeInputInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor"); - } - input = unsqueezeInputInfo.value(); - - auto unsqueezeWeightInfo = - unsqueezeTensor(rewriter, op, weight, /*dim=*/-1); - if (failed(unsqueezeWeightInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor"); - } - weight = unsqueezeWeightInfo.value(); - - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - paddingIntValues.push_back(cstZero); - outputPaddingIntValues.push_back(cstZero); - strideInts.push_back(1); - dilationInts.push_back(1); - - inRank++; - numSpatialDims++; - } - Value inBatch = getDimOp(rewriter, loc, input, 0); Value inChannels = getDimOp(rewriter, loc, input, 1); SmallVector inDims; @@ -909,6 +860,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (size_t i = 2; i < inRank; i++) weightDims.push_back(getDimOp(rewriter, loc, weight, i)); + // Checks for valid group size + int64_t numGroups; + if (!matchPattern(op.getGroups(), m_TorchConstantInt(&numGroups))) + return rewriter.notifyMatchFailure(op, + "only constant group size supported."); + Value groups = castIntToIndex(rewriter, loc, adaptor.getGroups()); + auto validate = [&](Value toValidate, std::string err) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); @@ -1321,24 +1279,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } - - if (is1DGroupConv) { - // Squeezing the last dim of the result of conv. 
- auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); - if (failed(squeezeOutputInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate squeeze tensor"); - } - conv = squeezeOutputInfo.value(); - } - rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } if (numSpatialDims != 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only 1D and 2D grouped convolution supported"); + op, "unimplemented: only 2D grouped convolution supported"); // Grouped case, use the grouped conv linalg op auto expandGroups = [&](Value tensor, size_t dim) { @@ -1423,210 +1370,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } - - if (is1DGroupConv) { - // Squeezing the last dim of the result of conv. - auto squeezeOutputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); - if (failed(squeezeOutputInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate squeeze tensor"); - } - conv = squeezeOutputInfo.value(); - } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } }; } // namespace -namespace { - -/// Creates coefficients based on DFT definition, see -/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. 
-Value getDFTMatmulCoeff(OpBuilder b, Location loc, - RankedTensorType matrixType) { - - ComplexType complexTy = llvm::cast(matrixType.getElementType()); - mlir::FloatType floatType = - llvm::cast(complexTy.getElementType()); - - // scale = 2 * pi / N - double scale = 2 * M_PI / matrixType.getDimSize(0); - - SmallVector> values; - for (auto i : llvm::seq(0, matrixType.getDimSize(0))) { - for (auto j : llvm::seq(0, matrixType.getDimSize(1))) { - double v = scale * i * j; - double realV = cos(v); - double imagV = -sin(v); - - bool unused; - APFloat real(realV); - real.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - APFloat imag(imagV); - imag.convert(floatType.getFloatSemantics(), APFloat::rmNearestTiesToEven, - &unused); - - values.push_back(std::complex(real, imag)); - } - } - return b.create( - loc, matrixType, DenseElementsAttr::get(matrixType, values)); -} - -struct ConvertAtenFftRfftOp final : OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(AtenFftRfftOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = adaptor.getSelf(); - - int64_t dim; - auto dimVal = op.getDim(); - if (isa(dimVal.getType())) { - dim = -1; - } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: requires dim to be constant"); - } - - if (!isa(op.getN().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); - } - - if (!isa(op.getNorm().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); - } - - RankedTensorType inputType = - cast(adaptor.getSelf().getType()); - if (!inputType.hasRank()) { - return rewriter.notifyMatchFailure( - op, "unsupported: only ranked tensors are supported"); - } - - const ArrayRef inputShape = inputType.getShape(); - dim += dim < 0 ? 
inputShape.size() : 0; - - const int64_t fftLength = inputShape[dim]; - if (fftLength == ShapedType::kDynamic) { - return rewriter.notifyMatchFailure( - op, "unsupported: FFT signal length must be static"); - } - const int64_t rank = inputType.getRank(); - const int64_t lastDim = rank - 1; - const int64_t outputFftDim = fftLength / 2 + 1; - const bool needTranspose = dim != lastDim; - - // Transpose if FFT dimension is not the last one - llvm::SmallVector perms = llvm::to_vector(llvm::seq(rank)); - std::swap(perms[dim], perms[lastDim]); - if (needTranspose) { - self = transposeValue(loc, self, perms, rewriter); - } - - RankedTensorType newResultType = llvm::cast( - getTypeConverter()->convertType(op.getType())); - ComplexType complexElemType = - llvm::cast(newResultType.getElementType()); - Type elemType = complexElemType.getElementType(); - - // coeffMatrix : tensor> - RankedTensorType coeffType = - RankedTensorType::get({fftLength, outputFftDim}, complexElemType); - // coeffMatrix(n,m) = cos(2 pi n m / N) - j sin(2 pi n m / N) - Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, coeffType); - - // #matmul_trait = { - // indexing_maps = [ - // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... d_m, f)>, - // affine_map<(d_0, ... d_m, f, o) -> (f, o)>, - // affine_map<(d_0, ... d_m, f, o) -> (d_0, ... 
d_m, o)> - // ], - // iterator_types = ["parallel", ..., "parallel", "reduction", "parallel"] - // } - // linalg.generic #matmul_trait - // ins(%A, %B : tensor, - // tensor>) - // outs(%C : tensor>) { - // ^bb0(%a: f32, %b: complex, %c: complex) : - // %re = complex.re %b : f32 - // %im = complex.im %b : f32 - // %mulre = arith.mulf %a, %re: f32 - // %mulim = arith.mulf %a, %im: f32 - // %mulcplx = complex.create %mulre, %mulim : complex - // %add = complex.add %c, %mulcplx: complex - // linalg.yield %add : complex - // } -> (tensor>) - - Value lhs = self; - Value rhs = coeffMatrix; - RankedTensorType lhsType = llvm::cast(lhs.getType()); - ArrayRef lhsShape(lhsType.getShape()); - ArrayRef rhsShape(coeffType.getShape()); - - unsigned batchRank = lhsShape.size() - 1; - - SmallVector lhsExpr; - SmallVector rhsExpr; - SmallVector outExpr; - SmallVector iteratorTypes( - batchRank, utils::IteratorType::parallel); - SmallVector resultShape; - for (unsigned i = 0; i < batchRank; i++) { - lhsExpr.push_back(rewriter.getAffineDimExpr(i)); - outExpr.push_back(rewriter.getAffineDimExpr(i)); - resultShape.push_back(getDimOp(rewriter, loc, lhs, i)); - } - unsigned fIdx = batchRank, oIdx = batchRank + 1; - lhsExpr.insert(lhsExpr.end(), {rewriter.getAffineDimExpr(fIdx)}); - rhsExpr.insert(rhsExpr.end(), {rewriter.getAffineDimExpr(fIdx), - rewriter.getAffineDimExpr(oIdx)}); - outExpr.insert(outExpr.end(), {rewriter.getAffineDimExpr(oIdx)}); - resultShape.insert(resultShape.end(), - {getDimOp(rewriter, loc, rhs, rhsShape.size() - 1)}); - - Value zeroTensor = - createZeroInitTensor(rewriter, loc, resultShape, complexElemType); - auto indexingMaps = AffineMap::inferFromExprList( - {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); - iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::reduction, - utils::IteratorType::parallel}); - - Value complexRes = - rewriter - .create( - loc, zeroTensor.getType(), - /*inputs=*/ValueRange{lhs, rhs}, - /*outputs=*/zeroTensor, 
indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value l = args[0], r = args[1], res = args[2]; - Value re = b.create(loc, elemType, r); - Value im = b.create(loc, elemType, r); - Value mulRe = b.create(loc, l, re); - Value mulIm = b.create(loc, l, im); - Value mulCplx = b.create( - loc, complexElemType, mulRe, mulIm); - Value add = b.create(loc, mulCplx, res); - b.create(loc, add); - }) - .getResult(0); - - // Transpose back - if (needTranspose) { - complexRes = transposeValue(loc, complexRes, perms, rewriter); - } - - rewriter.replaceOp(op, complexRes); - return success(); - } -}; - -} // namespace - void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1641,6 +1390,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - target.addIllegalOp(); - patterns.add(typeConverter, context); } diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 854e3f86d367..aa4ec91d7da5 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -186,7 +186,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { Value res = randomUniformF64(b, loc, linearIndex, key, min, max); Value truncRes = res; - if (isa(elemTy)) + if (isa(elemTy)) truncRes = b.create(loc, elemTy, res); b.create(loc, truncRes); }) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d6b5aaf869c8..c129c9614eb0 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1019,20 +1019,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = cast(converter->convertType(pow.getType())) .getElementType(); if (!isa(dtype)) { - // 
The result type is integer when both operands are integer. - // Torch then uses the following implementation: - // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Pow.h pow.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Type powType = dtype; - if (payloadArgs[0].getType().isInteger() || - payloadArgs[1].getType().isInteger()) - powType = mlir::FloatType::getF64(op->getContext()); - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType); - auto powOp = b.create(loc, lhs, rhs); - return convertScalarToDtype(b, loc, powOp, dtype); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + return b.create(loc, lhs, rhs); } if (auto imag = dyn_cast(op)) { @@ -2691,7 +2683,7 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { }; } // namespace -static Value nearestInterpolate(OpBuilder &b, Location loc, +static Value NearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes, SmallVector scaleValues, @@ -2779,12 +2771,12 @@ static Value nearestInterpolate(OpBuilder &b, Location loc, return retVal; } -static SmallVector coordinateTransform( - OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, - SmallVector outputSizes, Value input, SmallVector inputSizes, - SmallVector scaleValues, std::string coordStr, bool alignCornersBool, - SmallVector indices, bool clip) { - +static Value BilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -2793,7 +2785,15 @@ static SmallVector coordinateTransform( Value cstHalf = b.create(loc, 
b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - SmallVector proj; + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, projEps, high, low, highFP, lowFP; for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = @@ -2856,50 +2856,13 @@ static SmallVector coordinateTransform( outputSizeFP, cstOneFloat); preClip = b.create(loc, cmp, zero, preClip); } - if (clip) { - // preClip is the fp position inside the input image to extract from. - // clip to [0,inf) - Value max = b.create(loc, preClip, zero); - Value inputSubOne = b.create(loc, inputFP, cstOneFloat); - // clip to [0,length_original - 1]. - // proj is properly within the input image. - proj.push_back(b.create(loc, max, inputSubOne)); - } else { - proj.push_back(preClip); - } - } - return proj; -} - -static Value bilinearInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, - std::string coordStr) { - unsigned dimOffset = 2; - auto inputType = cast(input.getType()); - auto inputRank = inputType.getRank(); - - Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); - - bool alignCornersBool; - matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); - } - - SmallVector proj, high, low, highFP, lowFP; - proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, - scaleValues, coordStr, alignCornersBool, indices, - true); - for (unsigned i = 0; i < inputRank - dimOffset; i++) { - // length_original - Value inputFP = - b.create(loc, b.getF32Type(), inputSizes[i]); + // preClip is the fp position inside the input 
image to extract from. + // clip to [0,inf) + Value max = b.create(loc, preClip, zero); Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1]. + // proj is properly within the input image. + proj.push_back(b.create(loc, max, inputSubOne)); // for bilinear interpolation, we look for the nearest indices below and // above proj @@ -2963,176 +2926,6 @@ static Value bilinearInterpolate(OpBuilder &b, return b.create(loc, left, right); } -static Value bicubicInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, - std::string coordStr) { - unsigned dimOffset = 2; - auto inputType = cast(input.getType()); - auto inputRank = inputType.getRank(); - - Value inputFPH = - b.create(loc, b.getF32Type(), inputSizes[0]); - Value inputFPW = - b.create(loc, b.getF32Type(), inputSizes[1]); - - Value a = b.create(loc, b.getF32FloatAttr(-0.75)); - Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); - Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); - Value cstThreeFloat = - b.create(loc, b.getF32FloatAttr(3.0)); - Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); - Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); - Value cstEightFloat = - b.create(loc, b.getF32FloatAttr(8.0)); - - // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1) - auto WeightLessThanEqualOne = [&](Value xDistance) -> Value { - Value xDistanceSquared = b.create(loc, xDistance, xDistance); - Value xDistanceCubed = - b.create(loc, xDistanceSquared, xDistance); - - Value lessEqualOne = b.create(loc, a, cstTwoFloat); - lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); - Value aPlusThree = b.create(loc, a, cstThreeFloat); - aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); - lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); - lessEqualOne = 
b.create(loc, lessEqualOne, cstOneFloat); - - return lessEqualOne; - }; - - // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2) - auto WeightLessThanTwo = [&](Value xDistance) -> Value { - Value xDistanceSquared = b.create(loc, xDistance, xDistance); - Value xDistanceCubed = - b.create(loc, xDistanceSquared, xDistance); - // a|x|^3 - Value lessThanTwo = b.create(loc, xDistanceCubed, a); - - Value fiveA = b.create(loc, xDistanceSquared, a); - fiveA = b.create(loc, fiveA, cstFiveFloat); - // a|x|^3 - 5a|x|^2 - lessThanTwo = b.create(loc, lessThanTwo, fiveA); - - Value eightA = b.create(loc, a, xDistance); - eightA = b.create(loc, eightA, cstEightFloat); - // a|x|^3 - 5a|x|^2 + 8a|x| - lessThanTwo = b.create(loc, eightA, lessThanTwo); - - Value fourA = b.create(loc, a, cstFourFloat); - // a|x|^3 - 5a|x|^2 + 8a|x| - 4a - lessThanTwo = b.create(loc, lessThanTwo, fourA); - return lessThanTwo; - }; - - bool alignCornersBool; - matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); - } - - SmallVector proj; - - proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, - scaleValues, coordStr, alignCornersBool, indices, - false); - - // get the nearest neighbors of proj - Value x1 = b.create(loc, proj[1]); - Value x_1 = b.create(loc, x1, cstOneFloat); - Value x_2 = b.create(loc, x_1, cstOneFloat); - Value x2 = b.create(loc, x1, cstOneFloat); - - Value y1 = b.create(loc, proj[0]); - Value y_1 = b.create(loc, y1, cstOneFloat); - Value y_2 = b.create(loc, y_1, cstOneFloat); - Value y2 = b.create(loc, y1, cstOneFloat); - - // calculate the distance of nearest neighbors x and y to proj - Value y2Distance = b.create(loc, proj[0], y2); - y2Distance = b.create(loc, y2Distance); - Value y1Distance = b.create(loc, proj[0], y1); - y1Distance = b.create(loc, y1Distance); - Value y_1Distance = b.create(loc, proj[0], y_1); - y_1Distance = 
b.create(loc, y_1Distance); - Value y_2Distance = b.create(loc, proj[0], y_2); - y_2Distance = b.create(loc, y_2Distance); - - Value x2Distance = b.create(loc, proj[1], x2); - x2Distance = b.create(loc, x2Distance); - Value x1Distance = b.create(loc, proj[1], x1); - x1Distance = b.create(loc, x1Distance); - Value x_1Distance = b.create(loc, proj[1], x_1); - x_1Distance = b.create(loc, x_1Distance); - Value x_2Distance = b.create(loc, proj[1], x_2); - x_2Distance = b.create(loc, x_2Distance); - - SmallVector y{y_2, y_1, y1, y2}; - SmallVector x{x_2, x_1, x1, x2}; - - SmallVector wys{ - WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance), - WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)}; - SmallVector wxs{ - WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance), - WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)}; - - // clip the nearest neighbors points to inside the original image - for (int k = 0; k < 4; k++) { - Value yClipped = b.create(loc, y[k], zero); - Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); - yClipped = b.create(loc, yClipped, inputHSubOne); - Value yInt = b.create(loc, b.getI64Type(), yClipped); - y[k] = b.create(loc, b.getIndexType(), yInt); - - Value xClipped = b.create(loc, x[k], zero); - Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); - xClipped = b.create(loc, xClipped, inputWSubOne); - Value xInt = b.create(loc, b.getI64Type(), xClipped); - x[k] = b.create(loc, b.getIndexType(), xInt); - } - // 1. Compute x_original and y_original (proj) - // 2. Compute nearest x and y neighbors - // 3. Compute Wx Wy - // 4. Extract inputs at nearest neighbors (inputExtracts) - // 5. 
Compute weighted sum (yield this) - - // 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original - // 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original - // Sum_x is over 4 nearest x neighbors (similar for Sum_y) - // f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y] - // * W(y_original - y) - Value fxy = zero; - - for (int j = 0; j < 4; j++) { - Value wy = wys[j]; - Value xInterpy = zero; - - indices[dimOffset] = y[j]; - - for (int i = 0; i < 4; i++) { - Value wx = wxs[i]; - - indices[dimOffset + 1] = x[i]; - - Value p = b.create(loc, input, indices); - - Value wxp = b.create(loc, wx, p); - xInterpy = b.create(loc, xInterpy, wxp); - } - Value wyXInterpy = b.create(loc, wy, xInterpy); - fxy = b.create(loc, fxy, wyXInterpy); - } - - return fxy; -} - namespace { class ConvertInterpolateOp : public OpConversionPattern { @@ -3148,8 +2941,7 @@ class ConvertInterpolateOp // coordinate_transformation_mode="asymmetric" will lower to an interpolate // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" && - mode.substr(0, 5) != "cubic") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { return failure(); } @@ -3231,18 +3023,13 @@ class ConvertInterpolateOp (mode.find(",") == std::string::npos) ? 
"" : mode.substr(mode.find(",") + 1); - retVal = nearestInterpolate( + retVal = NearestInterpolate( b, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, coordTfMode, nearestMode); } else if (mode.substr(0, 8) == "bilinear") { - retVal = bilinearInterpolate( + retVal = BilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, mode.substr(8)); - } else if (mode.substr(0, 5) == "cubic") { - - retVal = bicubicInterpolate( - b, op, loc, outputSizeIntValues, input, inputSizes, - ScaleFactorFloatValues, mode.substr(5)); } b.create(loc, retVal); }) diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 98dbc1957892..18e8fb449ef5 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -116,22 +116,6 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, else division = b.createOrFold(loc, dividend, strideInt); Value out = b.createOrFold(loc, division, c1); - - if (ceilMode) { - Value outMinusOneTimesStride = - b.createOrFold(loc, division, strideInt); - Value inAddLeftPadding = b.createOrFold( - loc, castIndexToInt64(b, loc, in), paddingInt); - - auto reduceOutputDimCond = - b.createOrFold(loc, arith::CmpIPredicate::uge, - outMinusOneTimesStride, inAddLeftPadding); - - auto reducedDim = b.createOrFold(loc, reduceOutputDimCond, - division, out); - return castIntToIndex(b, loc, reducedDim); - } - return castIntToIndex(b, loc, out); } @@ -594,12 +578,6 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, int64_t inputRank = inType.getRank(); Type elementType = inType.getElementType(); - // Check for 0-D tensor. - if (inputRank == 0) { - result = input; - return success(); - } - // Check if the dimensions are a valid constants. 
int64_t numDimensions = dimensions.size(); if (inputRank != numDimensions) @@ -618,10 +596,28 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, Value outVector = rewriter.create( loc, getAsOpFoldResult(outputDims), elementType); - - result = - rewriter.create(loc, input, outVector, dimensions) - ->getResult(0); + SmallVector idExprs; + SmallVector swapExprs; + for (uint32_t i = 0; i < inputRank; i++) + idExprs.push_back(getAffineDimExpr(i, rewriter.getContext())); + for (uint32_t i = 0; i < inputRank; i++) + swapExprs.push_back(idExprs[dimensions[i]]); + + AffineMap inputMap = + AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext()); + AffineMap outputMap = + AffineMap::get(inputRank, /*symbolCount=*/0, swapExprs, op->getContext()); + SmallVector indexingMaps{inputMap, outputMap}; + SmallVector iteratorTypes(inputRank, + utils::IteratorType::parallel); + result = rewriter + .create( + loc, outVector.getType(), input, outVector, indexingMaps, + iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + b.create(loc, args[0]); + }) + .getResult(0); return success(); } diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index d6ba57a08a8f..c4c3a874fbc4 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1143,49 +1143,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// AtenLogitOp -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenLogitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = op.getLoc(); - - Value self = adaptor.getSelf(); - auto selfTy = dyn_cast(self.getType()); - if (!selfTy) { - return op.emitError("only ranked tensor type is supported."); - } - - auto outTy = cast(getTypeConverter()->convertType(op.getType())); - self = hlo::promoteType(rewriter, op.getLoc(), self, outTy.getElementType()); - - selfTy = dyn_cast(self.getType()); - - Value eps 
= adaptor.getEps(); - auto epsTy = eps.getType(); - Value newSelf; - if (!isa(epsTy)) { - auto epsTensor = hlo::scalarToStablehloTensor(rewriter, op, eps, - selfTy.getElementType()); - Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor); - auto max = - rewriter.create(loc, oneEpsTensor, epsTensor); - newSelf = rewriter.create(loc, epsTensor, self, max); - } else { - newSelf = self; - } - - Value one = hlo::getConstantLike(rewriter, loc, 1.0, self); - Value zi1 = rewriter.create(loc, one, newSelf); - Value newZi = rewriter.create(loc, newSelf, zi1); - - Value log = rewriter.create(loc, outTy, newZi); - - rewriter.replaceOp(op, log); - - return success(); -} - // AtenErfOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -2291,7 +2248,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenGeluOp); INSERT_ATENOP_PATTERN(AtenLog2Op); INSERT_ATENOP_PATTERN(AtenLog10Op); - INSERT_ATENOP_PATTERN(AtenLogitOp); INSERT_ATENOP_PATTERN(AtenErfOp); INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1c2f7d6f2a11..c033dad1bbb4 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -23,7 +23,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" #include "llvm/ADT/TypeSwitch.h" -#include #include #include #include @@ -34,10 +33,10 @@ using namespace mlir::torch::Torch; namespace { -// These legalizations are for unary ops with promoting input to floating-point -// datatypes only. There is no supported quantized integer mode for these. +// These legalizations are for unary ops with only for floating point datatypes. +// There is no supported quantized integer mode for these. 
template -class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { +class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename AtenOpT::Adaptor; @@ -51,22 +50,17 @@ class ConvertAtenUnaryPromoteToFPOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "Only Tensor types supported in TOSA"); - auto resultTy = dyn_cast( - OpConversionPattern::getTypeConverter()->convertType( - op.getType())); - - if (!isa(resultTy.getElementType())) + if (isa(selfTy.getElementType())) { + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + self); + return success(); + } else { return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // Non floating point inputs are not supported in TOSA so we cast the input - // to result type - if (!isa(selfTy.getElementType())) - self = tosa::promoteType(rewriter, self, resultTy); - - rewriter.replaceOpWithNewOp(op, resultTy, self); - - return success(); + op, "Only floating-point datatype legalization supported"); + } } }; @@ -405,15 +399,12 @@ class ConvertAtenCompareOp : public OpConversionPattern { Value rhsAsTensor; if (!rhsTy) { if (failed(torchScalarToTosaTensor(rewriter, op, op.getOther(), - rhsAsTensor, rhs.getType(), {}))) + rhsAsTensor, lhsElemTy, {}))) return rewriter.notifyMatchFailure( op, "Currently only scalar constants are supported for " "conversion in TOSA operation"); } auto rhsTensor = rhsTy ? rhs : rhsAsTensor; - auto rhsTensorTy = dyn_cast(rhsTensor.getType()); - auto rhsElemTy = rhsTensorTy.getElementType(); - // There is no Lesser operator in TOSA. 
constexpr auto swapLhsRhs = (std::is_same() || std::is_same() || @@ -429,34 +420,6 @@ class ConvertAtenCompareOp : public OpConversionPattern { rhsTensor = tosa::promoteType(rewriter, rhsTensor, resultTy); } - // Support different types comparisons - auto isLhsElemFloat = isa(lhsElemTy); - auto isRhsElemFloat = isa(rhsElemTy); - - if (lhsElemTy != rhsElemTy && !isBitwiseOp) { - if (isLhsElemFloat && !isRhsElemFloat) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); - } else if (!isLhsElemFloat && isRhsElemFloat) { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); - } else if (isLhsElemFloat && isRhsElemFloat) { - auto lhsElemFloatTy = dyn_cast(lhsElemTy); - auto rhsElemFloatTy = dyn_cast(rhsElemTy); - if (lhsElemFloatTy.getWidth() > rhsElemFloatTy.getWidth()) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); - } else { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); - } - } else { - auto lhsElemIntTy = dyn_cast(lhsElemTy); - auto rhsElemIntTy = dyn_cast(rhsElemTy); - if (lhsElemIntTy.getWidth() > rhsElemIntTy.getWidth()) { - rhsTensor = tosa::promoteType(rewriter, rhsTensor, lhsTy); - } else { - lhs = tosa::promoteType(rewriter, lhs, rhsTensorTy); - } - } - } - auto resultOp = rewriter.create(op.getLoc(), resultTy, (swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? 
lhs : rhsTensor)); @@ -771,24 +734,17 @@ class ConvertAtenActivationFunctionOp : public OpConversionPattern { matchAndRewrite(AtenOpT op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value self = adaptor.getSelf(); - auto selfTy = dyn_cast(self.getType()); + auto selfTy = cast(self.getType()); if (!selfTy) return rewriter.notifyMatchFailure(op, "Only Tensor types supported"); - auto resultTy = dyn_cast( - this->getTypeConverter()->convertType(op.getType())); - - if (!isa(resultTy.getElementType())) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // Non floating point inputs are not supported for activation functions - // (erf, sigmoid, tanh) in TOSA so we cast the input to result type if (!isa(selfTy.getElementType())) - self = tosa::promoteType(rewriter, self, resultTy); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization currently supported"); - rewriter.replaceOpWithNewOp(op, resultTy, self); + rewriter.replaceOpWithNewOp( + op, this->getTypeConverter()->convertType(op.getType()), self); return success(); } @@ -1291,10 +1247,6 @@ class ConvertAtenPowOp : public OpConversionPattern { auto outType = cast(this->getTypeConverter()->convertType(op.getType())); - if (!isa(outType.getElementType())) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - Value selfTensor; if constexpr (std::is_same()) { Value selfScalar = op.getSelf(); @@ -1311,10 +1263,9 @@ class ConvertAtenPowOp : public OpConversionPattern { return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); - // Non floating point inputs are not supported for tosa.pow so we cast the - // input to result type if (!isa(selfTy.getElementType())) - selfTensor = tosa::promoteType(rewriter, selfTensor, outType); + return rewriter.notifyMatchFailure( + op, "Only floating-point datatype legalization supported"); } 
Value expTensor; @@ -1332,11 +1283,6 @@ class ConvertAtenPowOp : public OpConversionPattern { if (!expTy) return rewriter.notifyMatchFailure( op, "Only ranked tensor types supported in TOSA Pow"); - - // Non floating point exponents are not supported for tosa.pow so we cast - // the exponent to result type - if (!isa(expTy.getElementType())) - expTensor = tosa::promoteType(rewriter, expTensor, outType); } auto powOp = tosa::createBinaryOpAndCast( @@ -2945,32 +2891,24 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenLog2Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); // Not a tensor type. - auto selfType = dyn_cast(self.getType()); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); - auto outType = - dyn_cast(getTypeConverter()->convertType(op.getType())); - - // If input is not a float type then cast it to output type - auto selfElemTy = selfType.getElementType(); - if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, outType); - // Constant value of ln2. SmallVector ln2Shape(selfType.getRank(), 1); auto ln2Op = tosa::getConstTensor(rewriter, op, {0.69314718056f}, - ln2Shape, outType.getElementType()) + ln2Shape, selfType.getElementType()) .value(); - auto rcpOp = rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); - auto logOp = rewriter.create(op.getLoc(), outType, self); + auto outType = getTypeConverter()->convertType(op.getType()); + auto logOp = + rewriter.create(op.getLoc(), outType, adaptor.getSelf()); rewriter.replaceOpWithNewOp(op, outType, logOp, rcpOp, /*shift=*/0); @@ -3258,10 +3196,9 @@ template <> LogicalResult ConvertAtenOp::matchAndRewrite( AtenGeluOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); // Not a tensor type. 
- auto selfType = dyn_cast(self.getType()); + auto selfType = dyn_cast(adaptor.getSelf().getType()); if (!selfType) return rewriter.notifyMatchFailure( op, "Only tensor types are currently supported"); @@ -3272,104 +3209,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Only floating-point datatype legalization supported"); } - auto resultType = - dyn_cast(getTypeConverter()->convertType(op.getType())); - + // TODO: Handle approximate. std::string approximate; - if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate))) { - return rewriter.notifyMatchFailure( - op, "Non-const approximate value not supported"); + if (!matchPattern(op.getApproximate(), m_TorchConstantStr(approximate)) || + approximate != "none") { + return rewriter.notifyMatchFailure(op, "Unsupported value of approximate"); } - if (approximate.compare("none") == 0) { - // GELU(x) = x * CDF(x) - Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); - cdf = rewriter.createOrFold( - op->getLoc(), - cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); - - rewriter.replaceOpWithNewOp(op, resultType, self, cdf, - /*shift=*/0); - } else if (approximate.compare("tanh") == 0) { - // "tanh" approximate - // GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) - // Formula taken from: - // https://pytorch.org/docs/stable/generated/torch.nn.GELU.html - auto selfShape = selfType.getShape(); - if (!selfType.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "Only static shape tensor types are currently supported for Tanh " - "approximation"); - - auto numElem = std::accumulate(selfShape.begin(), selfShape.end(), 1, - std::multiplies()); - - Value half = tosa::getConstTensor(rewriter, op, - SmallVector(numElem, 0.5), - selfShape, selfElemTy) - .value(); - Value one = tosa::getConstTensor(rewriter, op, - SmallVector(numElem, 1.0), - selfShape, selfElemTy) - .value(); - Value three = tosa::getConstTensor(rewriter, op, - SmallVector(numElem, 3.0), - 
selfShape, selfElemTy) - .value(); - - // 0.044715 - Value magicNumber = tosa::getConstTensor( - rewriter, op, SmallVector(numElem, 0.044715), - selfShape, selfElemTy) - .value(); - - // From header: M_2_PI = 2 / pi - Value twoOverPi = tosa::getConstTensor( - rewriter, op, SmallVector(numElem, M_2_PI), - selfShape, selfElemTy) - .value(); - - // 0.5 * x - auto halfInput = rewriter.create(op->getLoc(), resultType, - half, self, /*shift=*/0); - - // sqrt(2/pi) - auto sqrtTwoOverPi = - rewriter.create(op->getLoc(), resultType, twoOverPi, half); - - // x^3 - auto inputPowThree = - rewriter.create(op->getLoc(), resultType, self, three); - - // 0.044715 * x^3 - auto inputPowThreeMul = - rewriter.create(op->getLoc(), resultType, magicNumber, - inputPowThree.getResult(), /*shift=*/0); - - // x + 0.044715 * x^3 - auto inputPowThreeMulAdd = rewriter.create( - op->getLoc(), resultType, self, inputPowThreeMul.getResult()); - - // sqrt(2/pi) * (x + 0.044715 * x^3) - auto sqrtTwoOverPiMul = rewriter.create( - op->getLoc(), resultType, sqrtTwoOverPi.getResult(), - inputPowThreeMulAdd.getResult(), /*shift=*/0); - - // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) - auto tanh = rewriter.create(op->getLoc(), resultType, - sqrtTwoOverPiMul.getResult()); - - // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) - auto tanhAdd = rewriter.create(op->getLoc(), resultType, one, - tanh.getResult()); + Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy); + cdf = rewriter.createOrFold( + op->getLoc(), + cast(cdf.getType()).cloneWith({}, selfElemTy), cdf); - rewriter.replaceOpWithNewOp( - op, resultType, halfInput.getResult(), tanhAdd.getResult(), - /*shift=*/0); - } else { - return rewriter.notifyMatchFailure(op, - "Unsupported approximation algorithm"); - } + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(), cdf, + /*shift=*/0); return success(); } @@ -5398,11 +5252,9 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern 
{ } else { int64_t dimSize = inputDim + padBefore + padAfter - dilation * (kernelDim - 1) - 1; - int64_t outputDim = dimSize / stride + 1; - if (ceilMode && (dimSize % stride != 0) && - (outputDim * stride < inputDim + padBefore)) - outputDim++; - return outputDim; + if (ceilMode && (dimSize % stride != 0)) + return dimSize / stride + 2; + return dimSize / stride + 1; } } @@ -5697,26 +5549,6 @@ static LogicalResult getOutputTypeAndPoolingParameters( std::is_same()) paddingInts.push_back(0); - if constexpr (std::is_same() || - std::is_same()) { - // Currently, we can not represent `count_include_pad` with the existing - // TOSA AvgPool2d specification. Without the below check, we produce silent - // wrong answer (SWA) when the `count_include_pad` value is `true.` - // - // Note: We need to check for `count_include_pad` only when the `padding` - // value is non-zero. - bool countIncludePad; - if ((paddingInts[0] != 0 || paddingInts[1] != 0) && - (!matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)) || - - countIncludePad)) { - return rewriter.notifyMatchFailure( - op, "Unsupported `count_include_pad` value, for tosa AvgPool " - "`count_include_pad` value should be `False`."); - } - } - SmallVector padArr = {paddingInts[0], paddingInts[0], paddingInts[1], paddingInts[1]}; kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts); @@ -5845,6 +5677,18 @@ class ConvertAtenAvgPool2dOp DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad, Type &outputTy) const override { + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. 
Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + // Currently, we can not represent `divisor_override` with the existing TOSA // AvgPool2d specification. Without the below check, we produce silent wrong // answers (SWA) when the `divisor_override` value is other than `None.` @@ -5893,7 +5737,7 @@ class ConvertAtenAvgPool1dOp // Expected a rank 3 input tensor if (selfTy.getRank() != 3) return rewriter.notifyMatchFailure( - op, "Input tensor for AvgPool1d should have rank 3"); + op, "Input tensor for MaxPool1d should have rank 3"); // Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp SmallVector rank4Shape(selfShape); @@ -5904,6 +5748,18 @@ class ConvertAtenAvgPool1dOp selfTy.getElementType()), self, rewriter.getDenseI64ArrayAttr(rank4Shape)); + // Currently, we can not represent `count_include_pad` with the existing + // TOSA AvgPool2d specification. 
Without the below check, we produce silent + // wrong answers (SWA) when the `count_include_pad` value is `true.` + bool countIncludePad; + if (!matchPattern(op.getCountIncludePad(), + m_TorchConstantBool(&countIncludePad)) || + countIncludePad) { + return rewriter.notifyMatchFailure( + op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp " + "`count_include_pad` value should be `False`."); + } + SmallVector dilationArray{1, 1}; if (failed(getOutputTypeAndPoolingParameters( @@ -7342,1487 +7198,348 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -// Legalization for aten.reflection_pad1d -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenReflectionPad1dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto selfShape = selfType.getShape(); - auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - - SmallVector paddingList; - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) - return rewriter.notifyMatchFailure( - op, "Non-const padding lists are not supported"); - - int64_t paddingLeft = paddingList[0]; - int64_t paddingRight = paddingList[1]; +} // namespace - if (paddingLeft >= selfShape[selfRank - 1] || - paddingRight >= selfShape[selfRank - 1]) - return rewriter.notifyMatchFailure( - op, "Padding should be less than input boundary size"); +// ----------------------------------------------------------------------------- +// TorchToTosa Pass +// ----------------------------------------------------------------------------- - // Identity case - if (paddingLeft == 0 && paddingRight == 0) { - rewriter.replaceOp(op, self); - return success(); +namespace { +class ConvertTorchToTosa : public 
ConvertTorchToTosaBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); } - SmallVector resultTensors; - - // Use tosa.slice and tosa.reverse to get the reflection pads based on the - // padding size - if (paddingLeft > 0) { - SmallVector leftStartSlice(selfRank, 0); - SmallVector leftSizeSlice(selfShape); + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); - leftStartSlice[selfRank - 1] = 1; - leftSizeSlice[selfRank - 1] = paddingLeft; + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); - SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); - leftPadShape.push_back(paddingLeft); + // The following ops are never the primary reason why lowering fails. + // The backend contract only allows functions to return tensors thus there + // is always another op using them. + // When we have a chain of torch.constant.int followed by a unsupported + // torch op, we want the pass to mention the unsupported torch op + // in the error message. 
+ target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalDialect(); - auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); + RewritePatternSet patterns(context); - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); +#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, tosa::LogOp) + INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, tosa::ExpOp) +#undef INSERT_UNARY_FPONLY_PATTERN - auto leftPad = rewriter.create( - op->getLoc(), leftPadType, leftPadSlice.getResult(), - static_cast(selfRank - 1)); +#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) + INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) + INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) + INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) + INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) + INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) + INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) + INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) +#undef INSERT_UNARY_PATTERN - resultTensors.push_back(leftPad.getResult()); - } +#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) + INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) + INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) + INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) + 
INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, + tosa::LogicalLeftShiftOp) + INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, + tosa::ArithmeticRightShiftOp) +#undef INSERT_BINARY_PATTERN - resultTensors.push_back(self); +#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) + INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) +#undef INSERT_BINARY_ADDSUB_PATTERN - if (paddingRight > 0) { - SmallVector rightStartSlice(selfRank, 0); - SmallVector rightSizeSlice(selfShape); +#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) + INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) +#undef 
INSERT_BINARY_COMPARE_PATTERN - rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; - rightSizeSlice[selfRank - 1] = paddingRight; +#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); + INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); +#undef INSERT_BINARY_MUL_PATTERN - SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); - rightPadShape.push_back(paddingRight); +#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); + INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); + INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); +#undef INSERT_BINARY_DIV_PATTERN - auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); +#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); + INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); +#undef INSERT_REMAINDER_FMOD_OP_PATTERN - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); +#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context); + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, + mlir::tosa::convertReduceMeanOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, + mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) +#undef INSERT_NDIMS_REDUCTION_OP_PATTERN - auto rightPad = rewriter.create( - op->getLoc(), 
rightPadType, rightPadSlice.getResult(), - static_cast(selfRank - 1)); +#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context); + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, + mlir::tosa::convertReduceAllOp) + INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, + mlir::tosa::convertReduceProdOp) +#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN - resultTensors.push_back(rightPad.getResult()); - } +#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ + target.addIllegalOp(); \ + patterns.add>( \ + typeConverter, context); + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, + mlir::tosa::convertReduceAllOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, + mlir::tosa::convertReduceAnyOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, + mlir::tosa::convertReduceSumOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, + mlir::tosa::convertReduceMaxOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, + mlir::tosa::convertReduceMinOp) + INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, + mlir::tosa::convertReduceProdOp) +#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN - auto result = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1); +#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); + INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); +#undef INSERT_INDICES_REDUCTION_OP_PATTERN - rewriter.replaceOp(op, result); - return success(); -} +#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) + INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, 
ConvertAtenSqueezeOneDimOp) +#undef INSERT_SQUEEZE_OP_PATTERN -// Legalization for aten.reflection_pad2d -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenReflectionPad2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); +#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); +#undef INSERT_MATMUL_ATEMOP_PATTERN - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); +#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MM_ATENOP_PATTERN(AtenMmOp); + INSERT_MM_ATENOP_PATTERN(AtenBmmOp); +#undef INSERT_MM_ATEMOP_PATTERN - auto selfShape = selfType.getShape(); - auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); +#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); +#undef INSERT_LINEAR_ATEMOP_PATTERN - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultShape = resultType.getShape(); +#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, + tosa::AvgPool2dOp); +#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - SmallVector paddingList; - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) - return rewriter.notifyMatchFailure( - op, "Non-const padding lists are not supported"); + target.addIllegalOp(); + patterns.add(typeConverter, context); - int64_t paddingLeft = paddingList[0]; - int64_t paddingRight = paddingList[1]; - int64_t paddingTop = paddingList[2]; - int64_t paddingBottom = paddingList[3]; + 
target.addIllegalOp(); + patterns.add(typeConverter, context); - if (paddingLeft >= selfShape[selfRank - 1] || - paddingRight >= selfShape[selfRank - 1] || - paddingTop >= selfShape[selfRank - 2] || - paddingBottom >= selfShape[selfRank - 2]) - return rewriter.notifyMatchFailure( - op, "Padding must be less than the corresponding input dimension"); + target.addIllegalOp(); + patterns.add(typeConverter, context); - // Identity case - if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && - paddingBottom == 0) { - rewriter.replaceOp(op, self); - return success(); - } + target.addIllegalOp(); + patterns.add(typeConverter, context); - // Use tosa.slice and tosa.reverse to get the reflection pads based on the - // padding size - SmallVector sideTensors; +#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); + INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); + INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); +#undef INSERT_CONSTANT_FILL_PATTERN - if (paddingLeft > 0) { - SmallVector leftStartSlice(selfRank, 0); - SmallVector leftSizeSlice(selfShape); +#define INSERT_FILL_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_FILL_PATTERN(AtenFill_ScalarOp); + INSERT_FILL_PATTERN(AtenFillScalarOp); + INSERT_FILL_PATTERN(AtenFillTensorOp); +#undef INSERT_FILL_PATTERN - leftStartSlice[selfRank - 1] = 1; - leftSizeSlice[selfRank - 1] = paddingLeft; +#define INSERT_MASKED_FILL_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); + INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); +#undef INSERT_MASKED_FILL_PATTERN - SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); - leftPadShape.push_back(paddingLeft); +#define INSERT_POW_OP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, 
context); + INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); + INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); + INSERT_POW_OP_PATTERN(AtenPowScalarOp); +#undef INSERT_POW_OP_PATTERN - auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); +#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, \ + context); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); + INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); +#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); +#define INSERT_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); + INSERT_ATENOP_PATTERN(AtenReluOp); + INSERT_ATENOP_PATTERN(AtenLeakyReluOp); + INSERT_ATENOP_PATTERN(AtenArgmaxOp); + INSERT_ATENOP_PATTERN(AtenRsubScalarOp); + INSERT_ATENOP_PATTERN(AtenConvolutionOp); + INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); + INSERT_ATENOP_PATTERN(AtenReshapeOp); + INSERT_ATENOP_PATTERN(AtenBatchNormOp); + INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); + INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); + INSERT_ATENOP_PATTERN(AtenPermuteOp); + INSERT_ATENOP_PATTERN(AtenLog2Op); + INSERT_ATENOP_PATTERN(AtenThresholdOp); + INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); + INSERT_ATENOP_PATTERN(AtenContiguousOp); + INSERT_ATENOP_PATTERN(AtenDropoutOp); + INSERT_ATENOP_PATTERN(AtenViewOp); + INSERT_ATENOP_PATTERN(AtenGeluOp); + INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); + INSERT_ATENOP_PATTERN(AtenEmbeddingOp); + INSERT_ATENOP_PATTERN(AtenTransposeIntOp); + INSERT_ATENOP_PATTERN(AtenSliceTensorOp); + INSERT_ATENOP_PATTERN(AtenBroadcastToOp); + 
INSERT_ATENOP_PATTERN(AtenGatherOp); + INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); + INSERT_ATENOP_PATTERN(AtenAbsOp); + INSERT_ATENOP_PATTERN(AtenWhereSelfOp); + INSERT_ATENOP_PATTERN(AtenClampOp); + INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); + INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); + INSERT_ATENOP_PATTERN(AtenCopyOp); + INSERT_ATENOP_PATTERN(AtenToDtypeOp); + INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); + INSERT_ATENOP_PATTERN(AtenCatOp); + INSERT_ATENOP_PATTERN(AtenSqrtOp); + INSERT_ATENOP_PATTERN(AtenIscloseOp); + INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); + INSERT_ATENOP_PATTERN(AtenTrilOp); + INSERT_ATENOP_PATTERN(AtenDiagonalOp); + INSERT_ATENOP_PATTERN(AtenIndexSelectOp); + INSERT_ATENOP_PATTERN(AtenFlipOp); + INSERT_ATENOP_PATTERN(AtenRoundOp); + INSERT_ATENOP_PATTERN(AtenScatterSrcOp); + INSERT_ATENOP_PATTERN(AtenSliceScatterOp); + INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); + INSERT_ATENOP_PATTERN(AtenUniformOp); + INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); + INSERT_ATENOP_PATTERN(AtenAsStridedOp); + INSERT_ATENOP_PATTERN(AtenClampTensorOp); + INSERT_ATENOP_PATTERN(PrimsCollapseOp); +#undef INSERT_ATENOP_PATTERN - auto leftPad = rewriter.create( - op->getLoc(), leftPadType, leftPadSlice.getResult(), - static_cast(selfRank - 1)); +#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ + target.addIllegalOp(); \ + patterns.add>(typeConverter, context); + INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); +#undef INSERT_CLONE_ATENOP_PATTERN - sideTensors.push_back(leftPad.getResult()); - } - - sideTensors.push_back(self); - - if (paddingRight > 0) { - SmallVector rightStartSlice(selfRank, 0); - SmallVector rightSizeSlice(selfShape); - - rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; - rightSizeSlice[selfRank - 1] = paddingRight; - - SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); - rightPadShape.push_back(paddingRight); - - auto 
rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); - - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); - - auto rightPad = rewriter.create( - op->getLoc(), rightPadType, rightPadSlice.getResult(), - static_cast(selfRank - 1)); - - sideTensors.push_back(rightPad.getResult()); - } - - SmallVector selfSidePaddedShape(selfShape.begin(), - selfShape.end() - 1); - selfSidePaddedShape.push_back(resultShape.back()); - - auto selfSidePadded = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, - selfRank - 1); - - SmallVector resultTensors; - - if (paddingTop > 0) { - SmallVector topStartSlice(selfRank, 0); - SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); - topSizeSlice.push_back(resultShape.back()); - - topStartSlice[selfRank - 2] = 1; - topSizeSlice[selfRank - 2] = paddingTop; - - SmallVector topPadShape(selfShape.begin(), selfShape.end() - 2); - topPadShape.push_back(paddingTop); - topPadShape.push_back(resultShape.back()); - - auto topPadType = RankedTensorType::get(topPadShape, selfElemTy); - - auto topPadSlice = rewriter.create( - op->getLoc(), topPadType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); - - auto topPad = rewriter.create( - op->getLoc(), topPadType, topPadSlice.getResult(), - static_cast(selfRank - 2)); - - resultTensors.push_back(topPad.getResult()); - } - - resultTensors.push_back(selfSidePadded.getResult()); - - if (paddingBottom > 0) { - SmallVector bottomStartSlice(selfRank, 0); - SmallVector bottomSizeSlice(selfShape.begin(), - selfShape.end() - 1); - bottomSizeSlice.push_back(resultShape.back()); - - bottomStartSlice[selfRank - 2] = - selfShape[selfRank - 2] - paddingBottom - 1; - bottomSizeSlice[selfRank - 2] = paddingBottom; - - SmallVector 
bottomPadShape(selfShape.begin(), selfShape.end() - 2); - bottomPadShape.push_back(paddingBottom); - bottomPadShape.push_back(resultShape.back()); - - auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy); - - auto bottomPadSlice = rewriter.create( - op->getLoc(), bottomPadType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); - - auto bottomPad = rewriter.create( - op->getLoc(), bottomPadType, bottomPadSlice.getResult(), - static_cast(selfRank - 2)); - - resultTensors.push_back(bottomPad.getResult()); - } - - auto result = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); - - rewriter.replaceOp(op, result); - return success(); -} - -// Legalization for aten.replication_pad2d -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenReplicationPad2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto selfShape = selfType.getShape(); - auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultShape = resultType.getShape(); - - SmallVector paddingList; - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) - return rewriter.notifyMatchFailure( - op, "Non-const padding lists are not supported"); - - int64_t paddingLeft = paddingList[0]; - int64_t paddingRight = paddingList[1]; - int64_t paddingTop = paddingList[2]; - int64_t paddingBottom = paddingList[3]; - - // Identity case - if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && - paddingBottom == 0) { - rewriter.replaceOp(op, self); - return success(); - } - - // Use tosa.slice to get the reflection pads based on the padding size - 
SmallVector sideTensors; - - if (paddingLeft > 0) { - SmallVector leftStartSlice(selfRank, 0); - SmallVector leftSizeSlice(selfShape); - - leftStartSlice[selfRank - 1] = 0; - leftSizeSlice[selfRank - 1] = 1; - - SmallVector leftPadSliceShape(selfShape.begin(), - selfShape.end() - 1); - leftPadSliceShape.push_back(1); - - auto leftPadSliceType = - RankedTensorType::get(leftPadSliceShape, selfElemTy); - - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadSliceType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); - - for (int64_t i = 0; i < paddingLeft; i++) - sideTensors.push_back(leftPadSlice.getResult()); - } - - sideTensors.push_back(self); - - if (paddingRight > 0) { - SmallVector rightStartSlice(selfRank, 0); - SmallVector rightSizeSlice(selfShape); - - rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - 1; - rightSizeSlice[selfRank - 1] = 1; - - SmallVector rightPadSliceShape(selfShape.begin(), - selfShape.end() - 1); - rightPadSliceShape.push_back(1); - - auto rightPadSliceType = - RankedTensorType::get(rightPadSliceShape, selfElemTy); - - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadSliceType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); - - for (int64_t i = 0; i < paddingRight; i++) - sideTensors.push_back(rightPadSlice.getResult()); - } - - SmallVector selfSidePaddedShape(selfShape.begin(), - selfShape.end() - 1); - selfSidePaddedShape.push_back(resultShape.back()); - - auto selfSidePadded = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, - selfRank - 1); - - SmallVector resultTensors; - - if (paddingTop > 0) { - SmallVector topStartSlice(selfRank, 0); - SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); - topSizeSlice.push_back(resultShape.back()); - - topStartSlice[selfRank - 2] = 0; - topSizeSlice[selfRank - 2] = 1; - - 
SmallVector topPadSliceShape(selfShape.begin(), - selfShape.end() - 2); - topPadSliceShape.push_back(1); - topPadSliceShape.push_back(resultShape.back()); - - auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy); - - auto topPadSlice = rewriter.create( - op->getLoc(), topPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); - - for (int64_t i = 0; i < paddingTop; i++) - resultTensors.push_back(topPadSlice.getResult()); - } - - resultTensors.push_back(selfSidePadded.getResult()); - - if (paddingBottom > 0) { - SmallVector bottomStartSlice(selfRank, 0); - SmallVector bottomSizeSlice(selfShape.begin(), - selfShape.end() - 1); - bottomSizeSlice.push_back(resultShape.back()); - - bottomStartSlice[selfRank - 2] = selfShape[selfRank - 2] - 1; - bottomSizeSlice[selfRank - 2] = 1; - - SmallVector bottomPadSliceShape(selfShape.begin(), - selfShape.end() - 2); - bottomPadSliceShape.push_back(1); - bottomPadSliceShape.push_back(resultShape.back()); - - auto bottomPadSliceType = - RankedTensorType::get(bottomPadSliceShape, selfElemTy); - - auto bottomPadSlice = rewriter.create( - op->getLoc(), bottomPadSliceType, selfSidePadded, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); - - for (int64_t i = 0; i < paddingBottom; i++) - resultTensors.push_back(bottomPadSlice.getResult()); - } - - auto result = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); - - rewriter.replaceOp(op, result); - return success(); -} - -// Legalization for torch.prims.split_dim -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - PrimsSplitDimOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getA(); - - // Not a tensor type - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - 
auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultShape = resultType.getShape(); - - int64_t dim, outerLength; - if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure( - op, "Only constant int dim value is supported"); - - auto selfRank = selfType.getRank(); - dim = toPositiveDim(dim, selfRank); - if (!isValidDim(dim, selfRank)) - return rewriter.notifyMatchFailure(op, "Dim is invalid"); - - if (!matchPattern(op.getOuterLength(), m_TorchConstantInt(&outerLength))) - return rewriter.notifyMatchFailure( - op, "Only constant int outer length value is supported"); - - // Technically, I should calculate the output shape based on the dim and outer - // length values. However, that would just give the same result as me taking - // the result shape straight from resultType and applying tosa::ReshapeOp to - // the input. Therefore, I'm opting for the latter approach here, which is - // more simple and quicker. - rewriter.replaceOpWithNewOp( - op, resultType, self, - rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); - - return success(); -} - -// Legalization for aten.outer -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenOuterOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - if (selfType.getRank() != 1) - return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); - - auto vec2 = adaptor.getVec2(); - - auto vec2Type = dyn_cast(vec2.getType()); - if (!vec2Type) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - if (vec2Type.getRank() != 1) - return rewriter.notifyMatchFailure(op, "Only rank 1 vectors are supported"); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultShape = 
resultType.getShape(); - - self = tosa::promoteType(rewriter, self, resultType); - vec2 = tosa::promoteType(rewriter, vec2, resultType); - - SmallVector resultShapeIndex1Replaced({resultShape[0], 1}); - SmallVector resultShapeIndex0Replaced({1, resultShape[1]}); - - // Reshape and tile self to shape {selfShape[0], resultShape[1]} - auto selfReshaped = rewriter.create( - op->getLoc(), - RankedTensorType::get(resultShapeIndex1Replaced, - resultType.getElementType()), - self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); - - auto selfTiled = rewriter.create( - op->getLoc(), resultType, selfReshaped.getResult(), - rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); - - // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]} - auto vec2Reshaped = rewriter.create( - op->getLoc(), - RankedTensorType::get(resultShapeIndex0Replaced, - resultType.getElementType()), - vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); - - auto vec2Tiled = rewriter.create( - op->getLoc(), resultType, vec2Reshaped.getResult(), - rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); - - auto result = - tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(), - vec2Tiled.getResult(), /*shift=*/0); - - rewriter.replaceOp(op, result); - return success(); -} - -// Legalization for aten.upsample_nearest2d -template -class ConvertUpsampleNearest2dForward : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - using OpAdaptor = typename AtenOpT::Adaptor; - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // aten.upsample_nearest2d lowering process: - // 1. Reshape input: (N, C, H, W) -> (N, C, H x W) - // 2. 
Calculate PyTorch-styled gather op indices based on the following - // formula (based on Torch to Linalg UpsampleNearest2d lowering formula): - // for i in range(N x C): - // for heightIndex in range(scaledHeight): - // for widthIndex in range(scaledWidth): - // indices.append(int(heightIndex // scalesH * selfWidth + - // widthIndex // scalesW)) - // 3. Convert PyTorch-styled indices to TensorFlow-styled indices - // 4. Apply TensorFlow-styled ConverGatherOpNd to retrieve the output - // 5. Reshape output to desired output shape - Value self; - if constexpr (std::is_same()) { - self = adaptor.getSelf(); - } else if constexpr (std::is_same()) { - self = adaptor.getInput(); - } else { - return rewriter.notifyMatchFailure( - op, "Expected either AtenUpsampleNearest2dOp or " - "AtenUpsampleNearest2dVecOp"); - } - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto selfShape = selfType.getShape(); - auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); - - auto selfHeight = selfShape[selfRank - 2]; - auto selfWidth = selfShape[selfRank - 1]; - - auto resultType = dyn_cast( - OpConversionPattern::getTypeConverter()->convertType( - op.getType())); - auto resultShape = resultType.getShape(); - auto resultElemTy = resultType.getElementType(); - - // Get op's parameters - SmallVector outputSize; - SmallVector scaleFactors; - double scalesH; - double scalesW; - int64_t outputHeight; - int64_t outputWidth; - if constexpr (std::is_same()) { - if (!matchPattern(op.getOutputSize(), - m_TorchListOfConstantInts(outputSize))) - return rewriter.notifyMatchFailure( - op, "Non-constant output size not supported"); - - outputHeight = outputSize[0]; - outputWidth = outputSize[1]; - - if (isa(op.getScalesH().getType())) { - scalesH = - static_cast(outputHeight) / static_cast(selfHeight); - } else { - if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH))) 
- return rewriter.notifyMatchFailure( - op, "Non-constant height scales not supported"); - - scalesH = std::ceil(scalesH); - } - - if (isa(op.getScalesW().getType())) { - scalesW = - static_cast(outputWidth) / static_cast(selfWidth); - } else { - if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW))) - return rewriter.notifyMatchFailure( - op, "Non-constant width scales not supported"); - - scalesW = std::ceil(scalesW); - } - } else if constexpr (std::is_same()) { - auto isOutputSizeNone = - isa(op.getOutputSize().getType()); - auto isScaleFactorsNone = - isa(op.getScaleFactors().getType()); - - if ((isOutputSizeNone && isScaleFactorsNone) || - (!isOutputSizeNone && !isScaleFactorsNone)) - return rewriter.notifyMatchFailure( - op, "Must specify exactly one of output size and scale factors"); - - if (!isOutputSizeNone) { - if (!matchPattern(op.getOutputSize(), - m_TorchListOfConstantInts(outputSize))) - return rewriter.notifyMatchFailure( - op, "Non-constant output size not supported"); - - outputHeight = outputSize[0]; - outputWidth = outputSize[1]; - - // Output size values being provided implies that scale values are not - // provided - scalesH = - static_cast(outputHeight) / static_cast(selfHeight); - scalesW = - static_cast(outputWidth) / static_cast(selfWidth); - } else { - if (!matchPattern(op.getScaleFactors(), - m_TorchListOfConstantFloats(scaleFactors))) - return rewriter.notifyMatchFailure( - op, "Non-constant output size not supported"); - - scalesH = std::ceil(scaleFactors[0]); - scalesW = std::ceil(scaleFactors[1]); - - // Scale values being provided implies that output size values are not - // provided - outputHeight = static_cast(scalesH * selfHeight); - outputWidth = static_cast(scalesW * selfWidth); - } - } - - // Reshape input - SmallVector reshapedSelfShape(selfShape.begin(), - selfShape.end() - 2); - reshapedSelfShape.push_back(selfHeight * selfWidth); - - auto reshapedSelf = rewriter.create( - op->getLoc(), 
RankedTensorType::get(reshapedSelfShape, selfElemTy), - self, rewriter.getDenseI64ArrayAttr(reshapedSelfShape)); - - // Calculate PyTorch-styled gather indices - SmallVector targetIndicesVec; - int64_t indexRepeat = std::accumulate( - selfShape.begin(), selfShape.end() - 2, 1, std::multiplies()); - for (int64_t i = 0; i < indexRepeat; i++) { - for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) { - for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) { - targetIndicesVec.push_back(static_cast( - std::floor(heightIndex / scalesH) * selfWidth + - std::floor(widthIndex / scalesW))); - } - } - } - - SmallVector targetIndicesShape(selfShape.begin(), - selfShape.end() - 2); - targetIndicesShape.push_back(outputHeight * outputWidth); - auto targetIndicesTorch = - tosa::getConstTensor(rewriter, op, targetIndicesVec, - targetIndicesShape) - .value(); - - // Convert PyTorch-styled indices to TensorFlow-styled indices - auto targetIndicesTF = tosa::convertTorchIndexToTfIndices( - rewriter, op, reshapedSelf.getResult(), targetIndicesTorch, - selfRank - 2); - if (!targetIndicesTF) - return rewriter.notifyMatchFailure( - op, "Convert PyTorch-styled indices and dim " - "to TensorFlow-styled indices failed"); - // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve - // target elements - auto gatherOp = tosa::convertGatherNdOp( - rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy), - reshapedSelf.getResult(), targetIndicesTF.value()); - if (!gatherOp) - return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); - - auto result = rewriter.create( - op->getLoc(), resultType, gatherOp.value(), - rewriter.getDenseI64ArrayAttr(resultShape)); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); - } -}; - -// Legalization for aten.logit -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenLogitOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // 
Logit formula: - // result = log(zi / (1 - zi)) - // Where: if eps is not None: - // zi = input clampled to [eps, 1 - eps] - // else: - // zi = input - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultElemTy = resultType.getElementType(); - - if (!isa(resultElemTy)) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // If input is not a float type then cast it to result element type - auto selfElemTy = selfType.getElementType(); - if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); - - bool isEpsNone = isa(op.getEps().getType()); - - double eps; - if (!isEpsNone && !matchPattern(op.getEps(), m_TorchConstantFloat(&eps))) - return rewriter.notifyMatchFailure(op, - "Non-const eps value is not supported"); - - auto zi = self; - - // Clamp input to [eps, 1 - eps] when eps is not None - if (!isEpsNone) { - zi = rewriter - .create( - op->getLoc(), resultType, self, - rewriter.getI64IntegerAttr(static_cast(eps)), - rewriter.getI64IntegerAttr(static_cast(1 - eps)), - rewriter.getF32FloatAttr(static_cast(eps)), - rewriter.getF32FloatAttr(static_cast(1 - eps))) - .getResult(); - } - - auto one = - tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); - - auto oneMinusZi = - rewriter.create(op->getLoc(), resultType, one, zi); - - auto oneMinusZiReciprocal = rewriter.create( - op->getLoc(), resultType, oneMinusZi.getResult()); - - auto mulOp = rewriter.create(op->getLoc(), resultType, zi, - oneMinusZiReciprocal.getResult(), - /*shift=*/0); - - auto result = - rewriter.create(op->getLoc(), resultType, mulOp.getResult()); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); -} - -// Legalization for aten.log1p -template <> -LogicalResult 
ConvertAtenOp::matchAndRewrite( - AtenLog1pOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // log1p formula: - // yi = log(xi + 1) - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultElemTy = resultType.getElementType(); - - if (!isa(resultElemTy)) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // If input is not a float type then cast it to result element type - auto selfElemTy = selfType.getElementType(); - if (!isa(selfElemTy)) - self = tosa::promoteType(rewriter, self, resultType); - - auto one = - tosa::getConstTensor(rewriter, op, 1.0f, {}, resultElemTy).value(); - - auto addOp = - rewriter.create(op->getLoc(), resultType, self, one); - - auto result = - rewriter.create(op->getLoc(), resultType, addOp.getResult()); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); -} - -// Legalization for aten.log10 -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenLog10Op op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // log10 formula (using log base changing formula since TOSA doesn't have a - // builtin log10 op): - // yi = log(xi) / log(10) - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultElemTy = resultType.getElementType(); - - if (!isa(resultElemTy)) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // If input is not a float type then cast it to result element type - auto selfElemTy = selfType.getElementType(); - if (!isa(selfElemTy)) - self = 
tosa::promoteType(rewriter, self, resultType); - - auto ten = tosa::getConstTensor(rewriter, op, 10.0f, {}, resultElemTy) - .value(); - - auto logOfSelf = rewriter.create(op->getLoc(), resultType, self); - - auto constType = RankedTensorType::get({}, resultElemTy); - - auto logOfTen = rewriter.create(op->getLoc(), constType, ten); - - auto reciprocalOp = rewriter.create( - op->getLoc(), constType, logOfTen.getResult()); - - auto result = rewriter.create( - op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(), - /*shift=*/0); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); -} - -// Legalization for aten.tan -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenTanOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // tan = sin / cos - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - - if (!isa(resultType.getElementType())) - return rewriter.notifyMatchFailure( - op, "Only floating-point datatype result types are supported"); - - // Non floating point inputs are not supported in TOSA so we cast the input - // to result type - if (!isa(selfType.getElementType())) - self = tosa::promoteType(rewriter, self, resultType); - - auto sinOp = rewriter.create(op->getLoc(), resultType, self); - - auto cosOp = rewriter.create(op->getLoc(), resultType, self); - - auto reciprocalOp = - rewriter.create(op->getLoc(), resultType, cosOp); - - auto result = rewriter.create( - op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(), - /*shift=*/0); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); -} - -// Legalization for aten.unfold -template <> -LogicalResult ConvertAtenOp::matchAndRewrite( - AtenUnfoldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { 
- // Approach: Use GatherOp to retrieve target elements from target dim and then - // reshape the output into slices according to the output shape - // - // Lowering steps: - // 1. Create PyTorch-style indices tensor corresponding to target elements and - // reshape them to (d_0, d_1, ..., nWindows * size, ..., d_(rank - 1)) - // with d_x being the dimension size of the input at dim x. - // The indices vector will be calculated using the following formula: - // for i in range(d_0 * d_1 * ... * d_(target_dim - 1)): - // for window in range(nWindows): - // for elementIndex in range(size): - // for j in range(d_(target_dim + 1) * ... * d_(rank-1)): - // indices_vec.push_back(elementIndex + window * step) - // 2. Convert PyTorch-style indices and target dim to TensorFlow-style indices - // 3. Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve - // target elements - // 4. Reshape result from above to correct output shape - auto self = adaptor.getSelf(); - - auto selfType = dyn_cast(self.getType()); - if (!selfType) - return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); - - auto selfShape = selfType.getShape(); - auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); - - auto resultType = - dyn_cast(typeConverter->convertType(op.getType())); - auto resultElemTy = resultType.getElementType(); - - int64_t dim; - if (!matchPattern(op.getDimension(), m_TorchConstantInt(&dim))) - return rewriter.notifyMatchFailure(op, - "Only constant int dims are supported"); - - int64_t size; - if (!matchPattern(op.getSize(), m_TorchConstantInt(&size))) - return rewriter.notifyMatchFailure(op, - "Only constant int sizes are supported"); - - int64_t step; - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) - return rewriter.notifyMatchFailure(op, - "Only constant int steps are supported"); - - if (step <= 0) - return rewriter.notifyMatchFailure(op, "Step value must be greater than 0"); - - // Handle rank zero - if 
(selfRank == 0) { - if (dim != 0) - return rewriter.notifyMatchFailure( - op, "Unsupported dim value for rank zero input"); - - if (size != 1) - return rewriter.notifyMatchFailure( - op, "Unsupported size value for rank zero input"); - - auto result = rewriter.create( - op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, - rewriter.getDenseI64ArrayAttr({1})); - - rewriter.replaceOp(op, {result.getResult()}); - return success(); - } - - dim = toPositiveDim(dim, selfRank); - if (!isValidDim(dim, selfRank)) - return rewriter.notifyMatchFailure(op, "Dim value is invalid"); - - // Size of dimension 'dim' in the returned tensor (or number of windows within - // the dimension that got sliced) - int64_t nWindows = (selfShape[dim] - size) / step + 1; - - // Find number of times that each base index value gets repeated for target - // dim based on dim values before and after target dim i.e. preDimAccumulate = - // d_0 * d_1 * ... * d_(target_dim - 1) - // postDimAccumulate = d_(target_dim + 1) * ... 
* d_(rank - 1) - int64_t preDimAccumulate = - std::accumulate(selfShape.begin(), selfShape.begin() + dim, 1, - std::multiplies()); - int64_t postDimAccumulate = - std::accumulate(selfShape.begin() + dim + 1, selfShape.end(), 1, - std::multiplies()); - - // Calculate PyTorch-style gather indices vector - // Example: shape = (2, 4, 3), dim = 1, size = 3, step = 1 - // -> preDimAccumulate = 2, postDimAccummulate = 3, nWindows = 2 - // pyTorchIndicesBaseVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, - // 1, 1, 1, 2, 2, 2, 3, 3, 3] - // pyTorchIndicesVec = [0, 0, 0, 1, 1, 1, 2, 2, 2, - // 1, 1, 1, 2, 2, 2, 3, 3, 3, - // 0, 0, 0, 1, 1, 1, 2, 2, 2, - // 1, 1, 1, 2, 2, 2, 3, 3, 3] - SmallVector pyTorchIndicesBaseVec; - SmallVector pyTorchIndicesVec; - - for (int64_t window = 0; window < nWindows; window++) { - for (int64_t elementIndex = 0; elementIndex < size; elementIndex++) { - int32_t baseIndex = static_cast(elementIndex + window * step); - for (int64_t i = 0; i < postDimAccumulate; i++) - pyTorchIndicesBaseVec.push_back(baseIndex); - } - } - - for (int64_t i = 0; i < preDimAccumulate; i++) - pyTorchIndicesVec.insert(pyTorchIndicesVec.end(), - pyTorchIndicesBaseVec.begin(), - pyTorchIndicesBaseVec.end()); - - // Create the PyTorch-style indices tensor - // Continuing with the previous example: - // pyTorchIndicesShape = (2, nWindows * size, 3) = (2, 6, 3) - // pyTorchIndices = tensor([[[0, 0, 0], - // [1, 1, 1], - // [2, 2, 2], - // [1, 1, 1], - // [2, 2, 2], - // [3, 3, 3]], - // [[0, 0, 0], - // [1, 1, 1], - // [2, 2, 2], - // [1, 1, 1], - // [2, 2, 2], - // [3, 3, 3]]]) - SmallVector pyTorchIndicesShape(selfShape); - pyTorchIndicesShape[dim] = nWindows * size; - auto pyTorchIndices = - tosa::getConstTensor(rewriter, op, pyTorchIndicesVec, - pyTorchIndicesShape) - .value(); - - // Convert PyTorch-style indices to TensorFlow-style indices - auto tfIndices = tosa::convertTorchIndexToTfIndices(rewriter, op, self, - pyTorchIndices, dim); - if (!tfIndices) - return 
rewriter.notifyMatchFailure(op, - "Convert PyTorch-style indices and dim " - "to TensorFlow-style indices failed"); - - // Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve - // target elements - auto gatherNdOp = tosa::convertGatherNdOp( - rewriter, op, RankedTensorType::get(pyTorchIndicesShape, resultElemTy), - self, tfIndices.value()); - if (!gatherNdOp) - return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); - - // Reshape to an intermediary shape where the gathered elements in dimension - // 'dim' are split back into 2 dimensions of sizes 'nWindows' and 'size' - SmallVector intermediaryShape; - for (int64_t currentDim = 0; currentDim < selfRank; currentDim++) { - if (currentDim == dim) { - intermediaryShape.push_back(nWindows); - intermediaryShape.push_back(size); - } else { - intermediaryShape.push_back(pyTorchIndicesShape[currentDim]); - } - } - - auto reshapeOp = rewriter.create( - op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), - gatherNdOp.value(), rewriter.getDenseI64ArrayAttr(intermediaryShape)); - - // Permute dims to the correct result order - SmallVector permutedDims; - for (int64_t currentDim = 0; currentDim < selfRank + 1; currentDim++) { - if (currentDim != dim + 1) - permutedDims.push_back(static_cast(currentDim)); - } - permutedDims.push_back(static_cast(dim + 1)); - - auto permutedDimsConst = tosa::getConstTensor( - rewriter, op, - /*vec=*/permutedDims, - /*shape=*/{static_cast(selfRank + 1)}) - .value(); - - auto result = rewriter.create( - op->getLoc(), resultType, reshapeOp.getResult(), permutedDimsConst); - - rewriter.replaceOp(op, {result.getResult()}); - - return success(); -} - -} // namespace - -// ----------------------------------------------------------------------------- -// TorchToTosa Pass -// ----------------------------------------------------------------------------- - -namespace { -class ConvertTorchToTosa : public ConvertTorchToTosaBase { -public: - void 
getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - TorchConversion::getBackendTypeConversionDependentDialects(registry); - } - - void runOnOperation() override { - MLIRContext *context = &getContext(); - ConversionTarget target(*context); - target.addLegalDialect(); - target.addIllegalDialect(); - - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { return type; }); - TorchConversion::setupBackendTypeConversion(target, typeConverter); - - populateTorchToTosaConversionLegalOps(target); - - RewritePatternSet patterns(context); - - auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps( - typeConverter, patterns); - - for (auto op : illegalOps) { - target.addIllegalOp(OperationName(op, context)); - } - - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) - return signalPassFailure(); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); } }; } // namespace -void torch::populateTorchToTosaConversionLegalOps(ConversionTarget &target) { - // The following ops are never the primary reason why lowering fails. - // The backend contract only allows functions to return tensors thus there - // is always another op using them. - // When we have a chain of torch.constant.int followed by a unsupported - // torch op, we want the pass to mention the unsupported torch op - // in the error message. 
- target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); -} - -std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( - TypeConverter &typeConverter, RewritePatternSet &patterns) { - - MLIRContext *context = patterns.getContext(); - std::set illegalOps; - -#define INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, \ - context); - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenLogOp, tosa::LogOp) - INSERT_UNARY_PROMOTE_TO_FP_PATTERN(AtenExpOp, tosa::ExpOp) -#undef INSERT_UNARY_PROMOTE_TO_FP_PATTERN - -#define INSERT_UNARY_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_UNARY_PATTERN(AtenNegOp, tosa::NegateOp) - INSERT_UNARY_PATTERN(AtenFloorOp, tosa::FloorOp) - INSERT_UNARY_PATTERN(AtenRsqrtOp, tosa::RsqrtOp) - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, tosa::BitwiseNotOp) - INSERT_UNARY_PATTERN(AtenCeilOp, tosa::CeilOp) - INSERT_UNARY_PATTERN(AtenReciprocalOp, tosa::ReciprocalOp) - INSERT_UNARY_PATTERN(AtenCosOp, tosa::CosOp) - INSERT_UNARY_PATTERN(AtenSinOp, tosa::SinOp) - INSERT_UNARY_PATTERN(AtenLogicalNotOp, tosa::LogicalNotOp) -#undef INSERT_UNARY_PATTERN - -#define INSERT_BINARY_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) - INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) - INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) - INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp) - INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp) - INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp, tosa::LogicalLeftShiftOp) - INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp, - tosa::ArithmeticRightShiftOp) -#undef 
INSERT_BINARY_PATTERN - -#define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_ADDSUB_PATTERN(AtenAddTensorOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenAddScalarOp, tosa::AddOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubTensorOp, tosa::SubOp) - INSERT_BINARY_ADDSUB_PATTERN(AtenSubScalarOp, tosa::SubOp) -#undef INSERT_BINARY_ADDSUB_PATTERN - -#define INSERT_BINARY_COMPARE_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_COMPARE_PATTERN(AtenGtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp) - INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp) -#undef INSERT_BINARY_COMPARE_PATTERN - -#define INSERT_BINARY_MUL_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_MUL_PATTERN(AtenMulTensorOp); - INSERT_BINARY_MUL_PATTERN(AtenMulScalarOp); -#undef 
INSERT_BINARY_MUL_PATTERN - -#define INSERT_BINARY_DIV_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarOp); - INSERT_BINARY_DIV_PATTERN(AtenDivTensorModeOp); - INSERT_BINARY_DIV_PATTERN(AtenDivScalarModeOp); -#undef INSERT_BINARY_DIV_PATTERN - -#define INSERT_REMAINDER_FMOD_OP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenRemainderTensorOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodScalarOp); - INSERT_REMAINDER_FMOD_OP_PATTERN(AtenFmodTensorOp); -#undef INSERT_REMAINDER_FMOD_OP_PATTERN - -#define INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>( \ - typeConverter, context); - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenMeanDimOp, - mlir::tosa::convertReduceMeanOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, - mlir::tosa::convertReduceSumOp) - INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, - mlir::tosa::convertLinalgVectorNormOp) -#undef INSERT_NDIMS_REDUCTION_OP_PATTERN - -#define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>( \ - typeConverter, context); - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAnyDimOp, - mlir::tosa::convertReduceAnyOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenAllDimOp, - mlir::tosa::convertReduceAllOp) - INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenProdDimIntOp, - mlir::tosa::convertReduceProdOp) -#undef INSERT_ONEDIM_REDUCTION_OP_PATTERN - -#define INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>( \ - typeConverter, context); - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAllOp, 
mlir::tosa::convertReduceAllOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenAnyOp, mlir::tosa::convertReduceAnyOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenSumOp, mlir::tosa::convertReduceSumOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMaxOp, mlir::tosa::convertReduceMaxOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenMinOp, mlir::tosa::convertReduceMinOp) - INSERT_ALLDIMS_REDUCTION_OP_PATTERN(AtenProdOp, - mlir::tosa::convertReduceProdOp) -#undef INSERT_ALLDIMS_REDUCTION_OP_PATTERN - -#define INSERT_INDICES_REDUCTION_OP_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMaxDimOp, tosa::ReduceMaxOp); - INSERT_INDICES_REDUCTION_OP_PATTERN(AtenMinDimOp, tosa::ReduceMinOp); -#undef INSERT_INDICES_REDUCTION_OP_PATTERN - -#define INSERT_SQUEEZE_OP_PATTERN(AtenOp, TemplateForm) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeOp, ConvertAtenSqueezeAllDimsOp) - INSERT_SQUEEZE_OP_PATTERN(AtenSqueezeDimOp, ConvertAtenSqueezeOneDimOp) -#undef INSERT_SQUEEZE_OP_PATTERN - -#define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp); -#undef INSERT_MATMUL_ATEMOP_PATTERN - -#define INSERT_MM_ATENOP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_MM_ATENOP_PATTERN(AtenMmOp); - INSERT_MM_ATENOP_PATTERN(AtenBmmOp); -#undef INSERT_MM_ATEMOP_PATTERN - -#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp); -#undef INSERT_LINEAR_ATEMOP_PATTERN - -#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \ - illegalOps.insert(AtenOp::getOperationName()); \ - 
patterns.add>(typeConverter, \ - context); - INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenAdaptiveAvgPool2dOp, - tosa::AvgPool2dOp); -#undef INSERT_ADAPTIVE_POOLING_ATEMOP_PATTERN - - illegalOps.insert(AtenMaxPool2dOp::getOperationName()); - patterns.add(typeConverter, context); - - illegalOps.insert(AtenMaxPool1dOp::getOperationName()); - patterns.add(typeConverter, context); - - illegalOps.insert(AtenAvgPool2dOp::getOperationName()); - patterns.add(typeConverter, context); - - illegalOps.insert(AtenAvgPool1dOp::getOperationName()); - patterns.add(typeConverter, context); - -#define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, \ - context); - INSERT_CONSTANT_FILL_PATTERN(AtenOnesOp, 1); - INSERT_CONSTANT_FILL_PATTERN(AtenZerosOp, 0); - INSERT_CONSTANT_FILL_PATTERN(AtenEmptyMemoryFormatOp, 0); -#undef INSERT_CONSTANT_FILL_PATTERN - -#define INSERT_FILL_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_FILL_PATTERN(AtenFill_ScalarOp); - INSERT_FILL_PATTERN(AtenFillScalarOp); - INSERT_FILL_PATTERN(AtenFillTensorOp); -#undef INSERT_FILL_PATTERN - -#define INSERT_MASKED_FILL_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillScalarOp); - INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp); -#undef INSERT_MASKED_FILL_PATTERN - -#define INSERT_POW_OP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp); - INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp); - INSERT_POW_OP_PATTERN(AtenPowScalarOp); -#undef INSERT_POW_OP_PATTERN - -#define INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - 
INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dOp); - INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN(AtenUpsampleNearest2dVecOp); -#undef INSERT_UPSAMPLE_NEAREST_2D_FORWARD_OP_PATTERN - -#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, \ - context); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp); - INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp); -#undef INSERT_ACTIVATION_FUNCITON_OP_PATTERN - -#define INSERT_ATENOP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp); - INSERT_ATENOP_PATTERN(AtenReluOp); - INSERT_ATENOP_PATTERN(AtenLeakyReluOp); - INSERT_ATENOP_PATTERN(AtenArgmaxOp); - INSERT_ATENOP_PATTERN(AtenRsubScalarOp); - INSERT_ATENOP_PATTERN(AtenConvolutionOp); - INSERT_ATENOP_PATTERN(ValueTensorLiteralOp); - INSERT_ATENOP_PATTERN(AtenReshapeOp); - INSERT_ATENOP_PATTERN(AtenBatchNormOp); - INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); - INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp); - INSERT_ATENOP_PATTERN(AtenUnflattenIntOp); - INSERT_ATENOP_PATTERN(AtenPermuteOp); - INSERT_ATENOP_PATTERN(AtenLog2Op); - INSERT_ATENOP_PATTERN(AtenThresholdOp); - INSERT_ATENOP_PATTERN(AtenUnsqueezeOp); - INSERT_ATENOP_PATTERN(AtenContiguousOp); - INSERT_ATENOP_PATTERN(AtenDropoutOp); - INSERT_ATENOP_PATTERN(AtenViewOp); - INSERT_ATENOP_PATTERN(AtenGeluOp); - INSERT_ATENOP_PATTERN(AtenGeluBackwardOp); - INSERT_ATENOP_PATTERN(AtenEmbeddingOp); - INSERT_ATENOP_PATTERN(AtenTransposeIntOp); - INSERT_ATENOP_PATTERN(AtenSliceTensorOp); - INSERT_ATENOP_PATTERN(AtenBroadcastToOp); - INSERT_ATENOP_PATTERN(AtenGatherOp); - INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp); - INSERT_ATENOP_PATTERN(AtenAbsOp); - 
INSERT_ATENOP_PATTERN(AtenWhereSelfOp); - INSERT_ATENOP_PATTERN(AtenClampOp); - INSERT_ATENOP_PATTERN(AtenArangeStartStepOp); - INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp); - INSERT_ATENOP_PATTERN(AtenCopyOp); - INSERT_ATENOP_PATTERN(AtenToDtypeOp); - INSERT_ATENOP_PATTERN(AtenConstantPadNdOp); - INSERT_ATENOP_PATTERN(AtenCatOp); - INSERT_ATENOP_PATTERN(AtenSqrtOp); - INSERT_ATENOP_PATTERN(AtenIscloseOp); - INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp); - INSERT_ATENOP_PATTERN(AtenTrilOp); - INSERT_ATENOP_PATTERN(AtenDiagonalOp); - INSERT_ATENOP_PATTERN(AtenIndexSelectOp); - INSERT_ATENOP_PATTERN(AtenFlipOp); - INSERT_ATENOP_PATTERN(AtenRoundOp); - INSERT_ATENOP_PATTERN(AtenScatterSrcOp); - INSERT_ATENOP_PATTERN(AtenSliceScatterOp); - INSERT_ATENOP_PATTERN(AtenDiagEmbedOp); - INSERT_ATENOP_PATTERN(AtenUniformOp); - INSERT_ATENOP_PATTERN(AtenThresholdBackwardOp); - INSERT_ATENOP_PATTERN(AtenAsStridedOp); - INSERT_ATENOP_PATTERN(AtenClampTensorOp); - INSERT_ATENOP_PATTERN(PrimsCollapseOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); - INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); - INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); - INSERT_ATENOP_PATTERN(PrimsSplitDimOp); - INSERT_ATENOP_PATTERN(AtenOuterOp); - INSERT_ATENOP_PATTERN(AtenLogitOp); - INSERT_ATENOP_PATTERN(AtenLog1pOp); - INSERT_ATENOP_PATTERN(AtenLog10Op); - INSERT_ATENOP_PATTERN(AtenTanOp); - INSERT_ATENOP_PATTERN(AtenUnfoldOp); -#undef INSERT_ATENOP_PATTERN - -#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \ - illegalOps.insert(AtenOp::getOperationName()); \ - patterns.add>(typeConverter, context); - INSERT_CLONE_ATENOP_PATTERN(AtenCloneOp); -#undef INSERT_CLONE_ATENOP_PATTERN - - return illegalOps; -} - std::unique_ptr> mlir::torch::createConvertTorchToTosaPass() { return std::make_unique(); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index ee7f61becf4f..4df8a221d556 100644 --- 
a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -1031,17 +1031,11 @@ convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, return std::nullopt; } - auto input_value_casted = - tosa::promoteType(rewriter, input_value, output_type); - auto absVal = CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(input_type.getShape(), elemType), - input_value_casted) + auto absVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, input_value) .getResult(); - auto powVal = CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(input_type.getShape(), elemType), - absVal, ordVal) + auto powVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, absVal, ordVal) .getResult(); std::optional result = convertReduceSumOp( rewriter, op, output_type, powVal, axes_elems, keep_dims); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 72217e5f4afd..e3f5b6d0299a 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -447,119 +447,6 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, return castIntToIndex(rewriter, loc, boundedByDimSize); } -// Helper function to unsqueeze the input tensor at given dim. -// Returns the unsqueezed tensor or failure. -FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, - Value input, int64_t dim) { - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - ArrayRef inputShape = inputType.getShape(); - - // `input` has a reduced rank. Hence add 1. 
- int64_t unsqueezedRank = inputShape.size() + 1; - dim = toPositiveDim(dim, unsqueezedRank); - if (!isValidDim(dim, unsqueezedRank)) { - return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - } - - SmallVector unsqueezedShape{inputShape}; - unsqueezedShape.insert(unsqueezedShape.begin() + dim, 1); - Type unsqueezedType = - RankedTensorType::get(unsqueezedShape, inputType.getElementType()); - - SmallVector reassociationMap(inputRank); - // From the perspective of the reassociation map, the situation of - // unsqueezing before or after the last dimension is symmetrical. - // Normalize it to the "before" case. - // The 0 case is special here, since there is no last dimension to insert - // before -- we simply rely on the loop below iterating 0 times. - if (dim == inputRank && inputRank != 0) - dim = inputRank - 1; - bool alreadyCrossedExpandedDim = false; - for (int i = 0; i != inputRank; i++) { - if (alreadyCrossedExpandedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (i == dim) { - reassociationMap[i].push_back(i + 1); - alreadyCrossedExpandedDim = true; - } - } - } - Value unsqueezed = rewriter.create( - op->getLoc(), unsqueezedType, input, reassociationMap); - return unsqueezed; -} - -// Helper function to squeeze the input tensor at given dim. -// Returns the squeezed tensor or failure. -FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, - Value input, int64_t dim) { - Location loc = op->getLoc(); - auto inputType = cast(input.getType()); - int64_t inputRank = inputType.getRank(); - - // No scope for squeezing the input. 
- if (inputRank == 0) - return input; - - dim = toPositiveDim(dim, inputRank); - if (!isValidDim(dim, inputRank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - - // assert dynamic squeeze dim size == 1 - if (inputType.isDynamicDim(dim)) { - Value cstDim = rewriter.create(loc, dim); - Value dimVal = rewriter.create(loc, input, cstDim); - Value cstOne = rewriter.create(loc, 1); - Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, - dimVal, cstOne); - rewriter.create( - loc, cmp, - rewriter.getStringAttr( - "Expected dynamic squeeze dim size to be statically 1")); - } - - ArrayRef inputShape = inputType.getShape(); - SmallVector squeezedShape; - squeezedShape.append(inputShape.begin(), inputShape.begin() + dim); - squeezedShape.append(inputShape.begin() + dim + 1, inputShape.end()); - int64_t squeezedRank = inputRank - 1; - Type squeezedType = - RankedTensorType::get(squeezedShape, inputType.getElementType()); - - // If the dim(th) dimension of operand tensor type is not statically unit, - // squeeze will behave as an identity operation. - if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) { - return input; - } - - SmallVector reassociationMap(squeezedRank); - bool alreadyCrossedSqueezedDim = false; - for (int i = 0; i != squeezedRank; i++) { - if (alreadyCrossedSqueezedDim) { - reassociationMap[i].push_back(i + 1); - } else { - reassociationMap[i].push_back(i); - if (dim != 0 && i != dim - 1) - continue; - - alreadyCrossedSqueezedDim = true; - if (dim == 0) - reassociationMap[0].push_back(1); - if (i == dim - 1) - reassociationMap[i].push_back(dim); - } - } - // Note: In case the operand tensor type is of unit rank and is statically - // shaped with unit dimension, the `reassociationMap` will be empty and the - // input will be collapsed to a 0-D tensor. 
- Value squeezed = rewriter.create( - op->getLoc(), squeezedType, input, reassociationMap); - return squeezed; -} - } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index eafbe14162cc..3774e65f0859 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -769,22 +769,6 @@ OpFoldResult Aten__Or__BoolOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenEqBoolOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenEqBoolOp::fold(FoldAdaptor adaptor) { - if (getOperand(0) == getOperand(1)) - return IntegerAttr::get(IntegerType::get(getContext(), 1), true); - - auto intAttrA = dyn_cast_or_null(adaptor.getA()); - auto intAttrB = dyn_cast_or_null(adaptor.getB()); - if (!intAttrA || !intAttrB) - return nullptr; - return IntegerAttr::get(IntegerType::get(getContext(), 1), - intAttrA.getValue() == intAttrB.getValue()); -} - //===----------------------------------------------------------------------===// // AtenNeBoolOp //===----------------------------------------------------------------------===// @@ -793,12 +777,12 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { if (getOperand(0) == getOperand(1)) return IntegerAttr::get(IntegerType::get(getContext(), 1), false); - auto intAttrA = dyn_cast_or_null(adaptor.getA()); - auto intAttrB = dyn_cast_or_null(adaptor.getB()); - if (!intAttrA || !intAttrB) + bool a, b; + if (!matchPattern(getOperand(0), m_TorchConstantBool(&a))) + return nullptr; + if (!matchPattern(getOperand(1), m_TorchConstantBool(&b))) return nullptr; - return IntegerAttr::get(IntegerType::get(getContext(), 1), - intAttrA.getValue() != intAttrB.getValue()); + return IntegerAttr::get(IntegerType::get(getContext(), 1), a != b); } 
//===----------------------------------------------------------------------===// @@ -1147,35 +1131,6 @@ void AtenLenTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } -//===----------------------------------------------------------------------===// -// AtenMulLeftTOp -//===----------------------------------------------------------------------===// - -void AtenMulLeftTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) { - // `[1,2] * 3` -> `[1,2,1,2,1,2]`, if it is not mutated. - patterns.add(+[](AtenMulLeftTOp op, PatternRewriter &rewriter) { - auto listLiteral = op.getL().getDefiningOp(); - if (!listLiteral || isListPotentiallyMutated(listLiteral)) - return failure(); - - int64_t numReps; - if (!matchPattern(op.getN(), m_TorchConstantInt(&numReps))) - return failure(); - - SmallVector newListElements; - for (int rep = 0; rep < numReps; ++rep) { - for (auto operand : listLiteral.getOperands()) { - newListElements.push_back(operand); - } - } - - rewriter.replaceOpWithNewOp(op, op.getL().getType(), - newListElements); - return success(); - }); -} - //===----------------------------------------------------------------------===// // AtenMinOtherOp //===----------------------------------------------------------------------===// @@ -4113,10 +4068,6 @@ OpFoldResult AtenMulIntOp::fold(FoldAdaptor adaptor) { int64_t lhs, rhs; bool lConstant = matchPattern(getOperand(0), m_TorchConstantInt(&lhs)); bool rConstant = matchPattern(getOperand(1), m_TorchConstantInt(&rhs)); - if (lConstant && lhs == 1) - return getOperand(1); - if (rConstant && rhs == 1) - return getOperand(0); if ((lConstant && lhs == 0) || (rConstant && rhs == 0)) return getI64IntegerAttr(getContext(), 0); if (lConstant && rConstant) @@ -4219,19 +4170,6 @@ OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a * b; }); } -//===----------------------------------------------------------------------===// -// 
AtenMulIntFloatOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenMulIntFloatOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA() || !adaptor.getB()) { - return nullptr; - } - return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), - [](double a, double b) -> double { return a * b; }); -} - //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// @@ -4278,18 +4216,6 @@ OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) { adaptor.getOperands(), [](double a, double b) { return a + b; }); } -//===----------------------------------------------------------------------===// -// AtenMulFloatIntOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenMulFloatIntOp::fold(FoldAdaptor adaptor) { - if (!adaptor.getA() || !adaptor.getB()) { - return nullptr; - } - return atenBinaryFloatOperatorFoldHelper( - adaptor.getOperands(), [](double a, double b) { return a * b; }); -} - //===----------------------------------------------------------------------===// // AtenPowIntFloatOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 5fd05708961c..0e4d7c40a292 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) 
: (!torch.int, !torch.float) -> !torch.float \n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %22 = torch.aten.mul.int_float %20, %21 : !torch.int, !torch.float -> !torch.float\n" +" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield\n" @@ -6495,10 +6495,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.special_expm1\"(%arg0: !torch.list) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.isfinite\"(%arg0: !torch.list) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -6912,12 +6908,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list, !torch.int -> !torch.float\n" " %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.aten.mul.float_int %11, %12 : !torch.float, !torch.int -> !torch.float\n" +" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n" " %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n" " %15 = torch.aten.append.t %3, %14 : !torch.list, !torch.int -> !torch.list\n" " %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list, !torch.int -> 
!torch.float\n" " %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %18 = torch.aten.mul.float_int %16, %17 : !torch.float, !torch.int -> !torch.float\n" +" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n" " %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n" " %20 = torch.aten.append.t %3, %19 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield %true, %3 : !torch.bool, !torch.list\n" @@ -7304,12 +7300,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple, list> {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" %1 = call @__torch__.torch.jit._shape_functions.unary(%arg1) : (!torch.list) -> !torch.list\n" -" %2 = torch.prim.TupleConstruct %0, %1 : !torch.list, !torch.list -> !torch.tuple, list>\n" -" return %2 : !torch.tuple, list>\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7747,7 +7737,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.If %2 -> (!torch.list) {\n" " %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" " %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.aten.mul.left_t %5, %6 : !torch.list, !torch.int -> !torch.list\n" +" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list \n" " %8 = torch.aten.add.t %7, %arg1 : !torch.list, 
!torch.list -> !torch.list\n" " torch.prim.If.yield %8 : !torch.list\n" " } else {\n" @@ -8958,7 +8948,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = call @__torch__.torch.jit._shape_functions.broadcast_three(%5, %6, %7) : (!torch.list, !torch.list, !torch.list) -> !torch.list\n" " %15 = torch.prim.ListConstruct %false : (!torch.bool) -> !torch.list\n" " %16 = torch.aten.len.t %14 : !torch.list -> !torch.int\n" -" %17 = torch.aten.mul.left_t %15, %16 : !torch.list, !torch.int -> !torch.list\n" +" %17 = torch.operator \"aten.mul.left_t\"(%15, %16) : (!torch.list, !torch.int) -> !torch.list \n" " %18 = torch.aten.len.t %arg6 : !torch.list -> !torch.int\n" " torch.prim.Loop %18, %true, init() {\n" " ^bb0(%arg8: !torch.int):\n" @@ -9822,7 +9812,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %76 = torch.aten.append.t %72, %75 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.Loop.condition %true, iter()\n" " } : (!torch.int, !torch.bool) -> ()\n" -" %74 = torch.aten.add.t %71, %72 : !torch.list, !torch.list -> !torch.list\n" +" %74 = torch.operator \"aten.add_.t\"(%71, %72) : (!torch.list, !torch.list) -> !torch.list \n" " return %74 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.topk\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple, list> {\n" @@ -10034,65 +10024,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.conv2d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" -" %0 = call 
@__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" -" %1 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" -" return %1 : !torch.list\n" -" }\n" -" func.func @__torch__._conv_padding(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str) -> !torch.list {\n" -" %true = torch.constant.bool true\n" -" %int-1 = torch.constant.int -1\n" -" %str = torch.constant.str \"same\"\n" -" %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: conv: weight must be at least 3 dimensional.\"\n" -" %int2 = torch.constant.int 2\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %1 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %2 = torch.aten.sub.int %0, %int2 : !torch.int, !torch.int -> !torch.int\n" -" %3 = torch.prim.ListConstruct %int0 : (!torch.int) -> !torch.list\n" -" %4 = torch.aten.mul.left_t %3, %2 : !torch.list, !torch.int -> !torch.list\n" -" %5 = torch.aten.eq.str %arg2, %str : !torch.str, !torch.str -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" -" %6 = torch.aten.sub.int %2, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %8 = torch.aten.__range_length %6, %int-1, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %9 = torch.prim.ListConstruct %7, %8 : (!torch.int, !torch.int) -> !torch.list\n" -" %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" -" torch.prim.Loop %10, %true, init() {\n" -" ^bb0(%arg3: !torch.int):\n" -" %11 = 
torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" -" %12 = torch.aten.__derive_index %arg3, %6, %int-1 : !torch.int, !torch.int, !torch.int -> !torch.int\n" -" %13 = torch.aten.add.int %int2, %12 : !torch.int, !torch.int -> !torch.int\n" -" %14 = torch.aten.__getitem__.t %arg0, %13 : !torch.list, !torch.int -> !torch.int\n" -" %15 = torch.aten.sub.int %14, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %16 = torch.aten.mul.int %11, %15 : !torch.int, !torch.int -> !torch.int\n" -" %17 = torch.aten.floordiv.int %16, %int2 : !torch.int, !torch.int -> !torch.int\n" -" %18 = torch.aten._set_item.t %4, %12, %17 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" return %4 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.conv3d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" -" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" -" %1 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" -" return %1 : !torch.list\n" -" }\n" " func.func 
@\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.list to !torch.optional>\n" " %1 = torch.derefine %arg4 : !torch.list to !torch.optional>\n" @@ -10162,14 +10097,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.conv1d.padding\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.str, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" -" %false = torch.constant.bool false\n" -" %int1 = torch.constant.int 1\n" -" %0 = call @__torch__._conv_padding(%arg1, %arg5, %arg4) : (!torch.list, !torch.list, !torch.str) -> !torch.list\n" -" %1 = torch.prim.ListConstruct : () -> !torch.list\n" -" %2 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %0, %arg5, %false, %1, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" -" return %2 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg7, %true, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, 
!torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" @@ -10538,10 +10465,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.deg2rad\"(%arg0: !torch.list) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.nll_loss_forward\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.nll_loss_forward(%arg0, %arg1, %arg2, %arg3) : (!torch.list, !torch.list, !torch.optional>, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" @@ -10562,18 +10485,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.l1_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.aten.eq.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.list) {\n" -" %2 = func.call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" torch.prim.If.yield %2 : !torch.list\n" -" } else {\n" -" %2 = torch.prim.ListConstruct : () -> !torch.list\n" -" torch.prim.If.yield %2 : !torch.list\n" -" }\n" -" return %1 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.cross_entropy_loss\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.cross_entropy_loss(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.optional>, !torch.int, !torch.int, !torch.float) -> !torch.list\n" " 
return %0 : !torch.list\n" @@ -10975,84 +10886,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %5 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.list {\n" -" %true = torch.constant.bool true\n" -" %int0 = torch.constant.int 0\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" -" %1 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" -" torch.prim.Loop %1, %true, init() {\n" -" ^bb0(%arg1: !torch.int):\n" -" %3 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.list\n" -" %4 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" -" %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.list) {\n" -" %8 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list\n" -" torch.prim.If.yield %8 : !torch.list\n" -" } else {\n" -" %8 = torch.aten.len.t %3 : !torch.list -> !torch.int\n" -" %9 = torch.aten.eq.int %8, %int1 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %9 -> () {\n" -" %10 = torch.aten.append.t %3, %int1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.If.yield\n" -" }\n" -" torch.prim.If.yield %3 : !torch.list\n" -" }\n" -" %7 = torch.aten.append.t %0, %6 : !torch.list>, !torch.list -> !torch.list>\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %2 = call @__torch__.torch.jit._shape_functions.cat(%0, %int1) : (!torch.list>, !torch.int) -> !torch.list\n" -" return %2 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.fft_fft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.fft_rfft\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.int, 
%arg3: !torch.optional) -> !torch.list {\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: Expected dim in [-rank, rank-1]\"\n" -" %false = torch.constant.bool false\n" -" %int0 = torch.constant.int 0\n" -" %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %0 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %1 = torch.prim.If %0 -> (!torch.int) {\n" -" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %11 = torch.aten.add.int %arg2, %10 : !torch.int, !torch.int -> !torch.int\n" -" torch.prim.If.yield %11 : !torch.int\n" -" } else {\n" -" torch.prim.If.yield %arg2 : !torch.int\n" -" }\n" -" %2 = torch.aten.ge.int %1, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.bool) {\n" -" %10 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %11 = torch.aten.lt.int %1, %10 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %11 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.prim.ListConstruct : () -> !torch.list\n" -" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" torch.prim.Loop %5, %true, init() {\n" -" ^bb0(%arg4: !torch.int):\n" -" %10 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" -" %11 = torch.aten.append.t %4, %10 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %6 = torch.aten.__getitem__.t %arg0, %1 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.floordiv.int %6, %int2 : !torch.int, !torch.int -> !torch.int\n" -" %8 = torch.aten.add.int %7, %int1 : !torch.int, !torch.int -> !torch.int\n" -" %9 = 
torch.aten._set_item.t %4, %1, %8 : !torch.list, !torch.int, !torch.int -> !torch.list\n" -" return %4 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.stft\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.list {\n" " %str = torch.constant.str \"AssertionError: Expected hop_length to be greater than 0\"\n" " %str_0 = torch.constant.str \"AssertionError: Expected that 0 < n_fft <= len\"\n" @@ -11140,7 +10976,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield %true : !torch.bool\n" " } else {\n" " %24 = torch.prim.unchecked_cast %arg6 : !torch.optional -> !torch.bool\n" -" %25 = torch.aten.eq.bool %24, %true : !torch.bool, !torch.bool -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %true) : (!torch.bool, !torch.bool) -> !torch.bool \n" " torch.prim.If.yield %25 : !torch.bool\n" " }\n" " torch.prim.If %17 -> () {\n" @@ -11159,7 +10995,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %22 = torch.aten.__isnot__ %arg7, %none : !torch.optional, !torch.none -> !torch.bool\n" " %23 = torch.prim.If %22 -> (!torch.bool) {\n" " %24 = torch.prim.unchecked_cast %arg7 : !torch.optional -> !torch.bool\n" -" %25 = torch.aten.eq.bool %24, %false : !torch.bool, !torch.bool -> !torch.bool\n" +" %25 = torch.operator \"aten.eq.bool\"(%24, %false) : (!torch.bool, !torch.bool) -> !torch.bool \n" " torch.prim.If.yield %25 : !torch.bool\n" " } else {\n" " torch.prim.If.yield %false : !torch.bool\n" @@ -11301,7 +11137,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = 
torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.prim.ListConstruct %13, %14, %18 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" " torch.prim.If.yield %19 : !torch.list\n" @@ -11381,11 +11217,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %14 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %12, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.aten.mul.int_float %15, %16 : !torch.int, !torch.float -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %20 = torch.aten.__getitem__.t %12, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %21 = torch.aten.mul.int_float %19, %20 : !torch.int, !torch.float -> !torch.float\n" +" %21 = torch.operator \"aten.mul.int_float\"(%19, %20) : (!torch.int, !torch.float) -> !torch.float \n" " %22 = torch.aten.Int.float %21 : !torch.float -> !torch.int\n" " %23 = torch.prim.ListConstruct %13, %14, %18, %22 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " torch.prim.If.yield %23 : !torch.list\n" @@ -11599,11 +11435,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.special_expm1\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, 
!torch.int\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.isfinite\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -12605,15 +12436,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rrelu\"(%arg0: !torch.tuple, %arg1: !torch.number, %arg2: !torch.number, %arg3: !torch.bool, %arg4: !torch.any) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" return %0#1 : !torch.int\n" -" }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" }\n" " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -12622,20 +12455,46 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" -" func.func 
@\"__torch_mlir_dtype_fn.aten.rrelu_with_noise_functional\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.tuple {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.rrelu_with_noise\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number, %arg3: !torch.number, %arg4: !torch.bool, %arg5: !torch.any) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" +" %true = torch.constant.bool true\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %2 -> () {\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %3 = torch.prim.TupleConstruct %0#1, %1#1 : !torch.int, !torch.int -> !torch.tuple\n" -" return %3 : !torch.tuple\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_float_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %7 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" torch.prim.If.yield %7 : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () 
{\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.eq.int %0#0, %1#0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.relu6\"(%arg0: !torch.tuple) -> !torch.int {\n" " %none = torch.constant.none\n" @@ -13171,44 +13030,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %3 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.fft_rfft\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.int, %arg3: !torch.optional) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" -" %int10 = torch.constant.int 10\n" -" %int7 = torch.constant.int 7\n" -" %int9 = torch.constant.int 9\n" -" %int6 = torch.constant.int 6\n" -" %int8 = torch.constant.int 8\n" -" %int5 = torch.constant.int 5\n" -" %0 = torch.prim.Uninitialized : !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.aten.eq.int %1#1, %int5 : !torch.int, !torch.int -> !torch.bool\n" -" %3 = torch.prim.If %2 -> (!torch.int) {\n" -" torch.prim.If.yield %int8 : !torch.int\n" -" } else {\n" -" %4 = torch.aten.eq.int %1#1, %int6 : !torch.int, !torch.int -> !torch.bool\n" -" %5 = torch.prim.If %4 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" %6 = torch.aten.eq.int %1#1, %int7 : !torch.int, !torch.int -> !torch.bool\n" -" %7 = torch.prim.If %6 -> (!torch.int) {\n" -" torch.prim.If.yield %int10 : !torch.int\n" -" } else {\n" -" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%1#1) : (!torch.int) -> 
!torch.bool\n" -" %9 = torch.prim.If %8 -> (!torch.int) {\n" -" torch.prim.If.yield %int9 : !torch.int\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield %0 : !torch.int\n" -" }\n" -" torch.prim.If.yield %9 : !torch.int\n" -" }\n" -" torch.prim.If.yield %7 : !torch.int\n" -" }\n" -" torch.prim.If.yield %5 : !torch.int\n" -" }\n" -" return %3 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.stft\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.optional, %arg7: !torch.optional) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: Unsupported dtype\"\n" " %int7 = torch.constant.int 7\n" @@ -14012,24 +13833,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.l1_loss\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" -" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" -" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" -" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" -" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%4) : (!torch.int) -> !torch.bool\n" -" %6 = torch.aten.__not__ %5 : !torch.bool -> !torch.bool\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" 
return %4 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mul.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" @@ -15818,33 +15621,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.column_stack\"(%arg0: !torch.list>) -> !torch.int {\n" -" %true = torch.constant.bool true\n" -" %none = torch.constant.none\n" -" %str = torch.constant.str \"AssertionError: \"\n" -" %int0 = torch.constant.int 0\n" -" %0 = torch.prim.ListConstruct : () -> !torch.list>\n" -" %1 = torch.prim.ListConstruct : () -> !torch.list\n" -" %2 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" -" %3 = torch.aten.ne.int %2, %int0 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %3 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %4 = torch.aten.len.t %arg0 : !torch.list> -> !torch.int\n" -" torch.prim.Loop %4, %true, init() {\n" -" ^bb0(%arg1: !torch.int):\n" -" %6 = torch.aten.__getitem__.t %arg0, %arg1 : !torch.list>, !torch.int -> !torch.tuple\n" -" %7:2 = torch.prim.TupleUnpack %6 : !torch.tuple -> !torch.int, !torch.int\n" -" %8 = torch.aten.append.t %0, %7#0 : !torch.list>, !torch.int -> !torch.list>\n" -" %9 = torch.aten.append.t %1, %7#1 : !torch.list, !torch.int -> !torch.list\n" -" torch.prim.Loop.condition %true, iter()\n" -" } : (!torch.int, !torch.bool) -> ()\n" -" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" -" return %5 : 
!torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.einsum\"(%arg0: !torch.str, %arg1: !torch.list>, %arg2: !torch.optional>) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %none = torch.constant.none\n" @@ -16084,11 +15860,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.deg2rad\"(%arg0: !torch.tuple) -> !torch.int {\n" -" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" -" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" -" return %1 : !torch.int\n" -" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int3 = torch.constant.int 3\n" " %int1 = torch.constant.int 1\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9c2a80187c93..aa15e3735dae 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1334,44 +1334,6 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenDeg2radOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenDeg2radOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.getSelf(); - auto selfTy = dyn_cast(self.getType()); - if (!selfTy || !selfTy.getDtype()) { - return rewriter.notifyMatchFailure(op, "requires tensor types input."); - } - - auto outTy = dyn_cast(op.getType()); - if (!outTy || !outTy.getDtype()) { - return rewriter.notifyMatchFailure( - op, "requires output is a tensor with dtype."); - } - - if (selfTy.getDtype() != outTy.getDtype()) { - self = convertTensorToDtype(rewriter, loc, self, outTy.getDtype()); - } - - Value pi = - rewriter.create(loc, 
rewriter.getF64FloatAttr(M_PI)); - Value basic = - rewriter.create(loc, rewriter.getF64FloatAttr(180.0)); - Value rad = - rewriter.create(loc, op.getType(), self, basic); - Value result = rewriter.create(loc, op.getType(), rad, pi); - - rewriter.replaceOp(op, result); - - return success(); - } -}; -} // namespace - namespace { class DecomposeAtenSizeOp : public OpRewritePattern { public: @@ -2593,22 +2555,16 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { // first the input tensor is flattened to 1d tensor and then the reduction // happens on the 0th dimension. if (isa(dim.getType())) { - Value zero = rewriter.create(loc, 0); + BaseTensorType flattenType = + cast(inputType.getWithSizesAndDtype( + {kUnknownSize}, inputType.getOptionalDtype())); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputRank - 1)); Value falseValue = rewriter.create(loc, false); - if (inputType.getSizes().size() > 1) { - int64_t flattenSize = Torch::kUnknownSize; - if (inputType.areAllSizesKnown()) { - flattenSize = 1; - for (int64_t sze : inputType.getSizes()) - flattenSize *= sze; - } - auto flattenType = cast(inputType.getWithSizesAndDtype( - {flattenSize}, inputType.getOptionalDtype())); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - input = rewriter.create(loc, flattenType, input, - zero, end); - } + input = rewriter.create(loc, flattenType, input, + zero, end); Value resultIndices = rewriter .create( @@ -3791,7 +3747,11 @@ class DecomposeAtenRreluOp : public OpRewritePattern { // Create a uniform random op with low and high set to `lower` and // `upper`, respectively. 
Value none = rewriter.create(loc); - alpha = rewriter.create(loc, resType, self, + Value emptyTensor = rewriter.create( + loc, resType, self, constantZeroFloat, /*dtype=*/none, + /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + alpha = rewriter.create(loc, resType, emptyTensor, /*from=*/lower, /*to=*/upper, /*generator=*/none); } else { @@ -3836,33 +3796,6 @@ class DecomposeAtenRreluWithNoiseOp Value lower = op.getLower(); Value upper = op.getUpper(); auto resType = cast(op.getType()); - Value cstNone = rewriter.create(loc); - Value cstFalse = - rewriter.create(loc, rewriter.getBoolAttr(false)); - Value result = - rewriter - .create( - loc, resType, self, noise, lower, upper, cstFalse, cstNone) - ->getResult(0); - rewriter.replaceOp(op, result); - return success(); - } -}; -} // namespace - -namespace { -class DecomposeAtenRreluWithNoiseFunctionalOp - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenRreluWithNoiseFunctionalOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.getSelf(); - Value noise = op.getNoise(); - Value lower = op.getLower(); - Value upper = op.getUpper(); - auto resType = cast(op.getResultTypes()[0]); if (!resType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } @@ -3908,7 +3841,7 @@ class DecomposeAtenRreluWithNoiseFunctionalOp rewriter.getI1Type()); Value oneTensor = createRank0Tensor(rewriter, loc, resType, constantOneFloat); - Value not_positive = rewriter.create( + Value not_positive = rewriter.create( loc, boolResType, self, constantZeroFloat); noise = rewriter.create(loc, resType, not_positive, alpha, oneTensor); @@ -3920,7 +3853,7 @@ class DecomposeAtenRreluWithNoiseFunctionalOp rewriter.create(loc, resType, zeroTensor, scaledSelf); Value rreluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOneFloat); - 
rewriter.replaceOp(op, {rreluOutput, noise}); + rewriter.replaceOp(op, rreluOutput); return success(); } }; @@ -4259,68 +4192,6 @@ class DecomposeAtenHstackOp : public OpRewritePattern { }; } // namespace -// Decompose `aten.column_stack` into `aten.reshape` and `aten.cat`. -// https://github.com/pytorch/pytorch/blob/207564bab1c4fe42750931765734ee604032fb69/torch/_refs/__init__.py#L2822 -namespace { -class DecomposeAtenColumnStackOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenColumnStackOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - - SmallVector tensors; - if (!getListConstructElements(op.getTensors(), tensors)) - return rewriter.notifyMatchFailure( - op, "unimplemented: the tensor list is not from list construct"); - - for (auto tensor : tensors) { - auto tTy = dyn_cast(tensor.getType()); - if (!tTy || !tTy.hasSizes()) - return rewriter.notifyMatchFailure( - op, "unimplemented: one tensor does not have known sizes"); - } - - SmallVector tensors2d; - for (auto tensor : tensors) { - auto tTy = dyn_cast(tensor.getType()); - SmallVector tSizes(tTy.getSizes()); - if (tSizes.size() <= 1) { - if (tSizes.size() == 0) { - tSizes.push_back(1); - } - tSizes.push_back(1); - auto newTy = tTy.getWithSizesAndDtype(tSizes, tTy.getDtype()); - SmallVector newShapeList; - for (auto tSize : tSizes) { - newShapeList.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(tSize))); - } - auto newShape = rewriter.create( - loc, Torch::ListType::get(rewriter.getType()), - newShapeList); - Value tensor2d = - rewriter.create(loc, newTy, tensor, newShape); - tensors2d.push_back(tensor2d); - } else { - tensors2d.push_back(tensor); - } - } - - auto elemType = cast(tensors2d[0].getType()) - .getWithSizesAndDtype(std::nullopt, nullptr); - Value newTensors = rewriter.create( - loc, Torch::ListType::get(elemType), tensors2d); - - rewriter.replaceOpWithNewOp( - op, op.getType(), 
newTensors, - rewriter.create(loc, rewriter.getI64IntegerAttr(1))); - - return success(); - } -}; -} // namespace - // Decompose aten.roll into aten.slice and aten.cat ops. // https://pytorch.org/docs/stable/generated/torch.roll.html namespace { @@ -4716,11 +4587,6 @@ class DecomposeAtenUnflattenIntOp if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); - if (inputShape[dimInt] == Torch::kUnknownSize && - llvm::count(sizesInts, -1) > 0) - return rewriter.notifyMatchFailure( - op, "Unimplemented: dynamic unflatten dim with an inferred size."); - SmallVector sizesTorchInt; if (!getListConstructElements(op.getSizes(), sizesTorchInt)) return rewriter.notifyMatchFailure( @@ -5198,82 +5064,6 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace -// Decompose aten.conv(1/2/3)d.padding to aten.convolution -namespace { -template -class DecomposeAtenConvPaddingOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConvPaddingOp op, - PatternRewriter &rewriter) const override { - - Location loc = op.getLoc(); - - Value weight = op.getWeight(); - std::optional maybeRank = getTensorRank(weight); - if (!maybeRank) { - return rewriter.notifyMatchFailure(op, "expected weight to have a rank"); - } - unsigned rank = *maybeRank; - // first 2 dimensions of weight are out_channels and in_channels / groups - if (rank < 3) - return rewriter.notifyMatchFailure( - op, "ConvPaddingOp weight must be at least 3 dimensional."); - - std::string padding_str; - if (!matchPattern(op.getPadding(), m_TorchConstantStr(padding_str))) - return rewriter.notifyMatchFailure(op, - "padding must be a constant string"); - - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - - SmallVector paddingValues; - if (padding_str == "valid") { - // valid means no padding - for (unsigned iRank = 2; iRank < rank; iRank++) { - paddingValues.push_back(zero); - } - } else 
{ - - SmallVector dilation; - getListConstructElements(op.getDilation(), dilation); - - Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value two = - rewriter.create(loc, rewriter.getI64IntegerAttr(2)); - for (unsigned iRank = 2; iRank < rank; iRank++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(iRank)); - Value kernelSize = - rewriter.create(loc, weight, dim); - Value kernelSizeMinusOne = - rewriter.create(loc, kernelSize, one); - Value padding = rewriter.create( - loc, dilation[iRank - 2], kernelSizeMinusOne); - padding = rewriter.create(loc, padding, two); - paddingValues.push_back(padding); - } - } - - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); - Value padding = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - paddingValues); - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), - op.getStride(), padding, op.getDilation(), cstFalse, emptyList, - op.getGroups()); - - return success(); - } -}; -} // namespace - // Decompose aten.conv3d to aten.convolution namespace { class DecomposeAtenConv3dOp : public OpRewritePattern { @@ -5728,240 +5518,6 @@ class DecomposeAtenConvolutionBackwardOp }; } // namespace -/** - * # one dim input - * t = torch.tensor([0, 0, 1, 1, 0, 0] - * # t_flat:[0, 0, 1, 1, 0, 0] - * t_flat = t.flatten(0, 0) - * nonzero_mask = t_flat != 0 - * # nonzero_mask:[0, 0, 1, 1, 0, 0] - * nonzero_mask = nonzero_mask.long() - * # destination_indices:[-1, -1, 0, 1, 1, 1] - * destination_indices = torch.cumsum(nonzero_mask, 0) - 1 - * # destination_indices_clamp:[0, 0, 0, 1, 1, 1] - * destination_indices_clamp = torch.clamp(destination_indices, min=0) - * # iota:[0, 0, 2, 3, 0, 0] - * iota = torch.arange(t_flat.size(0)) * nonzero_mask - * # scatter_self:[0, 0, 0, 0, 0, 0] - * 
scatter_self = torch.zeros_like(t_flat, dtype=torch.int64) - * # compacted:[2, 3, 0, 0, 0, 0] - * compacted = torch.scatter_add( - * scatter_self, dim=0, index=destination_indices_clamp, src=iota - * ) - * # result_flat:[2, 3] - * result_flat = compacted[: torch.sum(nonzero_mask)] - * - * # multi dim support - * original_shape = t.shape - * # input_shape_tensor:[6] - * input_shape_tensor = torch.tensor(original_shape) - * strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0) - * - * one = torch.tensor([1]) - * if(t.dim() > 1): - * slicedStrides = strides[1:-1] - * strides = torch.cat([slicedStrides, one]) - * else: - * strides = one - * # a: tensor([[2], [3]]) torch.Size([2, 1]) - * a = result_flat.unsqueeze(1) # tensor([[2], [3]]) torch.Size([2, 1]) - * # b: tensor([[1]]) torch.Size([1, 1]) - * b = strides.unsqueeze(0) - * # c: tensor([[2], [3]]) torch.Size([2, 1]) - * c = a // b - * # result: tensor([[2], [3]]) torch.Size([2, 1]) - * result = c % input_shape_tensor - */ -class DecomposeAtenNonzeroOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenNonzeroOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto resultType = cast(op.getType()); - auto intType = resultType.getDtype(); - Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType); - auto constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - auto constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - std::function makeOneElementList = [&](Value element) { - auto listType = Torch::ListType::get(element.getType()); - return rewriter.create(loc, listType, - ArrayRef{element}); - }; - - Value input = op.getSelf(); - auto inputType = dyn_cast(input.getType()); - int64_t inputRank = inputType.getSizes().size(); - - // t_flat = t.flatten() # torch.flatten(t, 0, 0) - int64_t flattenedSize = 1; - if (inputType.hasSizes()) { - for (auto size : 
inputType.getSizes()) { - flattenedSize *= size; - } - } else { - flattenedSize = kUnknownSize; - } - - auto flattendInputShape = SmallVector{flattenedSize}; - auto flattenedInputType = rewriter.getType( - flattendInputShape, inputType.getOptionalDtype()); - - // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 : - auto inputDimsEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - Value flattenedInput = rewriter.create( - loc, flattenedInputType, input, constantZero /*inputDimsStart*/, - inputDimsEnd /*inputDimsEnd*/); - - // nonzero_mask = (t_flat != 0) - auto boolMaskType = inputType.getWithSizesAndDtype( - flattenedInputType.getOptionalSizes(), rewriter.getI1Type()); - Value boolMask = rewriter.create( - loc, boolMaskType, flattenedInput, constantZero); - - // nonzero_mask = nonzero_mask.int() - Value falseCst = rewriter.create(loc, false); - Value noneCst = rewriter.create(loc); - auto intMaskType = flattenedInputType.getWithSizesAndDtype( - flattenedInputType.getOptionalSizes(), intType); - Value intMask = rewriter.create( - loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst); - - // destination_indices = torch.cumsum(nonzero_mask, 0) - 1 - Value cumulativeSum = rewriter.create( - loc, intMaskType, intMask, constantZero, noneCst); - Value subtracted = rewriter.create( - loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne); - - // destination_indices = torch.clamp(destination_indices, min=0) - Value indices = rewriter.create(loc, intMaskType, - subtracted, constantZero); - - // iota = torch.arange(len(t_flat)) * nonzero_mask - Value end = rewriter.create(loc, flattenedInput, - /*dim=*/constantZero); - Value rangeTensor = rewriter.create( - loc, intMaskType, /*start*/ constantZero, /*end*/ end, - /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst); - Value multiplied = rewriter.create(loc, intMaskType, - rangeTensor, intMask); - - // scatter_self = torch.zeros_like(t, dtype=torch.int64) - // 
AtenFullLike doesn't support index type so we have to use int. - Value zerosTensor = rewriter.create( - loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst, - noneCst, noneCst); - - // compacted = torch.scatter_add( - // scatter_self, dim=0, index=destination_indices_clamp, src=iota) - Value scatteredTensor = rewriter.create( - loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero, - /*index=*/indices, /*src=*/multiplied); - - // result_flat = compacted[:torch.sum(nonzero_mask)] - auto scalarType = ValueTensorType::get(rewriter.getContext(), - ArrayRef{}, intType); - Value sumMask = - rewriter.create(loc, scalarType, intMask, noneCst); - Value numNonzero = rewriter.create(loc, sumMask); - - auto slicedResultType = Torch::ValueTensorType::get( - rewriter.getContext(), SmallVector{kUnknownSize}, intType); - Value slicedResult = - rewriter.create(loc, slicedResultType, - /*self=*/scatteredTensor, - /*dim=*/constantZero, - /*start=*/noneCst, - /*end=*/numNonzero, - /*step=*/constantOne); - - // TODO fix multidim dynamic support. The following code only work for - // static multidim. 
Convert flattened indices back to multi-dimensional - // indices original_shape = t.shape input_shape_tensor = - // torch.tensor(original_shape) - auto shapeType = Torch::ValueTensorType::get( - rewriter.getContext(), SmallVector{inputRank}, intType); - SmallVector shapeValues; - for (int i = 0; i < inputRank; i++) { - auto constantI = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - Value shape = rewriter.create(loc, input, - /*dim=*/constantI); - shapeValues.push_back(shape); - } - Value shapeTensorList = rewriter.create( - loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues); - Value inputShapeTensor = rewriter.create( - loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst); - - // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0) - Value flippedShape = rewriter.create( - loc, shapeType, inputShapeTensor, makeOneElementList(constantZero)); - Value cumulativeProduct = rewriter.create( - loc, shapeType, flippedShape, constantZero, noneCst); - Value flippedCumulativeProduct = rewriter.create( - loc, shapeType, cumulativeProduct, makeOneElementList(constantZero)); - - // strides = torch.cat([strides[1:-1], torch.tensor([1])]) - auto oneTensorType = ValueTensorType::get(rewriter.getContext(), - SmallVector{1}, intType); - Value oneTensor = rewriter.create( - loc, oneTensorType, constantOne, intTypeValue, noneCst, noneCst, - noneCst); - - Value strides; - if (inputRank > 1) { - // strides[1:-1] - auto slicedStrideType = Torch::ValueTensorType::get( - rewriter.getContext(), SmallVector{inputRank - 1}, // sizes - intType); - Value strideSliceEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank)); - Value slicedStrides = rewriter.create( - loc, slicedStrideType, /*self*/ flippedCumulativeProduct, - /*dim*/ constantZero, - /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); - // torch.cat - auto tensorListElementType = Torch::ValueTensorType::get( - rewriter.getContext(), 
SmallVector{kUnknownSize}, intType); - Value tensorList = rewriter.create( - loc, Torch::ListType::get(tensorListElementType), - SmallVector{slicedStrides, oneTensor}); - strides = rewriter.create(loc, shapeType, tensorList, - constantZero); - } else { - // strides[1:-1] is empty - strides = oneTensor; - } - - // multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) % - // input_shape_tensor - auto unsqueezedResultType = ValueTensorType::get( - rewriter.getContext(), SmallVector{kUnknownSize, 1}, intType); - Value unsqueezedResult = rewriter.create( - loc, unsqueezedResultType, slicedResult, constantOne); - - auto unsqueezedStridesType = ValueTensorType::get( - rewriter.getContext(), SmallVector{1, inputRank}, intType); - Value unsqueezedStrides = rewriter.create( - loc, unsqueezedStridesType, strides, constantZero); - - auto dividedBroadcastType = ValueTensorType::get( - rewriter.getContext(), SmallVector{kUnknownSize, inputRank}, - intType); - Value divided = rewriter.create( - loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides); - - Value modded = rewriter.create( - loc, resultType, divided, inputShapeTensor); - - rewriter.replaceOp(op, modded); - return success(); - } -}; - // Decompose aten.addmm into aten.mm and aten.add.Tensor op. 
namespace { class DecomposeAtenAddmmOp : public OpRewritePattern { @@ -6968,7 +6524,7 @@ class DecomposeAtenNativeLayerNormOp Location loc = op.getLoc(); auto context = op.getContext(); - auto inputTy = cast(op.getInput().getType()); + auto inputTy = cast(op.getInput().getType()); if (!inputTy.hasSizes()) return rewriter.notifyMatchFailure( op, "input tensor should have known sizes."); @@ -7023,18 +6579,6 @@ class DecomposeAtenNativeLayerNormOp loc, inputTy, inputRsqrtVar, op.getInput()); Value inputNormalized = rewriter.create( loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); - // Convert resultType if dtype is different - auto resultTensorType = - dyn_cast(op.getResult(0).getType()); - if (inputTy.getDtype() != resultTensorType.getDtype()) { - Value dtypeValue = Torch::getDtypeIntValueForType( - rewriter, loc, resultTensorType.getDtype()); - Value cstFalse = rewriter.create(loc, false); - inputNormalized = rewriter.create( - loc, resultTensorType, inputNormalized, - /*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); - } Value out = rewriter.create( loc, op.getResult(0).getType(), inputNormalized); @@ -9029,71 +8573,6 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { }; } // namespace -namespace { -class DecomposeAtenL1LossOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenL1LossOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.getSelf(); - auto selfTy = dyn_cast(self.getType()); - if (!selfTy || !selfTy.hasSizes() || !selfTy.hasDtype()) { - return rewriter.notifyMatchFailure( - op, "Expected self to be a tensor with sizes and a dtype"); - } - - Value target = op.getTarget(); - auto targetTy = dyn_cast(target.getType()); - if (!targetTy || !targetTy.hasDtype()) { - return rewriter.notifyMatchFailure( - op, "Expected target to be a tensor with sizes and a dtype"); - } - - auto outTy = 
dyn_cast(op.getType()); - if (!outTy || !outTy.hasDtype()) { - return rewriter.notifyMatchFailure( - op, "Expected output type to be a tensor with a dtype"); - } - - auto outDtype = outTy.getDtype(); - if (selfTy.getDtype() != outDtype) { - self = convertTensorToDtype(rewriter, loc, self, outDtype); - } - if (targetTy.getDtype() != outDtype) { - target = convertTensorToDtype(rewriter, loc, target, outDtype); - } - - Value reduction = op.getReduction(); - int64_t reductionInt; - if (!matchPattern(reduction, m_TorchConstantInt(&reductionInt))) { - return rewriter.notifyMatchFailure( - op, "Expected reduction to be a constant int"); - } - - auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype); - Value sub = createTensorSub(rewriter, loc, subTy, self, target); - Value abs = rewriter.create(loc, subTy, sub); - - if (reductionInt == 0) { - rewriter.replaceOp(op, abs); - } else if (reductionInt == 1) { - Value none = rewriter.create(loc); - Value sum = rewriter.create(loc, outTy, abs, none); - Value numel = rewriter.create(loc, abs); - Value mean = rewriter.create(loc, outTy, sum, numel); - rewriter.replaceOp(op, mean); - } else { - Value none = rewriter.create(loc); - Value sum = rewriter.create(loc, outTy, abs, none); - rewriter.replaceOp(op, sum); - } - - return success(); - } -}; -} // namespace - namespace { // Decompose `aten.norm.ScalarOpt_dim` op to `aten.linalg_vector_norm` op class DecomposeAtenNormScalarOptDimOp @@ -10206,156 +9685,6 @@ class DecomposeAtenTopkOp : public OpRewritePattern { }; } // namespace -namespace { - -/// Creates coefficients based on DFT definition, see -/// https://en.wikipedia.org/wiki/Discrete_Fourier_transform. -/// Even indices of the second dimension are for the real components of the -/// output. Odd indices for the imaginary components. 
-Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, - ValueTensorType matrixType) { - // scale = 2 * pi / N - double scale = 2 * M_PI / matrixType.getSizes()[0]; - - SmallVector values; - assert(matrixType.getSizes().size() == 2 && "expected 2D matrix"); - for (auto i : llvm::seq(0, matrixType.getSizes()[0])) { - for (auto j : llvm::seq(0, matrixType.getSizes()[1])) { - const bool isImagPart = j % 2; - double v = scale * i * (j / 2); - v = isImagPart ? -sin(v) : cos(v); - values.push_back(rewriter.getF32FloatAttr(v)); - } - } - - return rewriter.create( - loc, matrixType, - DenseElementsAttr::get(matrixType.toBuiltinTensor(), - ArrayRef(values))); -} - -class DecomposeAtenFftRfftOp final : public OpRewritePattern { - - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenFftRfftOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - Value self = op.getSelf(); - - int64_t dim; - auto dimVal = op.getDim(); - if (isa(dimVal.getType())) { - dim = -1; - } else if (!matchPattern(dimVal, torch::Torch::m_TorchConstantInt(&dim))) { - return rewriter.notifyMatchFailure( - op, "unimplemented: requires dim to be constant"); - } - - if (!isa(op.getN().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter n"); - } - - if (!isa(op.getNorm().getType())) { - return rewriter.notifyMatchFailure(op, "unimplemented: parameter norm"); - } - - BaseTensorType inputType = cast(self.getType()); - - if (!inputType.hasSizes()) { - return rewriter.notifyMatchFailure( - op, "unsupported: only ranked tensors are supported"); - } - - const ArrayRef inputShape = inputType.getSizes(); - dim += dim < 0 ? 
inputShape.size() : 0; - - const int64_t fftLength = inputShape[dim]; - if (fftLength == kUnknownSize) { - return rewriter.notifyMatchFailure( - op, "unsupported: input signal length must be known"); - } - const int64_t rank = inputShape.size(); - const int64_t lastDim = rank - 1; - const int64_t outputFftDim = fftLength / 2 + 1; - const bool needTranspose = dim != lastDim; - - auto transposeValue = [](PatternRewriter &rewriter, Location loc, - Value input, int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(cast(input.getType()), dimA, - dimB, transposedType))) - return failure(); - Value cstDimA = - rewriter.create(loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = - rewriter.create(loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create(loc, transposedType, - input, cstDimA, cstDimB); - return success(); - }; - - SmallVector lhsShape(inputShape); - // Transpose if FFT dimension is not the last one - if (needTranspose) { - if (failed(transposeValue(rewriter, loc, self, dim, lastDim, self))) { - return failure(); - } - std::swap(lhsShape[dim], lhsShape[lastDim]); - } - // self : (D_0 x ... x D_m x fftLength) - - Type dtype = inputType.getOptionalDtype(); - - // coeff : (fftLength x outputFftDim*2) - ValueTensorType matrixType = ValueTensorType::get( - op.getContext(), SmallVector{fftLength, outputFftDim * 2}, - dtype); - Value coeffMatrix = getDFTMatmulCoeff(rewriter, loc, matrixType); - - // X = matmul(self, coeff) : (D_0 x ... x D_m x outputFftDim*2) - SmallVector matmulShape(lhsShape.begin(), lhsShape.end() - 1); - matmulShape.push_back(outputFftDim * 2); - ValueTensorType matmulType = - ValueTensorType::get(op.getContext(), matmulShape, dtype); - Value flatRes = - rewriter.create(loc, matmulType, self, coeffMatrix); - - // Y = unflatten(X, -1, [outputFftDim, 2]) - // : (D_0 x ... x D_m x outputFftDim x 2) - // Z = view_as_complex(Y) : complex(D_0 x ... 
x D_m x outputFftDim) - SmallVector complexResShape(matmulShape); - complexResShape.back() = outputFftDim; - SmallVector unflattenedResShape(complexResShape); - unflattenedResShape.push_back(2); - Type unflattenedResType = - ValueTensorType::get(op.getContext(), unflattenedResShape, dtype); - Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); - Value unflattenSizes = toIntListConstruct(rewriter, loc, {outputFftDim, 2}); - Value unflattenedRes = rewriter.create( - loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes); - Type complexResType = ValueTensorType::get(op.getContext(), complexResShape, - ComplexType::get(dtype)); - Value complexRes = rewriter.create(loc, complexResType, - unflattenedRes); - - // Transpose back - if (needTranspose) { - if (failed(transposeValue(rewriter, loc, complexRes, dim, lastDim, - complexRes))) { - return failure(); - } - } - - rewriter.replaceOp(op, {complexRes}); - - return success(); - } -}; - -} // namespace - namespace { // Decompose `aten.hann_window` into `aten.arange.start`, `aten.mul.Scalar`, // `aten.sin` and `aten.square` or into `aten.ones` in the trivial case @@ -11167,286 +10496,6 @@ class DecomposeAtenFloatPowerTensorTensorOp }; } // namespace -namespace { -class DecomposeTorchvisionNmsOp : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TorchvisionNmsOp op, - PatternRewriter &rewriter) const override { - Location loc = op->getLoc(); - MLIRContext *context = op->getContext(); - Value boxes = op.getDets(); - Value scores = op.getScores(); - Value iouThreshold = op.getIouThreshold(); - - Value cst0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cst1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value cst2 = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - Value cst4 = rewriter.create( - loc, rewriter.getI64IntegerAttr(4)); - Value cstNone = rewriter.create(loc); - Value cstTrue 
= - rewriter.create(loc, rewriter.getBoolAttr(true)); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); - - // Get number of boxes for the loop count - auto boxesTensorType = dyn_cast(boxes.getType()); - auto dType = boxesTensorType.getDtype(); - int64_t boxesSize = boxesTensorType.getSizes()[0]; - Value len = rewriter.create(loc, boxes, /*dim=*/cst0); - - // Calculate the area of each box: (x2 - x1) * (y2 - y1) - auto sliceTy = rewriter.getType( - SmallVector{boxesSize, 2}, dType); - Value lowSlice = rewriter.create( - loc, sliceTy, boxes, - /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); - Value highSlice = rewriter.create( - loc, sliceTy, boxes, - /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); - Value distance = rewriter.create( - loc, sliceTy, highSlice, lowSlice, cst1); - auto areaTy = rewriter.getType( - SmallVector{boxesSize}, dType); - Value area = rewriter.create( - loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse, - /*dtype=*/cstNone); - - // Sort scores in descending order - // Use the sorted indices to iterate boxes - auto scoresType = dyn_cast(scores.getType()); - auto intTensorType = scoresType.getWithSizesAndDtype( - scoresType.getOptionalSizes(), - IntegerType::get(context, 64, IntegerType::Signed)); - auto sortResult = rewriter.create( - loc, TypeRange({scores.getType(), intTensorType}), scores, - /*dim=*/cst0, /*descending=*/cstTrue); - - // Create a mask to mark if we keep the boxes - Value lenShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - SmallVector{len}); - Value mask = rewriter.create( - loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone); - Value zeroShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - SmallVector{cst1}); - auto zeroTy = rewriter.getType( - SmallVector{1}, rewriter.getIntegerType(64, /*signed=*/true)); - Value falseMask = rewriter.create( - loc, zeroTy, zeroShapeList, 
cstNone, cstNone, cstNone, cstNone); - - // Create an empty tensor for result - Value result = rewriter.create( - loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone, - /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); - - auto intTy = rewriter.getType(); - auto rowSliceTy = - rewriter.getType(SmallVector{1, 4}, dType); - auto pointTy = - rewriter.getType(SmallVector{1, 2}, dType); - auto extractTy = rewriter.getType( - SmallVector{1}, rewriter.getIntegerType(64, true)); - Value float0 = rewriter.create( - loc, rewriter.getFloatAttr(dType, 0.0)); - auto scalarFloatType = rewriter.getType( - SmallVector{1}, dType); - Value float0Tensor = rewriter.create( - loc, scalarFloatType, float0); - - // 1. Loop through the boxes based on sorted indices - // 2. Add the current box to result if it's not suppressed - // 3. Calculate the IoUs with all boxes - // 4. Loop through the rest boxes in sorted indices - // 5. Suppress the box if the corresponding IoU is larger than threshold - auto loop1 = rewriter.create( - loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue, - ValueRange({mask, result, cst0})); - { - PatternRewriter::InsertionGuard guard(rewriter); - Block *loopBody1 = rewriter.createBlock( - &loop1.getRegion(), loop1.getRegion().begin(), - TypeRange({intTy, intTensorType, intTensorType, intTy}), - {loc, loc, loc, loc}); - Value i = loopBody1->getArgument(0); - Value mask1 = loopBody1->getArgument(1); - Value curResult = loopBody1->getArgument(2); - Value curCnt = loopBody1->getArgument(3); - - // Extract the mask to check if the base box is suppressed - Value extract = rewriter.create( - loc, extractTy, mask1, /*dim=*/cst0, /*index=*/i); - Value scalar = rewriter.create(loc, intTy, extract); - Value iskept = rewriter.create( - loc, rewriter.getType(), scalar); - auto ifFilterOthers = rewriter.create( - loc, TypeRange({intTensorType, intTensorType, intTy}), iskept); - { - PatternRewriter::InsertionGuard 
guard(rewriter); - rewriter.createBlock(&ifFilterOthers.getThenRegion(), - ifFilterOthers.getThenRegion().begin()); - - // Scatter the selected indices into result - Value extractIdx1 = rewriter.create( - loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, - /*index=*/i); - Value next = rewriter.create(loc, curCnt, cst1); - Value updatedResult = rewriter.create( - loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0, - /*start=*/curCnt, /*end=*/next, /*step=*/cst1); - - // Get the coordinates of base box - Value idx1 = - rewriter.create(loc, intTy, extractIdx1); - Value idx1End = rewriter.create(loc, idx1, cst1); - Value curBox = rewriter.create( - loc, rowSliceTy, boxes, - /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1); - - // Calculate IoUs: intersectionArea / unionArea - // Intersection area = intersectionWidth * intersectionHeight - Value point1 = rewriter.create( - loc, pointTy, curBox, - /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); - Value point2 = rewriter.create( - loc, pointTy, curBox, - /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); - Value innerLow = rewriter.create( - loc, sliceTy, lowSlice, point1); - Value innerHigh = rewriter.create( - loc, sliceTy, highSlice, point2); - Value innerDistance = rewriter.create( - loc, sliceTy, innerHigh, innerLow, cst1); - innerDistance = rewriter.create( - loc, sliceTy, innerDistance, float0Tensor); - Value intersectionArea = rewriter.create( - loc, areaTy, innerDistance, /*dim=*/cst1, /*keepdim=*/cstFalse, - /*dtype=*/cstNone); - Value iEnd = rewriter.create(loc, i, cst1); - Value curArea = rewriter.create( - loc, scalarFloatType, area, - /*dim=*/cst0, /*start=*/i, /*end=*/iEnd, /*step=*/cst1); - // Union area = area1 + area2 - intersectionArea - Value unionArea = rewriter.create( - loc, areaTy, area, curArea, cst1); - unionArea = rewriter.create( - loc, areaTy, unionArea, intersectionArea, cst1); - Value iou = rewriter.create( - loc, areaTy, intersectionArea, 
unionArea); - - // Loop through the rest of boxes in sorted indices - auto loop2 = rewriter.create(loc, intTensorType, len, - cstTrue, mask1); - { - PatternRewriter::InsertionGuard guard(rewriter); - Block *loopBody2 = rewriter.createBlock( - &loop2.getRegion(), loop2.getRegion().begin(), - TypeRange({intTy, intTensorType}), {loc, loc}); - Value j = loopBody2->getArgument(0); - Value mask2 = loopBody2->getArgument(1); - - // Check if current index is out of range - j = rewriter.create(loc, j, i); - j = rewriter.create(loc, j, cst1); - Value isInRange = rewriter.create(loc, j, len); - auto ifCalculateIou = rewriter.create( - loc, TypeRange({intTensorType}), isInRange); - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifCalculateIou.getThenRegion(), - ifCalculateIou.getThenRegion().begin()); - - // Retrieve IoU and check if suppress the box - Value extractIdx2 = rewriter.create( - loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, - /*index=*/j); - Value idx2 = - rewriter.create(loc, intTy, extractIdx2); - Value idx2End = - rewriter.create(loc, idx2, cst1); - Value curIoU = rewriter.create( - loc, scalarFloatType, iou, - /*dim=*/cst0, /*start=*/idx2, /*end=*/idx2End, /*step=*/cst1); - curIoU = rewriter.create( - loc, rewriter.getType(), curIoU); - Value isSuppressed = rewriter.create( - loc, curIoU, iouThreshold); - - auto ifUnmask = rewriter.create( - loc, TypeRange({intTensorType}), isSuppressed); - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifUnmask.getThenRegion(), - ifUnmask.getThenRegion().begin()); - - // Update the mask if suppress - Value jEnd = rewriter.create(loc, j, cst1); - Value updatedMask = rewriter.create( - loc, intTensorType, mask2, falseMask, /*dim=*/cst0, - /*start=*/j, /*end=*/jEnd, /*step=*/cst1); - rewriter.create(loc, updatedMask); - } - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifUnmask.getElseRegion(), - 
ifUnmask.getElseRegion().begin()); - rewriter.create(loc, mask2); - } - - rewriter.create(loc, ifUnmask.getResult(0)); - } - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifCalculateIou.getElseRegion(), - ifCalculateIou.getElseRegion().begin()); - rewriter.create(loc, mask2); - } - - rewriter.create( - loc, cstTrue, ifCalculateIou.getResult(0)); - } - - rewriter.create( - loc, ValueRange({loop2.getResult(0), updatedResult, next})); - } - { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.createBlock(&ifFilterOthers.getElseRegion(), - ifFilterOthers.getElseRegion().begin()); - rewriter.create( - loc, ValueRange({mask1, curResult, curCnt})); - } - - rewriter.create(loc, cstTrue, - ifFilterOthers.getResults()); - } - - rewriter.replaceOpWithNewOp( - op, op.getType(), loop1.getResult(1), /*dim=*/cst0, /*start=*/cst0, - /*end=*/loop1.getResult(2), /*step=*/cst1); - return success(); - } -}; -} // namespace - -namespace { -class DecomposeAtenSpecialExpm1Op - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenSpecialExpm1Op op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf()); - return success(); - } -}; -} // namespace - namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -11500,7 +10549,6 @@ class DecomposeComplexOpsPass DecomposeConstantTensorAllocLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( @@ -11520,7 +10568,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); 
addPatternIfTargetOpIsIllegal(patterns); @@ -11591,8 +10638,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal( - patterns); addPatternIfTargetOpIsIllegal( patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -11663,7 +10708,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -11699,7 +10743,6 @@ class DecomposeComplexOpsPass patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -11710,7 +10753,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -11726,25 +10768,15 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenConvPaddingOp>(patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenConvPaddingOp>(patterns); - addPatternIfTargetOpIsIllegal< - DecomposeAtenConvPaddingOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenFMaxMinOp>(patterns); addPatternIfTargetOpIsIllegal< 
DecomposeAtenFMaxMinOp>(patterns); - // Torchvision ops - addPatternIfTargetOpIsIllegal(patterns); - GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index f15911e2b5ba..4bca74470772 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -382,7 +382,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -501,7 +500,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -528,7 +526,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); @@ -566,11 +563,9 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 634e910d4c32..3d1a54de29f9 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -37,8 +37,8 @@ LogicalResult materializeFolds(ImplicitLocOpBuilder b, if (auto attr = dyn_cast(f)) { if (auto val = dyn_cast(attr)) { - values.push_back( - 
b.create(APFloat(val.getValueAsDouble()))); + values.push_back(b.create( + b.getType(), val)); continue; } @@ -714,7 +714,7 @@ class PropagateAtenItemPattern : public OpRewritePattern { ImplicitLocOpBuilder b(op.getLoc(), rewriter); // Rank 0 item op prop - if (selfTy.getSizes().empty()) { + if (selfTy.getSizes().size() == 0) { auto numToTensor = self.getDefiningOp(); auto squeezeDim = self.getDefiningOp(); if (!squeezeDim && !numToTensor) @@ -746,109 +746,6 @@ class PropagateAtenItemPattern : public OpRewritePattern { }; } // namespace -namespace { - -LogicalResult convertOpFoldResults(ImplicitLocOpBuilder &b, - SmallVector &converted, - SmallVector &elements, - Type inputDtype, Type resultDtype) { - auto inputIsInt = dyn_cast(inputDtype); - auto resultIsInt = dyn_cast(resultDtype); - if (!inputIsInt && !isa(inputDtype)) - return failure(); - if (!resultIsInt && !isa(resultDtype)) - return failure(); - - // if dtypes are both int or both float, no conversion needed - if (static_cast(inputIsInt) == static_cast(resultIsInt)) { - converted = elements; - return success(); - } - - if (resultIsInt) { - for (auto &e : elements) { - auto eValue = dyn_cast(e); - if (eValue) { - converted.push_back(b.createOrFold(eValue)); - continue; - } - auto eAttr = dyn_cast(e); - auto eFloatAttr = dyn_cast_or_null(eAttr); - if (!eFloatAttr) - return failure(); - - converted.push_back(IntegerAttr::get( - resultDtype, static_cast(eFloatAttr.getValueAsDouble()))); - } - return success(); - } - - // result is float - for (auto &e : elements) { - auto eValue = dyn_cast(e); - if (eValue) { - converted.push_back(b.createOrFold(eValue)); - continue; - } - auto eAttr = dyn_cast(e); - auto eIntAttr = dyn_cast(eAttr); - if (!eIntAttr) - return failure(); - - auto eInt = (inputIsInt.isSigned()) ? 
eIntAttr.getValue().getSExtValue() - : eIntAttr.getValue().getZExtValue(); - converted.push_back(FloatAttr::get(resultDtype, static_cast(eInt))); - } - return success(); -} - -class PropagateAtenToDtypePattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenToDtypeOp op, - PatternRewriter &rewriter) const override { - bool nonBlocking, copyArg; - // The non_blocking arg must be `False`. - if (!matchPattern(op.getNonBlocking(), m_TorchConstantBool(&nonBlocking)) || - nonBlocking) - return failure(); - // The copy arg must be `False`. - if (!matchPattern(op.getCopy(), m_TorchConstantBool(©Arg)) || copyArg) - return failure(); - // The memory_format arg must be `none`. - if (!isa(op.getMemoryFormat().getType())) - return failure(); - - auto inputType = dyn_cast(op.getSelf().getType()); - auto resultType = dyn_cast(op.getType()); - if (!inputType || !resultType || !inputType.hasDtype() || - !resultType.hasDtype()) - return failure(); - auto inputDtype = inputType.getDtype(); - auto resultDtype = resultType.getDtype(); - - SmallVector elements; - if (failed(getListFromTensor(op.getSelf(), elements))) - return failure(); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - SmallVector converted; - if (failed(convertOpFoldResults(b, converted, elements, inputDtype, - resultDtype))) - return rewriter.notifyMatchFailure( - op, "Unhandled attribute type encountered."); - - SmallVector vals; - if (failed(materializeFolds(b, converted, vals))) - return failure(); - - Value result = constructAtenTensorOpFromList(b, op.getType(), vals); - rewriter.replaceOp(op, result); - return success(); - } -}; -} // namespace - namespace { template class PropagateAtenViewLikePattern : public OpRewritePattern { @@ -931,49 +828,6 @@ class PropagateAtenArithmeticPattern : public OpRewritePattern { if (failed(materializeFolds(b, resultFolds, resultVals))) return failure(); - if (resultTy.getSizes().empty()) { - 
rewriter.replaceOpWithNewOp( - op, resultTy, resultVals.front()); - return success(); - } - - Value result = constructAtenTensorOpFromList(b, resultTy, resultVals); - rewriter.replaceOp(op, result); - return success(); - } -}; -} // namespace - -namespace { -template -class PropagateAtenUnaryPattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OpTy op, - PatternRewriter &rewriter) const override { - // Check type - auto resultTy = cast(op.getType()); - if (resultTy.getSizes().size() > 1) - return rewriter.notifyMatchFailure(op, "unsupported: rank > 1"); - if (!resultTy.hasDtype() || !isa(resultTy.getDtype())) - return rewriter.notifyMatchFailure(op, "not an int type"); - - ImplicitLocOpBuilder b(op.getLoc(), rewriter); - SmallVector selfFold; - if (failed(getListFromTensor(op.getSelf(), selfFold))) - return failure(); - SmallVector selfVals; - if (failed(materializeFolds(b, selfFold, selfVals))) - return failure(); - SmallVector resultFolds; - for (uint64_t i = 0; i < selfVals.size(); i++) { - resultFolds.push_back( - b.createOrFold(selfVals[i].getType(), selfVals[i])); - } - SmallVector resultVals; - if (failed(materializeFolds(b, resultFolds, resultVals))) - return failure(); - if (resultTy.getSizes().size() == 0) { rewriter.replaceOpWithNewOp( op, resultTy, resultVals.front()); @@ -986,6 +840,7 @@ class PropagateAtenUnaryPattern : public OpRewritePattern { } }; } // namespace + /// ------ Fold Patterns ------ /// // These are shape-specific folding patterns @@ -1060,11 +915,6 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { auto resultTy = cast(op.getType()); if (!resultTy.hasSizes() || !resultTy.areAllSizesKnown()) return rewriter.notifyMatchFailure(op, "dynamic output shape"); - if (resultTy.getSizes().size() == 0) { - rewriter.replaceOpWithNewOp( - op, op.getType(), elements.front()); - return success(); - } auto loc = op.getLoc(); SmallVector sizes; @@ -1072,10 +922,12 @@ 
class FoldAtenTensorSplatPattern : public OpRewritePattern { sizes.push_back(rewriter.create( loc, rewriter.getI64IntegerAttr(size))); + Value one = rewriter.create( + loc, rewriter.getType(), 1); Value sizeList = rewriter.create( loc, rewriter.getType(rewriter.getType()), - sizes); + one); Value none = rewriter.create(loc); Value cstFalse = rewriter.create(loc, false); @@ -1179,24 +1031,6 @@ class FoldAtenWhereSelf : public OpRewritePattern { }; } // namespace -namespace { -// fold ridiculous patterns like size.int -> float.scalar -> int.scalar -class FoldAtenIntScalarPattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenIntScalarOp op, - PatternRewriter &rewriter) const override { - auto floatScalarOp = op.getA().getDefiningOp(); - if (!floatScalarOp) - return failure(); - auto sizeOp = floatScalarOp.getA().getDefiningOp(); - if (!sizeOp) - return failure(); - rewriter.replaceOp(op, floatScalarOp.getA()); - return success(); - } -}; -} // namespace namespace { class FoldAtenUnsqueezePattern : public OpRewritePattern { public: @@ -1348,29 +1182,8 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { if (inputUnmatched == 1 && outputUnmatched > 1) { Value dimVal = rewriter.create(op.getLoc(), leftMatchEnd); - SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd, - viewSizes.end() - rightMatchEnd); - // try to convert a single dynamic size input to -1 - int64_t dynCount = 0; - int64_t dynIdx = 0; - for (auto [i, v] : llvm::enumerate(unflattenSizes)) { - int64_t szeInt; - if (!matchPattern(v, m_TorchConstantInt(&szeInt))) { - dynCount++; - dynIdx = i; - continue; - } - // if we have a -1 already, make dynCount invalid and break - if (szeInt == -1) { - dynCount = -1; - break; - } - } - // if only one size is dynamic, make it -1 - if (dynCount == 1) - unflattenSizes[dynIdx] = - rewriter.create(op.getLoc(), -1); - + ArrayRef unflattenSizes(viewSizes.begin() + leftMatchEnd, + 
viewSizes.end() - rightMatchEnd); Value unflattenList = rewriter.create( op.getLoc(), op.getSize().getType(), unflattenSizes); rewriter.replaceOpWithNewOp( @@ -1414,18 +1227,6 @@ template class RemoveUnusedPattern : public OpRewritePattern { namespace { -bool isItemForSliceOp(Operation *op) { - auto itemOp = dyn_cast_or_null(op); - if (!itemOp) - return false; - for (OpOperand &use : op->getUses()) { - Operation *userOp = use.getOwner(); - if (isa(userOp)) - return true; - } - return false; -} - bool isSourceOpForShapeScalarization(Operation *op) { return llvm::isa(op); @@ -1443,7 +1244,7 @@ bool isPrimListOfInts(Operation *op) { bool isAnchorOp(Operation *op) { return isa(op) || isa(op) || - isPrimListOfInts(op) || isItemForSliceOp(op); + isPrimListOfInts(op); } // The argument to this function, op, is the use of some source op, srcOp. If @@ -1477,9 +1278,9 @@ bool isInvalidValidViewConsumer(Operation *op, void populateScalarizationFoldPatterns(RewritePatternSet &patterns) { patterns.insert, FoldAtenSqueezePattern, - FoldAtenIntScalarPattern, FoldAtenUnsqueezePattern, - FoldAtenWhereSelf, FoldAtenTensorSplatPattern, - FoldAtenEqIntPattern>(patterns.getContext()); + FoldAtenUnsqueezePattern, FoldAtenWhereSelf, + FoldAtenTensorSplatPattern, FoldAtenEqIntPattern>( + patterns.getContext()); } void populateScalarizationCanonicalizePatterns(RewritePatternSet &patterns) { @@ -1502,12 +1303,10 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { PropagateAtenItemPattern, PropagateAtenShapeToTensorPattern, PropagateAtenSliceTensorPattern, PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern, PropagateAtenBroadcastToPattern, - PropagateAtenTransposeIntPattern, PropagateAtenToDtypePattern, - PropagateAtenUnaryPattern, + PropagateAtenTransposeIntPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern, - PropagateAtenArithmeticPattern, PropagateAtenArithmeticPattern>( patterns.getContext()); } @@ 
-1515,7 +1314,6 @@ void populateScalarizationPropagationPatterns(RewritePatternSet &patterns) { void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { patterns.insert, RemoveUnusedPattern, - RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, @@ -1523,8 +1321,6 @@ void populateScalarizationRemovePatterns(RewritePatternSet &patterns) { RemoveUnusedPattern, RemoveUnusedPattern, RemoveUnusedPattern, - RemoveUnusedPattern, - RemoveUnusedPattern, RemoveUnusedPattern>( patterns.getContext()); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 390a2f2d7862..664bbb2d5d8e 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -36,18 +36,6 @@ Torch::matchLegalConstantIndexIntoListOfSize(Value v, int64_t length) { return dim; } -Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc, - ArrayRef cstInput) { - SmallVector cstValues; - for (int64_t i : cstInput) { - cstValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - } - return rewriter.create( - loc, Torch::ListType::get(IntType::get(rewriter.getContext())), - cstValues); -} - bool Torch::getListConstructElements(Value v, SmallVectorImpl &elems) { auto listConstruct = v.getDefiningOp(); if (!listConstruct) diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 229b352094e8..1e6879530ce6 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -104,8 +104,7 @@ class UnpackQuantizedMatmulWeights char mask = (1 << unpackedBitWidth) - 1; for (int b = 0; b < packRatio; b++) { newData[i * packRatio + b] = - APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b), - /*isSigned=*/false, /*implicitTrunc=*/true); + APInt(unpackedBitWidth, (el & mask) >> (unpackedBitWidth * b)); mask = mask << 
unpackedBitWidth; } } diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index d9d7ef1a0cd4..c9638c8353b1 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -53,8 +52,7 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, tensor::TensorDialect, tosa::TosaDialect>(); } void mlir::torch::registerAllPasses() { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 1dce55f06158..8c38d0112f6c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -33,8 +33,6 @@ # if a dimension is specified in all expand lists, and not in sumdim list. # This is a bug in the implementation of _trilinear in PyTorch. 
"Aten_TrilinearModuleZerodDimBug_basic", - # missing lowering from aten.pow.Tensor_Tensor for integer result - "PowIntIntModule_basic", } if torch_version_for_comparison() < version.parse("2.5.0.dev"): @@ -222,6 +220,7 @@ "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "IntFloatModule_basic", + "PowIntFloatModule_basic", # END tests failing due to: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.Int # ERROR: torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor int call_function aten.len "LenStrModule_basic", @@ -398,7 +397,7 @@ "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", "AtenIntMM_basic", - "AtenNonzero1DDynamicModule_basic", # no lowering for torch.aten.sym_constrain_range_for_size + "AtenItemFpOpModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", "QuantizedReluInt32_basic", @@ -424,6 +423,7 @@ "CumsumModule_basic", "CumprodModule_basic", "DeformConv2D_basic", + "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", @@ -437,6 +437,7 @@ "IntFloatModule_basic", "IntImplicitModule_basic", "LenStrModule_basic", + "MulFloatModule_basic", "NativeGroupNormBackwardModule_basic", "NeFloatIntModule_basic", "NllLossModuleBackward1DMeanWeight_basic", @@ -447,7 +448,7 @@ "NllLossModuleBackward1D_basic", "NumelModule_basic", "NumelZeroRankModule_basic", - "PowIntIntModule_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -461,11 +462,17 @@ "QuantizedSingleLayer_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", + "ScalarImplicitFloatModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", "SplitDimDynamicModule_basic", "SplitDimStaticModule_basic", "SqrtIntModule_basic", + "SubFloatModule_basic", "TensorToBoolZeroRank_basic", 
"TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", "ThresholdBackward2dMixedModule_basic", "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dDynamicFactor_basic", @@ -494,18 +501,32 @@ "AdaptiveMaxPool1dStatic_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", "IsInfiniteModule_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", + "OneHotModule_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "ElementwiseSignbitModule_basic", "ElementwiseCopysignModule_basic", - "BernoulliFloatModule_basic", - "BernoulliTensorModule_basic", - "UniformModule_basic", - "UniformStaticShapeModule_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -517,10 +538,17 @@ "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_XFAIL_SET = { "AddFloatIntModule_basic", + "ArgmaxIntModule_basic", + 
"ArgmaxIntModule_multiple_maxs", + "ArgmaxKeepdimModule_basic", + "ArgmaxModule_basic", "AtenKthvalueDynamicDimsModule_basic", "AtenKthvalueFloat64DynamicDimsModule_basic", "AtenKthvalueFloat64Module_basic", @@ -590,6 +618,9 @@ "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", "ArangeStartOutViewModule_basic", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", "AtenComplexImagModule_basic", "AtenComplexRealModule_basic", "AtenComplexViewModule_basic", @@ -600,8 +631,6 @@ "AtenDiagEmbedOffsetDiag_basic", "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagSumExample_basic", - "AtenFftRfft2DLastDim_basic", - "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", @@ -616,7 +645,6 @@ "AtenMmQMixedSigni8_basic", "AtenMmQint8_basic", "AtenMmQuint8_basic", - "AtenNonzero1DDynamicModule_basic", "AtenRealView128Module_basic", "AtenRealView64Module_basic", "AtenTopKModule_basic", @@ -673,6 +701,7 @@ "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseErfIntModule_basic", + "ElementwiseLogitModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseQuantizePerTensorModule_basic", @@ -735,7 +764,6 @@ "LenStrModule_basic", "MaxPool2dCeilModeTrueModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", "MaxPool2dWithIndicesBackwardStatic3DModule_basic", @@ -774,6 +802,7 @@ "NormalFunctionalModule_basic", "NumelModule_basic", "NumelZeroRankModule_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -833,6 +862,8 @@ "ScatterValueFloatModule_basic", "ScatterValueIntModule_basic", 
"SliceOutOfLowerBoundEndIndexModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", "SortTensorDescending_basic", "SortTensorInteger_basic", "SortTensorNegativeDimension_basic", @@ -897,6 +928,8 @@ "AtenItemIntOpModule_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", "InterpolateDynamicModule_sizes_nearest", "IouOfModule_basic", "IscloseStaticModuleTrue_basic", @@ -905,6 +938,7 @@ "MeshgridIndexingXY_basic", "Meshgrid_basic", "MulIntModule_basic", + "OneHotModule_basic", "ReduceFrobeniusNormComplexModule_basic", "ScalarImplicitIntModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", @@ -922,9 +956,11 @@ "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", "UpSampleNearest2d_basic", - "BernoulliFloatModule_basic", - "UniformModule_basic", - "UniformStaticShapeModule_basic", + # RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } FX_IMPORTER_STABLEHLO_CRASHING_SET = { @@ -946,8 +982,9 @@ "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", - "CrossEntropyLossModule_basic", - "CrossEntropyLossNoReductionModule_basic", + # torch export: RuntimeError: cannot mutate tensors with frozen storage + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", } STABLEHLO_PASS_SET = { @@ -1189,8 +1226,6 @@ "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSinModule_basic", - "ElementwiseSpecialExpm1IntModule_basic", - "ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtModule_basic", "ElementwiseTanIntModule_basic", "ElementwiseTanModule_basic", @@ -1680,74 +1715,27 @@ 
"Aten_TrilinearModuleSumAllDims_basic", "Aten_TrilinearModuleSumdims_basic", "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic", - "CrossEntropyLossModule_basic", - "CrossEntropyLossNoReductionModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", "ScatterSrcModule_basic", "ScatterSrcStaticModule_basic", "HBC_basic", + "InterpolateDynamicModule_scales_recompute_bilinear", + "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", + "UpSampleNearest2d_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dStaticFactor_basic", } # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { - "Unfold_Module_Rank_4", - "Unfold_Module_Rank_Zero_basic", - "Unfold_Module_basic", - "ElementwiseErfIntModule_basic", - "ElementwiseIntTensorLtFloatScalarModule_basic", - "ElementwiseSigmoidIntModule_basic", - "ElementwiseTanIntModule_basic", - "ElementwiseTanModule_basic", - "ElementwiseUnaryIntModule_basic", - "PowIntFloatModule_basic", - "Deg2radModule_basic", - "ElementwiseIntTensorLtFloatTensorModule_basic", - "L1LossMeanReductionModule_basic", - "L1LossNoReductionModule_basic", - "L1LossSumReductionModule_basic", - "PixelShuffleModuleStaticRank3Int64_basic", - "PixelShuffleModuleStaticRank4Float32_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", - "RenormModuleFloat16_basic", - "SplitDimStaticModule_basic", - "Deg2radModule_basic", - "ElementwiseExpIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog10Module_basic", - "ElementwiseLog1pModule_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLogIntModule_basic", - "ElementwiseLogitModule_basic", - "ElementwiseMishModule_basic", - 
"L1LossMeanReductionModule_basic", - "L1LossNoReductionModule_basic", - "L1LossSumReductionModule_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", - "SoftplusModule_basic", - "ReflectionPad1dModule2dInput_Right", - "ReflectionPad1dModule2dInput_basic", - "ReflectionPad1dModule3dInput_Left", - "ReflectionPad1dModule3dInput_basic", - "ReflectionPad2dModule_Bottom", - "ReflectionPad2dModule_Left", - "ReflectionPad2dModule_Right", - "ReflectionPad2dModule_Top", - "ReflectionPad2dModule_basic", - "ReplicationPad2dModule_basic", - "ReplicationPad2dModule_bottom0", - "ReplicationPad2dModule_left0", - "ReplicationPad2dModule_right0", - "ReplicationPad2dModule_top0", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseCosIntModule_basic", "ElementwiseReciprocalIntModule_basic", @@ -2048,8 +2036,6 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingModule_basic", - "Conv2dWithValidPaddingModule_basic", - "Conv2dWithSamePaddingModule_basic", "Convolution2DStaticModule_basic", "CosineSimilarityStaticModule_basic", "DetachModule_basic", @@ -2256,7 +2242,6 @@ "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", "MaxPool2dStaticModule_basic", "MeanModule_basic", "MmDagModule_basic", @@ -2302,8 +2287,6 @@ "PadWithNoneValModule_basic", "PermuteModule_basic", "PermuteNegativeIndexModule_basic", - "PowFloatFloatModule_basic", - "PowFloatIntModule_basic", 
"PrimListUnpackNumMismatchModule_basic", "PrimsIotaModule_basic", "PrimsSqueezeEmptyDimensionsModule_basic", @@ -2333,7 +2316,6 @@ "ReshapeExpandModule_basic", "ReturnThreeTensorFloat32_basic", "ReturnTwoTensorF32I64_basic", - "ResNet18StaticModule_basic", "RsubFloatModule_basic", "RsubFloatModule_noalpha_basic", "RsubInt0d_NumToTensor_Module_basic", @@ -2450,7 +2432,6 @@ TOSA_PASS_SET | { ### Tests additionally passing in make_fx_tosa - "AdaptiveAvgPool1dStaticEvenMultiple_basic", "IsInfiniteModule_basic", "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", @@ -2543,8 +2524,6 @@ "Conv2dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", - "Conv2dWithSamePaddingModule_basic", - "Conv2dWithValidPaddingModule_basic", # failed to legalize operation 'torch.operator' "ElementwisePreluModule_basic", "ElementwisePreluStaticModule_basic", @@ -2721,7 +2700,6 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", - "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -2815,8 +2793,6 @@ "AtenDiagEmbedRevDimDiag_basic", "AtenEmbeddingBagStaticModule_basic", "AtenEmbeddingBagSumExample_basic", - "AtenFftRfft2DLastDim_basic", - "AtenFftRfft2DMiddleDim_basic", "AtenFloatScalarModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", @@ -2868,16 +2844,10 @@ "CollapsePartialDynamicModule_basic", "CollapseRank1DynamicModule_basic", "CollapseStaticModule_basic", - "ColumnStackBasicIntModule_basic", - "ColumnStack1dModule_basic", - "ColumnStack0dModule_basic", "ConstantBoolParameterModule_basic", "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", - "Conv1dWithSamePaddingModule_basic", - "Conv1dWithValidPaddingModule_basic", - "Conv1dGroupModule_basic", 
"Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -2890,11 +2860,7 @@ "Conv2dQInt8PerChannelModule_grouped", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", - "Conv2dWithSamePaddingModule_basic", - "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -2908,7 +2874,6 @@ "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", "ConvolutionModule2DTransposeStrided_basic", "ConvolutionModule2DTranspose_basic", - "Deg2radModule_basic", "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseAcoshIntModule_basic", @@ -2941,8 +2906,6 @@ "ElementwiseEluNonDefaultModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", - "ElementwiseSpecialExpm1IntModule_basic", - "ElementwiseSpecialExpm1Module_basic", "ElementwiseFmodTensor_Int_basic", "ElementwiseCreateComplexModule_basic", "ElementwiseMulTensorComplexModule_basic", @@ -2998,9 +2961,6 @@ "IsFloatingPointInt_False", "IscloseStaticModuleTrue_basic", "IscloseStaticModule_basic", - "L1LossNoReductionModule_basic", - "L1LossMeanReductionModule_basic", - "L1LossSumReductionModule_basic", "LeakyReluBackwardModule_basic", "LeakyReluBackwardStaticModule_basic", "LenStrModule_basic", @@ -3008,6 +2968,7 @@ "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", + "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", "MaxPool2dCeilModeTrueModule_basic", @@ -3095,7 +3056,7 @@ "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", - "PowIntIntModule_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -3382,13 +3343,6 @@ 
"ScaledDotProductAttentionBoolMaskModule_basic", } -if torch_version_for_comparison() > version.parse("2.5.1"): - ONNX_XFAIL_SET = ONNX_XFAIL_SET | { - # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast incompatible - # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", - } - if torch_version_for_comparison() < version.parse("2.4.0.dev"): STABLEHLO_PASS_SET = STABLEHLO_PASS_SET - { "AtenIntMM_basic", @@ -3436,10 +3390,6 @@ } FX_IMPORTER_TOSA_XFAIL_SET = { - "UniformModule_basic", - "UniformStaticShapeModule_basic", - "AtenFftRfft2DLastDim_basic", - "AtenFftRfft2DMiddleDim_basic", "IsInfiniteModule_basic", "LayerNormFwAndBwModule_basic", "LayerNormManualFwAndBwModule_basic", @@ -3449,13 +3399,19 @@ "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "ElementwiseRreluWithNoiseTrainModule_basic", + "ElementwiseRreluWithNoiseTrainStaticModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", "MaxPool3dLargeDatadModule_basic", "MaxPool3dModuleRandomSimple_basic", "MaxPool3dModule_basic", "MaxPool3dStaticModule_basic", "ViewDtypeStaticModule_basic", + "Unfold_Module_Dynamic_basic", + "Unfold_Module_Rank_4", "Unfold_Module_Rank_Zero_Size_Zero_basic", + "Unfold_Module_Rank_Zero_basic", + "Unfold_Module_basic", "ArangeZeroElementOutputModule_basic", "NumpyTRank0Module_basic", "Permute0RankModule_basic", @@ -3478,6 +3434,8 @@ "Conv_Transpose2dStaticModule_basic", "Conv_Transpose3dModule_basic", "Conv_Transpose3dStaticModule_basic", + "ElementwiseFloatTensorGtIntTensorModule_basic", + "ElementwiseIntTensorLtFloatTensorModule_basic", "IndexPutWithNoneAndBroadcastModule_basic", "MaskedScatterStaticBasic_basic", "MaxUnpool3dModulePad0_basic", @@ -3485,6 +3443,7 @@ "MultinomialModule2D_F32", "MultinomialModule2D_basic", "MultinomialModule_basic", + 
"RenormModuleFloat16_basic", # REMOVE WHEN ENABLE_GQA IS ADDED "ScatterAddStaticModule_basic", "TensorsConcatComplex128FloatModule_basic", @@ -3585,9 +3544,6 @@ "ContainsIntList_True", "Conv1dModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", - "Conv1dWithSamePaddingModule_basic", - "Conv1dWithValidPaddingModule_basic", - "Conv1dGroupModule_basic", "Conv2dQInt8Module_basic", "Conv2dQInt8Module_depthwise", "Conv2dQInt8Module_grouped", @@ -3598,8 +3554,6 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -3649,18 +3603,31 @@ "ElementwiseCoshModule_basic", "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", "ElementwiseExpm1IntModule_basic", "ElementwiseExpm1Module_basic", + "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog1pModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLogitModule_basic", + "ElementwiseMishModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseSigmoidIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", - "ElementwiseSpecialExpm1IntModule_basic", - "ElementwiseSpecialExpm1Module_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseUnaryIntModule_basic", 
"ElementwiseWhereScalarOtherStaticModule_basic", "EqIntModule_basic", "FloatImplicitModule_basic", @@ -3696,6 +3663,8 @@ "IndexPutImpl3DFloatAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "InterpolateDynamicModule_sizes_bilinear", + "InterpolateDynamicModule_sizes_nearest", + "InterpolateStaticModule_scales_bilinear_align_corners", "InterpolateDynamicModule_scales_recompute_bilinear", "IntFloatModule_basic", "IntImplicitModule_basic", @@ -3706,6 +3675,7 @@ "LinalgVectorNormComplexModule_basic", "LinspaceDtypeModule_basic", "LinspaceEmptyModule_basic", + "MaskedFillTensorFloatValueModule_basic", "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dModule_basic", @@ -3766,7 +3736,12 @@ "NumelModule_basic", "NumelZeroRankModule_basic", "OnesLikeModule_falsePinMemory", - "PowIntIntModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -3782,6 +3757,9 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "QuantizedSingleLayer_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", "RandnDtypeDeviceModule_basic", "RandnGeneratorF64Module_basic", "RandnGeneratorModule_basic", @@ -3791,11 +3769,26 @@ "ReduceAllDimEmpty_basic", "ReduceFrobeniusNormComplexModule_basic", "ReduceL1NormComplexModule_basic", + "ReduceL1NormWithDTypeModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", "ReduceMaxAlongDimUnsignedInt_basic", "ReduceMinAlongDimUnsignedInt_basic", "ReduceSumDimIntListEmptyDimModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + 
"ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", "RollModule_basic", "ScalarConstantTupleModule_basic", "ScalarImplicitFloatModule_basic", @@ -3833,6 +3826,7 @@ "SliceOutOfLowerBoundEndIndexModule_basic", "SliceOutOfLowerBoundStartIndexModule_basic", "SliceSizeTwoStepModule_basic", + "SoftplusModule_basic", "SortIntListReverse_basic", "SortIntList_basic", "SortTensorDescending_basic", @@ -3866,23 +3860,50 @@ "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", "UpSampleNearest2dBackwardScalesNone_basic", "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", "ViewCollapseDynamicWithAtenSizeIntModule_basic", "ViewSizeFromOtherTensor_basic", "VisionTransformerModule_basic", "ZerosLikeModule_falsePinMemory", + # count_include_pad and divisor_override check in TOSA AvgPool2d + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic", + "AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "MobilenetV3Module_basic", # Unexpected failures due to new PyTorch version update "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + 
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool2dDynamicNoBatch_basic", "AdaptiveAvgPool2dDynamic_basic", "CrossEntropyLossModule_basic", "CrossEntropyLossNoReductionModule_basic", + "ElementwiseRreluTrainModule_basic", + "ElementwiseRreluTrainStaticModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", "IouOfModule_basic", "MeshgridIndexingIJ_basic", "MeshgridIndexingXY_basic", "Meshgrid_basic", + "OneHotModule_basic", "ReduceFrobeniusNormKeepDimModule_basic", "ReduceFrobeniusNormModule_basic", "ScaledDotProductAttentionBoolMaskModule_basic", @@ -3905,18 +3926,6 @@ } ONNX_TOSA_XFAIL_SET = { - "AtenFftRfft2DLastDim_basic", - "AtenFftRfft2DMiddleDim_basic", - "PowFloatIntModule_basic", - "PowIntFloatModule_basic", - "PowIntIntModule_basic", - "ColumnStack0dModule_basic", - "ColumnStack1dModule_basic", - "ColumnStackBasicIntModule_basic", - "Deg2radModule_basic", - "L1LossMeanReductionModule_basic", - "L1LossNoReductionModule_basic", - "L1LossSumReductionModule_basic", "FloatPowerTensorTensorStaticModule_basic", "IsInfiniteModule_basic", "ElementwiseCopysignModule_basic", @@ -4162,6 +4171,7 @@ "ChunkListUnpackDynamic_Module_basic", "ChunkListUnpackUnevenDynamic_Module_basic", "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", "CollapseAllDimensionsModule_basic", "CollapseFullDynamicModule_basic", "CollapsePartialDynamicModule_basic", @@ -4171,10 +4181,7 @@ "ContainsIntList_False", "ContainsIntList_True", "Conv1dModule_basic", - "Conv1dWithSamePaddingModule_basic", - "Conv1dWithValidPaddingModule_basic", "Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic", - 
"Conv1dGroupModule_basic", "Conv2dBiasNoPaddingModule_basic", "Conv2dModule_basic", "Conv2dNoPaddingModule_basic", @@ -4189,11 +4196,7 @@ "Conv2dWithPaddingDilationStrideStaticModule_grouped", "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", "Conv2dWithPaddingModule_basic", - "Conv2dWithSamePaddingModule_basic", - "Conv2dWithValidPaddingModule_basic", "Conv3dModule_basic", - "Conv3dWithSamePaddingModule_basic", - "Conv3dWithValidPaddingModule_basic", "ConvTbcModule_basic", "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", @@ -4326,6 +4329,7 @@ "ElementwiseLog2IntModule_basic", "ElementwiseLogIntModule_basic", "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseMishModule_basic", "ElementwiseMulScalarModule_basic", "ElementwiseMulTensorComplexDiffModule_basic", "ElementwiseMulTensorComplexModule_basic", @@ -4346,11 +4350,10 @@ "ElementwiseSinIntModule_basic", "ElementwiseSinhIntModule_basic", "ElementwiseSinhModule_basic", - "ElementwiseSpecialExpm1IntModule_basic", - "ElementwiseSpecialExpm1Module_basic", "ElementwiseSqrtIntModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", "ElementwiseTernaryModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseToDtypeI64ToI8Module_basic", @@ -4543,6 +4546,7 @@ "MeanDimNoneDimModule_basic", "MeanDtypeModule_basic", "MeanDynamicSizesModule_basic", + "MeanModule_basic", "Mlp1LayerModule_basic", "Mlp2LayerModuleNoBias_basic", "Mlp2LayerModule_basic", @@ -4614,6 +4618,7 @@ "PixelShuffleModuleSpatiallyStatic_basic", "PixelShuffleModuleStaticRank3Int64_basic", "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", "PrimMaxIntModule_basic", "PrimMinIntDynamicModule_basic", "PrimMinIntModule_basic", @@ -4632,6 +4637,7 @@ "QuantizedSingleLayer_basic", "RandIntDtypeModule_basic", "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", "RandIntModule_basic", "RandIntPinMemoryModule_basic", 
"RandLikeDtypeModule_basic", @@ -4697,9 +4703,27 @@ "ReduceSumDimIntListDtypeFloatModule_basic", "ReduceSumDimIntListDtypeIntModule_basic", "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDtypeFloatModule_basic", "ReduceSumDtypeIntModule_basic", "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", "ResNet18Module_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", @@ -4770,6 +4794,7 @@ "SoftmaxIntModule_basic", "SoftmaxIntNegDimModule_basic", "SoftmaxIntNonNoneDtypeModule_basic", + "SoftplusModule_basic", "SortIntListReverse_basic", "SortIntList_basic", "SortTensorDescending_basic", @@ -4861,6 +4886,10 @@ "TypePromotionDifferentCategoryModule_basic", "TypePromotionSameCategoryDifferentWidthModule_basic", "TypePromotionZeroRankHigherCategoryModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", "UniformModule_basic", "UniformNoCorrelationModule_basic", "UniformStaticShapeModule_basic", @@ -4941,10 +4970,3 @@ "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", } - -if torch_version_for_comparison() > version.parse("2.5.1"): - ONNX_TOSA_XFAIL_SET = ONNX_TOSA_XFAIL_SET | { - # error: 'memref.cast' op operand type 'memref<2x6x4x3xf32>' and result type 'memref<2x6x5x3xf32>' are cast 
incompatible - # torch.onnx.export produces onnx.MaxPool op with incorrect output shape of 2x6x5x3 instead of 2x6x4x3 - "MaxPool2dStaticCeilModeTrueReduceOutputModule_basic", - } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a73d188d7168..12b1f8c76b37 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -222,9 +222,6 @@ def aten〇exp2〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇special_expm1〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇isfinite〡shape(self: List[int]) -> List[int]: return self @@ -649,9 +646,6 @@ def aten〇rrelu〡shape(self: List[int], lower: float = 0.125, upper: float = 0 def aten〇rrelu_with_noise〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇rrelu_with_noise_functional〡shape(self: List[int], noise: List[int], lower: float = 0.125, upper: float = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[List[int], List[int]]: - return upstream_shape_functions.unary(self), upstream_shape_functions.unary(noise) - def aten〇selu〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1763,7 +1757,8 @@ def aten〇col2im〡shape(self: List[int], output_size: List[int], kernel_size: # compute the shape of the output num_channels = n_input_plane // (kernel_size[0] * kernel_size[1]) - out: List[int] = ([self[0], num_channels] if batch_dim == 0 else [num_channels]) + [elem for elem in output_size] + out: List[int] = 
[self[0], num_channels] if batch_dim == 0 else [num_channels] + out += [elem for elem in output_size] return out @@ -1845,32 +1840,9 @@ def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weigh def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) -def _conv_padding(weight: List[int], dilation: List[int], padding: str): - rank = len(weight) - # first 2 dimensions of weight corresponds to out_channels and in_channels/groups - num_unpadded_dims = 2 - assert rank > num_unpadded_dims, "conv: weight must be at least 3 dimensional." - num_kernel_elems = rank - num_unpadded_dims - padding_int = [0] * num_kernel_elems - if padding == "same": - for d, i in zip( - dilation, range(num_kernel_elems - 1, -1, -1) - ): - padding_val = d * (weight[num_unpadded_dims+i] - 1) - padding_int[i] = padding_val // 2 - return padding_int - -def aten〇conv2d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: str = "valid", dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: - padding_int = _conv_padding(weight, dilation, padding) - return upstream_shape_functions.conv2d(input, weight, bias, stride, padding_int, dilation, groups) - def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) -def aten〇conv3d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: str = "valid", dilation: List[int] = (1, 1, 1,), groups: int 
= 1) -> List[int]: - padding_int = _conv_padding(weight, dilation, padding) - return upstream_shape_functions.conv3d(input, weight, bias, stride, padding_int, dilation, groups) - def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) @@ -1912,10 +1884,6 @@ def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Option def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) -def aten〇conv1d〇padding〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: str = "valid", dilation: List[int] = (1,), groups: int = 1) -> List[int]: - padding_int = _conv_padding(weight, dilation, padding) - return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding_int, dilation, transposed=False, output_padding=[], groups=1) - def aten〇conv_transpose1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), output_padding: List[int] = (0,), groups: int = 1, dilation: List[int] = (1,)) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, True, output_padding, groups) @@ -2095,9 +2063,6 @@ def aten〇tril_indices〡shape(row: int, col: int, offset: int = 0, dtype: Opti return [2, trapezoid_size + rectangle_size] -def aten〇deg2rad〡shape(self: List[int]) -> List[int]: - return 
upstream_shape_functions.unary(self) - @check_shape_function([ Invocation(TensorOfShape(2, 3), LongTensorOfShape(2), None, 1, -100), # Basic case. Invocation(TensorOfShape(3), LongTensorOfShape(), None, 1, -100), # No batch dim. @@ -2116,11 +2081,6 @@ def aten〇mse_loss〡shape(self: List[int], target: List[int], reduction: int = return upstream_shape_functions.unary(self) return [] -def aten〇l1_loss〡shape(self: List[int], target: List[int], reduction: int = 1) -> List[int]: - if reduction == 0: - return upstream_shape_functions.unary(self) - return [] - def aten〇cross_entropy_loss〡shape(self: List[int], target: List[int], weight: Optional[List[int]] = None, reduction: int = 1, ignore_index: int = -100, label_smoothing: float = 0.) -> List[int]: return upstream_shape_functions.cross_entropy_loss(self, target, weight, reduction, ignore_index, label_smoothing) @@ -2320,35 +2280,9 @@ def aten〇hstack〡shape(tensors: List[List[int]]) -> List[int]: return upstream_shape_functions.cat(tensors_atleast1d, dim=1) -@check_shape_function([ - Invocation([LongTensorOfShape(2, 4, 3), LongTensorOfShape(2, 5, 3)]), # Basic case. 
-]) -def aten〇column_stack〡shape(tensors: List[List[int]]) -> List[int]: - tensors2d: List[List[int]] = [] - for tensor in tensors: - if len(tensor) == 0: - tensor = [1, 1] - elif len(tensor) == 1: - tensor.append(1) - tensors2d.append(tensor) - - return upstream_shape_functions.cat(tensors2d, dim=1) - def aten〇fft_fft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: return self -@check_shape_function([ - Invocation(TensorOfShape(3, 9, 5), None, -2, None) # Second-last dim -]) -def aten〇fft_rfft〡shape(self: List[int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> List[int]: - dim = (dim + len(self)) if dim < 0 else dim - assert dim >= 0 and dim < len(self), "Expected dim in [-rank, rank-1]" - out: List[int] = [] - for s in self: - out.append(s) - out[dim] = self[dim] // 2 + 1 - return out - @check_shape_function([ Invocation(TensorOfShape(1, 128), 16, None, 16, TensorOfShape(16), False, None, True) # With an explicit 1-D window. 
]) @@ -2723,11 +2657,6 @@ def aten〇expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) -def aten〇special_expm1〡dtype(self_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - return _get_dtype_of_floating_point_op(self_dtype) - def aten〇isfinite〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool @@ -3475,25 +3404,21 @@ def aten〇celu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, floa self_rank, self_dtype = self_rank_dtype return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, *all_integer_dtypes()})) def aten〇rrelu〡dtype(self_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) return self_dtype -@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2, error_types={torch.bool, *all_integer_dtypes()})) def aten〇rrelu_with_noise〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype noise_rank, noise_dtype = noise_rank_dtype + assert is_float_dtype(self_dtype) or is_complex_dtype(self_dtype) + assert is_float_dtype(noise_dtype) or is_complex_dtype(noise_dtype) assert self_rank == noise_rank return self_dtype 
-@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2)) -def aten〇rrelu_with_noise_functional〡dtype(self_rank_dtype: Tuple[int, int], noise_rank_dtype: Tuple[int, int], lower: Union[int, float, complex] = 0.125, upper: Union[int, float, complex] = 0.33333333333333331, training: bool = False, generator: Any = None) -> Tuple[int, int]: - self_rank, self_dtype = self_rank_dtype - noise_rank, noise_dtype = noise_rank_dtype - assert self_rank == noise_rank - return self_dtype, noise_dtype - @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool})) def aten〇relu6〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -3946,23 +3871,6 @@ def aten〇fft_fft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = else: assert False, "Unsupported dtype" -@check_dtype_function( - _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex32, torch.complex64, torch.complex128, torch.bfloat16})) -def aten〇fft_rfft〡dtype(self_rank_dtype: Tuple[int, int], n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> int: - self_rank, self_dtype = self_rank_dtype - if self_dtype == torch.float16: - return torch.complex32 - elif self_dtype == torch.float32: - return torch.complex64 - elif self_dtype == torch.float64: - return torch.complex128 - elif is_integer_dtype(self_dtype): - return torch.complex64 - else: - assert False, "Unsupported dtype" - - - @check_dtype_function([ Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=False), # output dtype = torch.float32 Invocation(TensorOfShape(1,128, dtype=torch.complex64), n_fft=16, return_complex=True), # output dtype = torch.complex64 @@ -4341,15 +4249,6 @@ def aten〇mse_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: assert not is_integer_dtype(promoted_dtype) return promoted_dtype -def aten〇l1_loss〡dtype(self_rank_dtype: Tuple[int, int], target_rank_dtype: 
Tuple[int, int], reduction: int = 1) -> int: - self_rank, self_dtype = self_rank_dtype - target_rank, target_dtype = target_rank_dtype - ranks: List[Optional[int]] = [self_rank, target_rank] - dtypes = [self_dtype, target_dtype] - promoted_dtype = promote_dtypes(ranks, dtypes) - assert not is_integer_dtype(promoted_dtype) - return promoted_dtype - @check_dtype_function(_check_two_tensor_op()) def aten〇mul〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: other_rank, other_dtype = other_rank_dtype @@ -5662,23 +5561,6 @@ def aten〇hstack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: return promote_dtypes(ranks, dtypes) -@check_dtype_function( - [Invocation([NonZeroDTensorWithDtype(torch.bool), NonZeroDTensorWithDtype(torch.int32), NonZeroDTensorWithDtype(torch.int64)]), - Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32)]), - Invocation([NonZeroDTensorWithDtype(torch.float16), NonZeroDTensorWithDtype(torch.float64)]), - Invocation([NonZeroDTensorWithDtype(torch.float32), NonZeroDTensorWithDtype(torch.int32), - NonZeroDTensorWithDtype(torch.complex64)])]) -def aten〇column_stack〡dtype(tensors_rank_dtype: List[Tuple[int, int]]) -> int: - ranks: List[Optional[int]] = [] - dtypes: List[int] = [] - assert len(tensors_rank_dtype) != 0 - for tensor_rank_dtype in tensors_rank_dtype: - tensor_rank, tensor_dtype = tensor_rank_dtype - ranks.append(tensor_rank) - dtypes.append(tensor_dtype) - - return promote_dtypes(ranks, dtypes) - @check_dtype_function( [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),]) @@ -5822,10 +5704,6 @@ def aten〇triu_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Opti def aten〇tril_indices〡dtype(row: int, col: int, offset: int = 0, dtype: Optional[int] = 4, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: return torch.int64 if dtype is None else 
dtype -def aten〇deg2rad〡dtype(self_rank_dtype: Tuple[int, int]) -> int: - self_rank, self_dtype = self_rank_dtype - return _get_dtype_of_floating_point_op(self_dtype) - def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype if (self_dtype == torch.quint8): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 930979b3c939..1a81a4dcd7ea 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -452,7 +452,6 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::trunc : (Tensor) -> (Tensor)", has_folder=True) - emit("aten::special_expm1 : (Tensor) -> (Tensor)") emit_with_mutating_variants( "aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True ) @@ -575,21 +574,12 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) - emit( - "aten::conv3d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" - ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) - emit( - "aten::conv2d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" - ) emit( "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) - emit( - "aten::conv1d.padding : (Tensor, Tensor, Tensor?, int[], str, int[], int) -> (Tensor)" - ) emit( "aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)" ) @@ -757,7 +747,6 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::frobenius_norm.dim : (Tensor, 
int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") emit("aten::mse_loss_backward : (Tensor, Tensor, Tensor, int) -> (Tensor)") - emit("aten::l1_loss : (Tensor, Tensor, int) -> (Tensor)") emit( "aten::upsample_nearest2d_backward : (Tensor, int[], int[], float?, float?) -> (Tensor)" ) @@ -982,7 +971,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::hann_window.periodic : (int, bool, int?, int?, Device?, bool?) -> (Tensor)" ) emit("aten::fft_fft : (Tensor, int?, int, str?) -> (Tensor)") - emit("aten::fft_rfft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fft_ifft : (Tensor, int?, int, str?) -> (Tensor)") emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit( @@ -1065,7 +1053,6 @@ def emit_with_mutating_variants(key, **kwargs): ) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::hstack : (Tensor[]) -> (Tensor)") - emit("aten::column_stack : (Tensor[]) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) emit("aten::eq.int_list : (int[], int[]) -> (bool)", has_folder=True) @@ -1129,12 +1116,10 @@ def emit_with_mutating_variants(key, **kwargs): has_folder=True, has_canonicalizer=True, ) - emit("aten::mul.int_float : (int, float) -> (float)", has_folder=True) emit("aten::div.int : (int, int) -> (float)", has_folder=True) emit("aten::neg.int : (int) -> (int)", has_folder=True) emit("aten::log.int : (int) -> (float)") emit("aten::add.float_int : (float, int) -> (float)", has_folder=True) - emit("aten::mul.float_int : (float, int) -> (float)", has_folder=True) emit("aten::sub.float : (float, float) -> (float)", has_folder=True) emit("aten::mul.float : (float, float) -> (float)", has_folder=True) emit("aten::div.float : (float, float) -> (float)", has_folder=True) @@ -1149,14 +1134,12 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gt.float_int : (float, int) -> (bool)") emit("aten::pow.int_float : (int, float) -> 
(float)", has_folder=True) emit("aten::__and__.bool : (bool, bool) -> (bool)") - emit("aten::eq.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::ne.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::__is__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__isnot__ : (t1, t2) -> (bool)", has_folder=True) emit("aten::__not__ : (bool) -> (bool)", has_folder=True) emit("aten::__or__.bool : (bool, bool) -> (bool)", has_folder=True) emit("aten::len.t : (t[]) -> (int)", has_folder=True, has_canonicalizer=True) - emit("aten::mul.left_t : (t[], int) -> (t[])", has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) @@ -1184,8 +1167,6 @@ def emit_with_mutating_variants(key, **kwargs): has_verifier=True, ) - emit("aten::deg2rad : (Tensor) -> (Tensor)") - # backprop ops emit("aten::_softmax_backward_data : (Tensor, Tensor, int, int) -> (Tensor)") emit("aten::tanh_backward : (Tensor, Tensor) -> (Tensor)") @@ -1212,9 +1193,6 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::rrelu_with_noise_backward : (Tensor, Tensor, Tensor, Scalar, Scalar, bool, bool) -> (Tensor)" ) - emit( - "aten::rrelu_with_noise_functional : (Tensor, Tensor, Scalar, Scalar, bool, Generator?) 
-> (Tensor, Tensor)" - ) # quantized ops emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 927bfe85df8a..5aa22ce3b122 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -87,29 +87,6 @@ def BmmFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) -class BmmFloat16Module(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float16, True), - ([-1, -1, -1], torch.float16, True), - ] - ) - def forward(self, lhs, rhs): - return torch.bmm(lhs, rhs) - - -@register_test_case(module_factory=lambda: BmmFloat16Module()) -def BmmFloat16Module_basic(module, tu: TestUtils): - module.forward( - tu.rand(3, 4, 5).to(torch.float16), tu.rand(3, 5, 4).to(torch.float16) - ) - - class BmmIntModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1432,83 +1409,6 @@ def HstackBasicComplexModule_basic(module, tu: TestUtils): # ============================================================================== -class ColumnStackBasicIntModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([2, 3, 4], torch.bool, True), - ([2, 3, 4], torch.int32, True), - ([2, 3, 4], torch.int64, True), - ] - ) - def forward(self, x, y, z): - return torch.ops.aten.column_stack([x, y, z]) - - -@register_test_case(module_factory=lambda: ColumnStackBasicIntModule()) -def ColumnStackBasicIntModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(2, 3, 4, low=0, high=2).bool(), - tu.randint(2, 3, 4, low=0, high=100).int(), - tu.randint(2, 3, 4, low=0, high=100).long(), - ) - - -# ============================================================================== - - -class 
ColumnStack1dModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([4], torch.float32, True), - ([4], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.column_stack([x, y]) - - -@register_test_case(module_factory=lambda: ColumnStack1dModule()) -def ColumnStack1dModule_basic(module, tu: TestUtils): - module.forward(tu.rand(4), tu.rand(4)) - - -# ============================================================================== - - -class ColumnStack0dModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([], torch.float32, True), - ([], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.column_stack([x, y]) - - -@register_test_case(module_factory=lambda: ColumnStack0dModule()) -def ColumnStack0dModule_basic(module, tu: TestUtils): - module.forward(torch.tensor(4.0), torch.tensor(1.0)) - - -# ============================================================================== - - class GatherModule(torch.nn.Module): def __init__(self): super().__init__() @@ -4449,100 +4349,25 @@ def IntImplicitModule_basic(module, tu: TestUtils): # ============================================================================== -class PowModule(torch.nn.Module): +class PowIntFloat(torch.nn.Module): def __init__(self): super().__init__() + self.value = 2 + self.power_value = 3.0 @export @annotate_args( [ None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), ] ) - def forward(self, x, y): - return torch.ops.aten.pow(x, y) - - -@register_test_case(module_factory=lambda: PowModule()) -def PowFloatFloatModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.rand(3, 4, 5)) - - -# ============================================================================== - - -class PowIntFloatModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ 
- None, - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.pow(x, y) + def forward(self): + return torch.ops.aten.pow(self.value, self.power_value) -@register_test_case(module_factory=lambda: PowIntFloatModule()) +@register_test_case(module_factory=lambda: IntFloatModule()) def PowIntFloatModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, 5, dtype=torch.int32), tu.rand(3, 4, 5)) - - -# ============================================================================== - - -class PowFloatIntModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.int32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.pow(x, y) - - -@register_test_case(module_factory=lambda: PowFloatIntModule()) -def PowFloatIntModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5), tu.randint(3, 4, 5, dtype=torch.int32)) - - -# ============================================================================== - - -class PowIntIntModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.int32, True), - ([-1, -1, -1], torch.int32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.pow(x, y) - - -@register_test_case(module_factory=lambda: PowIntIntModule()) -def PowIntIntModule_basic(module, tu: TestUtils): - module.forward( - tu.randint(3, 4, 5, high=10, dtype=torch.int32), - tu.randint(3, 4, 5, high=20, dtype=torch.int32), - ) + module.forward() # ============================================================================== @@ -6430,26 +6255,3 @@ def AtenPolarDoubleModule_basic(module, tu: TestUtils): module.forward( tu.rand(2, 5, 3, 4).to(torch.float64), tu.rand(2, 5, 3, 4).to(torch.float64) ) - - -# 
============================================================================== - - -class AtenNonzero1DDynamicModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1], torch.bool, True), - ] - ) - def forward(self, x): - return torch.ops.aten.nonzero(x) - - -@register_test_case(module_factory=lambda: AtenNonzero1DDynamicModule()) -def AtenNonzero1DDynamicModule_basic(module, tu: TestUtils): - module.forward(torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.bool)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 663c4b6a746b..e6332579d575 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -191,54 +191,6 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier( module.forward(tu.rand(5, 4, 10, 20)) -class Conv2dWithSamePaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - torch.manual_seed(0) - self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="same") - self.train(False) - - @export - @annotate_args( - [ - None, - ([-1, -1, -1, -1], torch.float32, True), - ] - ) - def forward(self, x): - return self.conv(x) - - -@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule()) -def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils): - t = tu.rand(5, 2, 10, 20) - module.forward(t) - - -class Conv2dWithValidPaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - torch.manual_seed(0) - self.conv = torch.nn.Conv2d(2, 10, 3, bias=False, padding="valid") - self.train(False) - - @export - @annotate_args( - [ - None, - ([-1, -1, -1, -1], torch.float32, True), - ] - ) - def forward(self, x): - return self.conv(x) - - -@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule()) -def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils): - t = 
tu.rand(5, 2, 10, 20) - module.forward(t) - - # ============================================================================== @@ -1142,90 +1094,6 @@ def Conv1dDepthwiseWithPaddingDilationStrideStaticModule_basic(module, tu: TestU module.forward(inputVec, weight) -class Conv1dWithSamePaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - torch.manual_seed(0) - self.conv = torch.nn.Conv1d(2, 10, 3, bias=False, padding="same") - self.train(False) - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ] - ) - def forward(self, x): - return self.conv(x) - - -@register_test_case(module_factory=lambda: Conv1dWithSamePaddingModule()) -def Conv1dWithSamePaddingModule_basic(module, tu: TestUtils): - t = tu.rand(5, 2, 10) - module.forward(t) - - -class Conv1dWithValidPaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ] - ) - def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv1d( - inputVec, - weight, - bias=bias, - stride=[1], - padding="valid", - dilation=[1], - groups=1, - ) - - -@register_test_case(module_factory=lambda: Conv1dWithValidPaddingModule()) -def Conv1dWithValidPaddingModule_basic(module, tu: TestUtils): - inputVec = tu.rand(2, 2, 6) - weight = torch.randn(8, 2, 3) - bias = torch.randn(8) - module.forward(inputVec, weight, bias) - - -class Conv1dGroupModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1], torch.float32, True), - ([-1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ] - ) - def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv1d( - inputVec, weight, bias=bias, stride=[1], padding=[0], dilation=[1], groups=2 - ) - - -@register_test_case(module_factory=lambda: Conv1dGroupModule()) -def 
Conv1dGroupModule_basic(module, tu: TestUtils): - inputVec = tu.rand(2, 4, 6) - weight = torch.randn(8, 2, 3) - bias = torch.randn(8) - module.forward(inputVec, weight, bias) - - class Conv2dModule(torch.nn.Module): def __init__(self): super().__init__() @@ -1292,72 +1160,6 @@ def Conv3dModule_basic(module, tu: TestUtils): module.forward(inputVec, weight, bias) -class Conv3dWithSamePaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ] - ) - def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv3d( - inputVec, - weight, - bias=bias, - stride=[1, 1, 1], - padding="same", - dilation=[1, 1, 1], - groups=1, - ) - - -@register_test_case(module_factory=lambda: Conv3dWithSamePaddingModule()) -def Conv3dWithSamePaddingModule_basic(module, tu: TestUtils): - inputVec = tu.rand(2, 2, 6, 6, 6) - weight = torch.randn(8, 2, 3, 3, 3) - bias = torch.randn(8) - module.forward(inputVec, weight, bias) - - -class Conv3dWithValidPaddingModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1, -1], torch.float32, True), - ([-1], torch.float32, True), - ] - ) - def forward(self, inputVec, weight, bias): - return torch.ops.aten.conv3d( - inputVec, - weight, - bias=bias, - stride=[1, 1, 1], - padding="valid", - dilation=[1, 1, 1], - groups=1, - ) - - -@register_test_case(module_factory=lambda: Conv3dWithValidPaddingModule()) -def Conv3dWithValidPaddingModule_basic(module, tu: TestUtils): - inputVec = tu.rand(2, 2, 6, 6, 6) - weight = torch.randn(8, 2, 3, 3, 3) - bias = torch.randn(8) - module.forward(inputVec, weight, bias) - - class ConvTbcModule(torch.nn.Module): def __init__(self): super().__init__() diff --git 
a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 3ee851611ac0..a6679ec4dfc4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1240,20 +1240,13 @@ def __init__(self): [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] ) def forward(self, x, noise): - out, out_noise = torch.ops.aten.rrelu_with_noise_functional( - x, noise, 0.2, 0.5, True - ) - return ( - torch.mean(out), - torch.std(out), - torch.mean(out_noise), - torch.std(out_noise), - ) + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.2, 0.5, True) + return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainModule()) def ElementwiseRreluWithNoiseTrainModule_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) + module.forward(tu.rand(128, 128, low=-1, high=1), tu.rand(128, 128)) # ============================================================================== @@ -1265,23 +1258,16 @@ def __init__(self): @export @annotate_args( - [None, ([256, 256], torch.float32, True), ([256, 256], torch.float32, True)] + [None, ([128, 128], torch.float32, True), ([128, 128], torch.float32, True)] ) def forward(self, x, noise): - out, out_noise = torch.ops.aten.rrelu_with_noise_functional( - x, noise, 0.4, 0.6, True - ) - return ( - torch.mean(out), - torch.std(out), - torch.mean(out_noise), - torch.std(out_noise), - ) + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, True) + return torch.mean(res), torch.std(res) @register_test_case(module_factory=lambda: ElementwiseRreluWithNoiseTrainStaticModule()) def ElementwiseRreluWithNoiseTrainStaticModule_basic(module, tu: TestUtils): - module.forward(tu.rand(256, 256, low=-1, high=1), tu.rand(256, 256)) + module.forward(tu.rand(128, 128, low=-1, high=1), 
tu.rand(128, 128)) # ============================================================================== @@ -1296,7 +1282,7 @@ def __init__(self): [None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)] ) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) return torch.mean(res), torch.std(res) @@ -1315,7 +1301,7 @@ def __init__(self): @export @annotate_args([None, ([5, 3], torch.float32, True), ([5, 3], torch.float32, True)]) def forward(self, x, noise): - res = torch.ops.aten.rrelu_with_noise_functional(x, noise, 0.4, 0.6, False)[0] + res = torch.ops.aten.rrelu_with_noise(x, noise, 0.4, 0.6, False) return torch.mean(res), torch.std(res) @@ -5221,7 +5207,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.expm1(a) + return torch.special.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1Module()) @@ -5244,7 +5230,7 @@ def __init__(self): ] ) def forward(self, a): - return torch.expm1(a) + return torch.special.expm1(a) @register_test_case(module_factory=lambda: ElementwiseExpm1IntModule()) @@ -5255,52 +5241,6 @@ def ElementwiseExpm1IntModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseSpecialExpm1Module(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1], torch.float32, True), - ] - ) - def forward(self, a): - return torch.special.expm1(a) - - -@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1Module()) -def ElementwiseSpecialExpm1Module_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4)) - - -# ============================================================================== - - -class ElementwiseSpecialExpm1IntModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([-1, -1], 
torch.int32, True), - ] - ) - def forward(self, a): - return torch.special.expm1(a) - - -@register_test_case(module_factory=lambda: ElementwiseSpecialExpm1IntModule()) -def ElementwiseSpecialExpm1IntModule_basic(module, tu: TestUtils): - module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) - - -# ============================================================================== - - class ElementwiseRad2DegModule(torch.nn.Module): def __init__(self): super().__init__() @@ -7233,26 +7173,3 @@ def forward(self): @register_test_case(module_factory=lambda: TrilIndicesOfssetGreaterThanRowModule()) def TrilIndicesOfssetGreaterThanRowModule_basic(module, tu: TestUtils): module.forward() - - -# ============================================================================== - - -class Deg2radModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([3, 4], torch.float32, True), - ] - ) - def forward(self, x): - return torch.ops.aten.deg2rad(x) - - -@register_test_case(module_factory=lambda: Deg2radModule()) -def Deg2radModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index e2eaa4cfd0fe..84e0e2eb9cf5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -420,35 +420,6 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) -class MaxPool2dStaticCeilModeTrueReduceOutputModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.mp2d = torch.nn.MaxPool2d( - kernel_size=6, - stride=6, - padding=3, - dilation=1, - ceil_mode=True, - ) - - @export - @annotate_args( - [ - None, - ([2, 6, 20, 10], torch.float32, True), - ] - ) - def forward(self, x): - return self.mp2d(x) - - -@register_test_case( 
- module_factory=lambda: MaxPool2dStaticCeilModeTrueReduceOutputModule() -) -def MaxPool2dStaticCeilModeTrueReduceOutputModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 6, 20, 10, low=0.5, high=1.0)) - - # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 3e379deacb79..89774c5d13b1 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -2260,78 +2260,6 @@ def MseLossSumReductionWithDifferentElemTypeModule_basic(module, tu: TestUtils): # ============================================================================== -class L1LossNoReductionModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([2, 4], torch.float32, True), - ([2, 4], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.l1_loss(x, y, reduction=0) - - -@register_test_case(module_factory=lambda: L1LossNoReductionModule()) -def L1LossNoReductionModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4), tu.rand(2, 4)) - - -# ============================================================================== - - -class L1LossMeanReductionModule(torch.nn.Module): - def __init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([2, 4], torch.float32, True), - ([2, 4], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.l1_loss(x, y, reduction=1) - - -@register_test_case(module_factory=lambda: L1LossMeanReductionModule()) -def L1LossMeanReductionModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4), tu.rand(2, 4)) - - -# ============================================================================== - - -class L1LossSumReductionModule(torch.nn.Module): - def 
__init__(self): - super().__init__() - - @export - @annotate_args( - [ - None, - ([2, 4], torch.float32, True), - ([2, 4], torch.float32, True), - ] - ) - def forward(self, x, y): - return torch.ops.aten.l1_loss(x, y, reduction=2) - - -@register_test_case(module_factory=lambda: L1LossSumReductionModule()) -def L1LossSumReductionModule_basic(module, tu: TestUtils): - module.forward(tu.rand(2, 4), tu.rand(2, 4)) - - -# ============================================================================== - - class CrossEntropyLossModule(torch.nn.Module): def __init__(self): super().__init__() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index d1ddc42b39b1..a8820f59c373 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -1752,7 +1752,7 @@ def forward(self, x): return x.unfold(0, 0, 1) -@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero()) +@register_test_case(module_factory=lambda: Unfold_Module_Rank_Zero_Size_Zero()) def Unfold_Module_Rank_Zero_Size_Zero_basic(module, tu: TestUtils): module.forward(tu.rand()) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py index 57a7270f9d09..8e259fbe0c2a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/spectral.py @@ -51,43 +51,3 @@ def forward(self): @register_test_case(module_factory=lambda: AtenHannWindowPeriodicTrueModule()) def AtenHannWindowPeriodicTrueModule_basic(module, tu: TestUtils): module.forward() - - -# ============================================================================== - - -class AtenFftRfft2DLastDim(torch.nn.Module): - @export - @annotate_args( - [ - None, - ([16, 9], torch.float32, True), - ] - ) - def forward(self,
input): - return torch.fft.rfft(input, dim=-1) - - -@register_test_case(module_factory=lambda: AtenFftRfft2DLastDim()) -def AtenFftRfft2DLastDim_basic(module, tu: TestUtils): - module.forward(tu.rand(16, 9)) - - -# ============================================================================== - - -class AtenFftRfft2DMiddleDim(torch.nn.Module): - @export - @annotate_args( - [ - None, - ([36, 10], torch.float32, True), - ] - ) - def forward(self, input): - return torch.fft.rfft(input, dim=0) - - -@register_test_case(module_factory=lambda: AtenFftRfft2DMiddleDim()) -def AtenFftRfft2DMiddleDim_basic(module, tu: TestUtils): - module.forward(tu.rand(36, 10)) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 9aa2ae8994e4..9fe29212386a 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -739,10 +739,7 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: def _sanitize_name(self, name): if not name.isidentifier(): name = "_" + name - - # Remove characters that are invalid in MLIR identifier names. - # https://mlir.llvm.org/docs/LangRef/#identifiers-and-keywords - return re.sub("[:/-]", "_", name) + return re.sub("[:/]", "_", name) def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: tensor_type = self.tensor_proto_to_builtin_type(tp) diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 4f852d34bb0a..fa0e2a89dbba 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -137,7 +137,7 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # Load the temp file and the external data. 
inferred_model = onnx.load(temp_inferred_file, load_external_data=False) data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) - onnx.load_external_data_for_model(inferred_model, str(data_dir)) + onnx.load_external_data_for_model(inferred_model, str(data_dir)) # Remove the inferred shape file unless asked to keep it if not args.keep_temps: diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 0439f8244a0b..dd4f3a19ad33 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -3f159d635772fa2a8fd352d96b95100d885f8169 +c787213d413e85c66bdad0d8c9cde1c5ced34b1b diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 7ab5a78d074f..960ca904e045 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch/ --pre -torch==2.6.0.dev20241216 +torch==2.6.0.dev20241029 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 30b85e63ab0f..d567db79fdf8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1182,9 +1182,12 @@ func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 func.func @test_pow_i32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],f64> + // CHECK: %[[DTY:.+]] = torch.constant.int 6 + // CHECK: %[[CAST_LHS:.+]] = torch.aten.to.dtype %arg0, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + //
CHECK: %[[CAST_RHS:.+]] = torch.aten.to.dtype %arg1, %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + // CHECK: %[[POW:.+]] = torch.aten.pow.Tensor_Tensor %[[CAST_LHS]], %[[CAST_RHS]] // CHECK: %[[DTY:.+]] = torch.constant.int 3 - // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[POW]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] + // CHECK: %[[RES:.+]] = torch.aten.to.dtype %[[POW]], %[[DTY]], %[[FALSE]], %[[FALSE]], %[[NONE]] // CHECK: return %[[RES]] %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> return %0 : !torch.vtensor<[3,4,5],si32> @@ -1580,14 +1583,12 @@ func.func @test_nllloss_iii_reduction_none_ignore_negative(%arg0: !torch.vtensor // ----- -func.func @test_nonzero(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[ZERO:.*]] = torch.constant.int 0 - // CHECK: %[[ONE:.*]] = torch.constant.int 1 - // CHECK: %[[NONZERO:.*]] = torch.aten.nonzero %arg0 : !torch.vtensor<[?],f32> -> !torch.vtensor<[?,1],si64> - // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[NONZERO]], %[[ZERO]], %[[ONE]] : !torch.vtensor<[?,1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64> - %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[?],f32>) -> !torch.vtensor<[1,?],si64> - return %0 : !torch.vtensor<[1,?],si64> -} +// CHECK-LABEL: func.func @test_nonzero + func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) :
(!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> + return %0 : !torch.vtensor<[3,4,5],si64> + } // ----- @@ -2053,34 +2054,26 @@ func.func @test_nonmaxsuppression_identical_boxes(%arg0: !torch.vtensor<[1,10,4] // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." // CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,10],f32>, !torch.int -> !torch.vtensor<[10],f32> // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<[],f32> - // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[10],f32> -> !torch.vtensor<*,f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" - // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 - // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[?],si64> - // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool - // CHECK: 
%[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) - // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.prim.If.yield %[[SLICE]] : !torch.vtensor<[1],si64> - // CHECK: } else { - // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> - // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> - // CHECK: } - // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_37:.*]] = torch.constant.none - // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list - // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> - // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> + // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[10,4],f32>, !torch.vtensor<[10],f32>, !torch.float -> !torch.vtensor<[1,3],si64> + // CHECK: %[[VAL_26:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_27:.*]] = 
torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64> + // CHECK: %[[VAL_28:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_30:.*]] = torch.constant.int 2 + // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_32:.*]] = torch.constant.none + // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list + // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" + // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64> %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,10,4],f32>, !torch.vtensor<[1,1,10],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> return %0 : !torch.vtensor<[1,3],si64> } @@ -2113,34 +2106,27 @@ func.func @test_nonmaxsuppression_single_box(%arg0: !torch.vtensor<[1,1,4],f32>, // CHECK: torch.runtime.assert %[[VAL_18]], "squeeze operation possible for dim only when input_shape[dim] == 1." 
// CHECK: %[[VAL_19:.*]] = torch.aten.squeeze.dim %[[VAL_14]], %[[VAL_15]] : !torch.vtensor<[1,1],f32>, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[VAL_20:.*]] = torch.aten.item %[[VAL_4]] : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<[],f32> - // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[VAL_21:.*]] = torch.aten.min %[[VAL_19]] : !torch.vtensor<[1],f32> -> !torch.vtensor<*,f32> + // CHECK: %[[VAL_22:.*]] = torch.aten.item %[[VAL_21]] : !torch.vtensor<*,f32> -> !torch.float // CHECK: %[[VAL_23:.*]] = torch.aten.ge.float %[[VAL_22]], %[[VAL_20]] : !torch.float, !torch.float -> !torch.bool // CHECK: torch.runtime.assert %[[VAL_23]], "unimplemented: score_threshold should be <= min(scores)" - // CHECK: %[[VAL_24:.*]] = torch.constant.int 0 - // CHECK: %[[VAL_25:.*]] = torch.constant.int 1 - // CHECK: %[[VAL_26:.*]] = torch.constant.float 0.000000e+00 - // CHECK: %[[VAL_27:.*]] = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float - // CHECK: %[[VAL_28:.*]] = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[VAL_29:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_27]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[?],si64> - // CHECK: %[[VAL_30:.*]] = torch.aten.size.int %[[VAL_29]], %[[VAL_24]] : !torch.vtensor<[?],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_31:.*]] = torch.aten.gt.int %[[VAL_30]], %[[VAL_28]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[VAL_32:.*]] = torch.prim.If %[[VAL_31]] -> (!torch.vtensor<[1],si64>) - // CHECK: %[[SLICE:.*]] = torch.aten.slice.Tensor %[[VAL_29]], %[[VAL_24]], %[[VAL_24]], %[[VAL_28]], %[[VAL_25]] : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.prim.If.yield %[[SLICE]] : 
!torch.vtensor<[1],si64> - // CHECK: } else { - // CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[VAL_29]] : !torch.vtensor<[?],si64> to !torch.vtensor<[1],si64> - // CHECK: torch.prim.If.yield %[[CAST]] : !torch.vtensor<[1],si64> - // CHECK: } - // CHECK: %[[VAL_33:.*]] = torch.aten.unsqueeze %[[VAL_32]], %[[VAL_25]] : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[1,1],si64> - // CHECK: %[[VAL_34:.*]] = torch.aten.size.int %[[VAL_33]], %[[VAL_24]] : !torch.vtensor<[1,1],si64>, !torch.int -> !torch.int - // CHECK: %[[VAL_35:.*]] = torch.constant.int 2 - // CHECK: %[[VAL_36:.*]] = torch.prim.ListConstruct %[[VAL_34]], %[[VAL_35]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[VAL_37:.*]] = torch.constant.none - // CHECK: %[[VAL_38:.*]] = torch.aten.zeros %[[VAL_36]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]], %[[VAL_37]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> - // CHECK: %[[VAL_39:.*]] = torch.prim.ListConstruct %[[VAL_38]], %[[VAL_33]] : (!torch.vtensor<[1,2],si64>, !torch.vtensor<[1,1],si64>) -> !torch.list - // CHECK: %[[VAL_40:.*]] = torch.aten.cat %[[VAL_39]], %[[VAL_25]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> - // CHECK: return %[[VAL_40]] : !torch.vtensor<[1,3],si64> + // CHECK: %[[VAL_24:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[VAL_25:.*]] = torch.torchvision.nms %[[VAL_9]], %[[VAL_19]], %[[VAL_24]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[1],f32>, !torch.float -> !torch.vtensor<[1,3],si64> + // CHECK: %[[VAL_26:.*]] = torch.constant.int 1 + // CHECK: %[[VAL_27:.*]] = torch.aten.unsqueeze %[[VAL_25]], %[[VAL_26]] : !torch.vtensor<[1,3],si64>, !torch.int -> !torch.vtensor<[1,1,3],si64> + // CHECK: %[[VAL_28:.*]] = torch.constant.int 0 + // CHECK: %[[VAL_29:.*]] = torch.aten.size.int %[[VAL_27]], %[[VAL_28]] : !torch.vtensor<[1,1,3],si64>, !torch.int -> !torch.int + // CHECK: %[[VAL_30:.*]] = torch.constant.int 
2 + // CHECK: %[[VAL_31:.*]] = torch.prim.ListConstruct %[[VAL_29]], %[[VAL_30]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[VAL_32:.*]] = torch.constant.none + // CHECK: %[[VAL_33:.*]] = torch.aten.zeros %[[VAL_31]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]], %[[VAL_32]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1,2],si64> + // CHECK: %[[VAL_34:.*]] = torch.prim.ListConstruct %[[VAL_27]], %[[VAL_33]] : (!torch.vtensor<[1,1,3],si64>, !torch.vtensor<[1,2],si64>) -> !torch.list + // CHECK: %[[VAL_35:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[VAL_36:.*]] = torch.aten.le.int %[[VAL_29]], %[[VAL_35]] : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.runtime.assert %[[VAL_36]], "unimplemented: number of output boxes per class should be <= max_output_boxes_per_class" + // CHECK: %[[VAL_37:.*]] = torch.aten.cat %[[VAL_34]], %[[VAL_26]] : !torch.list, !torch.int -> !torch.vtensor<[1,3],si64> + // CHECK: return %[[VAL_37]] : !torch.vtensor<[1,3],si64> + // CHECK: } %0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[1,1,4],f32>, !torch.vtensor<[1,1,1],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,3],si64> return %0 : !torch.vtensor<[1,3],si64> } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 16c86218dbc8..30fd60dbde3a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -707,8 +707,17 @@ func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_max_bool_inputs func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, 
torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> @@ -720,8 +729,17 @@ func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! 
// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> @@ -733,9 +751,19 @@ func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 // CHECK-LABEL: func.func @test_reduce_max_all_dims_default func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, 
torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I0:.*]] = torch.constant.int 0 - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> @@ -747,7 +775,13 @@ func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // 
CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMAX]] @@ -759,12 +793,9 @@ func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens // CHECK-LABEL: func.func @test_reduce_l1_default_axes_keepdims_example func.func @test_reduce_l1_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[ABS:.*]] = torch.aten.abs %arg0 + // CHECK: %[[ABS:.+]] = torch.aten.abs %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = 
torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[ABS]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -814,11 +845,8 @@ func.func @test_reduce_l1_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f // CHECK-LABEL: func.func @test_reduce_l2_default_axes_keepdims_example func.func @test_reduce_l2_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> - // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE_0:.+]] = torch.constant.bool true // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -916,10 +944,7 @@ func.func @test_reduce_l2_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2 // CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: 
!torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -975,10 +1000,7 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, 
!torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> @@ -1070,10 +1092,7 @@ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vte // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1158,10 +1177,7 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< func.func @test_reduce_sum_square_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[MULT:.+]] = torch.aten.mul.Tensor %arg0, %arg0 : !torch.vtensor<[3,2,2],f32>, !torch.vtensor<[3,2,2],f32> -> !torch.vtensor<[3,2,2],f32> // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // 
CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[MULT]], %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> @@ -1369,8 +1385,17 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: 
%[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> @@ -1382,8 +1407,17 @@ func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: ! // CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> @@ -1397,7 +1431,17 @@ func.func 
@test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[I0:.+]] = torch.constant.int 0 // CHECK: %[[I1:.+]] = torch.constant.int 1 - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[I0]], %[[I1]] + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> @@ -1409,7 +1453,13 @@ func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, 
torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> // CHECK: return %[[AMIN]] diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir index 8fa13b47e588..86ad4e972f8e 100644 --- a/test/Conversion/TorchToArith/basic.mlir +++ b/test/Conversion/TorchToArith/basic.mlir @@ -236,34 +236,6 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in return %0 : !torch.int } -// CHECK-LABEL: func.func @torch.aten.mul.int_float( -// CHECK-SAME: %[[LHS:.*]]: !torch.int, -// CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { -// CHECK-DAG: %[[LHS_I64:.*]] = torch_c.to_i64 %[[LHS]] -// CHECK-DAG: %[[RHS_F64:.*]] = torch_c.to_f64 %[[RHS]] -// CHECK: %[[LHS_F64:.*]] = arith.sitofp %[[LHS_I64]] : i64 to f64 -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64 -// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]] -// CHECK: return %[[OUT]] : !torch.float -func.func 
@torch.aten.mul.int_float(%arg0: !torch.int, %arg1: !torch.float) -> !torch.float { - %0 = torch.aten.mul.int_float %arg0, %arg1 : !torch.int, !torch.float -> !torch.float - return %0 : !torch.float -} - -// CHECK-LABEL: func.func @torch.aten.mul.float_int( -// CHECK-SAME: %[[LHS:.*]]: !torch.float, -// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.float { -// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]] -// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]] -// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64 -// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64 -// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]] -// CHECK: return %[[OUT]] : !torch.float -func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float { - %0 = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float - return %0 : !torch.float -} - // CHECK-LABEL: func.func @torch.aten.div.float( // CHECK-SAME: %[[LHS:.*]]: !torch.float, // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float { diff --git a/test/Conversion/TorchToLinalg/datamovement.mlir b/test/Conversion/TorchToLinalg/datamovement.mlir deleted file mode 100644 index dd5e5c553d31..000000000000 --- a/test/Conversion/TorchToLinalg/datamovement.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK-LABEL: func.func @torch.aten.permute( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[64,32,16,8,4],f32> -> tensor<64x32x16x8x4xf32> -// CHECK: %[[VAL_2:.*]] = tensor.empty() : tensor<64x8x4x32x16xf32> -// CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<64x32x16x8x4xf32>) outs(%[[VAL_2]] : tensor<64x8x4x32x16xf32>) permutation = [0, 3, 4, 1, 2] -// CHECK: %[[VAL_4:.*]] = 
torch_c.from_builtin_tensor %[[VAL_3]] : tensor<64x8x4x32x16xf32> -> !torch.vtensor<[64,8,4,32,16],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[64,8,4,32,16],f32> -// CHECK: } -func.func @torch.aten.permute(%arg0: !torch.vtensor<[64,32,16,8,4],f32>) -> !torch.vtensor<[64,8,4,32,16],f32> { - %int0 = torch.constant.int 0 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %0 = torch.prim.ListConstruct %int0, %int3, %int4, %int1, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[64,32,16,8,4],f32>, !torch.list -> !torch.vtensor<[64,8,4,32,16],f32> - return %1 : !torch.vtensor<[64,8,4,32,16],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.permute$rank0( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = torch_c.from_builtin_tensor %[[VAL_1]] : tensor -> !torch.vtensor<[],f32> -// CHECK: return %[[VAL_2]] : !torch.vtensor<[],f32> -// CHECK: } -func.func @torch.aten.permute$rank0(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { - %0 = torch.prim.ListConstruct : () -> !torch.list - %1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[],f32>, !torch.list -> !torch.vtensor<[],f32> - return %1 : !torch.vtensor<[],f32> -} diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 1dfe45492312..7976b1ad8b16 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 // CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32 // CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], 
%[[cst_5]] : f32 - // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32 + // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32 // CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32 // CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32 // CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32 // CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32 // CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 // CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32 + // CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32 // CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64 // CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index // CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32> @@ -304,51 +304,4 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens return %5 : !torch.vtensor<[?,?,?],f32> } -// CHECK-LABEL: func.func @test_resize_sizes_cubic -func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] -,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 -: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32 - // CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32 - // CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32 - // CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32 - // CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32 - // CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32 - // CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32 - // CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32 - // CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32 - // CHECK-DAG: %[[y2Dist:.*]] = math.absf 
%[[y2D]] : f32 - // CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32 - // CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32 - // CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32 - // CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32 - // CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32 - // CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32 - // CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32 - // CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32 - // CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32 - // CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32 - // CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32 - // CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32 - // CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32 - // CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32 - // CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32 - // CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32 - %none = torch.constant.none - %none_0 = torch.constant.none - %int0 = torch.constant.int 0 - %false = torch.constant.bool false - %true = torch.constant.bool true - %str = torch.constant.str "cubic" - %int2 = torch.constant.int 2 - %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - %int3 = torch.constant.int 3 - %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int - %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list - %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> - return %5 : 
!torch.vtensor<[?,?,?,?],f32> -} - // ----- diff --git a/test/Conversion/TorchToLinalg/spectral.mlir b/test/Conversion/TorchToLinalg/spectral.mlir deleted file mode 100644 index abd45183bd84..000000000000 --- a/test/Conversion/TorchToLinalg/spectral.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -canonicalize -split-input-file -verify-diagnostics | FileCheck %s - -// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1)> -// CHECK: #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK: #map2 = affine_map<(d0, d1, d2) -> (d0, d2)> - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( -// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<9x5xcomplex> -// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[16,9],f32> -> tensor<16x9xf32> -// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<16x5xcomplex> -// CHECK: %[[VAR2:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR1]] : tensor<16x5xcomplex>) -> tensor<16x5xcomplex> -// CHECK: %[[VAR3:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[VAR0]], %[[CST_0]] : tensor<16x9xf32>, tensor<9x5xcomplex>) outs(%[[VAR2]] : tensor<16x5xcomplex>) { -// CHECK: ^bb0(%in: f32, %in_1: complex, %out: complex): -// CHECK: %[[VAR5:.*]] = complex.re %in_1 : complex -// CHECK: %[[VAR6:.*]] = complex.im %in_1 : complex -// CHECK: %[[VAR7:.*]] = arith.mulf %in, %[[VAR5]] : f32 -// CHECK: %[[VAR8:.*]] = arith.mulf %in, %[[VAR6]] : f32 -// CHECK: %[[VAR9:.*]] = complex.create %[[VAR7]], %[[VAR8]] : complex -// CHECK: %[[VAR10:.*]] = complex.add %[[VAR9]], %out : complex -// CHECK: linalg.yield %[[VAR10]] : complex -// CHECK: } -> tensor<16x5xcomplex> -// CHECK: %[[VAR4:.*]] = torch_c.from_builtin_tensor %[[VAR3]] : tensor<16x5xcomplex> -> 
!torch.vtensor<[16,5],complex> -// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> - -func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { - %int-1 = torch.constant.int -1 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> - return %out : !torch.vtensor<[16,5],complex> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( -// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { -// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<36x19xcomplex> -// CHECK: %[[VAR0:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[36,23],f32> -> tensor<36x23xf32> -// CHECK-DAG: %[[VAR1:.*]] = tensor.empty() : tensor<23x36xf32> -// CHECK: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[VAR0]] : tensor<36x23xf32>) outs(%[[VAR1]] : tensor<23x36xf32>) permutation = [1, 0] -// CHECK-DAG: %[[VAR2:.*]] = tensor.empty() : tensor<23x19xcomplex> -// CHECK: %[[VAR3:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[VAR2]] : tensor<23x19xcomplex>) -> tensor<23x19xcomplex> -// CHECK: %[[VAR4:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[TRANSPOSED]], %[[CST_0]] : tensor<23x36xf32>, tensor<36x19xcomplex>) outs(%[[VAR3]] : tensor<23x19xcomplex>) { -// CHECK: ^bb0(%in: f32, %in_2: complex, %out: complex): -// CHECK: %[[VAR7:.*]] = complex.re %in_2 : complex -// CHECK: %[[VAR8:.*]] = complex.im %in_2 : complex -// CHECK: %[[VAR9:.*]] = arith.mulf %in, %[[VAR7]] : f32 -// CHECK: %[[VAR10:.*]] = arith.mulf %in, %[[VAR8]] : f32 -// CHECK: %[[VAR11:.*]] = complex.create %[[VAR9]], %[[VAR10]] : complex -// CHECK: %[[VAR12:.*]] = complex.add %[[VAR11]], %out : complex -// CHECK: linalg.yield 
%[[VAR12]] : complex -// CHECK: } -> tensor<23x19xcomplex> -// CHECK-DAG: %[[VAR5:.*]] = tensor.empty() : tensor<19x23xcomplex> -// CHECK: %[[TRANSPOSED_1:.*]] = linalg.transpose ins(%[[VAR4]] : tensor<23x19xcomplex>) outs(%[[VAR5]] : tensor<19x23xcomplex>) permutation = [1, 0] -// CHECK: %[[VAR6:.*]] = torch_c.from_builtin_tensor %[[TRANSPOSED_1]] : tensor<19x23xcomplex> -> !torch.vtensor<[19,23],complex> -// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex> -func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { - %int0 = torch.constant.int 0 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> - return %out : !torch.vtensor<[19,23],complex> -} diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index a3d52166385a..548c0b4baf06 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1766,11 +1766,10 @@ func.func @torch.aten.erf$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_4]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],si32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],si32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.bitwise_and %[[VAL_1]], %[[VAL_3]] : 
(tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],si32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32> // CHECK: } func.func @torch.aten.bitwise_and.Scalar$basic(%arg0: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> { %int2 = torch.constant.int 2 @@ -1800,11 +1799,10 @@ func.func @torch.aten.le.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.greater_equal %[[VAL_4]], %[[VAL_1]] : (tensor, tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor -> !torch.vtensor<[?,?],i1> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.greater_equal %[[VAL_3]], %[[VAL_1]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> // CHECK: } func.func @torch.aten.le.Scalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> { %int2 = torch.constant.int 2 @@ -2426,570 +2424,3 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to %0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32> return %0 : !torch.vtensor<[2,12],f32> } - -// ----- - -func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> 
!torch.vtensor<[1,512,10],f32> { - %int1 = torch.constant.int 1 - %int3 = torch.constant.int 3 - %false = torch.constant.bool false - %count_include_pad = torch.constant.bool true - %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list - %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list - %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list - // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}} - %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32> - return %3 : !torch.vtensor<[1,512,10],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reverse %[[VAL_5]] {axis = 2 : i32} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x2x4xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reverse %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x1xf32>) -> tensor<1x2x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_6]], %[[VAL_1]], %[[VAL_8]] {axis = 2 : i32} : (tensor<1x2x3xf32>, tensor<1x2x4xf32>, tensor<1x2x1xf32>) -> tensor<1x2x8xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<1x2x8xf32> -> 
!torch.vtensor<[1,2,8],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[1,2,8],f32> -// CHECK: } -func.func @torch.aten.reflection_pad1d$basic(%arg0: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> { - %int3 = torch.constant.int 3 - %int1 = torch.constant.int 1 - %0 = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list - %1 = torch.aten.reflection_pad1d %arg0, %0 : !torch.vtensor<[1,2,4],f32>, !torch.list -> !torch.vtensor<[1,2,8],f32> - return %1 : !torch.vtensor<[1,2,8],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.reflection_pad2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,20,20],f32> -> tensor<1x20x20xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 10 -// CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_2]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reverse %[[VAL_4]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_6:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x20x20xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reverse %[[VAL_6]] {axis = 2 : i32} : (tensor<1x20x10xf32>) -> tensor<1x20x10xf32> -// CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_5]], %[[VAL_1]], %[[VAL_7]] {axis = 2 : i32} : (tensor<1x20x10xf32>, tensor<1x20x20xf32>, tensor<1x20x10xf32>) -> tensor<1x20x40xf32> -// CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]] {size = array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_8]] {size = 
array, start = array} : (tensor<1x20x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_12:.*]] = tosa.reverse %[[VAL_11]] {axis = 1 : i32} : (tensor<1x10x40xf32>) -> tensor<1x10x40xf32> -// CHECK: %[[VAL_13:.*]] = tosa.concat %[[VAL_10]], %[[VAL_8]], %[[VAL_12]] {axis = 1 : i32} : (tensor<1x10x40xf32>, tensor<1x20x40xf32>, tensor<1x10x40xf32>) -> tensor<1x40x40xf32> -// CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_13]] : tensor<1x40x40xf32> -> !torch.vtensor<[1,40,40],f32> -// CHECK: return %[[VAL_14]] : !torch.vtensor<[1,40,40],f32> -// CHECK: } -func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32>) -> !torch.vtensor<[1,40,40],f32> { - %int10 = torch.constant.int 10 - %0 = torch.prim.ListConstruct %int10, %int10, %int10, %int10 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1 = torch.aten.reflection_pad2d %arg0, %0 : !torch.vtensor<[1,20,20],f32>, !torch.list -> !torch.vtensor<[1,40,40],f32> - return %1 : !torch.vtensor<[1,40,40],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,3,3],f32> -> tensor<1x1x3x3xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 -// CHECK: %[[VAL_5:.*]] = torch.constant.int 4 -// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<1x1x3x3xf32>) -> tensor<1x1x3x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.concat %[[VAL_7]], %[[VAL_1]], %[[VAL_8]], 
%[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x3x1xf32>, tensor<1x1x3x3xf32>, tensor<1x1x3x1xf32>, tensor<1x1x3x1xf32>) -> tensor<1x1x3x6xf32> -// CHECK: %[[VAL_10:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_11:.*]] = tosa.slice %[[VAL_9]] {size = array, start = array} : (tensor<1x1x3x6xf32>) -> tensor<1x1x1x6xf32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_10]], %[[VAL_10]], %[[VAL_9]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]], %[[VAL_11]] {axis = 2 : i32} : (tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x3x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>, tensor<1x1x1x6xf32>) -> tensor<1x1x10x6xf32> -// CHECK: %[[VAL_13:.*]] = torch_c.from_builtin_tensor %[[VAL_12]] : tensor<1x1x10x6xf32> -> !torch.vtensor<[1,1,10,6],f32> -// CHECK: return %[[VAL_13]] : !torch.vtensor<[1,1,10,6],f32> -// CHECK: } -func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,10,6],f32> { - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %int3 = torch.constant.int 3 - %int4 = torch.constant.int 4 - %0 = torch.prim.ListConstruct %int1, %int2, %int3, %int4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - %1 = torch.aten.replication_pad2d %arg0, %0 : !torch.vtensor<[1,1,3,3],f32>, !torch.list -> !torch.vtensor<[1,1,10,6],f32> - return %1 : !torch.vtensor<[1,1,10,6],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.outer$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3],f32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : 
(tensor<3xf32>) -> tensor<3x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array} : (tensor<3x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.prims.split_dim$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,8,3,3],si64> -> tensor<1x8x3x3xi64> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x8x3x3xi64>) -> tensor<1x2x4x3x3xi64> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<1x2x4x3x3xi64>) -> tensor<1x2x2x2x3x3xi64> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x2x2x2x3x3xi64> -> !torch.vtensor<[1,2,2,2,3,3],si64> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,2,2,2,3,3],si64> -// CHECK: } -func.func @torch.prims.split_dim$basic(%arg0: !torch.vtensor<[1,8,3,3],si64>) -> !torch.vtensor<[1,2,2,2,3,3],si64> { - %int1 = torch.constant.int 1 - %int2 = torch.constant.int 2 - %0 = 
torch.prims.split_dim %arg0, %int1, %int2 : !torch.vtensor<[1,8,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,4,3,3],si64> - %1 = torch.prims.split_dim %0, %int2, %int2 : !torch.vtensor<[1,2,4,3,3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,3],si64> - return %1 : !torch.vtensor<[1,2,2,2,3,3],si64> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,2,3],f64> -> tensor<1x1x2x3xf64> -// CHECK: %[[VAL_2:.*]] = torch.constant.float 4.000000e+00 -// CHECK: %[[VAL_3:.*]] = torch.constant.float 3.000000e+00 -// CHECK: %[[VAL_4:.*]] = torch.constant.int 8 -// CHECK: %[[VAL_5:.*]] = torch.constant.int 9 -// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_5]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x2x3xf64>) -> tensor<1x1x6xf64> -// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5, 3, 3, 3, 4, 4, 4, 5, 5, 5]]]> : tensor<1x1x72xi32>}> : () -> tensor<1x1x72xi32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<1x1x72xi32>) -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x72x1xi32>}> : () -> tensor<1x1x72x1xi32> -// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>, tensor<1x1x72x1xi32>) -> tensor<1x1x72x3xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape 
%[[VAL_7]] {new_shape = array} : (tensor<1x1x6xf64>) -> tensor<1x6x1xf64> -// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<1x1x72x3xi32>) -> tensor<72x3xi32> -// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[6, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<72x3xi32>, tensor<3xi32>) -> tensor<72x3xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<72x3xi32>) -> tensor<72x1xi32> -// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<72x1xi32>) -> tensor<1x72xi32> -// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x6x1xf64>, tensor<1x72xi32>) -> tensor<1x72x1xf64> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x72x1xf64>) -> tensor<1x1x72xf64> -// CHECK: %[[VAL_21:.*]] = tosa.reshape %[[VAL_20]] {new_shape = array} : (tensor<1x1x72xf64>) -> tensor<1x1x8x9xf64> -// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x1x8x9xf64> -> !torch.vtensor<[1,1,8,9],f64> -// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,1,8,9],f64> -// CHECK: } -func.func @torch.aten.upsample_nearest2d$basic(%arg0: !torch.vtensor<[1,1,2,3],f64>) -> !torch.vtensor<[1,1,8,9],f64> { - %float4.000000e00 = torch.constant.float 4.000000e+00 - %float3.000000e00 = torch.constant.float 3.000000e+00 - %int8 = torch.constant.int 8 - %int9 = torch.constant.int 9 - %0 = torch.prim.ListConstruct %int8, %int9 : (!torch.int, !torch.int) -> !torch.list - %1 = torch.aten.upsample_nearest2d %arg0, %0, %float4.000000e00, %float3.000000e00 : !torch.vtensor<[1,1,2,3],f64>, !torch.list, !torch.float, !torch.float -> !torch.vtensor<[1,1,8,9],f64> - return %1 : !torch.vtensor<[1,1,8,9],f64> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.upsample_nearest2d.vec$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,4,5],f32>) -> 
!torch.vtensor<[1,1,2,7],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,4,5],f32> -> tensor<1x1x4x5xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.none -// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = torch.constant.int 7 -// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<1x1x4x5xf32>) -> tensor<1x1x20xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0, 0, 1, 2, 2, 3, 4, 10, 10, 11, 12, 12, 13, 14]]]> : tensor<1x1x14xi32>}> : () -> tensor<1x1x14xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<1x1x14xi32>) -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x14x1xi32>}> : () -> tensor<1x1x14x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>, tensor<1x1x14x1xi32>) -> tensor<1x1x14x3xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]] {new_shape = array} : (tensor<1x1x20xf32>) -> tensor<1x20x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<1x1x14x3xi32>) -> tensor<14x3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[20, 20, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<14x3xi32>, tensor<3xi32>) -> tensor<14x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<14x3xi32>) -> tensor<14x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<14x1xi32>) -> tensor<1x14xi32> -// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], 
%[[VAL_17]] : (tensor<1x20x1xf32>, tensor<1x14xi32>) -> tensor<1x14x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x14x1xf32>) -> tensor<1x1x14xf32> -// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x1x14xf32>) -> tensor<1x1x2x7xf32> -// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<1x1x2x7xf32> -> !torch.vtensor<[1,1,2,7],f32> -// CHECK: return %[[VAL_21]] : !torch.vtensor<[1,1,2,7],f32> -// CHECK: } -func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,5],f32>) -> !torch.vtensor<[1,1,2,7],f32> { - %none = torch.constant.none - %int2 = torch.constant.int 2 - %int7 = torch.constant.int 7 - %0 = torch.prim.ListConstruct %int2, %int7 : (!torch.int, !torch.int) -> !torch.list - %1 = torch.aten.upsample_nearest2d.vec %arg0, %0, %none : !torch.vtensor<[1,1,4,5],f32>, !torch.list, !torch.none -> !torch.vtensor<[1,1,2,7],f32> - return %1 : !torch.vtensor<[1,1,2,7],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.gelu$tanh( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.str "tanh" -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0.636619746> : tensor<5x3xf32>}> : () -> tensor<5x3xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : 
(tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_9:.*]] = tosa.pow %[[VAL_7]], %[[VAL_3]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_10:.*]] = tosa.pow %[[VAL_1]], %[[VAL_5]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_6]], %[[VAL_10]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_12:.*]] = tosa.add %[[VAL_1]], %[[VAL_11]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_9]], %[[VAL_12]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_14:.*]] = tosa.tanh %[[VAL_13]] : (tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_4]], %[[VAL_14]] : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_8]], %[[VAL_15]] {shift = 0 : i8} : (tensor<5x3xf32>, tensor<5x3xf32>) -> tensor<5x3xf32> -// CHECK: %[[VAL_17:.*]] = torch_c.from_builtin_tensor %[[VAL_16]] : tensor<5x3xf32> -> !torch.vtensor<[5,3],f32> -// CHECK: return %[[VAL_17]] : !torch.vtensor<[5,3],f32> -// CHECK: } -func.func @torch.aten.gelu$tanh(%arg0: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> { - %str = torch.constant.str "tanh" - %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[5,3],f32>, !torch.str -> !torch.vtensor<[5,3],f32> - return %0 : !torch.vtensor<[5,3],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.exp$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.exp %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> 
!torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.exp$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.exp %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log10$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.log %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_2]] : (tensor) -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log10$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log10$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+01> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: 
%[[VAL_5:.*]] = tosa.log %[[VAL_3]] : (tensor) -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_4]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_8]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log10$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log10 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log1p$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.log %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log1p$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log1p$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// 
CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = tosa.add %[[VAL_2]], %[[VAL_3]] : (tensor<3x4xf32>, tensor) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log1p$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log1p %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.logit$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8 -// CHECK: %[[VAL_3:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.sub %[[VAL_4]], %[[VAL_3]] : (tensor, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reciprocal %[[VAL_5]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_3]], %[[VAL_6]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.log %[[VAL_7]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.logit$basic(%arg0: !torch.vtensor<[3,4],f32>) -> 
!torch.vtensor<[3,4],f32> { - %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 - %0 = torch.aten.logit %arg0, %float9.999990e-08 : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.logit$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = torch.constant.float 9.9999999999999995E-8 -// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_3]] {max_fp = 0.99999988 : f32, max_int = 0 : i64, min_fp = 1.000000e-07 : f32, min_int = 0 : i64} : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_6:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.reciprocal %[[VAL_6]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_4]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = tosa.log %[[VAL_8]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_10]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.logit$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %float9.999990e-08 = torch.constant.float 9.9999999999999995E-8 - %0 = torch.aten.logit %arg0, %float9.999990e-08 : !torch.vtensor<[3,4],si32>, !torch.float -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log$int( -// CHECK-SAME: %[[VAL_0:.*]]: 
!torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.log2$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.693147182> : tensor<1x1xf32>}> : () -> tensor<1x1xf32> -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<1x1xf32>) -> tensor<1x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.log %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<1x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.log2$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.log2 %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func 
@torch.aten.erf$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.erf %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.erf$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.erf %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.lt.Scalar$intfloat( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4],si64> -> tensor<4xi64> -// CHECK: %[[VAL_2:.*]] = torch.constant.float 1.100000e+00 -// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.100000e+00> : tensor}> : () -> tensor -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.1000000238418579> : tensor}> : () -> tensor -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_1]] : (tensor<4xi64>) -> tensor<4xf64> -// CHECK: %[[VAL_6:.*]] = tosa.greater %[[VAL_4]], %[[VAL_5]] : (tensor, tensor<4xf64>) -> tensor<4xi1> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<4xi1> -> !torch.vtensor<[4],i1> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[4],i1> -// CHECK: } -func.func @torch.aten.lt.Scalar$intfloat(%arg0: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { - %float1.100000e00 = torch.constant.float 1.100000e+00 - %0 = torch.aten.lt.Scalar %arg0, %float1.100000e00 : !torch.vtensor<[4],si64>, !torch.float -> !torch.vtensor<[4],i1> - return %0 : 
!torch.vtensor<[4],i1> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.sigmoid$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,5],si32> -> tensor<3x5xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x5xi32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_3:.*]] = tosa.sigmoid %[[VAL_2]] : (tensor<3x5xf32>) -> tensor<3x5xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x5xf32> -> !torch.vtensor<[3,5],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,5],f32> -// CHECK: } -func.func @torch.aten.sigmoid$int(%arg0: !torch.vtensor<[3,5],si32>) -> !torch.vtensor<[3,5],f32> { - %0 = torch.aten.sigmoid %arg0 : !torch.vtensor<[3,5],si32> -> !torch.vtensor<[3,5],f32> - return %0 : !torch.vtensor<[3,5],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.tan$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],f32> -> tensor<3x4xf32> -// CHECK: %[[VAL_2:.*]] = tosa.sin %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.cos %[[VAL_1]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.reciprocal %[[VAL_3]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.mul %[[VAL_2]], %[[VAL_4]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.tan$basic(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func 
@torch.aten.tan$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.sin %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = tosa.cos %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_5:.*]] = tosa.reciprocal %[[VAL_4]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_3]], %[[VAL_5]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_7]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.tan$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.tan %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.tanh$int( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4],si32> -> tensor<3x4xi32> -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<3x4xi32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_3:.*]] = tosa.tanh %[[VAL_2]] : (tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_4]] : !torch.vtensor<[3,4],f32> -// CHECK: } -func.func @torch.aten.tanh$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],f32> { - %0 = torch.aten.tanh %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],f32> - return %0 : !torch.vtensor<[3,4],f32> -} - -// ----- - -// CHECK-LABEL: 
func.func @torch.aten.pow.Tensor_Tensor$intfloat( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,5],si32>, -// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { -// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[3,4,5],f32> -> tensor<3x4x5xf32> -// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,5],si32> -> tensor<3x4x5xi32> -// CHECK: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<3x4x5xi32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_5:.*]] = tosa.pow %[[VAL_4]], %[[VAL_2]] : (tensor<3x4x5xf32>, tensor<3x4x5xf32>) -> tensor<3x4x5xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<3x4x5xf32> -> !torch.vtensor<[3,4,5],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[3,4,5],f32> -// CHECK: } -func.func @torch.aten.pow.Tensor_Tensor$intfloat(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> { - %0 = torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> - return %0 : !torch.vtensor<[3,4,5],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.unfold$basic( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[6,4],f32> -> tensor<6x4xf32> -// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 2 -// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]]> : tensor<6x4xi32>}> : () -> tensor<6x4xi32> -// CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]] {new_shape = array} : (tensor<6x4xi32>) -> tensor<6x4x1xi32> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<{{\[\[}}[0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], 
{{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]], {{\[\[}}0], [1], [2], [3]]]> : tensor<6x4x1xi32>}> : () -> tensor<6x4x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.concat %[[VAL_5]], %[[VAL_6]] {axis = 2 : i32} : (tensor<6x4x1xi32>, tensor<6x4x1xi32>) -> tensor<6x4x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<1x24x1xf32> -// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<6x4x2xi32>) -> tensor<24x2xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<[4, 1]> : tensor<2xi32>}> : () -> tensor<2xi32> -// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_9]], %[[VAL_10]] {shift = 0 : i8} : (tensor<24x2xi32>, tensor<2xi32>) -> tensor<24x2xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reduce_sum %[[VAL_11]] {axis = 1 : i32} : (tensor<24x2xi32>) -> tensor<24x1xi32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<24x1xi32>) -> tensor<1x24xi32> -// CHECK: %[[VAL_14:.*]] = tosa.gather %[[VAL_8]], %[[VAL_13]] : (tensor<1x24x1xf32>, tensor<1x24xi32>) -> tensor<1x24x1xf32> -// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array} : (tensor<1x24x1xf32>) -> tensor<6x4xf32> -// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]] {new_shape = array} : (tensor<6x4xf32>) -> tensor<3x2x4xf32> -// CHECK: %[[VAL_17:.*]] = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_16]], %[[VAL_17]] : (tensor<3x2x4xf32>, tensor<3xi32>) -> tensor<3x4x2xf32> -// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<3x4x2xf32> -> !torch.vtensor<[3,4,2],f32> -// CHECK: return %[[VAL_19]] : !torch.vtensor<[3,4,2],f32> -// CHECK: } -func.func @torch.aten.unfold$basic(%arg0: !torch.vtensor<[6,4],f32>) -> !torch.vtensor<[3,4,2],f32> { - %int0 = torch.constant.int 0 - %int2 = torch.constant.int 2 - %0 = torch.aten.unfold %arg0, %int0, %int2, %int2 : !torch.vtensor<[6,4],f32>, 
!torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,4,2],f32> - return %0 : !torch.vtensor<[3,4,2],f32> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.unfold$rank_zero( -// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { -// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array} : (tensor) -> tensor<1xf32> -// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1xf32> -> !torch.vtensor<[1],f32> -// CHECK: return %[[VAL_5]] : !torch.vtensor<[1],f32> -// CHECK: } -func.func @torch.aten.unfold$rank_zero(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { - %int0 = torch.constant.int 0 - %int1 = torch.constant.int 1 - %0 = torch.aten.unfold %arg0, %int0, %int1, %int1 : !torch.vtensor<[],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1],f32> - return %0 : !torch.vtensor<[1],f32> -} - -// ----- diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index d4afd67d65db..263e69169cf3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -137,46 +137,6 @@ func.func @torch.aten.__isnot__$none_isnot_none(%arg0: !torch.none, %arg1: !torc return %0 : !torch.bool } -// CHECK-LABEL: func.func @torch.aten.eq.bool$same_value() -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: return %[[TRUE]] : !torch.bool -func.func @torch.aten.eq.bool$same_value() -> !torch.bool { - %a = torch.constant.bool false - %b = torch.constant.bool false - %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool - return %0 : !torch.bool -} - -// CHECK-LABEL: func.func @torch.aten.eq.bool$different_value() -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: return 
%[[FALSE]] : !torch.bool -func.func @torch.aten.eq.bool$different_value() -> !torch.bool { - %a = torch.constant.bool true - %b = torch.constant.bool false - %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool - return %0 : !torch.bool -} - -// CHECK-LABEL: func.func @torch.aten.eq.bool$same_operand( -// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { -// CHECK: %[[TRUE:.*]] = torch.constant.bool true -// CHECK: return %[[TRUE]] : !torch.bool -func.func @torch.aten.eq.bool$same_operand(%arg0: !torch.bool) -> !torch.bool { - %0 = torch.aten.eq.bool %arg0, %arg0: !torch.bool, !torch.bool -> !torch.bool - return %0 : !torch.bool -} - -// CHECK-LABEL: func.func @torch.aten.eq.bool$different_operand( -// CHECK-SAME: %[[ARG0:.*]]: !torch.bool) -> !torch.bool { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[RET:.*]] = torch.aten.eq.bool %[[ARG0]], %[[FALSE]] : !torch.bool, !torch.bool -> !torch.bool -// CHECK: return %[[RET]] : !torch.bool -func.func @torch.aten.eq.bool$different_operand(%a: !torch.bool) -> !torch.bool { - %b = torch.constant.bool false - %0 = torch.aten.eq.bool %a, %b: !torch.bool, !torch.bool -> !torch.bool - return %0 : !torch.bool -} - // CHECK-LABEL: func.func @torch.aten.ne.bool() -> !torch.bool { // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool @@ -738,20 +698,6 @@ func.func @torch.aten.len.t$no_fold_list_mutated() -> !torch.int { return %2 : !torch.int } -// CHECK-LABEL: func.func @torch.aten.mul.left_t( -// CHECK: %[[C4:.*]] = torch.constant.int 4 -// CHECK: %[[C5:.*]] = torch.constant.int 5 -// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C4]], %[[C5]], %[[C4]], %[[C5]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: return %[[LIST]] : !torch.list -func.func @torch.aten.mul.left_t() -> !torch.list { - %int4 = torch.constant.int 4 - %int5 = torch.constant.int 5 - %int2 = torch.constant.int 2 - %0 = torch.prim.ListConstruct 
%int4, %int5 : (!torch.int, !torch.int) -> !torch.list - %1 = torch.aten.mul.left_t %0, %int2 : !torch.list, !torch.int -> !torch.list - return %1 : !torch.list -} - // CHECK-LABEL: func.func @torch.aten.__getitem__.t( // CHECK: %[[C5:.*]] = torch.constant.int 5 // CHECK: return %[[C5]] : !torch.int @@ -1235,16 +1181,6 @@ func.func @torch.aten.mul.int$canonicalize(%arg0: !torch.int) -> !torch.int { return %ret : !torch.int } -// CHECK-LABEL: func.func @torch.aten.mul.int_float() -> !torch.float { -// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00 -// CHECK: return %[[CST6]] : !torch.float -func.func @torch.aten.mul.int_float() -> !torch.float { - %cst2 = torch.constant.int 2 - %cst3 = torch.constant.float 3.0 - %ret = torch.aten.mul.int_float %cst2, %cst3: !torch.int, !torch.float -> !torch.float - return %ret : !torch.float -} - // CHECK-LABEL: func.func @torch.aten.mul.float() -> !torch.float { // CHECK: %[[CST30:.*]] = torch.constant.float 3.000000e+01 // CHECK: return %[[CST30]] : !torch.float @@ -1255,16 +1191,6 @@ func.func @torch.aten.mul.float() -> !torch.float { return %ret : !torch.float } -// CHECK-LABEL: func.func @torch.aten.mul.float_int() -> !torch.float { -// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00 -// CHECK: return %[[CST6]] : !torch.float -func.func @torch.aten.mul.float_int() -> !torch.float { - %cst2 = torch.constant.float 2.0 - %cst3 = torch.constant.int 3 - %ret = torch.aten.mul.float_int %cst2, %cst3: !torch.float, !torch.int -> !torch.float - return %ret : !torch.float -} - // CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float { // CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00 // CHECK: return %[[CST_6]] : !torch.float diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 384502ecd2af..f938a2637835 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -25,19 +25,6 @@ func.func 
@matmul_decompose_3d(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch return %0 : !torch.tensor } -// ----- -// CHECK-LABEL: func.func @argmax_rank_1 -// CHECK: %[[I0:.*]] = torch.constant.int 0 -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[VALUES:.*]], %[[INDICES:.*]] = torch.aten.max.dim %arg0, %[[I0]], %[[FALSE]] : !torch.vtensor<[20],si32>, !torch.int, !torch.bool -> !torch.vtensor<[],si32>, !torch.vtensor<[],si64> -// CHECK: return %[[INDICES]] : !torch.vtensor<[],si64> -func.func @argmax_rank_1(%arg0: !torch.vtensor<[20],si32>) -> !torch.vtensor<[],si64> { - %none = torch.constant.none - %false = torch.constant.bool false - %7 = torch.aten.argmax %arg0, %none, %false : !torch.vtensor<[20],si32>, !torch.none, !torch.bool -> !torch.vtensor<[],si64> - return %7 : !torch.vtensor<[],si64> -} - // ----- // CHECK-LABEL: func.func @torch.aten.type_as$basic( // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { @@ -118,9 +105,9 @@ func.func @torch.aten.fake_quantize_per_channel_affine_cachemask(%arg0: !torch.v // CHECK-LABEL: test_einsum_inner_prod func.func @test_einsum_inner_prod(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[5],f64>) -> !torch.vtensor<[],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64} { - // CHECK-DAG: %[[INT5:.+]] = torch.constant.int 5 - // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT5:.+]] = torch.constant.int 5 + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[LHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] // CHECK: %[[LHS_PERM:.+]] = torch.aten.permute %arg0, %[[LHS_LIST]] // CHECK: %[[RHS_LIST:.+]] = torch.prim.ListConstruct %[[INT0]] @@ -184,47 +171,3 @@ func.func @torch.aten.fmod_float(%arg0: !torch.vtensor<[?],f16>, %arg1: !torch.v %0 = torch.aten.fmod.Tensor %arg0, %arg1 : 
!torch.vtensor<[?],f16>, !torch.vtensor<[1],f16> -> !torch.vtensor<[?],f16> return %0 : !torch.vtensor<[?],f16> } - -// ----- - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_last_dim( -// CHECK-SAME: %arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { -// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 -// CHECK-DAG: %[[INT5:.*]] = torch.constant.int 5 -// CHECK-DAG: %[[INT16:.*]] = torch.constant.int 16 -// CHECK: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<9x10xf32>) : !torch.vtensor<[9,10],f32> -// CHECK: %[[VAR1:.*]] = torch.aten.mm %arg0, %[[VAR0]] : !torch.vtensor<[16,9],f32>, !torch.vtensor<[9,10],f32> -> !torch.vtensor<[16,10],f32> -// CHECK: %[[VAR2:.*]] = torch.prim.ListConstruct %[[INT16]], %[[INT5]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAR3:.*]] = torch.aten.view %[[VAR1]], %[[VAR2]] : !torch.vtensor<[16,10],f32>, !torch.list -> !torch.vtensor<[16,5,2],f32> -// CHECK: %[[VAR4:.*]] = torch.aten.view_as_complex %[[VAR3]] : !torch.vtensor<[16,5,2],f32> -> !torch.vtensor<[16,5],complex> -// CHECK: return %[[VAR4]] : !torch.vtensor<[16,5],complex> -func.func @torch.aten.fft_rfft$2d_last_dim(%arg0: !torch.vtensor<[16,9],f32>) -> !torch.vtensor<[16,5],complex> { - %int-1 = torch.constant.int -1 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int-1, %none : !torch.vtensor<[16,9],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[16,5],complex> - return %out : !torch.vtensor<[16,5],complex> -} - -// ----- - -// CHECK-LABEL: func.func @torch.aten.fft_rfft$2d_first_dim( -// CHECK-SAME: %arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { -// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 -// CHECK-DAG: %[[INT19:.*]] = torch.constant.int 19 -// CHECK-DAG: %[[INT23:.*]] = torch.constant.int 23 -// CHECK-DAG: %[[VAR0:.*]] = torch.vtensor.literal(dense<{{.*}}> : tensor<36x38xf32>) : !torch.vtensor<[36,38],f32> -// CHECK-DAG: 
%[[INT0:.*]] = torch.constant.int 0 -// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[VAR1:.*]] = torch.aten.transpose.int %arg0, %[[INT0]], %[[INT1]] : !torch.vtensor<[36,23],f32>, !torch.int, !torch.int -> !torch.vtensor<[23,36],f32> -// CHECK: %[[VAR2:.*]] = torch.aten.mm %[[VAR1]], %[[VAR0]] : !torch.vtensor<[23,36],f32>, !torch.vtensor<[36,38],f32> -> !torch.vtensor<[23,38],f32> -// CHECK: %[[VAR3:.*]] = torch.prim.ListConstruct %[[INT23]], %[[INT19]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAR4:.*]] = torch.aten.view %[[VAR2]], %[[VAR3]] : !torch.vtensor<[23,38],f32>, !torch.list -> !torch.vtensor<[23,19,2],f32> -// CHECK: %[[VAR5:.*]] = torch.aten.view_as_complex %[[VAR4]] : !torch.vtensor<[23,19,2],f32> -> !torch.vtensor<[23,19],complex> -// CHECK: %[[VAR6:.*]] = torch.aten.transpose.int %[[VAR5]], %[[INT0]], %[[INT1]] : !torch.vtensor<[23,19],complex>, !torch.int, !torch.int -> !torch.vtensor<[19,23],complex> -// CHECK: return %[[VAR6]] : !torch.vtensor<[19,23],complex> -func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>) -> !torch.vtensor<[19,23],complex> { - %int0 = torch.constant.int 0 - %none = torch.constant.none - %out = torch.aten.fft_rfft %arg0, %none, %int0, %none : !torch.vtensor<[36,23],f32>, !torch.none, !torch.int, !torch.none -> !torch.vtensor<[19,23],complex> - return %out : !torch.vtensor<[19,23],complex> -} diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir index 00975a2405be..5ea715735c70 100644 --- a/test/Dialect/Torch/scalarize-shapes.mlir +++ b/test/Dialect/Torch/scalarize-shapes.mlir @@ -27,8 +27,12 @@ func.func @shape_as_tensor(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtenso // CHECK-LABEL: @shape_as_tensor_dim func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si32> { // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[SZ:.+]] = torch.aten.size.int 
%arg0, %[[INT1]] - // CHECK: %[[TENSOR:.+]] = torch.prim.NumToTensor.Scalar %[[SZ]] : !torch.int -> !torch.vtensor<[],si32> + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT1_0]] + // CHECK: %[[TENSOR:.+]] = torch.aten.full %[[LIST]], %[[SZ]], %[[NONE]], %[[NONE]], %[[NONE]], %[[FALSE]] // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si32> %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> %dim = torch.constant.int 0 @@ -39,75 +43,6 @@ func.func @shape_as_tensor_dim(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vt return %select : !torch.vtensor<[],si32> } -// ----- - -// CHECK-LABEL: @cast_int_int -func.func @cast_int_int(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],si64> { - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int - // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[SZE]] : !torch.int -> !torch.vtensor<[],si64> - // CHECK: return %[[TENSOR]] : !torch.vtensor<[],si64> - %int4 = torch.constant.int 4 - %false = torch.constant.bool false - %none = torch.constant.none - %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> - %cast_shape = torch.aten.to.dtype %shape, %int4, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],si64> - %dim = torch.constant.int 0 - %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> - %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],si64> - %item = torch.aten.item %select : 
!torch.vtensor<[],si64> -> !torch.int - %list = torch.prim.ListConstruct %item : (!torch.int) -> !torch.list - return %select : !torch.vtensor<[],si64> -} - -// ----- - -// CHECK-LABEL: @cast_int_float -func.func @cast_int_float(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[],f32> { - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[5,?,?],f32>, !torch.int -> !torch.int - // CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar %[[SZE]] : !torch.int -> !torch.float - // CHECK: %[[TENSOR:.*]] = torch.prim.NumToTensor.Scalar %[[FLOAT]] : !torch.float -> !torch.vtensor<[],f32> - // CHECK: return %[[TENSOR]] : !torch.vtensor<[],f32> - %int6 = torch.constant.int 6 - %false = torch.constant.bool false - %none = torch.constant.none - %shape = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[5,?,?],f32> -> !torch.vtensor<[3],si32> - %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> - %dim = torch.constant.int 0 - %idx = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si32> - %select = torch.aten.index_select %cast_shape, %dim, %idx : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si32> -> !torch.vtensor<[],f32> - %item = torch.aten.item %select : !torch.vtensor<[],f32> -> !torch.float - %item_int = torch.aten.Int.Scalar %item : !torch.float -> !torch.int - %list = torch.prim.ListConstruct %item_int : (!torch.int) -> !torch.list - return %select : !torch.vtensor<[],f32> -} - -// ----- - -// CHECK-LABEL: @cast_int_float_static -func.func @cast_int_float_static(%arg0 : !torch.vtensor<[5,?,?],f32>) -> !torch.vtensor<[3],f32> { - // CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 - // CHECK: %[[FLOAT2:.*]] = torch.constant.float 2.000000e+00 - // CHECK: %[[FLOAT3:.*]] = torch.constant.float 3.000000e+00 - // CHECK: %[[LIST:.*]] = 
torch.prim.ListConstruct %[[FLOAT1:.*]], %[[FLOAT2:.*]], %[[FLOAT3:.*]] : (!torch.float, !torch.float, !torch.float) -> !torch.list - // CHECK: %[[NONE:.*]] = torch.constant.none - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],f32> - // CHECK: return %[[TENSOR]] : !torch.vtensor<[3],f32> - %int6 = torch.constant.int 6 - %false = torch.constant.bool false - %none = torch.constant.none - %shape = torch.vtensor.literal(dense<[1,2,3]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> - %cast_shape = torch.aten.to.dtype %shape, %int6, %false, %false, %none : !torch.vtensor<[3],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3],f32> - %dim = torch.constant.int 0 - %idx0 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> - %select0 = torch.aten.index_select %cast_shape, %dim, %idx0 : !torch.vtensor<[3],f32>, !torch.int, !torch.vtensor<[],si64> -> !torch.vtensor<[],f32> - %item0 = torch.aten.item %select0 : !torch.vtensor<[],f32> -> !torch.float - %item_int0 = torch.aten.Int.Scalar %item0 : !torch.float -> !torch.int - %list = torch.prim.ListConstruct %item_int0 : (!torch.int) -> !torch.list - return %cast_shape : !torch.vtensor<[3],f32> -} // ----- @@ -154,12 +89,14 @@ func.func @arith_prop(%arg0 : !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?] 
// CHECK: %[[x2:.*]] = torch.aten.floordiv.int %[[x0]], %[[int12]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[x3:.*]] = torch.aten.floordiv.int %[[x1]], %[[int1_0]] : !torch.int, !torch.int -> !torch.int // CHECK: %[[int12_1:.*]] = torch.constant.int 12 + // CHECK: %[[int1_2:.*]] = torch.constant.int 1 // CHECK: %[[x4:.*]] = torch.aten.mul.int %[[x2]], %[[int12_1]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x5:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x1]], %[[x3]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[x7:.*]] = torch.prim.ListConstruct %[[x6]], %[[x5]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[x8:.*]] = torch.aten.constant_pad_nd %arg0, %[[x7]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> - // CHECK: return %[[x8]] : !torch.vtensor<[?,?],f32> + // CHECK: %[[x5:.*]] = torch.aten.mul.int %[[x3]], %[[int1_2]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x6:.*]] = torch.aten.sub.int %[[x0]], %[[x4]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x7:.*]] = torch.aten.sub.int %[[x1]], %[[x5]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[x8:.*]] = torch.prim.ListConstruct %[[x7]], %[[x6]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[x9:.*]] = torch.aten.constant_pad_nd %arg0, %[[x8]], %[[float0]] : !torch.vtensor<[?,?],f32>, !torch.list, !torch.float -> !torch.vtensor<[?,?],f32> + // CHECK: return %[[x9]] : !torch.vtensor<[?,?],f32> %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<0> : tensor) : !torch.vtensor<[],si64> %float0.000000e00 = torch.constant.float 0.000000e+00 diff --git a/test/python/fx_importer/v2.3/auto_functionalized.py b/test/python/fx_importer/v2.3/auto_functionalized.py index 7fb0eeb3b67f..ab7401dcc2fb 100644 --- a/test/python/fx_importer/v2.3/auto_functionalized.py +++ 
b/test/python/fx_importer/v2.3/auto_functionalized.py @@ -59,9 +59,8 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". - # COM: torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> () - # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} + # CHECK: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify"({{.*}}) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + # CHECK: torch.aten.mul.Tensor %[[TIED]], %[[TIED]] print(m) m.operation.verify() @@ -87,8 +86,7 @@ def forward(self, x): # AssertionError: Current active mode not registered decomposition_table=[], ) - # The Torch 2.6 expects the IR to be same as the below one, while the torch versions < 2.6 does not, hence this check is kept as a "COM". - # COM: %[[TIED:.*]] = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> - # CHECK: torch.aten.mul.Tensor %{{.*}}, %{{.*}} + # CHECK: %[[TIED:.*]]:2 = torch.operator "torch.torch_mlir_test.inplace_modify_calc"(%0) : (!torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) + # CHECK: torch.aten.mul.Tensor %[[TIED]]#1, %[[TIED]]#0 print(m) m.operation.verify() diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index be1615525984..901fbd3d9a84 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torchvision/ --pre -torchvision==0.22.0.dev20241216 +torchvision==0.20.0.dev20241029