Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
praveen-g-ctt committed Jan 7, 2025
1 parent 34c2599 commit 1dd573c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 15 deletions.
10 changes: 5 additions & 5 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3589,7 +3589,7 @@ class ConvertSymConstrainRangeOp
int64_t minValue = std::numeric_limits<int64_t>::min();
int64_t maxValue = std::numeric_limits<int64_t>::max();

Type operandType = rewriter.getI64Type();
Type operandType = getTypeConverter()->convertType(op.getSize().getType());

if (!isa<Torch::ConstantNoneOp>(minOp))
if (!matchPattern(min, m_TorchConstantInt(&minValue)))
Expand All @@ -3615,10 +3615,10 @@ class ConvertSymConstrainRangeOp

// FIXME:: Skip the below checks if constraint ops are already inserted as
// part of symbol expr evaluation
auto checkMin = createLessThanOrEqual(rewriter, loc, operandType, min,
adaptor.getSize());
auto checkMax = createLessThanOrEqual(rewriter, loc, operandType,
adaptor.getSize(), max);
auto checkMin = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, min, adaptor.getSize());
auto checkMax = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sle, adaptor.getSize(), max);
auto compareVal = rewriter.create<arith::AndIOp>(loc, checkMin, checkMax);

std::string assertMessage = "Invalid value range for size between [" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6488,7 +6488,7 @@ def __init__(self):
def forward(self, x):
a = x.item()
torch._check_is_size(a)
# max should be >= 2
# max should be > 2
torch.ops.aten.sym_constrain_range_for_size(a, min=0, max=10)
return a

Expand Down
16 changes: 7 additions & 9 deletions test/Conversion/TorchToLinalg/constraints.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,26 @@
// CHECK: %[[VAL_5:.*]] = torch_c.to_i64 %[[VAL_4]]
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_7:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[VAL_8:.*]] = arith.cmpi ule, %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_9:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_7]] : i64
// CHECK: %[[VAL_8:.*]] = arith.cmpi sle, %[[VAL_6]], %[[VAL_5]] : i64
// CHECK: %[[VAL_9:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_7]] : i64
// CHECK: %[[VAL_10:.*]] = arith.andi %[[VAL_8]], %[[VAL_9]] : i1
// CHECK: cf.assert %[[VAL_10]], "Invalid value range for size between [0, 9223372036854775807]"
// CHECK: %[[VAL_11:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_12:.*]] = arith.constant 7 : i64
// CHECK: %[[VAL_13:.*]] = arith.cmpi ule, %[[VAL_11]], %[[VAL_5]] : i64
// CHECK: %[[VAL_14:.*]] = arith.cmpi ule, %[[VAL_5]], %[[VAL_12]] : i64
// CHECK: %[[VAL_13:.*]] = arith.cmpi sle, %[[VAL_11]], %[[VAL_5]] : i64
// CHECK: %[[VAL_14:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_12]] : i64
// CHECK: %[[VAL_15:.*]] = arith.andi %[[VAL_13]], %[[VAL_14]] : i1
// CHECK: cf.assert %[[VAL_15]], "Invalid value range for size between [0, 7]"
// CHECK: return %[[VAL_4]] : !torch.int

module {
func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
func.func @torch.aten.sym_constrain_range(%arg0: !torch.vtensor<[],si64>) -> !torch.int {
%int7 = torch.constant.int 7
%int0 = torch.constant.int 0
%none = torch.constant.none
%0 = torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int
torch.aten.sym_constrain_range %0, %int0, %none : !torch.int, !torch.int, !torch.none
torch.aten.sym_constrain_range %0, %int0, %int7 : !torch.int, !torch.int, !torch.int
return %0 : !torch.int
}
}

// -----
Expand All @@ -47,8 +45,8 @@ module {
// CHECK: %[[VAL_7:.*]] = torch_c.to_i64 %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_9:.*]] = arith.constant 9223372036854775807 : i64
// CHECK: %[[VAL_10:.*]] = arith.cmpi ule, %[[VAL_8]], %[[VAL_7]] : i64
// CHECK: %[[VAL_11:.*]] = arith.cmpi ule, %[[VAL_7]], %[[VAL_9]] : i64
// CHECK: %[[VAL_10:.*]] = arith.cmpi sle, %[[VAL_8]], %[[VAL_7]] : i64
// CHECK: %[[VAL_11:.*]] = arith.cmpi sle, %[[VAL_7]], %[[VAL_9]] : i64
// CHECK: %[[VAL_12:.*]] = arith.andi %[[VAL_10]], %[[VAL_11]] : i1
// CHECK: cf.assert %[[VAL_12]], "Invalid value range for size between [0, 9223372036854775807]"
// CHECK: %[[VAL_13:.*]] = torch.aten.ge.int %[[VAL_6]], %[[VAL_4]] : !torch.int, !torch.int -> !torch.bool
Expand Down

0 comments on commit 1dd573c

Please sign in to comment.