Skip to content

Commit

Permalink
[torch-dialect] emit aten.index_add and decompose it to scatter.add op
Browse files Browse the repository at this point in the history
  • Loading branch information
Vremold committed May 16, 2024
1 parent 28193fd commit 4e5577a
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 3 deletions.
53 changes: 53 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5703,6 +5703,59 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [
}];
}

def Torch_AtenIndexAddOp : Torch_Op<"aten.index_add", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$dim,
AnyTorchTensorType:$index,
AnyTorchTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAddOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAddOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenIndexAdd_Op : Torch_Op<"aten.index_add_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::index_add_ : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self,
Torch_IntType:$dim,
Torch_NonValueTensorType:$index,
Torch_NonValueTensorType:$source,
AnyTorchScalarType:$alpha
);
let results = (outs
Torch_NonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIndexAdd_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIndexAdd_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
16 changes: 16 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9185,6 +9185,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.index_select(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.scatter_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_add\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.float) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.index_put\"(%arg0: !torch.list<int>, %arg1: !torch.list<optional<list<int>>>, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -10399,6 +10407,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>, %arg4: !torch.number) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scatter_add\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>, %arg3: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.index_select\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
70 changes: 70 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5621,6 +5621,75 @@ class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
};
} // namespace

namespace {
// Decompose `aten.index_add` op into `aten.index_put`
class DecomposeAtenIndexAddOp : public OpRewritePattern<AtenIndexAddOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenIndexAddOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value src = op.getSource();
Value input = op.getSelf();
Value index = op.getIndex();
Value alpha = op.getAlpha();

int64_t dim;
if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
return rewriter.notifyMatchFailure(op,
"dim of index_add must be a constant");
}
std::optional<unsigned> maybeInputRank = getTensorRank(input);
if (!maybeInputRank) {
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
}
int64_t inputRank = static_cast<int64_t>(*maybeInputRank);
dim = toPositiveDim(dim, inputRank);
if (!isValidDim(dim, inputRank)) {
return rewriter.notifyMatchFailure(op, "index dim is not a valid dim");
}

auto resType = op.getType().cast<BaseTensorType>();
auto srcType = src.getType().cast<BaseTensorType>();
auto indexType = index.getType().cast<BaseTensorType>();
if (!indexType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "index should have dtype");
}
auto indexDtype = indexType.getDtype();

// calculate src * alpha first.
Value newSrc =
rewriter.create<Torch::AtenMulScalarOp>(loc, srcType, src, alpha);

// broadcast index to have the same shape as src.
Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(-1));
for (int64_t i = dim + 1; i < inputRank; ++i) {
index = *unsqueezeTensor(rewriter, op, index, /*dim=*/constMinusOne);
}

SmallVector<int64_t> bcastShape;
SmallVector<Value> bcastShapeValue;
computeBroadcastShape(rewriter, loc, index, src, bcastShape,
bcastShapeValue);

Type bcastType = ValueTensorType::get(
op.getContext(), llvm::ArrayRef(bcastShape), indexDtype);

Value indexBcastShapeTorchList = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
bcastShapeValue);

index = rewriter.create<Torch::AtenBroadcastToOp>(loc, bcastType, index,
indexBcastShapeTorchList);

rewriter.replaceOpWithNewOp<Torch::AtenScatterAddOp>(op, resType, input,
op.getDim(), index, newSrc);
return success();
}
};
} // namespace

namespace {
class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -8021,6 +8090,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNewFullOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenIndexAddOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandAsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_ToCopyOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCopyOp>(patterns);
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenMishOp>();
target.addIllegalOp<AtenFullLikeOp>();
target.addIllegalOp<AtenNewFullOp>();
target.addIllegalOp<AtenIndexAddOp>();
target.addIllegalOp<AtenExpandAsOp>();
target.addIllegalOp<Aten_ToCopyOp>();
target.addIllegalOp<AtenDropoutOp>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1607,15 +1607,18 @@ def aten〇scatter〇value〡shape(self: List[int], dim: int, index: List[int],
def aten〇index_select〡shape(self: List[int], dim: int, index: List[int]) -> List[int]:
return upstream_shape_functions.index_select(self, dim, index)

def aten〇scatter_add〡shape(self: List[int], dim: int, index: List[int], src: List[int]) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_add〡shape(self: List[int], dim: int, index: List[int], source: List[int], alpha: float = 1) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_put〡shape(self: List[int], indices: List[Optional[List[int]]], values: List[int], accumulate: bool = False) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇index_put〇hacked_twin〡shape(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
return upstream_shape_functions.unary(self)

def aten〇embedding〡shape(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)

def aten〇embedding_bag〇padding_idx〡shape(weight: List[int], indices: List[int], offsets: List[int], scale_grad_by_freq: bool, mode: int, sparse: bool, per_sample_weights: Optional[List[int]], include_last_offset: bool, padding_idx: Optional[int]) -> Tuple[List[int], List[int], List[int], List[int]]:
return _embedding_bag_helper(weight, indices, offsets, include_last_offset,
mode, per_sample_weights, padding_idx)
Expand Down Expand Up @@ -2534,6 +2537,16 @@ def aten〇index_put〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtyp
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇index_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], source_rank_dtype: Tuple[int, int], alpha: Union[int, float, complex] = 1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function([Invocation(TensorOfShape(3, dtype=dtype), 0, TensorOfShape(3, dtype=torch.int64), TensorOfShape(3, dtype=dtype)) for dtype in _SORTED_TORCH_TYPES])
def aten〇scatter_add〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int], src_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(None, [(5,)], None, None, 0, TensorOfShape(1, dtype=torch.int64)))
def aten〇index_select〡dtype(self_rank_dtype: Tuple[int, int], dim: int, index_rank_dtype: Tuple[int, int]) -> int:
self_rank, self_dtype = self_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,8 @@ def emit_with_mutating_variants(key, **kwargs):

emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
emit_with_mutating_variants("aten::index_add : (Tensor, int, Tensor, Tensor, Scalar) -> (Tensor)")

emit_with_mutating_variants(
"aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
)
Expand Down

0 comments on commit 4e5577a

Please sign in to comment.