diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 5fb17c79a65b..944c258a8d12 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -180,53 +180,67 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, return success(); } -Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter, - Value operand) { - SmallVector itemList; - auto sizes = dyn_cast(operand.getType()).getSizes(); - Torch::BaseTensorType operandType = - cast(operand.getType()); - - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = operandType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), operandType.getOptionalDtype()); - - auto extract = [&rewriter, &binder](Value x, Value v) { - auto xTy = cast(x.getType()); - Type extractTy = rewriter.getType(); - if (isa(xTy.getDtype())) - extractTy = rewriter.getType(); - - return rewriter.create(binder.getLoc(), extractTy, v); - }; - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - - MLIRContext *context = binder.op->getContext(); - for (int i = 2; i < sizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value ext = rewriter.create( - binder.getLoc(), selectResultType, operand, zero, selectIndex); - Value item = extract(operand, ext); - itemList.push_back(item); - } - auto xTy = cast(operand.getType()); - Value ValueList; - if (isa(xTy.getDtype())) { - ValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)), - itemList); - } else { - ValueList = rewriter.create( - binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)), - itemList); +Type getTorchScalarType( + /* forElementIn */ Torch::BaseTensorType givenTensorType, + /* using */ ConversionPatternRewriter &rewriter) { + auto elementTypeForGivenTensor = givenTensorType.getDtype(); + + if (isa(elementTypeForGivenTensor)) + return rewriter.getType(); + if (isa(elementTypeForGivenTensor)) + return rewriter.getType(); + + assert(false && "dtype for given tensor expected to be either int or float"); +} + +Value extractTorchScalar( + /* at */ Location givenLoc, + /* from */ int64_t givenIndex, + /* in */ Value given1DTensor, + /* using */ ConversionPatternRewriter &rewriter) { + auto some1DTensorType = cast(given1DTensor.getType()); + + Type selectionTypeForSome1DTensor = some1DTensorType.getWithSizesAndDtype( + ArrayRef{1}, some1DTensorType.getOptionalDtype()); + + Value frontDim = rewriter.create(givenLoc, 0); + + Value selectionIndex = + rewriter.create(givenLoc, givenIndex); + + auto someTorchScalarType = getTorchScalarType(some1DTensorType, rewriter); + + Value selectionFromGiven1DTensor = rewriter.create( + givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim, + selectionIndex); + + return rewriter.create(givenLoc, someTorchScalarType, + selectionFromGiven1DTensor); +} + +Value createScalarSublist( + /* at */ Location givenLoc, + /* movingForwardsThrough */ Value given1DTensor, + /* startingAt */ int64_t givenIndex, + /* using */ ConversionPatternRewriter &rewriter) { + auto some1DTensorType = cast(given1DTensor.getType()); + auto sizesOfSome1DTensor = some1DTensorType.getSizes(); + auto lengthOfFullList = sizesOfSome1DTensor[0]; + + SmallVector runningScalarSublist; + + for (int indexOfEachScalar = givenIndex; indexOfEachScalar < lengthOfFullList; + indexOfEachScalar++) { + Value eachScalar = extractTorchScalar(givenLoc, indexOfEachScalar, + given1DTensor, rewriter); + runningScalarSublist.push_back(eachScalar); } - return ValueList; + + auto someTorchScalarType = runningScalarSublist.front().getType(); + Type someTorchScalarListType = Torch::ListType::get(someTorchScalarType); + + return rewriter.create( + givenLoc, someTorchScalarListType, runningScalarSublist); } } // namespace @@ -2809,14 +2823,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStrValue = rewriter.create(binder.getLoc(), modeStr); } + + int64_t assumedForemostSpatialDim = 2; + if (operands.size() < 4) { Value scaleOperand = operands[2]; - scalesValueList = getValueList(binder, rewriter, scaleOperand); + scalesValueList = + createScalarSublist(binder.getLoc(), scaleOperand, + assumedForemostSpatialDim, rewriter); sizesValueList = noneVal; } else { Value sizeOperand = operands[3]; scalesValueList = noneVal; - sizesValueList = getValueList(binder, rewriter, sizeOperand); + sizesValueList = + createScalarSublist(binder.getLoc(), sizeOperand, + assumedForemostSpatialDim, rewriter); } if (isa(scalesValueList.getType()) && isa(sizesValueList.getType())) { @@ -3339,7 +3360,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "supports upto 3d upsampling only"); - Value scalesValueList = getValueList(binder, rewriter, scales); + int64_t assumedForemostSpatialDim = 2; + Value scalesValueList = createScalarSublist( + binder.getLoc(), scales, assumedForemostSpatialDim, rewriter); if (mode == "linear") { if (resultRank == 4) mode = "bilinear"; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 16c86218dbc8..5dd6ee037b75 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2803,8 +2803,9 @@ func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !t // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list // CHECK: %[[MODE:.*]] = torch.constant.str "nearest" @@ -2824,8 +2825,9 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: ! // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 // CHECK: %[[INT3:.*]] = torch.constant.int 3 - // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> + // CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32> // CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float // CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list // CHECK: %[[MODE:.*]] = torch.constant.str "bilinear"