[Torch] Canonicalize aten.log #3169

Closed · 24 commits

91 changes: 46 additions & 45 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -256,51 +256,6 @@ def Torch_AtenLeakyRelu_Op : Torch_Op<"aten.leaky_relu_", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -4319,6 +4274,52 @@ def Torch_AtenMaskedFill_TensorOp : Torch_Op<"aten.masked_fill_.Tensor", [
}];
}

def Torch_AtenLogOp : Torch_Op<"aten.log", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::log : (Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLogOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLogOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
let hasFolder = 1;
}

def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
IsTrailingUnderscoreInplaceVariant,
AllowsTypeRefinement
]> {
let summary = "Generated op for `aten::log_ : (Tensor) -> (Tensor)`";
let arguments = (ins
Torch_NonValueTensorType:$self
);
let results = (outs
AnyTorchOptionalNonValueTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenLog_Op::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 1, 1);
}
void AtenLog_Op::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 1, 1);
}
}];
}

def Torch_AtenAddcmulOp : Torch_Op<"aten.addcmul", [
AllowsTypeRefinement,
HasValueSemantics,
98 changes: 81 additions & 17 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -1179,9 +1179,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op,
// NAry folder helpers
//===----------------------------------------------------------------------===//

static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
bool allFp = true;
bool allInt = true;
static bool checkValidDTypes(llvm::ArrayRef<Attribute> attrs) {
bool allFpOrInt = true;

for (auto attr : attrs) {
if (!attr)
@@ -1196,11 +1195,12 @@ static bool checkSameDTypes(llvm::ArrayRef<Attribute> attrs) {
attrty = integer.getType();
if (auto shaped = dyn_cast_or_null<ShapedType>(attrty))
attrty = shaped.getElementType();
allFp &= isa<mlir::FloatType>(attrty);
allInt &= isa<mlir::IntegerType>(attrty);
bool isFloat = isa<mlir::FloatType>(attrty);
bool isInt = isa<mlir::IntegerType>(attrty);
allFpOrInt &= isFloat || isInt;
}

return allFp || allInt;
return allFpOrInt;
}

static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
@@ -1214,24 +1214,61 @@ static bool checkAllSplats(llvm::ArrayRef<Attribute> attrs) {
return true;
}

std::optional<double> convertIntegerAttributeToDouble(Attribute attr,
int64_t idx = 0) {
auto convertAPInt = [](const APInt &apint, bool isUnsigned) {
return isUnsigned ? static_cast<double>(apint.getZExtValue())
: static_cast<double>(apint.getSExtValue());
};

if (auto dense = attr.dyn_cast<ElementsAttr>()) {
if (!dense.tryGetValues<APInt>())
return std::nullopt;
if (auto intType = dense.getElementType().dyn_cast<mlir::IntegerType>()) {
bool isUnsigned = intType.isUnsigned();
if (dense.isSplat()) {
return convertAPInt(dense.getSplatValue<APInt>(), isUnsigned);
}
return convertAPInt(dense.getValues<APInt>()[idx], isUnsigned);
}
} else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
bool isUnsigned = intAttr.getType().cast<mlir::IntegerType>().isUnsigned();
return convertAPInt(intAttr.getValue(), isUnsigned);
}
return std::nullopt;
}

std::optional<double> convertFloatAttributeToDouble(Attribute attr,
int64_t idx = 0) {
if (auto dense = attr.dyn_cast<ElementsAttr>()) {
if (!dense.tryGetValues<APFloat>())
return std::nullopt;
if (auto floatType = dense.getElementType().dyn_cast<mlir::FloatType>()) {
if (dense.isSplat()) {
return dense.getSplatValue<APFloat>().convertToDouble();
}
return dense.getValues<APFloat>()[idx].convertToDouble();
}
} else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
return floatAttr.getValueAsDouble();
}
return std::nullopt;
}

llvm::SmallVector<double> getFoldValueAtIndexFp(llvm::ArrayRef<Attribute> attrs,
int64_t idx = 0) {
llvm::SmallVector<double> splattrs;

for (auto attr : attrs) {
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APFloat>().convertToDouble());
} else {
splattrs.push_back(dense.getValues<APFloat>()[idx].convertToDouble());
}
} else if (auto intattr = dyn_cast<FloatAttr>(attr)) {
splattrs.push_back(intattr.getValueAsDouble());
// The attribute can hold integer or float values.
if (auto floatVal = convertFloatAttributeToDouble(attr, idx)) {
splattrs.push_back(*floatVal);
} else if (auto intVal = convertIntegerAttributeToDouble(attr, idx)) {
splattrs.push_back(*intVal);
} else {
return {};
}
}

return splattrs;
}

@@ -1243,6 +1280,9 @@ llvm::SmallVector<APInt> getFoldValueAtIndexInt(llvm::ArrayRef<Attribute> attrs,
for (auto attr : attrs) {
bool isunsigned = false;
if (auto dense = dyn_cast<ElementsAttr>(attr)) {
if (!dense.tryGetValues<APInt>()) {
return {};
}
isunsigned = dyn_cast<IntegerType>(dense.getElementType()).isUnsigned();
if (dense.isSplat()) {
splattrs.push_back(dense.getSplatValue<APInt>());
@@ -1276,7 +1316,7 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
NAryFoldFpOperator fpFolder,
NAryFoldIntOperator intFolder) {
constexpr int64_t maxFold = 16;
if (!checkSameDTypes(operands))
if (!checkValidDTypes(operands))
return nullptr;

auto resultTy = dyn_cast<ValueTensorType>(ty);
@@ -1322,6 +1362,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
llvm::SmallVector<APFloat> folded;
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs = getFoldValueAtIndexFp(operands, i);
if (inputs.empty())
return nullptr;
double fold = fpFolder(inputs);

APFloat val(fold);
@@ -1338,6 +1380,8 @@ static OpFoldResult naryFolderHelper(ArrayRef<Attribute> operands, Type ty,
for (int i = 0, s = numValues; i < s; ++i) {
auto inputs =
getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i);
if (inputs.empty())
return nullptr;
folded.push_back(intFolder(inputs));
}
return DenseElementsAttr::get(resultBTy, folded);
@@ -1945,8 +1989,12 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns(
Value a = op.getA();
auto outType = op.getResult().getType();
Value scalarValue = getScalarIntValue(a, loc, rewriter);
if (!scalarValue)
if (!scalarValue) {
scalarValue = getScalarFloatValue(a, loc, rewriter);
}
if (!scalarValue) {
return failure();
}
rewriter.replaceOpWithNewOp<Torch::DerefineOp>(op, outType, scalarValue);
return success();
});
@@ -3374,6 +3422,22 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
});
}

//===----------------------------------------------------------------------===//
// AtenLogOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenLogOp::fold(FoldAdaptor adaptor) {
auto fpFold = [](llvm::ArrayRef<double> inputs) {
assert(inputs.size() == 1);
return std::log(inputs[0]);
};
auto intFold = [](llvm::ArrayRef<APInt> inputs) {
assert(inputs.size() == 1);
return inputs[0];
};
return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold);
}
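
As a side note for reviewers, here is a minimal Python sketch (not part of the diff; `fold_log_fp` and `MAX_FOLD` are made-up names) of what the floating-point path of `naryFolderHelper` amounts to once `AtenLogOp::fold` plugs in `std::log` as the element-wise operator:

```python
import math

MAX_FOLD = 16  # mirrors the maxFold constant in naryFolderHelper above

def fold_log_fp(values):
    """Rough sketch: element-wise log over a small constant tensor.

    Returns None when folding is skipped, loosely matching the folder
    bailing out on large non-splat constants.
    """
    if len(values) > MAX_FOLD:
        return None
    return [math.log(v) for v in values]

print(fold_log_fp([10.0]))  # [2.302585092994046]
```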

//===----------------------------------------------------------------------===//
// AtenBroadcastToOp
//===----------------------------------------------------------------------===//
4 changes: 4 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -1060,6 +1060,7 @@
"ElementwiseGeluModule_basic",
"ElementwiseLeakyReluStaticModule_basic",
"ElementwiseLogModule_basic",
"ElementwizeLogScalarInputModule_basic",
"ElementwiseNanToNumModule_Basic",
"ElementwiseNeFloatTensorStaticModule_basic",
"ElementwiseNeIntTensorStaticModule_basic",
@@ -1575,6 +1576,7 @@
"ElementwiseAddScalarInt64Module_basic",
"ElementwiseAddScalarInt8Module_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
"ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
"ElementwiseAtenDivIntScalarModule_basic",
"ElementwiseAtenIsinfOpModule_basic",
@@ -1648,6 +1650,7 @@
"ElementwiseLerpScalarFloatModule_basic",
"ElementwiseLog2Module_basic",
"ElementwiseLogModule_basic",
"ElementwizeLogScalarInputModule_basic",
"ElementwiseLtDiffWidthScalarModule_basic",
"ElementwiseLtFloatScalarModule_basic",
"ElementwiseLtFloatTensorModule_basic",
@@ -1921,6 +1924,7 @@
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"CosineSimilarityModule_basic",
"ElementwizeLogScalarInputModule_basic",
"NativeGroupNormBackwardModule_basic",
"ReduceFrobeniusNormKeepDimModule_basic",
"ReduceFrobeniusNormModule_basic",
@@ -268,7 +268,6 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::relu : (Tensor) -> (Tensor)",
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sinh : (Tensor) -> (Tensor)",
@@ -361,6 +360,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True)
emit_with_mutating_variants("aten::sign : (Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True)
emit_with_mutating_variants("aten::log : (Tensor) -> (Tensor)", has_folder=True)

emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
emit_with_mutating_variants("aten::addcdiv : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)")
17 changes: 17 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1752,6 +1752,23 @@ def ElementwiseLogModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))


class ElementwizeLogScalarInputModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args([
None,
])
def forward(self):
a = torch.tensor(10)
return torch.log(a)

@register_test_case(module_factory=lambda: ElementwizeLogScalarInputModule())
def ElementwizeLogScalarInputModule_basic(module, tu: TestUtils):
module.forward()
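
For reference, the e2e harness compares the compiled module against a golden eager run, and in eager PyTorch `torch.log` promotes an integer input to floating point — a quick sanity check of what the new scalar-input test exercises (assumption: default dtype is float32):

```python
import torch

# Eager reference for the scalar-input test above: the int64 input tensor is
# promoted, so a constant fold of aten.log would also need a float result.
print(torch.log(torch.tensor(10)))        # tensor(2.3026)
print(torch.log(torch.tensor(10)).dtype)  # torch.float32
```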


# ==============================================================================

