[Torch Dialect] Emit aten.index_add op and decompose it into aten.scatter_add op #3085

Status: Open. Wants to merge 1 commit into base: main.
53 changes: 53 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
}];
}

def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAddOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_IntType:$dim,
Torch_NonValueTensorType:$index,
Torch_NonValueTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAdd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
AllowsTypeRefinement,
HasValueSemantics,
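For reference, a minimal PyTorch sketch of the semantics these two ops encode (values and shapes here are illustrative, not taken from the PR):

import torch

# aten::index_add with dim = 0:
#   out[index[i], ...] = self[index[i], ...] + alpha * source[i, ...]
# Repeated indices accumulate rather than overwrite.
self = torch.zeros(4, 3)
index = torch.tensor([0, 2, 2])
source = torch.ones(3, 3)
out = torch.index_add(self, 0, index, source, alpha=2.0)
# Row 0 receives one write (2.0); row 2 receives two accumulated writes (4.0).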
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
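These entries are the C++ rendering of the Python shape and dtype functions added later in this diff. Both new ops use the unary shape rule, i.e. the result always takes self's shape regardless of index and source. A quick PyTorch check of that invariant (illustrative, not part of the PR):

import torch

# index_add's result shape never depends on index/source shapes,
# which is why the shape function simply forwards self's shape.
self = torch.zeros(4, 3)
out = torch.index_add(self, 0, torch.tensor([1, 2]), torch.ones(2, 3))
assert out.shape == self.shape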
70 changes: 70 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5621,6 +5621,75 @@ class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
};
} // namespace

namespace {
// Decompose `aten.index_add` op into `aten.scatter_add` op.

class DecomposeAtenIndexAddOp : public OpRewritePattern<AtenIndexAddOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexAddOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value src = op.getSource();
Value input = op.getSelf();
Value index = op.getIndex();
Value alpha = op.getAlpha();

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op,
"dim of index_add must be a constant");
}
std::optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
int64_t inputRank = static_cast<int64_t>(*maybeInputRank);
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "index dim is not a valid dim");
}

auto resType = op.getType().cast<BaseTensorType>();
auto srcType = src.getType().cast<BaseTensorType>();
auto indexType = index.getType().cast<BaseTensorType>();
if (!indexType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "index should have dtype");
}
auto indexDtype = indexType.getDtype();

// calculate src * alpha first.
Value newSrc =
rewriter.create<Torch::AtenMulScalarOp>(loc, srcType, src, alpha);

// Broadcast the index to the common broadcast shape of index and src:
// unsqueeze trailing dims past `dim`, then expand.
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
for (int64_t i = dim + 1; i < inputRank; ++i) {
index = *unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
}

SmallVector<int64_t> bcastShape;
SmallVector<Value> bcastShapeValue;
computeBroadcastShape(rewriter, loc, index, src, bcastShape,
bcastShapeValue);

Type bcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);

Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
bcastShapeValue);

index = rewriter.create<Torch::AtenBroadcastToOp>(loc, bcastType, index,
indexBcastShapeTorchList);

rewriter.replaceOpWithNewOp<Torch::AtenScatterAddOp>(op, resType, input,
op.getDim(), index, newSrc);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
using OpRewritePattern::OpRewritePattern;
@@ -8021,6 +8090,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
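The rewrite above can be sanity-checked by mirroring it in PyTorch: scale source by alpha, unsqueeze the 1-D index once per trailing dimension past dim, broadcast it against source, and apply scatter_add. A sketch under the assumption that PyTorch broadcasting matches computeBroadcastShape (values illustrative):

import torch

# index_add along dim = 1 of a rank-3 input.
self = torch.zeros(2, 4, 3)
index = torch.tensor([0, 3])
source = torch.ones(2, 2, 3)
expected = torch.index_add(self, 1, index, source, alpha=2.0)

# Mirror the decomposition: scale source, unsqueeze the index once
# (inputRank - dim - 1 = 1 trailing dim), broadcast, then scatter_add.
new_src = 2.0 * source
bidx = index.unsqueeze(-1).expand(source.shape)  # (2,) -> (2, 1) -> (2, 2, 3)
actual = self.scatter_add(1, bidx, new_src)
assert torch.equal(actual, expected)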
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
target.addIllegalOp<AtenIndexAddOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();
@@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_functions.index_select(self, dim, index)

def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)

def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
return _embedding_bag_helper(weight, indices, offsets, include_last_offset,
mode, per_sample_weights, padding_idx)
@@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64)))
def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
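The dtype functions encode that the result dtype follows self, and the @check_dtype_function decorators exercise this across all torch dtypes. A scaled-down PyTorch equivalent of what those invocations assert (illustrative):

import torch

# Result dtype of index_add and scatter_add follows `self` for any dtype.
for dtype in (torch.float32, torch.int64):
    self = torch.zeros(3, dtype=dtype)
    index = torch.tensor([0, 1, 2])
    source = torch.ones(3, dtype=dtype)
    assert torch.index_add(self, 0, index, source).dtype == dtype
    assert torch.scatter_add(self, 0, index, source).dtype == dtype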
@@ -516,6 +516,8 @@ def emit_with_mutating_variants(key, **kwargs):

emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)")

emit_with_mutating_variants(
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
)
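emit_with_mutating_variants registers both the value-semantics op and the trailing-underscore in-place variant shown in GeneratedTorchOps.td at the top of this diff. The in-place form mutates self directly (illustrative PyTorch):

import torch

a = torch.zeros(4)
a.index_add_(0, torch.tensor([1, 1]), torch.tensor([1.0, 2.0]))
# Repeated indices accumulate in place: a is now [0.0, 3.0, 0.0, 0.0].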