Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Canonicalize aten.log #3169

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 46 additions & 45 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down Expand Up @@ -4269,6 +4224,52 @@ def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasCanonicalizer = 1;
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
38 changes: 37 additions & 1 deletion lib/Dialect/Torch/IR/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1914,8 +1914,12 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
Value a = op.getA();
auto outType = op.getResult().getType();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
if (!scalarValue) {
scalarValue = getScalarFloatValue(a, loc, rewriter);
}
if (!scalarValue) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
return success();
});
Expand Down Expand Up @@ -3343,6 +3347,38 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenLogOp
//===----------------------------------------------------------------------===//

void AtenLogOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenLogOp op, PatternRewriter &rewriter) {
Location loc = op->getLoc();
if (op->getNumOperands() != 1) {
return failure();
}
Value self = getScalarIntValue(op->getOperand(0), loc, rewriter);
if (!self)
self = getScalarFloatValue(op->getOperand(0), loc, rewriter);
if (!self) {
return failure();
}
auto outType = op->getResult(0).getType();
double floatScalar;
int64_t intScalar;
double resultScalar;
if (matchPattern(self, m_TorchConstantInt(&intScalar)))
resultScalar = std::log(intScalar);
if (matchPattern(self, m_TorchConstantFloat(&floatScalar)))
resultScalar = std::log(floatScalar);
Value result = rewriter.create<ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(resultScalar));
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(op, outType, result);
return success();
});
}

//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic",
"ElementwizeLogScalarInputModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
Expand Down Expand Up @@ -1658,6 +1659,7 @@
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"CosineSimilarityModule_basic",
"ElementwizeLogScalarInputModule_basic",
"NativeGroupNormBackwardModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
Expand Down Expand Up @@ -360,6 +359,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_canonicalizer=True)

emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
Expand Down
17 changes: 17 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1752,6 +1752,23 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))


class ElementwizeLogScalarInputModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
a = torch.tensor(10)
return torch.log(a)

@register_test_case(module_factory=lambda: ElementwizeLogScalarInputModule())
def ElementwizeLogScalarInputModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


Expand Down
Loading