From 145a0a29bf9fd5623164304ac7f2cd057521d109 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Sun, 28 Apr 2024 17:47:32 -0700 Subject: [PATCH] Fix onnx cosh lowering --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 96f4e55fb12dd..8c9875ca40bf8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1348,17 +1348,31 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp("Cosh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + patterns.onOp( + "Cosh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + // 1/2 * (exp(x) + exp(-x)) + Value x = rewriter.create( + binder.getLoc(), resultType, operand); + Value neg = rewriter.create( + binder.getLoc(), resultType, operand); + Value y = rewriter.create( + binder.getLoc(), resultType, neg); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value z = rewriter.create( + binder.getLoc(), resultType, x, y, cstOne); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, z, cstTwo); + return success(); + }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc();