From f42e9ba237597cf3939d90675c6d4b04d692f512 Mon Sep 17 00:00:00 2001
From: Praveen G
Date: Wed, 8 Jan 2025 12:48:37 +0000
Subject: [PATCH] Lower aten::_assert_scalar to torch.runtime.assert

---
 .../TorchToLinalg/Uncategorized.cpp           | 39 ++------------
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 44 ++++++++++++----
 projects/pt1/e2e_testing/xfail_sets.py        |  8 ++-
 .../torch_mlir_e2e_test/test_suite/basic.py   | 14 ++---
 .../Conversion/TorchToLinalg/constraints.mlir | 52 +------------------
 test/Dialect/Torch/decompose-complex-ops.mlir | 35 +++++++++++--
 6 files changed, 85 insertions(+), 107 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index 8aaed117edd74..68cc77dc97c6a 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -21,6 +21,7 @@
 #include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
 #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "llvm/ADT/APSInt.h"
@@ -3580,23 +3581,17 @@
     auto min = op.getMin();
     auto max = op.getMax();
 
-    auto minOp = min.getDefiningOp();
-    auto maxOp = max.getDefiningOp();
-
-    if (!minOp || !maxOp)
-      return op.emitError("Unimplemented: Non constant min/max values");
-
     int64_t minValue = std::numeric_limits<int64_t>::min();
     int64_t maxValue = std::numeric_limits<int64_t>::max();
     Type operandType = getTypeConverter()->convertType(op.getSize().getType());
 
-    if (!isa<Torch::ConstantNoneOp>(minOp))
+    if (!isa<Torch::NoneType>(min.getType()))
       if (!matchPattern(min, m_TorchConstantInt(&minValue)))
        return rewriter.notifyMatchFailure(
            op, "Expected min value to be constant integer");
 
-    if (!isa<Torch::ConstantNoneOp>(maxOp))
+    if (!isa<Torch::NoneType>(max.getType()))
      if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
        return rewriter.notifyMatchFailure(
            op, "Expected max value to be constant integer");
 
@@ -3621,7 +3616,7 @@
         loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
     auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);
 
-    std::string assertMessage = "Invalid value range for size between [" +
+    std::string assertMessage = "Size constraint failed. Expected range: [" +
                                 std::to_string(minValue) + ", " +
                                 std::to_string(maxValue) + "]";
     rewriter.create<cf::AssertOp>(loc, compareVal,
@@ -3633,30 +3628,6 @@
 };
 } // namespace
 
-namespace {
-class ConvertAssertScalarOp : public OpConversionPattern<Aten_AssertScalarOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(Aten_AssertScalarOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
-      return failure();
-
-    auto assertCond = convertScalarToDtype(
-        rewriter, op.getLoc(), adaptor.getSelf(), rewriter.getI1Type());
-
-    std::string assertMessage;
-    if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
-      return rewriter.notifyMatchFailure(
-          op, "Assert message must be a constant string");
-
-    rewriter.replaceOpWithNewOp<cf::AssertOp>(op, assertCond, assertMessage);
-    return success();
-  }
-};
-} // namespace
-
 void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
     ConversionTarget &target) {
@@ -3721,6 +3692,4 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
   patterns.add<ConvertAtenPolarOp>(typeConverter, context);
   target.addIllegalOp<AtenSymConstrainRangeOp>();
   patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
-  target.addIllegalOp<Aten_AssertScalarOp>();
-  patterns.add<ConvertAssertScalarOp>(typeConverter, context);
 }
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 910f31013bf9d..f67e2ccaa25ef 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -11456,7 +11456,7 @@ class DecomposeAtenSpecialExpm1Op
 } // namespace
 
 namespace {
-class DecomposeConstrainRangeForSizeOp
+class DecomposeAtenConstrainRangeForSizeOp
     : public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -11466,15 +11466,10 @@ class DecomposeConstrainRangeForSizeOp
     auto loc = op.getLoc();
     auto min = op.getMin();
     auto max = op.getMax();
-    auto minOp = min.getDefiningOp();
-    auto maxOp = max.getDefiningOp();
-
-    if (!minOp || !maxOp)
-      return op.emitError("Unimplemented: Non constant min/max values");
 
     int64_t minValue, maxValue;
 
-    if (isa<Torch::ConstantNoneOp>(minOp)) {
+    if (isa<Torch::NoneType>(min.getType())) {
       // Set min value to 0
       min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
     } else {
@@ -11484,7 +11479,7 @@ class DecomposeConstrainRangeForSizeOp
             op, "Expected min value to be constant integer");
     }
 
-    if (!isa<Torch::ConstantNoneOp>(maxOp)) {
+    if (!isa<Torch::NoneType>(max.getType())) {
       // Verify that max value is greater than 2
       if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
         return rewriter.notifyMatchFailure(
@@ -11505,6 +11500,35 @@ class DecomposeConstrainRangeForSizeOp
 };
 } // namespace
 
+namespace {
+class DecomposeAten_AssertScalarOp
+    : public OpRewritePattern<Aten_AssertScalarOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
+                                PatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto assertCond = op.getSelf();
+
+    if (isa<Torch::IntType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
+    else if (isa<Torch::FloatType>(assertCond.getType()))
+      assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
+    assert(isa<Torch::BoolType>(assertCond.getType()) &&
+           "Unhandled type encountered in aten._assert_scalar op");
+
+    std::string assertMessage;
+    if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
+      return rewriter.notifyMatchFailure(
+          op, "Assert message must be a constant string");
+
+    rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond, assertMessage);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeComplexOpsPass
     : public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -11803,7 +11827,9 @@ class DecomposeComplexOpsPass
     // Torchvision ops
     addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);
 
-    addPatternIfTargetOpIsIllegal<DecomposeConstrainRangeForSizeOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
+        patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
 
     GreedyRewriteConfig config;
     config.useTopDownTraversal = true;
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 356c881fb5791..989985fde4fe1 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -38,7 +38,7 @@
     # Unknown builtin op: aten::_check_is_size in TorchScript
     "AtenSymConstrainRange_basic",
     "AtenSymConstrainRangeForSize_basic",
-    "AtenAssertScalar",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.5.0.dev"):
@@ -943,6 +943,9 @@
     "BernoulliFloatModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 FX_IMPORTER_STABLEHLO_CRASHING_SET = {
@@ -3373,6 +3376,9 @@
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "AtenSymConstrainRange_basic",
+    "AtenSymConstrainRangeForSize_basic",
+    "Aten_AssertScalar_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 369de448c55cf..cf29a96b54a03 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -6466,7 +6466,6 @@ def __init__(self):
     @annotate_args([None, ([-1], torch.int, True)])
     def forward(self, x):
         a = x.item()
-        torch._check_is_size(a)
         torch.ops.aten.sym_constrain_range(a, max=5)
         return a
 
@@ -6487,8 +6486,6 @@ def __init__(self):
     @annotate_args([None, ([-1], torch.int, True)])
     def forward(self, x):
         a = x.item()
-        torch._check_is_size(a)
-        # max should be > 2
        torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
         return a
 
@@ -6499,7 +6496,7 @@ def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class AtenAssertScalar(torch.nn.Module):
+class Aten_AssertScalar(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -6507,12 +6504,11 @@ def __init__(self):
     @annotate_args([None, ([-1], torch.int, True)])
     def forward(self, x):
         a = x.item()
-        # The below checks introduces aten._assert_scalar op
-        torch._check_is_size(a)
-        torch._check(a <= 5)
+        assert_msg = "Assertion failed for condition x.item() > 3"
+        torch.ops.aten._assert_scalar(a > 3, assert_msg)
         return a
 
 
-@register_test_case(module_factory=lambda: AtenAssertScalar())
-def AtenAssertScalar_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: Aten_AssertScalar())
+def Aten_AssertScalar_basic(module, tu: TestUtils):
     module.forward(torch.tensor(4))
diff --git a/test/Conversion/TorchToLinalg/constraints.mlir b/test/Conversion/TorchToLinalg/constraints.mlir
index bc48da402fb8d..11bafaa973d1c 100644
--- a/test/Conversion/TorchToLinalg/constraints.mlir
+++ b/test/Conversion/TorchToLinalg/constraints.mlir
@@ -1,5 +1,4 @@
 // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
-// -----
 
 // CHECK-LABEL:   func.func @torch.aten.sym_constrain_range(
 // CHECK-SAME:        %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
@@ -13,13 +12,13 @@
 // CHECK:           %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64
 // CHECK:           %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64
 // CHECK:           %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1
-// CHECK:           cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]"
+// CHECK:           cf.assert %[[VAL_10]], "Size constraint failed. Expected range: [0, 9223372036854775807]"
 // CHECK:           %[[VAL_11:.*]] = arith.constant 0 : i64
 // CHECK:           %[[VAL_12:.*]] = arith.constant 7 : i64
 // CHECK:           %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64
 // CHECK:           %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64
 // CHECK:           %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
-// CHECK:           cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]"
+// CHECK:           cf.assert %[[VAL_15]], "Size constraint failed. Expected range: [0, 7]"
 // CHECK:           return %[[VAL_4]] : !torch.int
 
 func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
@@ -31,50 +30,3 @@ func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !to
   torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int
   return %0 : !torch.int
 }
-
-// -----
-
-// CHECK-LABEL:   func.func @torch.aten._assert_scalar(
-// CHECK-SAME:        %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
-// CHECK:           %[[VAL_1:.*]] = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
-// CHECK:           %[[VAL_2:.*]] = torch.constant.int 7
-// CHECK:           %[[VAL_3:.*]] = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
-// CHECK:           %[[VAL_4:.*]] = torch.constant.int 0
-// CHECK:           %[[VAL_5:.*]] = torch.constant.none
-// CHECK:           %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
-// CHECK:           %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
-// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : i64
-// CHECK:           %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64
-// CHECK:           %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64
-// CHECK:           %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64
-// CHECK:           %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1
-// CHECK:           cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]"
-// CHECK:           %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool
-// CHECK:           %[[VAL_14:.*]] = torch.aten.Int.bool %[[VAL_13]] : !torch.bool -> !torch.int
-// CHECK:           %[[VAL_15:.*]] = torch_c.to_i64 %[[VAL_14]]
-// CHECK:           %[[VAL_16:.*]] = arith.constant 0 : i64
-// CHECK:           %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_15]], %[[VAL_16]] : i64
-// CHECK:           cf.assert %[[VAL_17]], "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
-// CHECK:           %[[VAL_18:.*]] = torch.aten.le.int %[[VAL_6]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
-// CHECK:           %[[VAL_19:.*]] = torch.aten.Int.bool %[[VAL_18]] : !torch.bool -> !torch.int
-// CHECK:           %[[VAL_20:.*]] = torch_c.to_i64 %[[VAL_19]]
-// CHECK:           %[[VAL_21:.*]] = arith.constant 0 : i64
-// CHECK:           %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_20]], %[[VAL_21]] : i64
-// CHECK:           cf.assert %[[VAL_22]], "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
-// CHECK:           return %[[VAL_6]] : !torch.int
-func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
-  %str = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
-  %int7 = torch.constant.int 7
-  %str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
-  %int0 = torch.constant.int 0
-  %none = torch.constant.none
-  %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
-  torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none
-  %1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
-  %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
-  torch.aten._assert_scalar %2, %str_0 : !torch.int, !torch.str
-  %3 = torch.aten.le.int %0, %int7 : !torch.int, !torch.int -> !torch.bool
-  %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
-  torch.aten._assert_scalar %4, %str : !torch.int, !torch.str
-  return %0 : !torch.int
-}
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 0adb10edac801..be3f6548fc98f 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -240,8 +240,7 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>)
 // CHECK:           torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
 // CHECK:           torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
 // CHECK:           return %[[VAL_4]] : !torch.int
-module {
-  func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
+func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
     %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
     %none = torch.constant.none
     %none_0 = torch.constant.none
@@ -250,5 +249,35 @@
     %int7_7 = torch.constant.int 7
     torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, !torch.int, !torch.int
     return %0 : !torch.int
-  }
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @torch.aten._assert_scalar(
+// CHECK-SAME:        %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
+// CHECK:           %[[VAL_1:.*]] = torch.constant.int 2
+// CHECK:           %[[VAL_2:.*]] = torch.constant.int 3
+// CHECK:           %[[VAL_3:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
+// CHECK:           %[[VAL_4:.*]] = torch.aten.ge.int %[[VAL_3]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:           %[[VAL_5:.*]] = torch.aten.Int.bool %[[VAL_4]] : !torch.bool -> !torch.int
+// CHECK:           %[[VAL_6:.*]] = torch.aten.Bool.int %[[VAL_5]] : !torch.int -> !torch.bool
+// CHECK:           torch.runtime.assert %[[VAL_6]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
+// CHECK:           %[[VAL_7:.*]] = torch.aten.gt.int %[[VAL_3]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
+// CHECK:           %[[VAL_8:.*]] = torch.aten.Int.bool %[[VAL_7]] : !torch.bool -> !torch.int
+// CHECK:           %[[VAL_9:.*]] = torch.aten.Bool.int %[[VAL_8]] : !torch.int -> !torch.bool
+// CHECK:           torch.runtime.assert %[[VAL_9]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
+// CHECK:           return %[[VAL_3]] : !torch.int
+func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
+  %0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
+  %int3 = torch.constant.int 3
+  %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool
+  %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
+  %str = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
+  torch.aten._assert_scalar %2, %str : !torch.int, !torch.str
+  %int2 = torch.constant.int 2
+  %3 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
+  %4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
+  %str_0 = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
+  torch.aten._assert_scalar %4, %str_0 : !torch.int, !torch.str
+  return %0 : !torch.int
 }
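
Reviewer note (not part of the patch): as exercised by the lit test above, the net
effect of the new DecomposeAten_AssertScalarOp pattern is roughly the following
rewrite; the SSA value names below are illustrative and taken from that test:

  // Before DecomposeComplexOps: the assert condition arrives as a !torch.int.
  %2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
  torch.aten._assert_scalar %2, %str : !torch.int, !torch.str

  // After DecomposeComplexOps: the condition is turned back into a !torch.bool
  // and the op becomes a plain torch.runtime.assert.
  %6 = torch.aten.Bool.int %2 : !torch.int -> !torch.bool
  torch.runtime.assert %6, "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"

Presumably the backend paths then reuse the existing lowering of
torch.runtime.assert (cf.assert on the Linalg path), which is why the dedicated
aten._assert_scalar pattern in TorchToLinalg/Uncategorized.cpp is removed above.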