Commit 79c50e1

[TorchToLinalg] Lower count_nonzero to Linalg

1 parent c632c86, commit 79c50e1

9 files changed: +205 −0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 25 additions & 0 deletions
@@ -9902,6 +9902,31 @@ def Torch_AtenRot90Op : Torch_Op<"aten.rot90", [
   let hasVerifier = 1;
 }
 
+def Torch_AtenCountNonzeroOp : Torch_Op<"aten.count_nonzero", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::count_nonzero : (Tensor, int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalIntType:$dim
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenCountNonzeroOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenCountNonzeroOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasVerifier = 1;
+}
+
 def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   AllowsTypeRefinement,
   HasValueSemantics,
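
For orientation, the new op mirrors the PyTorch signature `aten::count_nonzero : (Tensor, int?) -> (Tensor)`: with `dim` omitted it counts every nonzero element, and with an integer `dim` it counts along that dimension. A minimal eager-mode Python sketch of that behavior (illustrative only, not part of this diff):

import torch

x = torch.tensor([[0, 1, 2],
                  [3, 0, 0]])

# dim omitted: count all nonzero elements; the result is a 0-d int64 tensor.
print(torch.ops.aten.count_nonzero(x))     # tensor(3)

# dim given: count nonzeros along that dimension, dropping it from the shape.
print(torch.ops.aten.count_nonzero(x, 0))  # tensor([1, 1, 1])
print(torch.ops.aten.count_nonzero(x, 1))  # tensor([2, 1])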

lib/Dialect/Torch/IR/TorchOps.cpp

Lines changed: 27 additions & 0 deletions
@@ -6033,6 +6033,33 @@ LogicalResult AtenRot90Op::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenCountNonzero
+//===----------------------------------------------------------------------===//
+
+LogicalResult AtenCountNonzeroOp::verify() {
+
+  auto selfType = cast<BaseTensorType>(getSelf().getType());
+
+  if (!selfType.hasDtype() || !selfType.hasSizes())
+    return success();
+
+  if (!isa<Torch::IntType>(getDim().getType()) &&
+      !isa<Torch::NoneType>(getDim().getType()))
+    return emitOpError("parameter dim must be none or int type");
+
+  int64_t dim;
+  if (!matchPattern(getDim(), m_TorchConstantInt(&dim)))
+    return success();
+
+  int selfRank = selfType.getSizes().size();
+  if (dim >= selfRank || dim < -selfRank)
+    return emitOpError("expected dim to be in [ ")
+           << -selfRank << ", " << selfRank - 1 << " ], but got dim = " << dim;
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OnnxVariantRotaryEmbeddingOp
 //===----------------------------------------------------------------------===//
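
The verifier only fires when `dim` is a compile-time constant, and it accepts the usual PyTorch range `[-rank, rank - 1]`. A small sketch of the range being enforced, using eager PyTorch for comparison (illustrative only):

import torch

x = torch.zeros(2, 3, 4)            # rank 3, so valid dims are -3 .. 2

torch.count_nonzero(x, dim=-3)      # ok, equivalent to dim=0
torch.count_nonzero(x, dim=2)       # ok, largest valid dim

try:
    torch.count_nonzero(x, dim=3)   # out of range, rejected just like the verifier
except IndexError as err:
    print(err)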

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 37 additions & 0 deletions
@@ -9191,6 +9191,39 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " }\n"
 " return %arg0 : !torch.list<int>\n"
 " }\n"
+" func.func @\"__torch_mlir_shape_fn.aten.count_nonzero\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>) -> !torch.list<int> {\n"
+" %false = torch.constant.bool false\n"
+" %str = torch.constant.str \"AssertionError: \"\n"
+" %true = torch.constant.bool true\n"
+" %none = torch.constant.none\n"
+" %0 = torch.aten.__is__ %arg1, %none : !torch.optional<int>, !torch.none -> !torch.bool\n"
+" %1 = torch.prim.If %0 -> (!torch.list<int>) {\n"
+" %2 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
+" torch.prim.If.yield %2 : !torch.list<int>\n"
+" } else {\n"
+" %2 = torch.prim.unchecked_cast %arg1 : !torch.optional<int> -> !torch.int\n"
+" %3 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %4 = torch.aten.neg.int %3 : !torch.int -> !torch.int\n"
+" %5 = torch.aten.lt.int %2, %4 : !torch.int, !torch.int -> !torch.bool\n"
+" %6 = torch.prim.If %5 -> (!torch.bool) {\n"
+" torch.prim.If.yield %true : !torch.bool\n"
+" } else {\n"
+" %9 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
+" %10 = torch.aten.ge.int %2, %9 : !torch.int, !torch.int -> !torch.bool\n"
+" torch.prim.If.yield %10 : !torch.bool\n"
+" }\n"
+" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n"
+" torch.prim.If %7 -> () {\n"
+" torch.prim.If.yield\n"
+" } else {\n"
+" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+" torch.prim.If.yield\n"
+" }\n"
+" %8 = func.call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %false) : (!torch.list<int>, !torch.optional<int>, !torch.bool) -> !torch.list<int>\n"
+" torch.prim.If.yield %8 : !torch.list<int>\n"
+" }\n"
+" return %1 : !torch.list<int>\n"
+" }\n"
 " func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
 " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 " return %0 : !torch.list<int>\n"
@@ -15782,6 +15815,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"
 " }\n"
+" func.func @\"__torch_mlir_dtype_fn.aten.count_nonzero\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>) -> !torch.int {\n"
+" %int4 = torch.constant.int 4\n"
+" return %int4 : !torch.int\n"
+" }\n"
 " func.func @\"__torch_mlir_dtype_fn.aten.rot90\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.int {\n"
 " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 " return %0#1 : !torch.int\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 42 additions & 0 deletions
@@ -6661,6 +6661,47 @@ class DecomposeAtenRot90Op : public OpRewritePattern<AtenRot90Op> {
 };
 } // namespace
 
+// Decompose aten.count_nonzero to aten.ne.Scalar and
+// aten.sum/aten.sum.dim_IntList
+namespace {
+class DecomposeAtenCountNonzeroOp
+    : public OpRewritePattern<AtenCountNonzeroOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenCountNonzeroOp op,
+                                PatternRewriter &rewriter) const override {
+    auto dim = op.getDim();
+    if (!isa<Torch::NoneType>(dim.getType()) &&
+        !isa<Torch::IntType>(dim.getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "expected `dim` to be `None` or `int`");
+    }
+    Location loc = op.getLoc();
+    auto self = op.getSelf();
+    auto inputType = cast<BaseTensorType>(self.getType());
+    auto inpBoolTy = inputType.getWithSizesAndDtype(inputType.getSizes(),
+                                                    rewriter.getI1Type());
+    auto cstZero =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
+    auto nonZeroMask =
+        rewriter.create<AtenNeScalarOp>(loc, inpBoolTy, self, cstZero);
+    auto none = rewriter.create<ConstantNoneOp>(loc);
+    if (isa<Torch::NoneType>(dim.getType())) {
+      rewriter.replaceOpWithNewOp<AtenSumOp>(op, op.getResult().getType(),
+                                             nonZeroMask, none);
+    } else {
+      auto cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
+      auto dimIntList = rewriter.create<PrimListConstructOp>(
+          loc, ListType::get(IntType::get(op.getContext())),
+          SmallVector<Value>{dim});
+      rewriter.replaceOpWithNewOp<AtenSumDimIntListOp>(
+          op, op.getResult().getType(), nonZeroMask, dimIntList, cstFalse,
+          none);
+    }
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.std.correction to sqrt(var.correction(x))
 namespace {
 class DecomposeAtenStdCorrectionOp
@@ -12018,6 +12059,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRot90Op>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenCountNonzeroOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
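
At the PyTorch level the decomposition amounts to building a boolean mask of the nonzero elements (`aten.ne.Scalar` against 0) and summing it, using `aten.sum` when `dim` is `None` and `aten.sum.dim_IntList` otherwise. A rough eager-mode sketch of the rewrite (illustrative only, not the pass itself):

import torch

def count_nonzero_decomposed(x, dim=None):
    mask = x != 0                                  # aten.ne.Scalar -> boolean mask
    if dim is None:
        return mask.sum()                          # aten.sum over all elements
    return mask.sum(dim=[dim], keepdim=False)      # aten.sum.dim_IntList

x = torch.randn(2, 3, 4)
assert torch.equal(count_nonzero_decomposed(x), torch.count_nonzero(x))
assert torch.equal(count_nonzero_decomposed(x, -3), torch.count_nonzero(x, dim=-3))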

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -413,6 +413,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenMvOp>();
   target.addIllegalOp<AtenRenormOp>();
   target.addIllegalOp<AtenRot90Op>();
+  target.addIllegalOp<AtenCountNonzeroOp>();
   target.addIllegalOp<AtenLinalgCrossOp>();
   target.addIllegalOp<Aten_LinalgDetOp>();
   target.addIllegalOp<AtenLinalgSlogdetOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 4 additions & 0 deletions
@@ -2846,6 +2846,10 @@
     "ConvolutionModule2DTransposeNonUnitOutputPadding_basic",
     "ConvolutionModule2DTransposeStrided_basic",
     "ConvolutionModule2DTranspose_basic",
+    # Error: onnx lowering,
+    "CountNonzeroModuleBool_Basic",
+    "CountNonzeroModuleF32_basic",
+    "CountNonzeroModuleI64_basic",
     "Deg2radModule_basic",
     "DivFloatModule_basic",
     "DivIntModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 18 additions & 0 deletions
@@ -1436,6 +1436,21 @@ def aten〇rot90〡shape(self: List[int], k: int = 1, dims: List[int] = (0, 1,))
 
     return self
 
+@check_shape_function([
+    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
+    Invocation(TensorOfShape(2, 3, 4), dim = 1), # Test explicit dim.
+    Invocation(TensorOfShape(2, 3, 4), dim = -1), # Test explicit dim(negative).
+    Invocation(TensorOfShape(2, 3, 4), dim = -3), # Test explicit dim(negative).
+    Invocation(TensorOfShape(2, 3, 4), dim = 0), # Test explicit dim.
+    Invocation(TensorOfShape(2, 3, 4), dim = 2), # Test explicit maximum valid dim.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim = -4), # Test dim out of bound.
+    ErrorInvocation(TensorOfShape(2, 3, 4), dim = 3), # Test dim out of bound.
+])
+def aten〇count_nonzero〡shape(self: List[int], dim: Optional[int] = None) -> List[int]:
+    if dim is None: return []
+    assert not (dim < -len(self) or dim >= len(self))
+    return upstream_shape_functions.argmax(self, dim)
+
 def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5514,6 +5529,9 @@ def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0,
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+def aten〇count_nonzero〡dtype(self_rank_dtype: Tuple[int, int], dim: Optional[int] = None) -> int:
+    return torch.int64
+
 def aten〇rot90〡dtype(self_rank_dtype: Tuple[int, int], k: int = 1, dims: List[int] = (0, 1,)) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -787,6 +787,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::_weight_norm_interface : (Tensor, Tensor, int) -> (Tensor, Tensor)")
     emit("aten::rot90 : (Tensor, int, int[]) -> (Tensor)", has_verifier=True)
+    emit("aten::count_nonzero : (Tensor, int?) -> (Tensor)", has_verifier=True)
 
     # Misc tensor ops.
     emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")

projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py

Lines changed: 50 additions & 0 deletions
@@ -2496,3 +2496,53 @@ def TraceUnsignedIntModule_basic(module, tu: TestUtils):
 @register_test_case(module_factory=lambda: TraceIntModule())
 def TraceUnsignedIntModule_empty(module, tu: TestUtils):
     module.forward(tu.randint(0, 0, low=0, high=10))
+
+
+# ==============================================================================
+
+
+class CountNonzeroModuleI64(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dim = 2
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x, self.dim)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroModuleI64())
+def CountNonzeroModuleI64_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, 4, low=-2, high=2))
+
+
+class CountNonzeroModuleF32(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.dim = -3
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.float32, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x, self.dim)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroModuleF32())
+def CountNonzeroModuleF32_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 4))
+
+
+class CountNonzeroModuleBool(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([2, 3, 4], torch.bool, True)])
+    def forward(self, x):
+        return torch.ops.aten.count_nonzero(x)
+
+
+@register_test_case(module_factory=lambda: CountNonzeroModuleBool())
+def CountNonzeroModuleBool_Basic(module, tu: TestUtils):
+    module.forward(tu.randint(2, 3, 4, low=0, high=2).to(torch.bool))
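
The three e2e cases exercise an int64 input with a positive `dim`, a float32 input with a negative `dim`, and a bool input with `dim` omitted; for the bool case the op simply counts the `True` entries. A quick sanity check of that expectation (illustrative only, not part of the test suite):

import torch

x = torch.randint(0, 2, (2, 3, 4)).to(torch.bool)
# On a bool tensor, count_nonzero is just the number of True elements.
assert torch.ops.aten.count_nonzero(x) == x.sum()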
