Skip to content

Commit

Permalink
Lower aten::_assert_scalar to torch.runtime.assert
Browse files Browse the repository at this point in the history
  • Loading branch information
praveen-g-ctt committed Jan 8, 2025
1 parent 1dd573c commit f42e9ba
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 107 deletions.
39 changes: 4 additions & 35 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "torch-mlir/Conversion/TorchToLinalg/Utils.h"
#include "torch-mlir/Conversion/Utils/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "llvm/ADT/APSInt.h"
Expand Down Expand Up @@ -3580,23 +3581,17 @@ class ConvertSymConstrainRangeOp
auto min = op.getMin();
auto max = op.getMax();

auto minOp = min.getDefiningOp();
auto maxOp = max.getDefiningOp();

if (!minOp || !maxOp)
return op.emitError("Unimplemented: Non constant min/max values");

int64_t minValue = std::numeric_limits<int64_t>::min();
int64_t maxValue = std::numeric_limits<int64_t>::max();

Type operandType = getTypeConverter()->convertType(op.getSize().getType());

if (!isa<Torch::ConstantNoneOp>(minOp))
if (!isa<Torch::NoneType>(min.getType()))
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
return rewriter.notifyMatchFailure(
op, "Expected min value to be constant integer");

if (!isa<Torch::ConstantNoneOp>(maxOp))
if (!isa<Torch::NoneType>(max.getType()))
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
op, "Expected max value to be constant integer");
Expand All @@ -3621,7 +3616,7 @@ class ConvertSymConstrainRangeOp
loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);

std::string assertMessage = "Invalid value range for size between [" +
std::string assertMessage = "Size constraint failed. Expected range: [" +
std::to_string(minValue) + ", " +
std::to_string(maxValue) + "]";
rewriter.create<cf::AssertOp>(loc, compareVal,
Expand All @@ -3633,30 +3628,6 @@ class ConvertSymConstrainRangeOp
};
} // namespace

namespace {
class ConvertAssertScalarOp : public OpConversionPattern<Aten_AssertScalarOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten_AssertScalarOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();

auto assertCond = convertScalarToDtype(
rewriter, op.getLoc(), adaptor.getSelf(), rewriter.getI1Type());

std::string assertMessage;
if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
return rewriter.notifyMatchFailure(
op, "Assert message must be a constant string");

rewriter.replaceOpWithNewOp<cf::AssertOp>(op, assertCond, assertMessage);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
Expand Down Expand Up @@ -3721,6 +3692,4 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertAtenPolarOp>(typeConverter, context);
target.addIllegalOp<AtenSymConstrainRangeOp>();
patterns.add<ConvertSymConstrainRangeOp>(typeConverter, context);
target.addIllegalOp<Aten_AssertScalarOp>();
patterns.add<ConvertAssertScalarOp>(typeConverter, context);
}
44 changes: 35 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11456,7 +11456,7 @@ class DecomposeAtenSpecialExpm1Op
} // namespace

