diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 586b8d4ff053d..051944e6521f7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1424,18 +1424,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp("Sinh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + // 1/2 * (exp(x) – exp(-x)) + Value x = rewriter.create( + binder.getLoc(), resultType, operand); + Value neg = rewriter.create( + binder.getLoc(), resultType, operand); + Value y = rewriter.create( + binder.getLoc(), resultType, neg); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value z = rewriter.create( + binder.getLoc(), resultType, x, y, cstOne); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, z, cstTwo); + return success(); + }); // split with fixed-size parts // Arguments: