Use op.dtype to create aten.empty.memory_format during decomposition. #3941

Merged: 1 commit, Mar 27, 2025
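Summary of the change: several decompositions that lower allocation-style ops (aten.empty_like, aten.new_empty, aten.rand, aten.empty.strided, aten.randn.generator) to aten.empty.memory_format previously forwarded the op's dtype operand directly, even when it was the none constant. This patch adds a shared helper, Torch::getDtypeFromOp, that resolves the dtype from the op's result type when the operand is none, and uses it in all five patterns, so a float64 result no longer silently falls back to the backend default element type.

As a quick eager-mode PyTorch illustration (not part of the patch) of the scenario the new e2e test covers:

import torch

# empty_like on a float64 input with no explicit dtype must stay float64,
# not fall back to the default float32.
x = torch.ones((200, 200, 26), dtype=torch.float64)
out = torch.empty_like(x).fill_(0)
assert out.dtype == torch.float64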
3 changes: 3 additions & 0 deletions include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -36,6 +36,9 @@ Type getTypeForTorchType(
MLIRContext *context, Type type,
mlir::IntegerType::SignednessSemantics signedness = IntegerType::Signed);

template <typename OpTy>
FailureOr<Value> getDtypeFromOp(PatternRewriter &rewriter, OpTy op);

FailureOr<Type> getTorchTypeForScalarType(MLIRContext *context,
torch_upstream::ScalarType dtypeInt);

52 changes: 32 additions & 20 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -7087,9 +7087,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
Torch::ListType::get(Torch::IntType::get(op.getContext()));
Value sizeList =
rewriter.create<AtenSizeOp>(op.getLoc(), sizeListType, op.getSelf());

FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}

rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), sizeList, op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), op.getMemoryFormat());
op, op.getType(), sizeList, *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), op.getMemoryFormat());
return success();
}
};
@@ -7838,18 +7845,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
LogicalResult matchAndRewrite(AtenNewEmptyOp op,
PatternRewriter &rewriter) const override {
Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dtype = op.getDtype();
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf().getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), dtype, op.getLayout(), op.getDevice(),
op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), /*memoryFormat=*/noneVal);
return success();
}
@@ -9257,12 +9259,12 @@ class DecomposeAtenRandnGeneratorOp
Location loc = op.getLoc();
auto resultType = cast<BaseTensorType>(op.getType());

if (!resultType.hasDtype()) {
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "expected result type to have a dtype");
op, "could not determine dtype from the op.");
}

Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype());
Value none = rewriter.create<ConstantNoneOp>(loc);
Value low = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)0.0));
@@ -9274,12 +9276,12 @@ class DecomposeAtenRandnGeneratorOp
loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159)));

Value emptyTensorA = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/dtype,
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
Value emptyTensorB = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/dtype,
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/none);
@@ -9377,8 +9379,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
loc, rewriter.getF64FloatAttr((double)0.0));
Value high = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr((double)1.0));
FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
Value emptyTensor = rewriter.create<AtenEmptyMemoryFormatOp>(
loc, resultType, op.getSize(), /*dtype=*/op.getDtype(),
loc, resultType, op.getSize(), /*dtype=*/*dtype,
/*layout=*/op.getLayout(),
/*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(),
/*memory_format=*/noneVal);
@@ -9536,9 +9543,14 @@ class DecomposeAtenEmptyStridedOp

Value noneVal = rewriter.create<ConstantNoneOp>(op.getLoc());

FailureOr<Value> dtype = getDtypeFromOp(rewriter, op);
if (failed(dtype)) {
return rewriter.notifyMatchFailure(
op, "could not determine dtype from the op.");
}
rewriter.replaceOpWithNewOp<AtenEmptyMemoryFormatOp>(
op, op.getType(), op.getSize(), op.getDtype(), op.getLayout(),
op.getDevice(), op.getPinMemory(), /*memoryFormat=*/noneVal);
op, op.getType(), op.getSize(), *dtype, op.getLayout(), op.getDevice(),
op.getPinMemory(), /*memoryFormat=*/noneVal);
return success();
}
};
36 changes: 36 additions & 0 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -237,6 +237,42 @@ Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc,
rewriter.getI64IntegerAttr(intType));
}

template <typename OpTy>
FailureOr<Value> Torch::getDtypeFromOp(PatternRewriter &rewriter, OpTy op) {
// For ops like AtenEmptyLikeOp, if the dtype specified in the op is none, it
// defaults to the dtype of the input. Since the op's result type already
// carries that dtype, we can read the dtype from the result instead of the
// input itself.
// For ops like AtenRandOp, if the dtype specified in the op is none, it
// defaults to the global default dtype. Here too we can read the dtype from
// the result type, as it is already set according to that global default.
Value dtype = op.getDtype();
if (isa<Torch::NoneType>(dtype.getType())) {
BaseTensorType tensorType = cast<BaseTensorType>(op.getType());
if (!tensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "expected input tensor to have a dtype");
}
dtype =
getDtypeIntValueForType(rewriter, op.getLoc(), tensorType.getDtype());
}
return dtype;
}
// Explicit template instantiations.
template FailureOr<Value>
Torch::getDtypeFromOp<AtenEmptyLikeOp>(PatternRewriter &rewriter,
AtenEmptyLikeOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenNewEmptyOp>(PatternRewriter &rewriter,
AtenNewEmptyOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenRandOp>(PatternRewriter &rewriter, AtenRandOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenEmptyStridedOp>(PatternRewriter &rewriter,
AtenEmptyStridedOp op);
template FailureOr<Value>
Torch::getDtypeFromOp<AtenRandnGeneratorOp>(PatternRewriter &rewriter,
AtenRandnGeneratorOp op);

// Helper to convert a tensor to a specific scalar type.
Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc,
Value input, Type dtype) {
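To make the two defaulting rules described in the getDtypeFromOp comment above concrete, here is a small illustrative PyTorch sketch (eager mode, not part of the patch; assumes the process starts with the usual float32 default dtype):

import torch

# Rule 1: allocation ops that mirror an existing tensor (e.g. empty_like,
# new_empty) fall back to the input tensor's dtype when dtype is None.
src = torch.zeros(4, dtype=torch.float64)
assert torch.empty_like(src).dtype == torch.float64

# Rule 2: ops with no tensor to mirror (e.g. rand) fall back to the global
# default dtype when dtype is None.
torch.set_default_dtype(torch.float64)
assert torch.rand(4).dtype == torch.float64
torch.set_default_dtype(torch.float32)  # restore the usual default

In both cases the Torch-dialect op's result type already carries the resolved dtype, which is why the helper can read it from op.getType() rather than from the none dtype operand.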
@@ -641,6 +641,26 @@ def EmptyLikeModule_falsePinMemory(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))


class EmptyLikeDefaultDtypeFloat64InputModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float64, True),
]
)
def forward(self, x):
return torch.empty_like(x).fill_(0)


@register_test_case(module_factory=lambda: EmptyLikeDefaultDtypeFloat64InputModule())
def EmptyLikeDefaultDtypeFloat64InputModule_basic(module, tu: TestUtils):
module.forward(torch.ones((200, 200, 26), dtype=torch.float64))


# ==============================================================================


48 changes: 48 additions & 0 deletions test/Dialect/Torch/decompose-complex-ops.mlir
@@ -312,3 +312,51 @@ func.func @convolution_backward_none_result(%arg0: !torch.vtensor<[1,1,3,3],f32>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %2, %1, %false, %2, %int1, %3 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
return %result1, %result2 : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1],f32>
}

// -----
// CHECK-LABEL: func.func @emptyLikeNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[NONE]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
func.func @emptyLikeNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
%none = torch.constant.none
%none_0 = torch.constant.none
%none_1 = torch.constant.none
%false = torch.constant.bool false
%none_2 = torch.constant.none
%0 = torch.aten.empty_like %arg0, %none, %none_0, %none_1, %false, %none_2 : !torch.vtensor<[200,200,26],f64>, !torch.none, !torch.none, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
return %0 : !torch.vtensor<[200,200,26],f64>
}

// -----
// CHECK-LABEL: func.func @randNoneDtype(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[C1:.*]] = torch.constant.float 1.000000e+00
// CHECK: %[[C0:.*]] = torch.constant.float 0.000000e+00
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[C200:.*]] = torch.constant.int 200
// CHECK: %[[C26:.*]] = torch.constant.int 26
// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[C200]], %[[C200]], %[[C26]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[CPU:.*]] = torch.constant.device "cpu"
// CHECK: %[[MEM_FMT:.*]] = torch.aten.empty.memory_format %[[LIST]], %[[DTYPE]], %[[NONE]], %[[CPU]], %[[FALSE]], %[[NONE]] : !torch.list<int>, !torch.int, !torch.none, !torch.Device, !torch.bool, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: %[[UNIFORM:.*]] = torch.aten.uniform %[[MEM_FMT]], %[[C0]], %[[C1]], %[[NONE]] : !torch.vtensor<[200,200,26],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[200,200,26],f64>
// CHECK: return %[[UNIFORM]] : !torch.vtensor<[200,200,26],f64>
func.func @randNoneDtype(%arg0: !torch.vtensor<[200,200,26],f64>) -> !torch.vtensor<[200,200,26],f64> {
%int200 = torch.constant.int 200
%int200_0 = torch.constant.int 200
%int26 = torch.constant.int 26
%0 = torch.prim.ListConstruct %int200, %int200_0, %int26 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
%none = torch.constant.none
%none_1 = torch.constant.none
%cpu = torch.constant.device "cpu"
%false = torch.constant.bool false
%1 = torch.aten.rand %0, %none, %none_1, %cpu, %false : !torch.list<int>, !torch.none, !torch.none, !torch.Device, !torch.bool -> !torch.vtensor<[200,200,26],f64>
return %1 : !torch.vtensor<[200,200,26],f64>
}
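These two FileCheck cases run under the file's existing RUN line, presumably torch-mlir-opt with the -torch-decompose-complex-ops pass piped into FileCheck; the exact flags are whatever decompose-complex-ops.mlir already specifies, which is not shown in this diff.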