namespace {
class DecomposeConstrainRangeForSizeOp
class DecomposeAtenConstrainRangeForSizeOp
: public OpRewritePattern<AtenSymConstrainRangeForSizeOp> {
public:
using OpRewritePattern<AtenSymConstrainRangeForSizeOp>::OpRewritePattern;
Expand All @@ -11466,15 +11466,10 @@ class DecomposeConstrainRangeForSizeOp
auto loc = op.getLoc();
auto min = op.getMin();
auto max = op.getMax();
auto minOp = min.getDefiningOp();
auto maxOp = max.getDefiningOp();

if (!minOp || !maxOp)
return op.emitError("Unimplemented: Non constant min/max values");

int64_t minValue, maxValue;

if (isa<Torch::ConstantNoneOp>(minOp)) {
if (isa<Torch::NoneType>(min.getType())) {
// Set min value to 0
min = rewriter.create<Torch::ConstantIntOp>(loc, 0);
} else {
Expand All @@ -11484,7 +11479,7 @@ class DecomposeConstrainRangeForSizeOp
op, "Expected min value to be constant integer");
}

if (!isa<Torch::ConstantNoneOp>(maxOp)) {
if (!isa<Torch::NoneType>(max.getType())) {
// Verify that max value is greater than 2
if (!matchPattern(max, m_TorchConstantInt(&maxValue)))
return rewriter.notifyMatchFailure(
Expand All @@ -11505,6 +11500,35 @@ class DecomposeConstrainRangeForSizeOp
};
} // namespace

namespace {
class DecomposeAten_AssertScalarOp
: public OpRewritePattern<Aten_AssertScalarOp> {
public:
using OpRewritePattern<Aten_AssertScalarOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Aten_AssertScalarOp op,
PatternRewriter &rewriter) const override {

auto loc = op.getLoc();
auto assertCond = op.getSelf();

if (isa<Torch::IntType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolIntOp>(loc, assertCond);
else if (isa<Torch::FloatType>(assertCond.getType()))
assertCond = rewriter.create<AtenBoolFloatOp>(loc, assertCond);
assert(isa<Torch::BoolType>(assertCond.getType()) &&
"Unhandled type encountered in aten._assert_scalar op");

std::string assertMessage;
if (!matchPattern(op.getAssertMsg(), m_TorchConstantStr(assertMessage)))
return rewriter.notifyMatchFailure(
op, "Assert message must be a constant string");

rewriter.replaceOpWithNewOp<RuntimeAssertOp>(op, assertCond, assertMessage);
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
Expand Down Expand Up @@ -11803,7 +11827,9 @@ class DecomposeComplexOpsPass
// Torchvision ops
addPatternIfTargetOpIsIllegal<DecomposeTorchvisionNmsOp>(patterns);

addPatternIfTargetOpIsIllegal<DecomposeConstrainRangeForSizeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
Expand Down
8 changes: 7 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# Unknown builtin op: aten::_check_is_size in TorchScript
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"AtenAssertScalar",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
Expand Down Expand Up @@ -943,6 +943,9 @@
"BernoulliFloatModule_basic",
"UniformModule_basic",
"UniformStaticShapeModule_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand Down Expand Up @@ -3373,6 +3376,9 @@
"Aten_TrilinearModuleVaryingRanks_basic",
"Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
"Aten_TrilinearModuleZerodDimBug_basic",
"AtenSymConstrainRange_basic",
"AtenSymConstrainRangeForSize_basic",
"Aten_AssertScalar_basic",
}

if torch_version_for_comparison() < version.parse("2.3.0.dev"):
Expand Down
14 changes: 5 additions & 9 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6466,7 +6466,6 @@ def __init__(self):
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch._check_is_size(a)
torch.ops.aten.sym_constrain_range(a, max=5)
return a

Expand All @@ -6487,8 +6486,6 @@ def __init__(self):
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
torch._check_is_size(a)
# max should be > 2
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
return a

Expand All @@ -6499,20 +6496,19 @@ def AtenSymConstrainRangeForSize_basic(module, tu: TestUtils):


# ==============================================================================
class AtenAssertScalar(torch.nn.Module):
class Aten_AssertScalar(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1], torch.int, True)])
def forward(self, x):
a = x.item()
# The below checks introduces aten._assert_scalar op
torch._check_is_size(a)
torch._check(a <= 5)
assert_msg = "Assertion failed for condition x.item() > 3"
torch.ops.aten._assert_scalar(a > 3, assert_msg)
return a


@register_test_case(module_factory=lambda: AtenAssertScalar())
def AtenAssertScalar_basic(module, tu: TestUtils):
@register_test_case(module_factory=lambda: Aten_AssertScalar())
def Aten_AssertScalar_basic(module, tu: TestUtils):
module.forward(torch.tensor(4))
52 changes: 2 additions & 50 deletions test/Conversion/TorchToLinalg/constraints.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// -----

// CHECK-LABEL: func.func @torch.aten.sym_constrain_range(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
Expand All @@ -13,13 +12,13 @@
// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64
// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1
// CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]"
// CHECK: cf.assert %[[VAL_10]], "Size constraint failed. Expected range: [0, 9223372036854775807]"
// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64
// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64
// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
// CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]"
// CHECK: cf.assert %[[VAL_15]], "Size constraint failed. Expected range: [0, 7]"
// CHECK: return %[[VAL_4]] : !torch.int

func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
Expand All @@ -31,50 +30,3 @@ func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !to
torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %0 : !torch.int
}

// -----

// CHECK-LABEL: func.func @torch.aten._assert_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
// CHECK: %[[VAL_2:.*]] = torch.constant.int 7
// CHECK: %[[VAL_3:.*]] = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
// CHECK: %[[VAL_4:.*]] = torch.constant.int 0
// CHECK: %[[VAL_5:.*]] = torch.constant.none
// CHECK: %[[VAL_6:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64
// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64
// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1
// CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]"
// CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_14:.*]] = torch.aten.Int.bool %[[VAL_13]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_15:.*]] = torch_c.to_i64 %[[VAL_14]]
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_17:.*]] = arith.cmpi ne, %[[VAL_15]], %[[VAL_16]] : i64
// CHECK: cf.assert %[[VAL_17]], "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
// CHECK: %[[VAL_18:.*]] = torch.aten.le.int %[[VAL_6]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_19:.*]] = torch.aten.Int.bool %[[VAL_18]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_20:.*]] = torch_c.to_i64 %[[VAL_19]]
// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_22:.*]] = arith.cmpi ne, %[[VAL_20]], %[[VAL_21]] : i64
// CHECK: cf.assert %[[VAL_22]], "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
// CHECK: return %[[VAL_6]] : !torch.int
func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%str = torch.constant.str "Runtime assertion failed for expression u0 <= 7 on node 'le_1'"
%int7 = torch.constant.int 7
%str_0 = torch.constant.str "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"
%int0 = torch.constant.int 0
%none = torch.constant.none
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none
%1 = torch.aten.ge.int %0, %int0 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
torch.aten._assert_scalar %2, %str_0 : !torch.int, !torch.str
%3 = torch.aten.le.int %0, %int7 : !torch.int, !torch.int -> !torch.bool
%4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
torch.aten._assert_scalar %4, %str : !torch.int, !torch.str
return %0 : !torch.int
}
35 changes: 32 additions & 3 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,7 @@ func.func @torch.aten.fft_rfft$2d_first_dim(%arg0: !torch.vtensor<[36,23],f32>)
// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_3]] : !torch.int, !torch.int, !torch.none
// CHECK: torch.aten.sym_constrain_range %[[VAL_4]], %[[VAL_2]], %[[VAL_1]] : !torch.int, !torch.int, !torch.int
// CHECK: return %[[VAL_4]] : !torch.int
module {
func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
func.func @torch.aten.sym_constrain_range_for_size(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
%none = torch.constant.none
%none_0 = torch.constant.none
Expand All @@ -250,5 +249,35 @@ module {
%int7_7 = torch.constant.int 7
torch.aten.sym_constrain_range_for_size %0, %int0_6, %int7_7 : !torch.int, !torch.int, !torch.int
return %0 : !torch.int
}
}

// -----

// CHECK-LABEL: func.func @torch.aten._assert_scalar(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[],si64>) -> !torch.int {
// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
// CHECK: %[[VAL_2:.*]] = torch.constant.int 3
// CHECK: %[[VAL_3:.*]] = torch.aten.item %[[VAL_0]] : !torch.vtensor<[],si64> -> !torch.int
// CHECK: %[[VAL_4:.*]] = torch.aten.ge.int %[[VAL_3]], %[[VAL_2]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_5:.*]] = torch.aten.Int.bool %[[VAL_4]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_6:.*]] = torch.aten.Bool.int %[[VAL_5]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_6]], "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
// CHECK: %[[VAL_7:.*]] = torch.aten.gt.int %[[VAL_3]], %[[VAL_1]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[VAL_8:.*]] = torch.aten.Int.bool %[[VAL_7]] : !torch.bool -> !torch.int
// CHECK: %[[VAL_9:.*]] = torch.aten.Bool.int %[[VAL_8]] : !torch.int -> !torch.bool
// CHECK: torch.runtime.assert %[[VAL_9]], "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
// CHECK: return %[[VAL_3]] : !torch.int
func.func @torch.aten._assert_scalar(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
%int3 = torch.constant.int 3
%1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool
%2 = torch.aten.Int.bool %1 : !torch.bool -> !torch.int
%str = torch.constant.str "Runtime assertion failed for expression u0 >= 3 on node 'ge_1'"
torch.aten._assert_scalar %2, %str : !torch.int, !torch.str
%int2 = torch.constant.int 2
%3 = torch.aten.gt.int %0, %int2 : !torch.int, !torch.int -> !torch.bool
%4 = torch.aten.Int.bool %3 : !torch.bool -> !torch.int
%str_0 = torch.constant.str "Runtime assertion failed for expression 2 < u0 on node 'gt_1'"
torch.aten._assert_scalar %4, %str_0 : !torch.int, !torch.str
return %0 : !torch.int
}

0 comments on commit f42e9ba

Please sign in to comment.