diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index 0edef878f217..15f29fbc3cab 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -121,6 +121,17 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
 LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
                                   TypeAttr &accType);
 
+// Get accumulator type for TOSA convolution ops
+LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
+                                RankedTensorType inputTy,
+                                RankedTensorType weightTy,
+                                RankedTensorType outputTy, TypeAttr &accType);
+
+// Temporary function to get TOSA const shape
+// TODO: Remove this function when getTosaConstShape is available in
+// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape);
 } // namespace tosa
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 066126fb0906..4ec703d892ad 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
@@ -2252,6 +2253,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(op,
                                        "non-const dilation list unsupported");
 
+  TypeAttr accType;
+  if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy,
+                                     accType)))
+    return rewriter.notifyMatchFailure(
+        op, "failed to get accumulator type for convolution ops");
+
   // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights.
   // Perform the necessary transformations.
   std::optional<Value> nchwToNhwcTransposeConst =
@@ -2365,12 +2372,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     // full convolution
     convOpResult =
         rewriter
-            .create<tosa::Conv2DOp>(op->getLoc(),
-                                    getTypeConverter()->convertType(convOpTy),
-                                    transposedInput, transformedWeight, bias,
-                                    rewriter.getDenseI64ArrayAttr(padding),
-                                    rewriter.getDenseI64ArrayAttr(stride),
-                                    rewriter.getDenseI64ArrayAttr(dilation))
+            .create<tosa::Conv2DOp>(
+                op->getLoc(), getTypeConverter()->convertType(convOpTy),
+                transposedInput, transformedWeight, bias,
+                rewriter.getDenseI64ArrayAttr(padding),
+                rewriter.getDenseI64ArrayAttr(stride),
+                rewriter.getDenseI64ArrayAttr(dilation), accType)
             .getResult();
   } else if (weightShape[1] == 1) {
     // depthwise convolution
@@ -2381,7 +2388,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                 transposedInput, transformedWeight, bias,
                 rewriter.getDenseI64ArrayAttr(padding),
                 rewriter.getDenseI64ArrayAttr(stride),
-                rewriter.getDenseI64ArrayAttr(dilation))
+                rewriter.getDenseI64ArrayAttr(dilation), accType)
             .getResult();
   } else {
     llvm_unreachable("Unhandled convolution type");
@@ -3909,9 +3916,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     }
   }
 
-  auto result = rewriter.create<tosa::TileOp>(
-      op->getLoc(), resultType, reshapedInput,
-      rewriter.getDenseI64ArrayAttr(tileOpShape));
+  auto tileOpMultiples =
+      tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape);
+
+  auto result = rewriter.create<tosa::TileOp>(op->getLoc(), resultType,
+                                              reshapedInput, tileOpMultiples);
 
   rewriter.replaceOp(op, {result.getResult()});
   }
@@ -4104,9 +4113,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape),
                             rewriter.getIntegerType(32));
 
+  auto tileOpMultiples =
+      tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
+
   auto expandedIndices = rewriter.create<tosa::TileOp>(
-      op->getLoc(), tileType, reshapedIndices.getResult(),
-      rewriter.getDenseI64ArrayAttr(tileShape));
+      op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples);
 
   // convert torch style index and dim into tf style indices
   // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64>
@@ -4445,17 +4456,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
       if (needsTiling) {
        auto idxType =
            dyn_cast<RankedTensorType>(indicesTfConcatTensors[i].getType());
+
+        // indicesTfConcatTensors has a trailing [1] dim for the final concat.
        auto maxRankMaxDimShapeTf(maxRankMaxDimShape);
        maxRankMaxDimShapeTf.push_back(1);
+
        auto tileOpShapeTf(tileOpShape);
        tileOpShapeTf.push_back(1);
+
        auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf,
                                                  idxType.getElementType());
 
        auto reshapedIdxTensor = indicesTfConcatTensors[i];
+
+        auto tileOpMultiples =
+            tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf);
+
        indicesTfConcatTensors[i] = rewriter.create<tosa::TileOp>(
-            op->getLoc(), tileOutputTy, reshapedIdxTensor,
-            rewriter.getDenseI64ArrayAttr(tileOpShapeTf));
+            op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples);
       }
 
       // Every index tensor now has the same rank and shape
@@ -6023,12 +6040,14 @@ class ConvertAtenFillOp : public OpConversionPattern {
          op->getLoc(), fillValueMatchedInputRankType, fillValue,
          rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape));
 
+      auto tileOpMultiples =
+          tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape());
+
      fillValueTargetTensor = rewriter.create<tosa::TileOp>(
          op->getLoc(),
          RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()),
                                fillValueElemTy),
-          fillValueMatchedInputRankTensor.getResult(),
-          makeShapeTorchCompatible(outType.getShape()));
+          fillValueMatchedInputRankTensor.getResult(), tileOpMultiples);
    } else {
      if (failed(torchScalarToTosaTensor(
              rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy,
@@ -6179,7 +6198,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   }
 
   DenseElementsAttr paddingAttr = DenseIntElementsAttr::get(
-      RankedTensorType::get({rank, 2}, rewriter.getI64Type()),
+      RankedTensorType::get({2 * rank}, rewriter.getI64Type()),
       translatePadsList);
 
   Value padsList1 = rewriter.create<tosa::ConstOp>(
@@ -7836,9 +7855,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
                                              resultType.getElementType()),
      self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
 
+  auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
+                                                     resultShapeIndex0Replaced);
+
  auto selfTiled = rewriter.create<tosa::TileOp>(
-      op->getLoc(), resultType, selfReshaped.getResult(),
-      rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
+      op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples);
 
  // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]}
  auto vec2Reshaped = rewriter.create<tosa::ReshapeOp>(
      op->getLoc(),
      RankedTensorType::get(resultShapeIndex0Replaced,
                            resultType.getElementType()),
      vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced));
 
+  auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(),
+                                                     resultShapeIndex1Replaced);
+
  auto vec2Tiled = rewriter.create<tosa::TileOp>(
-      op->getLoc(), resultType, vec2Reshaped.getResult(),
-      rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced));
+      op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples);
 
  auto result = tosa::createMulOpAndCast(rewriter, op, resultType,
                                         selfTiled.getResult(),
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index ee7f61becf4f..9dedf457096a 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -8,6 +8,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "torch-mlir/Conversion/Utils/Utils.h"
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 
@@ -566,11 +567,12 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter,
 
   // [0] -> [0,0,0]
   SmallVector<int64_t> tileShape({W}); // {3}
+  auto tileOpMultiples =
+      tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape);
   auto tosaFillValuesTileOp = tosa::CreateOpAndInfer<tosa::TileOp>(
       rewriter, op->getLoc(),
       GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()),
-      tosaFillValuesOneReshapeOp.getResult(),
-      rewriter.getDenseI64ArrayAttr(tileShape));
+      tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples);
 
   // [0,0,0] -> [[0,0,0]]
   SmallVector<int64_t> newTosaFillValuesShape({N, W}); // {1,3}
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index 3d97b695f1ab..a27fa9736aaa 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -436,5 +436,63 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input,
   return success();
 }
 
+// Get accumulator type for TOSA convolution ops
+LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
+                                RankedTensorType inputTy,
+                                RankedTensorType weightTy,
+                                RankedTensorType outputTy, TypeAttr &accType) {
+  auto inputElemTy = inputTy.getElementType();
+  auto weightElemTy = weightTy.getElementType();
+  auto outputElemTy = outputTy.getElementType();
+
+  auto quantTy = dyn_cast<mlir::quant::QuantizedType>(inputElemTy);
+  if (quantTy)
+    inputElemTy = quantTy.getStorageType();
+
+  // Get TOSA conv ops acc type based on input, weight, and output types
+  // according to the spec:
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d
+  // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d
+  //
+  // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the
+  // output type but does not offer any guarantee on the numerical precision
+  // since such cases will fail TOSA validation.
+  if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) ||
+      (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) ||
+      (inputElemTy.isBF16() && weightElemTy.isBF16() &&
+       outputElemTy.isBF16())) {
+    accType = mlir::TypeAttr::get(rewriter.getF32Type());
+  } else if (inputElemTy.isInteger(8) &&
+             (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) &&
+             outputElemTy.isInteger(32)) {
+    accType = mlir::TypeAttr::get(rewriter.getIntegerType(32));
+  } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) &&
+             outputElemTy.isInteger(48)) {
+    accType = mlir::TypeAttr::get(rewriter.getIntegerType(48));
+  } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() &&
+              outputElemTy.isF16()) ||
+             (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() &&
+              outputElemTy.isF16())) {
+    accType = mlir::TypeAttr::get(rewriter.getF16Type());
+  } else {
+    accType = mlir::TypeAttr::get(outputElemTy);
+  }
+
+  return success();
+}
+
+// Temporary function to get TOSA const shape
+// TODO: Remove this function when getTosaConstShape is available in
+// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+Value getTosaConstShape(PatternRewriter &rewriter, Location loc,
+                        llvm::ArrayRef<int64_t> shape) {
+  auto attr = rewriter.getIndexTensorAttr(shape);
+  auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size());
+  mlir::Operation *mlir_op =
+      rewriter.create<tosa::ConstShapeOp>(loc, type, attr);
+  return mlir_op->getResult(0);
+}
+
 } // namespace tosa
 } // namespace mlir
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 2d9d95082a89..c83fb669ac3a 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -1896,21 +1896,22 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_4:.*]] = torch.constant.int 2
 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32>
-// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32>
-// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32>
-// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32>
-// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]]
{new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> -// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.tile %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x2xi32>, !tosa.shape<3>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: 
%[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,5,2],f32> // CHECK: } func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { %int2 = torch.constant.int 2 @@ -1941,10 +1942,11 @@ func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 12, 128, 128]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32> @@ -2561,12 +2563,14 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3xf32>) -> tensor<3x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array} : (tensor<3x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.tile %[[VAL_4]], %[[VAL_5]] : (tensor<3x1xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] 
= tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_9]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> @@ -3057,3 +3061,109 @@ func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte } // ----- + +// CHECK-LABEL: func.func @torch.aten.constant_pad_nd$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,20,20,4,4],f32> -> tensor<1x1x20x20x4x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 0xFFF0000000000000 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xi64>}> : () -> tensor<12xi64> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, tensor<12xi64>, tensor) -> tensor<1x1x20x20x4x5xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: } +func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { + %float-Inf = torch.constant.float 0xFFF0000000000000 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.constant_pad_nd %arg0, %0, %float-Inf : !torch.vtensor<[1,1,20,20,4,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,1,20,20,4,5],f32> + return %1 : !torch.vtensor<[1,1,20,20,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<10x2x3x3xf32>}> : () -> tensor<10x2x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: 
%[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[VAL_15:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_15]], %[[VAL_16]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK: } +func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { + %false = torch.constant.bool false + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<10x2x3x3xf32>) : !torch.vtensor<[10,2,3,3],f32> + %none = torch.constant.none + %int1 = torch.constant.int 1 + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[5,2,10,20],f32>, !torch.vtensor<[10,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,10,14,24],f32> + return %5 : !torch.vtensor<[5,10,14,24],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$depthwise( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense_resource : tensor<4x1x3x3xf32>}> : () -> tensor<4x1x3x3xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> 
tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_13]] : (tensor<5x4x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<3x3x4x1xf32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_18]], %[[VAL_19]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32> +// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32> +// CHECK: } +func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { + %false = torch.constant.bool false + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<4x1x3x3xf32>) : !torch.vtensor<[4,1,3,3],f32> + %none = torch.constant.none + %int2 = torch.constant.int 2 + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int4 : !torch.vtensor<[5,4,10,20],f32>, !torch.vtensor<[4,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,4,5,10],f32> + return %5 : !torch.vtensor<[5,4,5,10],f32> +} + +// -----
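
A minimal sketch of how the two helpers introduced above are intended to be used together at a lowering call site. This is illustrative only: the wrapper name `tileToStaticShape` is hypothetical, and the snippet assumes the same headers and namespaces as lib/Conversion/TorchToTosa/TorchToTosa.cpp; it simply mirrors the tosa.tile call sites in the hunks above.

// Illustrative only: mirrors the updated tosa.tile lowering pattern.
// tosa.tile now takes its multiples as a !tosa.shape operand instead of a
// DenseI64ArrayAttr, so the static multiples are first materialized as a
// tosa.const_shape value via the temporary getTosaConstShape helper.
static Value tileToStaticShape(PatternRewriter &rewriter, Location loc,
                               RankedTensorType resultTy, Value input,
                               ArrayRef<int64_t> multiples) {
  Value tileOpMultiples = tosa::getTosaConstShape(rewriter, loc, multiples);
  return rewriter.create<tosa::TileOp>(loc, resultTy, input, tileOpMultiples);
}

Convolution call sites follow the same shape: tosa::getConvOpsAccType derives the accumulator type from the input, weight, and output element types, and the resulting TypeAttr is passed as the trailing builder argument (the acc_type attribute) of tosa::Conv2DOp / tosa::DepthwiseConv2DOp, as in the TorchToTosa.cpp hunks above.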