From cb371ea4d3732a3bc05a5f7166fcda5778f1aaa6 Mon Sep 17 00:00:00 2001 From: Jacob Gordon Date: Fri, 10 Jan 2025 18:07:27 +0000 Subject: [PATCH] fix(ONNX): avoids resizing non scalable dimensions --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 62 +++++++++++++++++-- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 6 +- 2 files changed, 61 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ed4730444f9ed..e48725e41ab20 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2753,8 +2753,66 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented: cubic coeff must be -0.75"); } + Value inputTensor = operands[0]; + Torch::ValueTensorType typeOfInputTensor = + cast(inputTensor.getType()); + + ArrayRef sizesOfInputTensor = typeOfInputTensor.getSizes(); + ArrayRef sizesOfOutputTensor = typeOfOutputTensor.getSizes(); + + int64_t const dimensionAssumedToBeBatch = 0; + int64_t const dimensionAssumedToBeChannel = 1; + int64_t nonScalableDimensions[] = { + dimensionAssumedToBeBatch, + dimensionAssumedToBeChannel, + }; + + auto unknownSize = Torch::kUnknownSize; + + // Compile-time check for dimensions of static size + for (auto eachDimension : nonScalableDimensions) { + auto eachSizeOfInputTensor = sizesOfInputTensor[eachDimension]; + auto eachSizeOfOutputTensor = sizesOfOutputTensor[eachDimension]; + + if (eachSizeOfInputTensor == unknownSize || + eachSizeOfOutputTensor == unknownSize) { + continue; + } else if (eachSizeOfInputTensor == eachSizeOfOutputTensor) { + continue; + } + + auto scalingIntentErrorMessage = + "unsupported: non-trivial intent to scale dimension: " + + std::to_string(eachDimension); + + return rewriter.notifyMatchFailure(binder.op, + scalingIntentErrorMessage); + }; + auto opLocation = binder.getLoc(); + // Run-time check for dimensions of dynamic size + for (auto eachDimension : nonScalableDimensions) { + auto eachDimensionAsValue = rewriter.create( + opLocation, rewriter.getI64IntegerAttr(eachDimension)); + + Value eachSizeOfInputAsValue = rewriter.create( + opLocation, inputTensor, eachDimensionAsValue); + + int64_t eachSizeOfOutput = sizesOfOutputTensor[eachDimension]; + Value eachSizeOfOutputAsValue = rewriter.create( + opLocation, rewriter.getI64IntegerAttr(eachSizeOfOutput)); + + Value eachSizeComparison = rewriter.create( + opLocation, eachSizeOfInputAsValue, eachSizeOfOutputAsValue); + + rewriter.create( + opLocation, eachSizeComparison, + rewriter.getStringAttr( + "unsupported: non-trivial scaling of dimension " + + std::to_string(eachDimension))); + }; + Value cstFalse = rewriter.create(opLocation, false); Value cstTrue = @@ -2774,10 +2832,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.create(opLocation, modeStr); } - Value inputTensor = operands[0]; - Torch::ValueTensorType typeOfInputTensor = - cast(inputTensor.getType()); - ArrayRef sizesOfInputTensor = typeOfInputTensor.getSizes(); unsigned rankOfInputTensor = sizesOfInputTensor.size(); // supported modes: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 16c86218dbc8b..b4803be6ed3b9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2256,7 +2256,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK-LABEL: func.func @test_resize_sizes_nearest func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> } @@ -2267,7 +2267,7 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none // CHECK: %[[STR:.+]] = torch.constant.str "nearest_half_pixel,round_prefer_floor" - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %[[STR]], %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %[[STR]], %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) { torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "nearest"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> @@ -2280,7 +2280,7 @@ func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1 func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?], f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %none = torch.constant.none - // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %8, %none_1, %str, %false, %none_1, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32> }