From 4e5577ad88fc99b93eec0ede85e61ad5c7a87e99 Mon Sep 17 00:00:00 2001
From: "wujiawei.aml"
Date: Wed, 8 May 2024 22:36:24 +0800
Subject: [PATCH] [torch-dialect] emit aten.index_add and decompose it to
 scatter.add op

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 53 ++++++++++++++
 .../Transforms/AbstractInterpLibrary.cpp      | 16 +++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 73 +++++++++++++++++++
 .../Transforms/LowerToBackendContract.cpp     |  1 +
 .../build_tools/abstract_interp_lib_gen.py    | 19 ++++-
 .../build_tools/torch_ods_gen.py              |  2 +
 6 files changed, 161 insertions(+), 3 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4de41e13b80a..8734896c34cf 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
   }];
 }
 
+def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$dim,
+    AnyTorchTensorType:$index,
+    AnyTorchTensorType:$source,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenIndexAddOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
+def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_IntType:$dim,
+    Torch_NonValueTensorType:$index,
+    Torch_NonValueTensorType:$source,
+    AnyTorchScalarType:$alpha
+  );
+  let results = (outs
+    Torch_NonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 1);
+    }
+    void AtenIndexAdd_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 1);
+    }
+  }];
+}
+
 def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 43bcc3acc0eb..a414cca75697 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.float) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.number) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 5ec22233bbf5..e5a6b2fe7728 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5621,6 +5621,78 @@ class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
 };
 } // namespace
 
+namespace {
+// Decompose `aten.index_add` op into `aten.scatter.add` op.
+class DecomposeAtenIndexAddOp : public OpRewritePattern<AtenIndexAddOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIndexAddOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value src = op.getSource();
+    Value input = op.getSelf();
+    Value index = op.getIndex();
+    Value alpha = op.getAlpha();
+
+    int64_t dim;
+    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+      return rewriter.notifyMatchFailure(op,
+                                         "dim of index_add must be a constant");
+    }
+    std::optional<unsigned> maybeInputRank = getTensorRank(input);
+    if (!maybeInputRank) {
+      return rewriter.notifyMatchFailure(op, "expected input to have a rank");
+    }
+    int64_t inputRank = static_cast<int64_t>(*maybeInputRank);
+    dim = toPositiveDim(dim, inputRank);
+    if (!isValidDim(dim, inputRank)) {
+      return rewriter.notifyMatchFailure(op, "index dim is not a valid dim");
+    }
+
+    auto resType = op.getType().cast<BaseTensorType>();
+    auto srcType = src.getType().cast<BaseTensorType>();
+    auto indexType = index.getType().cast<BaseTensorType>();
+    if (!indexType.hasDtype()) {
+      return rewriter.notifyMatchFailure(op, "index should have dtype");
+    }
+    auto indexDtype = indexType.getDtype();
+
+    // calculate src * alpha first.
+    Value newSrc =
+        rewriter.create<AtenMulScalarOp>(loc, srcType, src, alpha);
+
+    // broadcast index to have the same shape as src.
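+    // e.g. with inputRank == 3 and dim == 0, an index of shape [n] is
+    // unsqueezed below to [n, 1, 1] and then broadcast against src, so the
+    // i-th slice of src is accumulated into slice index[i] of the input.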
+    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(-1));
+    for (int64_t i = dim + 1; i < inputRank; ++i) {
+      index = *unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
+    }
+
+    SmallVector<int64_t> bcastShape;
+    SmallVector<Value> bcastShapeValue;
+    computeBroadcastShape(rewriter, loc, index, src, bcastShape,
+                          bcastShapeValue);
+
+    Type bcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);
+
+    Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        bcastShapeValue);
+
+    index = rewriter.create<AtenBroadcastToOp>(loc, bcastType, index,
+                                               indexBcastShapeTorchList);
+
+    rewriter.replaceOpWithNewOp<AtenScatterAddOp>(op, resType, input,
+                                                  op.getDim(), index, newSrc);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -8021,6 +8093,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index 0ca7ea9c4f0e..39bb88e583a3 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
+  target.addIllegalOp<AtenIndexAddOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 1cf0c2c7696a..1d7fd7b30eaf 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
 def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
     return upstream_shape_functions.index_select(self, dim, index)
 
+def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
     return upstream_shape_functions.unary(self)
 
 def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
     return upstream_shape_functions.unary(self)
 
-def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
-    return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
-
 def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
     return _embedding_bag_helper(weight, indices, offsets, include_last_offset, mode, per_sample_weights, padding_idx)
@@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
+@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
+def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
+@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
+def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64)))
 def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int:
     self_rank, self_dtype = self_rank_dtype
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index c847e42d844a..e7fd7cf7c409 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -516,6 +516,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
 
+    emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)")
+
     emit_with_mutating_variants(
         "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
     )
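
The rewrite above mirrors an identity that also holds in eager PyTorch and can serve as a quick sanity check. The snippet below is a minimal sketch, not part of the patch: it assumes dim == 0, a rank-2 input, and made-up shapes, and the names (inp, new_src, bcast_index) are illustrative only.

# Minimal sketch: index_add expressed via scatter_add, mirroring
# DecomposeAtenIndexAddOp (dim == 0 and a rank-2 input assumed).
import torch

inp = torch.zeros(5, 3)
index = torch.tensor([0, 4, 2])
source = torch.ones(3, 3)
alpha = 2.0

expected = torch.index_add(inp, 0, index, source, alpha=alpha)

# 1) scale the source by alpha
new_src = source * alpha
# 2) unsqueeze the index to the source's rank, then broadcast it
bcast_index = index.unsqueeze(-1).expand_as(source)
# 3) scatter-add the scaled source along dim
actual = inp.scatter_add(0, bcast_index, new_src)

assert torch.equal(expected, actual)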