From cc72042a94b80b7b2f8f7c383ad6e157260b817f Mon Sep 17 00:00:00 2001 From: Jack Frankland Date: Mon, 3 Feb 2025 09:36:50 +0000 Subject: [PATCH] [mlir][tosa] Make Convolution Zero Points Inputs (#122939) The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately. Rename the "filter" argument of `tosa.transpose_conv2d` to weight to align with the TOSA specification. Remove the quantization_info attribute on the convolution operations. Co-authored-by: TatWai Chong --- .../mlir/Dialect/Tosa/IR/TosaOpBase.td | 7 + mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h | 118 +++++++++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 22 +- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 5 + .../mlir/Dialect/Tosa/Utils/QuantUtils.h | 3 + .../TosaToLinalg/TosaToLinalgNamed.cpp | 57 ++--- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 233 +++++++++++++++--- .../Tosa/Transforms/TosaDecomposeConv2D.cpp | 29 ++- .../Transforms/TosaDecomposeDepthwise.cpp | 25 +- .../Transforms/TosaDecomposeTransposeConv.cpp | 83 +++---- .../Tosa/Transforms/TosaValidation.cpp | 2 +- mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp | 63 ++++- .../TosaToLinalg/tosa-to-linalg-named.mlir | 19 +- mlir/test/Dialect/Tosa/invalid.mlir | 41 ++- mlir/test/Dialect/Tosa/ops.mlir | 3 +- mlir/test/Dialect/Tosa/quant-test.mlir | 4 +- .../Dialect/Tosa/tosa-decompose-conv2d.mlir | 12 +- .../Tosa/tosa-decompose-depthwise.mlir | 4 +- .../Tosa/tosa-decompose-transpose-conv.mlir | 44 ++-- 19 files changed, 576 insertions(+), 198 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 4975530a9588c..f492bad78e775 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp traits = []> "operands attr-dict `:` functional-type(operands, results)"; } +// The "SameVariadicOperandSize" trait allows us to pass optional arguments +// for multiple zero points in convolution ops. +class Tosa_ConvOp traits = []> + : Tosa_InferShapedTypeOp { +} + #endif // TOSA_OP_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 27061002b0295..069073bc2d164 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -16,6 +16,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" @@ -29,6 +30,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { class PatternRewriter; @@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t); #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" +namespace mlir { +namespace tosa { + +// Create a rank-1 const tensor for zero point of the source tensor. +std::optional createZeroPointTensor(OpBuilder &builder, Location loc, + Type srcElemType, int64_t zp = 0); + +// Get zero point value from the attribute argument. +LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp); + +// Verify if zero point falls into valid range. +template +LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) { + if constexpr (!std::is_same_v && !std::is_same_v && + !std::is_same_v && + !std::is_same_v) { + return failure(); + } + + if (!zpElemType.isIntOrFloat()) + return failure(); + + if (!zpElemType.isInteger(8) && zp != 0) + return failure(); + + if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127)) + return failure(); + + if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255)) + return failure(); + + return success(); +} + +// Helper type trait to determine if an operation is a tosa convolution. +template +struct IsTosaConv : std::false_type {}; + +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; + +template +constexpr bool is_tosa_conv_v = IsTosaConv::value; + +// Helper struct to hold the zero points of a TOSA convolution operation as +// named 64-bit integer fields. +struct ConvZpPair { + ConvZpPair(std::int64_t inputZp, std::int64_t weightZp) + : inputZp(inputZp), weightZp(weightZp) {} + std::int64_t inputZp; + std::int64_t weightZp; +}; + +// Helper function which attempts to extract the zero points from a TOSA +// convolution by matching them against defining ops which should be tosa.const +// operations. +// +// There are three possible results: +// 1. Failed to extract the zero-points i.e. they should exist and don't or they +// do exist but are invalid. +// 2. Succeeded in extracting zero-points. +// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized +// convolution. +using FailOrMaybeZP = llvm::FailureOr>; +template +std::enable_if_t, FailOrMaybeZP> +extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) { + // Strictly speaking the base TOSA spec requires that for non int8 types + // zero points must be zero. However, in the dialect these operands are + // optional and only required for int8. They have no semantic meaning for + // non-quantized types and can therefore be safely ignored. This is case 3. + if (auto opElementTY = + cast(op->getOperand(0).getType()).getElementType(); + !opElementTY.isInteger(8)) + return FailOrMaybeZP(std::nullopt); + + // Now we know we should have a zero point check it is valid. + if (!op.getInputZp()) + return rewriter.notifyMatchFailure(op, "missing input zero point"); + + // Helper to extract the zero point by matching its definition against a + // constant. + auto extractZeroPoint = [](Value zpValue) -> std::optional { + ElementsAttr zpAttr; + if (!matchPattern(zpValue, m_Constant(&zpAttr))) + return std::nullopt; + + int64_t zp; + if (tosa::getZeroPoint(zpAttr, zp).failed()) + return std::nullopt; + + return std::make_optional(zp); + }; + + auto maybeInputZp = extractZeroPoint(op.getInputZp()); + if (!maybeInputZp) + return rewriter.notifyMatchFailure(op, "unable to extract input zp"); + + if (!op.getWeightZp()) + return rewriter.notifyMatchFailure(op, "missing weight zero point"); + + auto maybeWeightZp = extractZeroPoint(op.getWeightZp()); + if (!maybeWeightZp) + return rewriter.notifyMatchFailure(op, "unable to extract weight zp"); + + return std::make_optional(*maybeInputZp, *maybeWeightZp); +} +} // namespace tosa +} // namespace mlir + #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index c59c582a1f522..819547855d101 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { //===----------------------------------------------------------------------===// // Operator: conv2d //===----------------------------------------------------------------------===// -def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { +def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { let summary = "2D Convolution Operator"; let description = [{ @@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { //===----------------------------------------------------------------------===// // Operator: conv3d //===----------------------------------------------------------------------===// -def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { +def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { let summary = "3D Convolution operator"; let description = [{ @@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { Tosa_Tensor5D:$input, TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { //===----------------------------------------------------------------------===// // Operator: depthwise_conv2d //===----------------------------------------------------------------------===// -def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { +def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { let summary = "Depthwise 2D Convolution operator"; let description = [{ @@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> { //===----------------------------------------------------------------------===// // Operator: transpose_conv2d //===----------------------------------------------------------------------===// -def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { +def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { let summary = "Transpose 2D Convolution operator."; let description = [{ @@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { let arguments = (ins Tosa_Tensor4D:$input, - TosaTensorRankOf<[Tosa_Weight], [4]>:$filter, + TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 5693acf3a01db..7aa1f72ec6e17 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -288,4 +288,9 @@ def Rank1TosaShape : TosaShapeOfRank<1>; def Rank2TosaShape : TosaShapeOfRank<2>; def Rank4TosaShape : TosaShapeOfRank<4>; +// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this +// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the +// following def can be removed. +def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>; + #endif // TOSA_TYPES_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 5e80745777b3b..10dc5dd36cfa9 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier, ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight); +std::pair createZPsAsConst(OpBuilder &builder, Value input, + Value weight); + //// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 57a5fe75a007b..cf9852e05cf7c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern { DenseI64ArrayAttr padAttr = op.getPadAttr(); DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr(); DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr(); - bool isQuantized = op.getQuantizationInfo().has_value(); + + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (llvm::failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -284,10 +289,7 @@ class ConvConverter : public OpConversionPattern { // Apply padding as necessary. TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); - if (isQuantized) { - auto quantizationInfo = *op.getQuantizationInfo(); - int64_t iZp = quantizationInfo.getInputZp(); - + if (maybeZps) { int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); @@ -295,11 +297,11 @@ class ConvConverter : public OpConversionPattern { APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); - if (iZp < intMin || iZp > intMax) + if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax) return rewriter.notifyMatchFailure( op, "tosa.conv op quantization has zp outside of input range"); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); + zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp); } llvm::SmallVector pad; @@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern { // For 2D convolutions, we need to check if the target convolution op // wants a HWCF kernel layout. bool wantHwcf = - isQuantized ? std::is_same_v - : std::is_same_v; + maybeZps ? std::is_same_v + : std::is_same_v; if (wantHwcf) { // Transpose the kernel to match dimension ordering of the linalg // convolution operation. @@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern { Value broadcastBias = linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor); - if (isQuantized) { - auto quantizationInfo = *op.getQuantizationInfo(); - auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); - auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); + if (maybeZps) { + auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp); + auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp); auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); @@ -440,26 +441,18 @@ class DepthwiseConvConverter /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter); - bool isQuantized = op->hasAttr("quantization_info"); - IntegerAttr iZp; - IntegerAttr kZp; - if (isQuantized) { - auto quantizationInfo = - cast(op->getAttr("quantization_info")); - iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); - kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); - } + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (llvm::failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); // Apply padding as necessary. TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); - if (isQuantized) { - auto quantizationInfo = - cast(op->getAttr("quantization_info")); - int64_t iZp = quantizationInfo.getInputZp(); - + if (maybeZps) { int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); @@ -467,12 +460,12 @@ class DepthwiseConvConverter APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); - if (iZp < intMin || iZp > intMax) + if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax) return rewriter.notifyMatchFailure( op, "tosa.depthwise_conv op quantization has zp outside of input " "range"); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); + zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp); } llvm::SmallVector pad; @@ -512,7 +505,7 @@ class DepthwiseConvConverter indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); - if (!isQuantized) { + if (!maybeZps) { Value conv = rewriter .create( loc, linalgConvTy, ValueRange{input, weight}, @@ -539,8 +532,10 @@ class DepthwiseConvConverter .getResult(0); rewriter.replaceOp(op, result); } else { + IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp); + IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp); auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto kZpVal = rewriter.create(loc, wZp); Value conv = rewriter .create( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0a10439db4080..e8b28906135ed 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -217,33 +217,59 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, template static LogicalResult verifyConvOp(T op) { - // All TOSA conv ops have an input() and weight(). + // All TOSA conv ops have an input and weight arguments which must be ranked + // tensors. auto inputType = llvm::dyn_cast(op.getInput().getType()); - - RankedTensorType weightType; - if constexpr (std::is_same_v) - weightType = llvm::dyn_cast(op.getFilter().getType()); - else - weightType = llvm::dyn_cast(op.getWeight().getType()); - - // Must be ranked tensor types if (!inputType) { op.emitOpError("expect a ranked tensor for input, got ") << op.getInput(); return failure(); } + + auto weightType = llvm::dyn_cast(op.getWeight().getType()); if (!weightType) { - if constexpr (std::is_same_v) { - op.emitOpError("expect a ranked tensor for filter, got ") - << op.getFilter(); - } else { - op.emitOpError("expect a ranked tensor for weight, got ") - << op.getWeight(); - } + op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight(); return failure(); } auto inputEType = inputType.getElementType(); auto weightEType = weightType.getElementType(); + auto biasEType = + llvm::cast(op.getBias().getType()).getElementType(); + auto resultEType = + llvm::cast(op.getResult().getType()).getElementType(); + bool biasIsFloat = llvm::isa(biasEType); + bool resultIsFloat = llvm::isa(resultEType); + + if (auto quantType = + llvm::dyn_cast(inputEType)) + inputEType = quantType.getStorageType(); + + if (auto quantType = + llvm::dyn_cast(biasEType)) + biasEType = quantType.getStorageType(); + + if (auto quantType = + llvm::dyn_cast(resultEType)) + resultEType = quantType.getStorageType(); + + if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { + // for now, only enforce bias element type == result element type for + // float types. + op.emitOpError( + "expect both bias and result to have same element type, got ") + << biasEType << " and " << resultEType; + return failure(); + } + + if (isa(inputEType) || isa(inputEType) || + isa(weightEType) || isa(weightEType)) { + if (inputEType != weightEType) { + op.emitOpError( + "expect both input and weight to have same element type, got ") + << inputEType << " and " << weightEType; + return failure(); + } + } bool inputIsQuant = !llvm::isa(inputEType); bool weightIsQuant = !llvm::isa(weightEType); @@ -256,14 +282,38 @@ static LogicalResult verifyConvOp(T op) { return failure(); } - // Quantized type must have constructed the quantizationattr, and unquantized - // types should not have a quantizationattr. - if ((inputIsQuant && !op.getQuantizationInfo()) || - (!inputIsQuant && op.getQuantizationInfo())) { - op.emitOpError("quantizationattr is required for quantized type, and not " - "allowed for float type"); + // We require an explicit input zero point and weight zero point for i8 + // convolution. + if (!op.getInputZp() && !op.getWeightZp()) + return inputEType.isInteger(8) ? failure() : success(); + + ElementsAttr inputZpAttr; + ElementsAttr weightZpAttr; + if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) || + !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) { + op.emitOpError( + "bail out if the actual value of zero points cannot be determined"); return failure(); } + + // Get and verify explicit zero points. + int64_t inputZpVal; + int64_t weightZpVal; + + if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() || + tosa::verifyZeroPoint(getElementTypeOrSelf(inputZpAttr), inputZpVal) + .failed()) { + op.emitOpError("input zero point must be zero for non-int8 integer types"); + return failure(); + } + + if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() || + tosa::verifyZeroPoint(getElementTypeOrSelf(weightZpAttr), weightZpVal) + .failed()) { + op.emitOpError("weight zero point must be zero for non-int8 integer types"); + return failure(); + } + return success(); } @@ -322,6 +372,39 @@ static LogicalResult verifyConvOpModes(T op) { return success(); } +// verify that inType and outType have same element types +template +static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { + auto inputType = llvm::dyn_cast(inType); + auto outputType = llvm::dyn_cast(outType); + if (!inputType) { + op.emitOpError("expect shaped tensor for input, got ") << inType; + return failure(); + } + if (!outputType) { + op.emitOpError("expect shaped tensor for output, got ") << outType; + return failure(); + } + auto inputElementType = inputType.getElementType(); + auto outputElementType = outputType.getElementType(); + auto inputQuantType = + llvm::dyn_cast(inputElementType); + auto outputQuantType = + llvm::dyn_cast(outputElementType); + if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && + (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && + inputElementType != outputElementType) { + // only check if both element types are int/index/float/UniformQuantized + // eg, not sure how to check quant::QuantizedType + // this happens in test_conv2d_q_grouped_convolution in + // tfl-to-tosa-pipeline.mlir + op.emitOpError("expect input and output to have same element type, got ") + << inputElementType << " and " << outputElementType; + return failure(); + } + return success(); +} + LogicalResult tosa::ArgMaxOp::verify() { // Ensure output is of 32-bit integer const auto resultETy = llvm::cast(getType()).getElementType(); @@ -421,21 +504,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType) { - - result.addOperands({input, weight, bias}); + auto zps = createZPsAsConst(builder, input, weight); + result.addOperands({input, weight, bias, zps.first, zps.second}); result.addAttribute("pad", pad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); result.addAttribute("acc_type", accType); - - auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); - if (quantAttr) { - result.addAttribute("quantization_info", quantAttr); - result.addTypes( - buildConvOpResultTypeInfo(builder, outputType, input, weight)); - } else { - result.addTypes(outputType); - } + result.addTypes(outputType); } /// Handles tosa.transpose_conv2d which has outpad and output shape @@ -790,7 +865,47 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( return success(); } -LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } +LogicalResult FullyConnectedOp::verify() { + // All TOSA conv ops have an input() and weight(). + auto inputType = llvm::dyn_cast(getInput().getType()); + + RankedTensorType weightType = + llvm::dyn_cast(getWeight().getType()); + + // Must be ranked tensor types + if (!inputType) { + emitOpError("expect a ranked tensor for input, got ") << getInput(); + return failure(); + } + if (!weightType) { + emitOpError("expect a ranked tensor for weight, got ") << getWeight(); + return failure(); + } + + auto inputEType = inputType.getElementType(); + auto weightEType = weightType.getElementType(); + + bool inputIsQuant = !llvm::isa(inputEType); + bool weightIsQuant = !llvm::isa(weightEType); + + // Either both must be quantized or both unquantized. + if (inputIsQuant != weightIsQuant) { + emitOpError( + "expect both input and weight to be float or not together, got ") + << inputEType << " and " << weightEType; + return failure(); + } + + // Quantized type must have constructed the quantizationattr, and unquantized + // types should not have a quantizationattr. + if ((inputIsQuant && !getQuantizationInfo()) || + (!inputIsQuant && getQuantizationInfo())) { + emitOpError("quantizationattr is required for quantized type, and not " + "allowed for float type"); + return failure(); + } + return success(); +} LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, @@ -2019,7 +2134,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( } // Weight shapes describes the filter width/height and the output channels. - ShapeAdaptor weightShape(adaptor.getFilter().getType()); + ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) @@ -2315,6 +2430,54 @@ void WhileOp::print(OpAsmPrinter &parser) { parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } +LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) { + Type zpElemType = zpAttr.getElementType(); + if (auto quantType = + llvm::dyn_cast(zpElemType)) { + zp = quantType.getZeroPoint(); + return success(); + } + if (llvm::isa(zpElemType)) { + // non-zero zero point is not allowed for float types. + if (!zpAttr.getValues()[0].isZero()) + return failure(); + zp = 0; + return success(); + } + if (llvm::isa(zpElemType)) { + zp = zpAttr.getValues()[0].getSExtValue(); + return success(); + } + // zero point is not allowed for unsupported types. + return failure(); +} + +// Create a rank-0 const tensor for zero point of the source tensor. +std::optional mlir::tosa::createZeroPointTensor(OpBuilder &builder, + Location loc, + Type srcElemType, + int64_t zp) { + if (auto quantType = + llvm::dyn_cast(srcElemType)) + srcElemType = quantType.getStorageType(); + + auto zpType = mlir::RankedTensorType::get({1}, srcElemType); + if (auto quantType = llvm::dyn_cast(srcElemType)) + srcElemType = quantType.getStorageType(); + if (llvm::isa(srcElemType)) { + auto zpAttr = DenseElementsAttr::get( + zpType, builder.getFloatAttr(srcElemType, static_cast(zp))); + return builder.create(loc, zpType, zpAttr); + } + if (llvm::isa(srcElemType)) { + auto zpAttr = + DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp)); + return builder.create(loc, zpType, zpAttr); + } + llvm::errs() << "zero point is not allowed for unsupported data types\n"; + return std::nullopt; +} + //===----------------------------------------------------------------------===// // TOSA Shape and Shape Operators Helper functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index cb08360f90228..7d3deae3330af 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -59,19 +59,17 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { for (const auto &it : llvm::enumerate(padAttr)) pad[it.index() + 2] = it.value(); + Type inputETy = inputType.getElementType(); if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { - Type inputETy = inputType.getElementType(); - Attribute zeroAttr = rewriter.getZeroAttr(inputETy); - if (op.getQuantizationInfo()) { - auto quantizationInfo = op.getQuantizationInfo(); - int64_t iZp = quantizationInfo->getInputZp(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); - if (!validIntegerRange(cast(inputETy), iZp)) - return rewriter.notifyMatchFailure( - op, "tosa.conv op quantization has zp outside of input range"); + auto maybeZps = failureOrMaybeZps.value(); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); - } + Attribute zeroAttr = + maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp) + : rewriter.getZeroAttr(inputETy); llvm::SmallVector newShape(inputType.getShape()); @@ -125,13 +123,20 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { auto fullyConnectedShapeType = RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); Value fullyConnectedValue; - if (op.getQuantizationInfo()) { + if (maybeZps) { + auto zeroPointAttr = rewriter.getAttr( + maybeZps->inputZp, maybeZps->weightZp); fullyConnectedValue = rewriter .create( op.getLoc(), fullyConnectedShapeType, reshapedInput, - reshapedWeight, op.getBias(), *op.getQuantizationInfo()) + reshapedWeight, op.getBias(), zeroPointAttr) .getResult(); } else { fullyConnectedValue = rewriter diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 181aff3a9ce04..ee857f1998a54 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -61,20 +61,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { rewriter.getDenseI64ArrayAttr(revisedInputShape)) .getResult(); - if (inputType.getElementType() != resultType.getElementType()) { - inputType = inputType.clone(resultType.getElementType()); + Type inputETy = inputType.getElementType(); + Type weightETy = weightType.getElementType(); + Type resultETy = resultType.getElementType(); + + if (inputETy != resultETy) { + inputType = inputType.clone(resultETy); input = rewriter.create(op.getLoc(), inputType, input); } - if (weightType.getElementType() != resultType.getElementType()) { - weightType = weightType.clone(resultType.getElementType()); + if (weightETy != resultETy) { + weightType = weightType.clone(resultETy); weight = rewriter.create(op.getLoc(), weightType, weight); } - if (auto quantizationInfo = op.getQuantizationInfo()) { - auto iZp = quantizationInfo->getInputZp(); - auto wZp = quantizationInfo->getWeightZp(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + auto maybeZps = failureOrMaybeZps.value(); + if (maybeZps) { auto applyZp = [&](Value val, int64_t zp) -> Value { if (zp == 0) return val; @@ -89,8 +95,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { zpVal); }; - input = applyZp(input, iZp); - weight = applyZp(weight, wZp); + input = applyZp(input, maybeZps->inputZp); + weight = applyZp(weight, maybeZps->weightZp); } ArrayRef padAttr = op.getPad(); @@ -99,7 +105,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { pad[it.index() + 2] = it.value(); if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { - Type inputETy = inputType.getElementType(); Attribute zeroAttr = rewriter.getZeroAttr(inputETy); llvm::SmallVector newShape(inputType.getShape()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 807f9cd683bb8..ae224671e304f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -69,22 +69,12 @@ class TransposeConvNonStridedConverter auto reverse2 = rewriter.create( loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2)); - Value conv2d; - if (op.getQuantizationInfo()) { - conv2d = rewriter.create( - loc, resultTy, input, reverse2, bias, - rewriter.getDenseI64ArrayAttr(convPad), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); - } else { - conv2d = rewriter.create( - loc, resultTy, input, reverse2, bias, - rewriter.getDenseI64ArrayAttr(convPad), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccTypeAttr()); - } + Value conv2d = rewriter.create( + loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(), + rewriter.getDenseI64ArrayAttr(convPad), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType()); rewriter.replaceOp(op, conv2d); return success(); @@ -144,12 +134,16 @@ class TransposeConvStridedConverter Value weightPaddingVal = getTosaConstShape(rewriter, op->getLoc(), weightPadding); - if (op.getQuantizationInfo().has_value()) { - auto quantInfo = op.getQuantizationInfo().value(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); + if (maybeZps) { weight = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(weightETy), weight, weightPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getWeightZp())); + rewriter.getAttr(maybeZps->weightZp)); } else { weight = CreateOpAndInferShape( @@ -205,12 +199,11 @@ class TransposeConvStridedConverter Value inputPaddingVal = getTosaConstShape(rewriter, op->getLoc(), inputPadding); - if (op.getQuantizationInfo().has_value()) { - auto quantInfo = op.getQuantizationInfo().value(); + if (maybeZps) { input = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(inputETy), input, inputPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getInputZp())); + rewriter.getAttr(maybeZps->inputZp)); } else { input = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(inputETy), input, @@ -227,28 +220,34 @@ class TransposeConvStridedConverter biasETy), rewriter.getZeroAttr(biasETy))); - // Perform the convolution using the zero bias. - Value conv2d; - if (op.getQuantizationInfo()) { - conv2d = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), input, - weight, zeroBias, - /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), - /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) - .getResult(); - } else { - conv2d = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), input, - weight, zeroBias, - /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), - /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccTypeAttr()) - .getResult(); + Value inputZp, weightZp; + if (maybeZps) { + auto maybeInputZp = createZeroPointTensor( + rewriter, loc, getElementTypeOrSelf(input.getType()), + maybeZps->inputZp); + auto maybeWeightZp = createZeroPointTensor( + rewriter, loc, getElementTypeOrSelf(weight.getType()), + maybeZps->weightZp); + + if (!maybeInputZp.has_value() || !maybeWeightZp.has_value()) { + return rewriter.notifyMatchFailure( + op, "fail to create a const zero point tensor"); + } + + inputZp = *maybeInputZp; + weightZp = *maybeWeightZp; } + // Perform the convolution using the zero bias. + Value conv2d = CreateOpAndInferShape( + rewriter, loc, UnrankedTensorType::get(resultETy), input, + weight, zeroBias, inputZp, weightZp, + /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), + /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType()) + .getResult(); + // Factor the resulting width / height. ShapedType convTy = cast(conv2d.getType()); Type convETy = convTy.getElementType(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index a49870687fdc6..678bb47935bd2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -357,7 +357,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { bool levelCheckTransposeConv2d(Operation *op) { if (auto transpose = dyn_cast(op)) { if (ShapedType filterType = - dyn_cast(transpose.getFilter().getType())) { + dyn_cast(transpose.getWeight().getType())) { auto shape = filterType.getShape(); assert(shape.size() == 4); // level check kernel sizes for kH and KW diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 5c546f59cde41..0f7562767001c 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -112,19 +112,14 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, #define GET_QTYPE(inputType) \ (llvm::dyn_cast((inputType).getElementType())) -/// Method to build ConvOpQuantizationAttr, called from -/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: -/// input_zp: input zeropoint -/// weight_zp: weight zeropoint. -ConvOpQuantizationAttr -mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, - Value weight) { +static std::optional> +getConvZeroPoints(Value input, Value weight) { auto inputType = dyn_cast(input.getType()); auto weightType = dyn_cast(weight.getType()); if (!inputType || !weightType) - return nullptr; + return std::nullopt; auto inputQType = GET_UQTYPE(inputType); auto weightPerTensorQType = GET_UQTYPE(weightType); @@ -150,10 +145,58 @@ mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, weightZp = weightPerAxisQType.getZeroPoints().front(); } - return builder.getAttr(inputZp, weightZp); + return std::make_pair(inputZp, weightZp); } - return nullptr; + return std::nullopt; +} + +std::pair +mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) { + std::int64_t inputZp, weightZp; + + auto inputEType = getElementTypeOrSelf(input.getType()); + auto weightEType = getElementTypeOrSelf(weight.getType()); + + if (mlir::isa(inputEType) && mlir::isa(weightEType)) { + inputZp = 0; + weightZp = 0; + } else { + auto maybeZps = getConvZeroPoints(input, weight); + if (!maybeZps.has_value()) + return {}; + + inputZp = maybeZps->first; + weightZp = maybeZps->second; + } + + auto maybeInputZpValue = + createZeroPointTensor(builder, input.getLoc(), inputEType, inputZp); + if (!maybeInputZpValue.has_value()) + return {}; + + auto maybeWeightZpValue = + createZeroPointTensor(builder, weight.getLoc(), weightEType, weightZp); + if (!maybeWeightZpValue.has_value()) + return {}; + + return std::make_pair(*maybeInputZpValue, *maybeWeightZpValue); +} + +/// Method to build ConvOpQuantizationAttr, called from +/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: +/// input_zp: input zeropoint +/// weight_zp: weight zeropoint. +ConvOpQuantizationAttr +mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, + Value weight) { + + auto maybeZps = getConvZeroPoints(input, weight); + if (!maybeZps.has_value()) + return nullptr; + + return builder.getAttr(maybeZps->first, + maybeZps->second); } /// Builds MatMulOpQuantizationAttr, called from diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 5eeaebb384e40..116cd045aa0d3 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -544,7 +544,8 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> - %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32> return } @@ -687,7 +688,9 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] // CHECK: linalg.conv_2d_nhwc_fhwc_q - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x1024xi32> return } @@ -799,7 +802,9 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x512xi32> return } @@ -823,7 +828,9 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x10x10x512xi32> return } @@ -905,7 +912,9 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5 // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32> - %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv3d %input, %weights, %bias , %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32> return } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index ac4d466aef94b..006c5bd52a9f6 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -33,45 +33,41 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, // ----- func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{expect a ranked tensor for input, got of type 'tensor<*xi8>' at index: 0}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} - : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} - : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> - return %0 : tensor<1x27x27x16xi8> -} - -// ----- - -func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { - // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array} - : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> { + %input_zp = "tosa.const"() {value = dense<0> : tensor<1xi16>} : () -> tensor<1xi16> + %weight_zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16> return %0 : tensor<1x27x27x16xi16> } @@ -123,25 +119,28 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3 // ----- func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8> return %0 : tensor<1x4x8x21x34xi8> } // ----- func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8> return %0 : tensor<1x4x4x8xi8> } // ----- func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array, out_shape = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8> return %0 : tensor<1x32x32x16xi8> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index a4596c8f9d536..d00230d12aab1 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -63,7 +63,8 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> { %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4> %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32> - %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.conv2d"(%arg0, %0, %1, %zp, %zp) {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x3xi32> %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8> return %3 : tensor<1x1x1x3xi8> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index 6437f12e3ff85..ee6caf285a248 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -12,7 +12,9 @@ func.func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform:f32 // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { // CHECK: tosa.conv2d - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array, dilation = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> + %input_zp = "tosa.const"() {value = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2, %input_zp, %weight_zp) {acc_type = i32, pad = array, dilation = array, stride = array} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16x!quant.uniform> return %0 : tensor<1x32x32x16x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir index 95d9bb1b98ab7..685f799bd3d2b 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -33,7 +33,9 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x3xi32> return %0 : tensor<4x10x10x3xi32> } @@ -50,7 +52,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor // CHECK: return %[[VAL_6]] : tensor // CHECK: } - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor + %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor return %0 : tensor } @@ -65,6 +69,8 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array} // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant} // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> + %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32> return %0 : tensor<4x12x12x3xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index 5f36dd3b3d137..ce29d1a498b4f 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -38,7 +38,9 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + %input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 12691f2e325a2..bb6de82ee1053 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 // CHECK-SAME: dilation = array, pad = array, stride = array - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> return %0 : tensor<2x18x19x5xf32> } @@ -15,10 +15,14 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK-LABEL: @transpose_conv2d_quantized func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) { + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-6> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<11> : tensor<1xi8>} // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} - // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> + // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} + %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x18x19x5xi32> return %0 : tensor<2x18x19x5xi32> } @@ -26,17 +30,20 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor // CHECK-LABEL: @transpose_conv2d_quantized_padded func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) { - // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i32} + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>} + // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32} // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} - // CHECK: tosa.conv2d %arg0, %1, %arg2 + // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] // CHECK-SAME: dilation = array, pad = array, - // CHECK-SAME: quantization_info = #tosa.conv_quant, stride = array} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + // CHECK-SAME: stride = array} + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp { acc_type = i32, out_pad = array, - quantization_info = #tosa.conv_quant, out_shape = array, - stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32> + stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32> return %0 : tensor<2x21x26x5xi32> } @@ -71,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> return %1 : tensor<2x?x?x5xf32> } @@ -98,7 +105,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] @@ -107,7 +116,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> tensor<2x35x47x5xi32> + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } @@ -135,12 +146,13 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]] // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]] - %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp { acc_type = i32, out_pad = array, out_shape = array, - stride = array, - quantization_info = #tosa.conv_quant} : - (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> tensor<1x19x2x1xi32> + stride = array} : + (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32> "func.return" (%2) : (tensor<1x19x2x1xi32>) -> () }