diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td index 4975530a9588c..f492bad78e775 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp traits = []> "operands attr-dict `:` functional-type(operands, results)"; } +// The "SameVariadicOperandSize" trait allows us to pass optional arguments +// for multiple zero points in convolution ops. +class Tosa_ConvOp traits = []> + : Tosa_InferShapedTypeOp { +} + #endif // TOSA_OP_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h index 27061002b0295..069073bc2d164 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -16,6 +16,7 @@ #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" @@ -29,6 +30,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc" +#include "mlir/Transforms/DialectConversion.h" namespace mlir { class PatternRewriter; @@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t); #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" +namespace mlir { +namespace tosa { + +// Create a rank-1 const tensor for zero point of the source tensor. +std::optional createZeroPointTensor(OpBuilder &builder, Location loc, + Type srcElemType, int64_t zp = 0); + +// Get zero point value from the attribute argument. +LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp); + +// Verify if zero point falls into valid range. 
+template +LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) { + if constexpr (!std::is_same_v && !std::is_same_v && + !std::is_same_v && + !std::is_same_v) { + return failure(); + } + + if (!zpElemType.isIntOrFloat()) + return failure(); + + if (!zpElemType.isInteger(8) && zp != 0) + return failure(); + + if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127)) + return failure(); + + if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255)) + return failure(); + + return success(); +} + +// Helper type trait to determine if an operation is a tosa convolution. +template +struct IsTosaConv : std::false_type {}; + +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; +template <> +struct IsTosaConv : std::true_type {}; + +template +constexpr bool is_tosa_conv_v = IsTosaConv::value; + +// Helper struct to hold the zero points of a TOSA convolution operation as +// named 64-bit integer fields. +struct ConvZpPair { + ConvZpPair(std::int64_t inputZp, std::int64_t weightZp) + : inputZp(inputZp), weightZp(weightZp) {} + std::int64_t inputZp; + std::int64_t weightZp; +}; + +// Helper function which attempts to extract the zero points from a TOSA +// convolution by matching them against defining ops which should be tosa.const +// operations. +// +// There are three possible results: +// 1. Failed to extract the zero-points i.e. they should exist and don't or they +// do exist but are invalid. +// 2. Succeeded in extracting zero-points. +// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized +// convolution. +using FailOrMaybeZP = llvm::FailureOr>; +template +std::enable_if_t, FailOrMaybeZP> +extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) { + // Strictly speaking the base TOSA spec requires that for non int8 types + // zero points must be zero. 
However, in the dialect these operands are + // optional and only required for int8. They have no semantic meaning for + // non-quantized types and can therefore be safely ignored. This is case 3. + if (auto opElementTY = + cast(op->getOperand(0).getType()).getElementType(); + !opElementTY.isInteger(8)) + return FailOrMaybeZP(std::nullopt); + + // Now we know we should have a zero point check it is valid. + if (!op.getInputZp()) + return rewriter.notifyMatchFailure(op, "missing input zero point"); + + // Helper to extract the zero point by matching its definition against a + // constant. + auto extractZeroPoint = [](Value zpValue) -> std::optional { + ElementsAttr zpAttr; + if (!matchPattern(zpValue, m_Constant(&zpAttr))) + return std::nullopt; + + int64_t zp; + if (tosa::getZeroPoint(zpAttr, zp).failed()) + return std::nullopt; + + return std::make_optional(zp); + }; + + auto maybeInputZp = extractZeroPoint(op.getInputZp()); + if (!maybeInputZp) + return rewriter.notifyMatchFailure(op, "unable to extract input zp"); + + if (!op.getWeightZp()) + return rewriter.notifyMatchFailure(op, "missing weight zero point"); + + auto maybeWeightZp = extractZeroPoint(op.getWeightZp()); + if (!maybeWeightZp) + return rewriter.notifyMatchFailure(op, "unable to extract weight zp"); + + return std::make_optional(*maybeInputZp, *maybeWeightZp); +} +} // namespace tosa +} // namespace mlir + #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index c59c582a1f522..819547855d101 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> { //===----------------------------------------------------------------------===// // Operator: conv2d //===----------------------------------------------------------------------===// -def Tosa_Conv2DOp : 
Tosa_InferShapedTypeOp<"conv2d"> { +def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { let summary = "2D Convolution Operator"; let description = [{ @@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> { //===----------------------------------------------------------------------===// // Operator: conv3d //===----------------------------------------------------------------------===// -def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { +def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { let summary = "3D Convolution operator"; let description = [{ @@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { Tosa_Tensor5D:$input, TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> { //===----------------------------------------------------------------------===// // Operator: depthwise_conv2d //===----------------------------------------------------------------------===// -def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { +def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { let summary = "Depthwise 2D Convolution operator"; let description = [{ @@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, 
Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); @@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> { //===----------------------------------------------------------------------===// // Operator: transpose_conv2d //===----------------------------------------------------------------------===// -def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { +def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { let summary = "Transpose 2D Convolution operator."; let description = [{ @@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> { let arguments = (ins Tosa_Tensor4D:$input, - TosaTensorRankOf<[Tosa_Weight], [4]>:$filter, + TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, TypeAttrOf:$acc_type, - OptionalAttr:$quantization_info, DefaultValuedOptionalAttr:$local_bound ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 5693acf3a01db..7aa1f72ec6e17 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -288,4 +288,9 @@ def Rank1TosaShape : TosaShapeOfRank<1>; def Rank2TosaShape : TosaShapeOfRank<2>; def Rank4TosaShape : TosaShapeOfRank<4>; +// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this +// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the +// following def can be removed. 
+def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>; + #endif // TOSA_TYPES_BASE diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h index 5e80745777b3b..10dc5dd36cfa9 100644 --- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h +++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier, ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight); +std::pair createZPsAsConst(OpBuilder &builder, Value input, + Value weight); + //// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 57a5fe75a007b..cf9852e05cf7c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern { DenseI64ArrayAttr padAttr = op.getPadAttr(); DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr(); DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr(); - bool isQuantized = op.getQuantizationInfo().has_value(); + + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (llvm::failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return rewriter.notifyMatchFailure( @@ -284,10 +289,7 @@ class ConvConverter : public OpConversionPattern { // Apply padding as necessary. 
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); - if (isQuantized) { - auto quantizationInfo = *op.getQuantizationInfo(); - int64_t iZp = quantizationInfo.getInputZp(); - + if (maybeZps) { int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); @@ -295,11 +297,11 @@ class ConvConverter : public OpConversionPattern { APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); - if (iZp < intMin || iZp > intMax) + if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax) return rewriter.notifyMatchFailure( op, "tosa.conv op quantization has zp outside of input range"); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); + zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp); } llvm::SmallVector pad; @@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern { // For 2D convolutions, we need to check if the target convolution op // wants a HWCF kernel layout. bool wantHwcf = - isQuantized ? std::is_same_v - : std::is_same_v; + maybeZps ? std::is_same_v + : std::is_same_v; if (wantHwcf) { // Transpose the kernel to match dimension ordering of the linalg // convolution operation. 
@@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern { Value broadcastBias = linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor); - if (isQuantized) { - auto quantizationInfo = *op.getQuantizationInfo(); - auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); - auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); + if (maybeZps) { + auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp); + auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp); auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); @@ -440,26 +441,18 @@ class DepthwiseConvConverter /*inputSizeDims=*/{1, 2}, /*kernelSizeDims=*/{0, 1}, rewriter); - bool isQuantized = op->hasAttr("quantization_info"); - IntegerAttr iZp; - IntegerAttr kZp; - if (isQuantized) { - auto quantizationInfo = - cast(op->getAttr("quantization_info")); - iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); - kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); - } + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (llvm::failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); // Apply padding as necessary. 
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy); - if (isQuantized) { - auto quantizationInfo = - cast(op->getAttr("quantization_info")); - int64_t iZp = quantizationInfo.getInputZp(); - + if (maybeZps) { int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); @@ -467,12 +460,12 @@ class DepthwiseConvConverter APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth()) .getSExtValue(); - if (iZp < intMin || iZp > intMax) + if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax) return rewriter.notifyMatchFailure( op, "tosa.depthwise_conv op quantization has zp outside of input " "range"); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); + zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp); } llvm::SmallVector pad; @@ -512,7 +505,7 @@ class DepthwiseConvConverter indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank)); - if (!isQuantized) { + if (!maybeZps) { Value conv = rewriter .create( loc, linalgConvTy, ValueRange{input, weight}, @@ -539,8 +532,10 @@ class DepthwiseConvConverter .getResult(0); rewriter.replaceOp(op, result); } else { + IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp); + IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp); auto iZpVal = rewriter.create(loc, iZp); - auto kZpVal = rewriter.create(loc, kZp); + auto kZpVal = rewriter.create(loc, wZp); Value conv = rewriter .create( diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0a10439db4080..e8b28906135ed 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -217,33 +217,59 @@ void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type, template static LogicalResult verifyConvOp(T op) { - // All TOSA conv ops have an input() and weight(). 
+ // All TOSA conv ops have an input and weight arguments which must be ranked + // tensors. auto inputType = llvm::dyn_cast(op.getInput().getType()); - - RankedTensorType weightType; - if constexpr (std::is_same_v) - weightType = llvm::dyn_cast(op.getFilter().getType()); - else - weightType = llvm::dyn_cast(op.getWeight().getType()); - - // Must be ranked tensor types if (!inputType) { op.emitOpError("expect a ranked tensor for input, got ") << op.getInput(); return failure(); } + + auto weightType = llvm::dyn_cast(op.getWeight().getType()); if (!weightType) { - if constexpr (std::is_same_v) { - op.emitOpError("expect a ranked tensor for filter, got ") - << op.getFilter(); - } else { - op.emitOpError("expect a ranked tensor for weight, got ") - << op.getWeight(); - } + op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight(); return failure(); } auto inputEType = inputType.getElementType(); auto weightEType = weightType.getElementType(); + auto biasEType = + llvm::cast(op.getBias().getType()).getElementType(); + auto resultEType = + llvm::cast(op.getResult().getType()).getElementType(); + bool biasIsFloat = llvm::isa(biasEType); + bool resultIsFloat = llvm::isa(resultEType); + + if (auto quantType = + llvm::dyn_cast(inputEType)) + inputEType = quantType.getStorageType(); + + if (auto quantType = + llvm::dyn_cast(biasEType)) + biasEType = quantType.getStorageType(); + + if (auto quantType = + llvm::dyn_cast(resultEType)) + resultEType = quantType.getStorageType(); + + if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { + // for now, only enforce bias element type == result element type for + // float types. 
+ op.emitOpError( + "expect both bias and result to have same element type, got ") + << biasEType << " and " << resultEType; + return failure(); + } + + if (isa(inputEType) || isa(inputEType) || + isa(weightEType) || isa(weightEType)) { + if (inputEType != weightEType) { + op.emitOpError( + "expect both input and weight to have same element type, got ") + << inputEType << " and " << weightEType; + return failure(); + } + } bool inputIsQuant = !llvm::isa(inputEType); bool weightIsQuant = !llvm::isa(weightEType); @@ -256,14 +282,38 @@ static LogicalResult verifyConvOp(T op) { return failure(); } - // Quantized type must have constructed the quantizationattr, and unquantized - // types should not have a quantizationattr. - if ((inputIsQuant && !op.getQuantizationInfo()) || - (!inputIsQuant && op.getQuantizationInfo())) { - op.emitOpError("quantizationattr is required for quantized type, and not " - "allowed for float type"); + // We require an explicit input zero point and weight zero point for i8 + // convolution. + if (!op.getInputZp() && !op.getWeightZp()) + return inputEType.isInteger(8) ? failure() : success(); + + ElementsAttr inputZpAttr; + ElementsAttr weightZpAttr; + if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) || + !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) { + op.emitOpError( + "bail out if the actual value of zero points cannot be determined"); return failure(); } + + // Get and verify explicit zero points. 
+ int64_t inputZpVal; + int64_t weightZpVal; + + if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() || + tosa::verifyZeroPoint(getElementTypeOrSelf(inputZpAttr), inputZpVal) + .failed()) { + op.emitOpError("input zero point must be zero for non-int8 integer types"); + return failure(); + } + + if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() || + tosa::verifyZeroPoint(getElementTypeOrSelf(weightZpAttr), weightZpVal) + .failed()) { + op.emitOpError("weight zero point must be zero for non-int8 integer types"); + return failure(); + } + return success(); } @@ -322,6 +372,39 @@ static LogicalResult verifyConvOpModes(T op) { return success(); } +// verify that inType and outType have same element types +template +static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) { + auto inputType = llvm::dyn_cast(inType); + auto outputType = llvm::dyn_cast(outType); + if (!inputType) { + op.emitOpError("expect shaped tensor for input, got ") << inType; + return failure(); + } + if (!outputType) { + op.emitOpError("expect shaped tensor for output, got ") << outType; + return failure(); + } + auto inputElementType = inputType.getElementType(); + auto outputElementType = outputType.getElementType(); + auto inputQuantType = + llvm::dyn_cast(inputElementType); + auto outputQuantType = + llvm::dyn_cast(outputElementType); + if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) && + (outputElementType.isIntOrIndexOrFloat() || outputQuantType) && + inputElementType != outputElementType) { + // only check if both element types are int/index/float/UniformQuantized + // eg, not sure how to check quant::QuantizedType + // this happens in test_conv2d_q_grouped_convolution in + // tfl-to-tosa-pipeline.mlir + op.emitOpError("expect input and output to have same element type, got ") + << inputElementType << " and " << outputElementType; + return failure(); + } + return success(); +} + LogicalResult tosa::ArgMaxOp::verify() { // Ensure output is of 
32-bit integer const auto resultETy = llvm::cast(getType()).getElementType(); @@ -421,21 +504,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType) { - - result.addOperands({input, weight, bias}); + auto zps = createZPsAsConst(builder, input, weight); + result.addOperands({input, weight, bias, zps.first, zps.second}); result.addAttribute("pad", pad); result.addAttribute("stride", stride); result.addAttribute("dilation", dilation); result.addAttribute("acc_type", accType); - - auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); - if (quantAttr) { - result.addAttribute("quantization_info", quantAttr); - result.addTypes( - buildConvOpResultTypeInfo(builder, outputType, input, weight)); - } else { - result.addTypes(outputType); - } + result.addTypes(outputType); } /// Handles tosa.transpose_conv2d which has outpad and output shape @@ -790,7 +865,47 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( return success(); } -LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } +LogicalResult FullyConnectedOp::verify() { + // All TOSA conv ops have an input() and weight(). + auto inputType = llvm::dyn_cast(getInput().getType()); + + RankedTensorType weightType = + llvm::dyn_cast(getWeight().getType()); + + // Must be ranked tensor types + if (!inputType) { + emitOpError("expect a ranked tensor for input, got ") << getInput(); + return failure(); + } + if (!weightType) { + emitOpError("expect a ranked tensor for weight, got ") << getWeight(); + return failure(); + } + + auto inputEType = inputType.getElementType(); + auto weightEType = weightType.getElementType(); + + bool inputIsQuant = !llvm::isa(inputEType); + bool weightIsQuant = !llvm::isa(weightEType); + + // Either both must be quantized or both unquantized. 
+ if (inputIsQuant != weightIsQuant) { + emitOpError( + "expect both input and weight to be float or not together, got ") + << inputEType << " and " << weightEType; + return failure(); + } + + // Quantized type must have constructed the quantizationattr, and unquantized + // types should not have a quantizationattr. + if ((inputIsQuant && !getQuantizationInfo()) || + (!inputIsQuant && getQuantizationInfo())) { + emitOpError("quantizationattr is required for quantized type, and not " + "allowed for float type"); + return failure(); + } + return success(); +} LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, @@ -2019,7 +2134,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( } // Weight shapes describes the filter width/height and the output channels. - ShapeAdaptor weightShape(adaptor.getFilter().getType()); + ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) @@ -2315,6 +2430,54 @@ void WhileOp::print(OpAsmPrinter &parser) { parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } +LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) { + Type zpElemType = zpAttr.getElementType(); + if (auto quantType = + llvm::dyn_cast(zpElemType)) { + zp = quantType.getZeroPoint(); + return success(); + } + if (llvm::isa(zpElemType)) { + // non-zero zero point is not allowed for float types. + if (!zpAttr.getValues()[0].isZero()) + return failure(); + zp = 0; + return success(); + } + if (llvm::isa(zpElemType)) { + zp = zpAttr.getValues()[0].getSExtValue(); + return success(); + } + // zero point is not allowed for unsupported types. + return failure(); +} + +// Create a rank-1 const tensor for zero point of the source tensor. 
+std::optional mlir::tosa::createZeroPointTensor(OpBuilder &builder, + Location loc, + Type srcElemType, + int64_t zp) { + if (auto quantType = + llvm::dyn_cast(srcElemType)) + srcElemType = quantType.getStorageType(); + + auto zpType = mlir::RankedTensorType::get({1}, srcElemType); + if (auto quantType = llvm::dyn_cast(srcElemType)) + srcElemType = quantType.getStorageType(); + if (llvm::isa(srcElemType)) { + auto zpAttr = DenseElementsAttr::get( + zpType, builder.getFloatAttr(srcElemType, static_cast(zp))); + return builder.create(loc, zpType, zpAttr); + } + if (llvm::isa(srcElemType)) { + auto zpAttr = + DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp)); + return builder.create(loc, zpType, zpAttr); + } + llvm::errs() << "zero point is not allowed for unsupported data types\n"; + return std::nullopt; +} + //===----------------------------------------------------------------------===// // TOSA Shape and Shape Operators Helper functions. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp index cb08360f90228..7d3deae3330af 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp @@ -59,19 +59,17 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { for (const auto &it : llvm::enumerate(padAttr)) pad[it.index() + 2] = it.value(); + Type inputETy = inputType.getElementType(); if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { - Type inputETy = inputType.getElementType(); - Attribute zeroAttr = rewriter.getZeroAttr(inputETy); - if (op.getQuantizationInfo()) { - auto quantizationInfo = op.getQuantizationInfo(); - int64_t iZp = quantizationInfo->getInputZp(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); - if 
(!validIntegerRange(cast(inputETy), iZp)) - return rewriter.notifyMatchFailure( - op, "tosa.conv op quantization has zp outside of input range"); + auto maybeZps = failureOrMaybeZps.value(); - zeroAttr = rewriter.getIntegerAttr(inputETy, iZp); - } + Attribute zeroAttr = + maybeZps ? rewriter.getIntegerAttr(inputETy, maybeZps->inputZp) + : rewriter.getZeroAttr(inputETy); llvm::SmallVector newShape(inputType.getShape()); @@ -125,13 +123,20 @@ struct Conv2DIsFullyConnected : public OpRewritePattern { auto fullyConnectedShapeType = RankedTensorType::get(fullyConnectedShape, resultType.getElementType()); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); Value fullyConnectedValue; - if (op.getQuantizationInfo()) { + if (maybeZps) { + auto zeroPointAttr = rewriter.getAttr( + maybeZps->inputZp, maybeZps->weightZp); fullyConnectedValue = rewriter .create( op.getLoc(), fullyConnectedShapeType, reshapedInput, - reshapedWeight, op.getBias(), *op.getQuantizationInfo()) + reshapedWeight, op.getBias(), zeroPointAttr) .getResult(); } else { fullyConnectedValue = rewriter diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 181aff3a9ce04..ee857f1998a54 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -61,20 +61,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { rewriter.getDenseI64ArrayAttr(revisedInputShape)) .getResult(); - if (inputType.getElementType() != resultType.getElementType()) { - inputType = inputType.clone(resultType.getElementType()); + Type inputETy = inputType.getElementType(); + Type weightETy = weightType.getElementType(); + Type resultETy = resultType.getElementType(); + + if (inputETy != resultETy) { + inputType = inputType.clone(resultETy); input = 
rewriter.create(op.getLoc(), inputType, input); } - if (weightType.getElementType() != resultType.getElementType()) { - weightType = weightType.clone(resultType.getElementType()); + if (weightETy != resultETy) { + weightType = weightType.clone(resultETy); weight = rewriter.create(op.getLoc(), weightType, weight); } - if (auto quantizationInfo = op.getQuantizationInfo()) { - auto iZp = quantizationInfo->getInputZp(); - auto wZp = quantizationInfo->getWeightZp(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + auto maybeZps = failureOrMaybeZps.value(); + if (maybeZps) { auto applyZp = [&](Value val, int64_t zp) -> Value { if (zp == 0) return val; @@ -89,8 +95,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { zpVal); }; - input = applyZp(input, iZp); - weight = applyZp(weight, wZp); + input = applyZp(input, maybeZps->inputZp); + weight = applyZp(weight, maybeZps->weightZp); } ArrayRef padAttr = op.getPad(); @@ -99,7 +105,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { pad[it.index() + 2] = it.value(); if (llvm::any_of(pad, [](int64_t p) { return p != 0; })) { - Type inputETy = inputType.getElementType(); Attribute zeroAttr = rewriter.getZeroAttr(inputETy); llvm::SmallVector newShape(inputType.getShape()); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 807f9cd683bb8..ae224671e304f 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -69,22 +69,12 @@ class TransposeConvNonStridedConverter auto reverse2 = rewriter.create( loc, weightTy, reverse1, /* axis = */ rewriter.getI32IntegerAttr(2)); - Value conv2d; - if (op.getQuantizationInfo()) { - conv2d = rewriter.create( - loc, resultTy, input, reverse2, bias, - rewriter.getDenseI64ArrayAttr(convPad), - 
rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()); - } else { - conv2d = rewriter.create( - loc, resultTy, input, reverse2, bias, - rewriter.getDenseI64ArrayAttr(convPad), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccTypeAttr()); - } + Value conv2d = rewriter.create( + loc, resultTy, input, reverse2, bias, op.getInputZp(), op.getWeightZp(), + rewriter.getDenseI64ArrayAttr(convPad), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType()); rewriter.replaceOp(op, conv2d); return success(); @@ -144,12 +134,16 @@ class TransposeConvStridedConverter Value weightPaddingVal = getTosaConstShape(rewriter, op->getLoc(), weightPadding); - if (op.getQuantizationInfo().has_value()) { - auto quantInfo = op.getQuantizationInfo().value(); + auto failureOrMaybeZps = extractConvZpPair(op, rewriter); + if (failed(failureOrMaybeZps)) + return failure(); + + auto maybeZps = failureOrMaybeZps.value(); + if (maybeZps) { weight = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(weightETy), weight, weightPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getWeightZp())); + rewriter.getAttr(maybeZps->weightZp)); } else { weight = CreateOpAndInferShape( @@ -205,12 +199,11 @@ class TransposeConvStridedConverter Value inputPaddingVal = getTosaConstShape(rewriter, op->getLoc(), inputPadding); - if (op.getQuantizationInfo().has_value()) { - auto quantInfo = op.getQuantizationInfo().value(); + if (maybeZps) { input = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(inputETy), input, inputPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getInputZp())); + rewriter.getAttr(maybeZps->inputZp)); } else { input = CreateOpAndInferShape( rewriter, loc, UnrankedTensorType::get(inputETy), input, @@ -227,28 +220,34 @@ class TransposeConvStridedConverter biasETy), 
rewriter.getZeroAttr(biasETy))); - // Perform the convolution using the zero bias. - Value conv2d; - if (op.getQuantizationInfo()) { - conv2d = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), input, - weight, zeroBias, - /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), - /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccType(), *op.getQuantizationInfo()) - .getResult(); - } else { - conv2d = CreateOpAndInferShape( - rewriter, loc, UnrankedTensorType::get(resultETy), input, - weight, zeroBias, - /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), - /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), - /* acc_type = */ op.getAccTypeAttr()) - .getResult(); + Value inputZp, weightZp; + if (maybeZps) { + auto maybeInputZp = createZeroPointTensor( + rewriter, loc, getElementTypeOrSelf(input.getType()), + maybeZps->inputZp); + auto maybeWeightZp = createZeroPointTensor( + rewriter, loc, getElementTypeOrSelf(weight.getType()), + maybeZps->weightZp); + + if (!maybeInputZp.has_value() || !maybeWeightZp.has_value()) { + return rewriter.notifyMatchFailure( + op, "fail to create a const zero point tensor"); + } + + inputZp = *maybeInputZp; + weightZp = *maybeWeightZp; } + // Perform the convolution using the zero bias. + Value conv2d = CreateOpAndInferShape( + rewriter, loc, UnrankedTensorType::get(resultETy), input, + weight, zeroBias, inputZp, weightZp, + /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}), + /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}), + /* acc_type = */ op.getAccType()) + .getResult(); + // Factor the resulting width / height. 
ShapedType convTy = cast(conv2d.getType()); Type convETy = convTy.getElementType(); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index a49870687fdc6..678bb47935bd2 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -357,7 +357,7 @@ struct TosaValidation : public tosa::impl::TosaValidationBase { bool levelCheckTransposeConv2d(Operation *op) { if (auto transpose = dyn_cast(op)) { if (ShapedType filterType = - dyn_cast(transpose.getFilter().getType())) { + dyn_cast(transpose.getWeight().getType())) { auto shape = filterType.getShape(); assert(shape.size() == 4); // level check kernel sizes for kH and KW diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp index 5c546f59cde41..0f7562767001c 100644 --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -112,19 +112,14 @@ void mlir::tosa::computeMultiplierAndShift(double scale, int32_t &multiplier, #define GET_QTYPE(inputType) \ (llvm::dyn_cast((inputType).getElementType())) -/// Method to build ConvOpQuantizationAttr, called from -/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: -/// input_zp: input zeropoint -/// weight_zp: weight zeropoint. 
-ConvOpQuantizationAttr -mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, - Value weight) { +static std::optional> +getConvZeroPoints(Value input, Value weight) { auto inputType = dyn_cast(input.getType()); auto weightType = dyn_cast(weight.getType()); if (!inputType || !weightType) - return nullptr; + return std::nullopt; auto inputQType = GET_UQTYPE(inputType); auto weightPerTensorQType = GET_UQTYPE(weightType); @@ -150,10 +145,58 @@ mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, weightZp = weightPerAxisQType.getZeroPoints().front(); } - return builder.getAttr(inputZp, weightZp); + return std::make_pair(inputZp, weightZp); } - return nullptr; + return std::nullopt; +} + +std::pair +mlir::tosa::createZPsAsConst(OpBuilder &builder, Value input, Value weight) { + std::int64_t inputZp, weightZp; + + auto inputEType = getElementTypeOrSelf(input.getType()); + auto weightEType = getElementTypeOrSelf(weight.getType()); + + if (mlir::isa(inputEType) && mlir::isa(weightEType)) { + inputZp = 0; + weightZp = 0; + } else { + auto maybeZps = getConvZeroPoints(input, weight); + if (!maybeZps.has_value()) + return {}; + + inputZp = maybeZps->first; + weightZp = maybeZps->second; + } + + auto maybeInputZpValue = + createZeroPointTensor(builder, input.getLoc(), inputEType, inputZp); + if (!maybeInputZpValue.has_value()) + return {}; + + auto maybeWeightZpValue = + createZeroPointTensor(builder, weight.getLoc(), weightEType, weightZp); + if (!maybeWeightZpValue.has_value()) + return {}; + + return std::make_pair(*maybeInputZpValue, *maybeWeightZpValue); +} + +/// Method to build ConvOpQuantizationAttr, called from +/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: +/// input_zp: input zeropoint +/// weight_zp: weight zeropoint. 
+ConvOpQuantizationAttr +mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, + Value weight) { + + auto maybeZps = getConvZeroPoints(input, weight); + if (!maybeZps.has_value()) + return nullptr; + + return builder.getAttr(maybeZps->first, + maybeZps->second); } /// Builds MatMulOpQuantizationAttr, called from diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index 5eeaebb384e40..116cd045aa0d3 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -544,7 +544,8 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> - %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32> + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %input, %weights, %bias, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x45x40x28xi32> return } @@ -687,7 +688,9 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK: tensor.pad 
%arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] // CHECK: linalg.conv_2d_nhwc_fhwc_q - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x1024xi32> return } @@ -799,7 +802,9 @@ func.func @depthwise_conv_quant(%arg0 : tensor<1x12x12x4xi8>, %arg1 : tensor<3x3 // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x12x12x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x12x12x512xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<1x12x12x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x12x12x512xi32> return } @@ -823,7 +828,9 @@ func.func @depthwise_conv_quant_dilations(%arg0 : tensor<1x14x14x4xi8>, %arg1 : // CHECK: [[ADD:%.+]] = arith.addi %[[ARG3]], %[[ARG4]] : i32 // CHECK: linalg.yield [[ADD]] : i32 // CHECK: } -> tensor<1x10x10x512xi32> - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, 
pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 , %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x10x10x512xi32> return } @@ -905,7 +912,9 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5 // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32) // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32> - %0 = tosa.conv3d %input, %weights, %bias {acc_type = i32, pad = array, quantization_info = #tosa.conv_quant, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32> + %input_zp = "tosa.const"() {value = dense<-128> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv3d %input, %weights, %bias , %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x47x45x43x28xi32> return } diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index ac4d466aef94b..006c5bd52a9f6 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -33,45 +33,41 @@ func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, // ----- func.func @test_conv2d(%arg0: tensor<*xi8>, %arg1: 
tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{expect a ranked tensor for input, got of type 'tensor<*xi8>' at index: 0}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} - : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<*xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or floating-point values, but got 'tensor<*xi8>'}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, stride = array} - : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> - return %0 : tensor<1x27x27x16xi8> -} - -// ----- - -func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { - // expected-error@+1 {{'tosa.conv2d' op quantizationattr is required for quantized type, and not allowed for float type}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array} - : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = i32, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>, 
tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi8>, tensor<16x3x3x4xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } // ----- func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xi16>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi16>) -> tensor<1x27x27x16xi16> { + %input_zp = "tosa.const"() {value = dense<0> : tensor<1xi16>} : () -> tensor<1xi16> + %weight_zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv2d' op accumulator type for i16 tensor is not i48}} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>) -> tensor<1x27x27x16xi16> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x29x29x4xi16>, tensor<16x3x3x4xi8>, tensor<16xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x27x27x16xi16> return %0 : tensor<1x27x27x16xi16> } @@ -123,25 +119,28 @@ func.func @test_conv2d_acc_type(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3 // ----- func.func @test_conv3d_acc_type(%arg0: tensor<1x4x8x21x17xi8>, %arg1: 
tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi8>) -> tensor<1x4x8x21x34xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv3d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.conv3d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} - : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>) -> tensor<1x4x8x21x34xi8> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} + : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi8> return %0 : tensor<1x4x8x21x34xi8> } // ----- func.func @test_depthwise_conv2d_acc_type(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<1x1x4x2xi8>, %arg2: tensor<8xi8>) -> tensor<1x4x4x8xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.depthwise_conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f16, dilation = array, pad = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>) -> tensor<1x4x4x8xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x4x4xi8>, tensor<1x1x4x2xi8>, tensor<8xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x4x8xi8> return %0 : tensor<1x4x4x8xi8> } // ----- func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xi8>, %arg1: tensor<16x1x1x8xi8>, %arg2: tensor<16xi8>) -> tensor<1x32x32x16xi8> { + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.transpose_conv2d' op accumulator type for i8 tensor is not i32}} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f16, out_pad = array, out_shape = array, stride = array, 
quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>) -> tensor<1x32x32x16xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f16, out_pad = array, out_shape = array, stride = array} : (tensor<1x32x32x8xi8>, tensor<16x1x1x8xi8>, tensor<16xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16xi8> return %0 : tensor<1x32x32x16xi8> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index a4596c8f9d536..d00230d12aab1 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -63,7 +63,8 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> { %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4> %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32> - %2 = "tosa.conv2d"(%arg0, %0, %1) {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32> + %zp = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = "tosa.conv2d"(%arg0, %0, %1, %zp, %zp) {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x1x1x3xi32> %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8> return %3 : tensor<1x1x1x3xi8> } diff --git a/mlir/test/Dialect/Tosa/quant-test.mlir b/mlir/test/Dialect/Tosa/quant-test.mlir index 6437f12e3ff85..ee6caf285a248 100644 --- a/mlir/test/Dialect/Tosa/quant-test.mlir +++ b/mlir/test/Dialect/Tosa/quant-test.mlir @@ -12,7 +12,9 @@ func.func @test_build_qtype(%arg0 : 
tensor<16x1x1x8x!quant.uniform:f32 // CHECK-LABEL: test_build_mult_and_shift func.func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { // CHECK: tosa.conv2d - %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {acc_type = i32, pad = array, dilation = array, stride = array, quantization_info = #tosa.conv_quant} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> + %input_zp = "tosa.const"() {value = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2, %input_zp, %weight_zp) {acc_type = i32, pad = array, dilation = array, stride = array} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x32x32x16x!quant.uniform> return %0 : tensor<1x32x32x16x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir index 95d9bb1b98ab7..685f799bd3d2b 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir @@ -33,7 +33,9 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array} // CHECK-SAME: -> tensor<4x10x10x3xi32> // CHECK: return %[[VAR3]] - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x10x10x3xi32> + %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<24> : 
tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x3xi32> return %0 : tensor<4x10x10x3xi32> } @@ -50,7 +52,9 @@ func.func @conv_with_dynamic_dim(%arg0: tensor, %arg1: tensor<384 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor) -> tensor // CHECK: return %[[VAL_6]] : tensor // CHECK: } - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor + %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor, tensor<384x1x1x64xi8>, tensor<384xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor return %0 : tensor } @@ -65,6 +69,8 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1: // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array} // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant} // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array} - %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>) -> tensor<4x12x12x3xi32> + %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = 
i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<3x1x1x2xi8>, tensor<3xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x12x12x3xi32> return %0 : tensor<4x12x12x3xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index 5f36dd3b3d137..ce29d1a498b4f 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -38,7 +38,9 @@ func.func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor< // CHECK: %[[reO:.+]] = tosa.reshape %[[mul]] {new_shape = array} // CHECK: %[[reArg2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[add:.+]] = tosa.add %[[reO]], %[[reArg2]] - %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array, stride = array, dilation = array, quantization_info = #tosa.conv_quant} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + %input_zp = "tosa.const"() {value = dense<7> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, pad = array, stride = array, dilation = array } : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 12691f2e325a2..bb6de82ee1053 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -6,7 +6,7 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 // CHECK-SAME: dilation = 
array, pad = array, stride = array - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x18x19x5xf32> return %0 : tensor<2x18x19x5xf32> } @@ -15,10 +15,14 @@ func.func @transpose_conv2d(%arg0: tensor<2x16x14x3xf32>, %arg1: tensor<5x3x6x3x // CHECK-LABEL: @transpose_conv2d_quantized func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x18x19x5xi32>) { + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-6> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<11> : tensor<1xi8>} // CHECK: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} // CHECK: %[[REV2:.+]] = tosa.reverse %[[REV1]] {axis = 2 : i32} - // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2 {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x18x19x5xi32> + // CHECK: tosa.conv2d %arg0, %[[REV2]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} + %input_zp = "tosa.const"() {value = dense<-6> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<11> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array, out_shape = array, stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, 
tensor<1xi8>) -> tensor<2x18x19x5xi32> return %0 : tensor<2x18x19x5xi32> } @@ -26,17 +30,20 @@ func.func @transpose_conv2d_quantized(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor // CHECK-LABEL: @transpose_conv2d_quantized_padded func.func @transpose_conv2d_quantized_padded(%arg0: tensor<2x16x14x3xi8>, %arg1: tensor<5x3x6x3xi8>, %arg2: tensor<5xi32>) -> (tensor<2x21x26x5xi32>) { - // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %0 {axis = 2 : i32} + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>} + // CHECK-DAG: %[[REV0:.+]] = tosa.reverse %2 {axis = 2 : i32} // CHECK-DAG: %[[REV1:.+]] = tosa.reverse %arg1 {axis = 1 : i32} - // CHECK: tosa.conv2d %arg0, %1, %arg2 + // CHECK: tosa.conv2d %arg0, %3, %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]] // CHECK-SAME: dilation = array, pad = array, - // CHECK-SAME: quantization_info = #tosa.conv_quant, stride = array} - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + // CHECK-SAME: stride = array} + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp { acc_type = i32, out_pad = array, - quantization_info = #tosa.conv_quant, out_shape = array, - stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>) -> tensor<2x21x26x5xi32> + stride = array} : (tensor<2x16x14x3xi8>, tensor<5x3x6x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x21x26x5xi32> return %0 : tensor<2x21x26x5xi32> } @@ -71,7 +78,7 @@ func.func @transpose_conv2d_strided(%arg0: tensor<2x17x15x3xf32>, %arg1: tensor< // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d 
%arg0, %arg1, %arg2 {acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2{acc_type = f32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>) -> tensor<2x35x47x5xf32> %1 = tensor.cast %0 : tensor<2x35x47x5xf32> to tensor<2x?x?x5xf32> return %1 : tensor<2x?x?x5xf32> } @@ -98,7 +105,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // Manipulate the final shape. // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>} - // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]] {acc_type = i32, dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} + // CHECK-DAG: %[[INPUT_ZP:.+]] = "tosa.const"() <{value = dense<-22> : tensor<1xi8>} + // CHECK-DAG: %[[WEIGHT_ZP:.+]] = "tosa.const"() <{value = dense<42> : tensor<1xi8>} + // CHECK-DAG: %[[CONV:.+]] = tosa.conv2d %[[NEWINPUT]], %[[NEWWEIGHT]], %[[BIAS]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK-DAG: %[[RESHAPE_OUT_1:.+]] = tosa.reshape %[[CONV]] {new_shape = array} // CHECK-DAG: %[[TRANS_OUT:.+]] = tosa.transpose %[[RESHAPE_OUT_1]], %[[TRANS2]] // CHECK-DAG: %[[RESHAPE_OUT_2:.+]] = tosa.reshape %[[TRANS_OUT]] @@ -107,7 +116,9 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1 // CHECK-DAG: %[[SLICE:.*]] = tosa.slice %[[RESHAPE_OUT_2]], %[[START]], %[[SIZE]] // CHECK-DAG: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 // CHECK: %[[ADD:.+]] = tosa.add %[[SLICE]], %[[RESHAPE_ARG2]] - %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {acc_type = i32, out_pad = array, quantization_info = #tosa.conv_quant, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>) -> 
tensor<2x35x47x5xi32> + %input_zp = "tosa.const"() {value = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = i32, out_pad = array, out_shape = array, stride = array} : (tensor<2x17x15x3xi8>, tensor<5x3x5x3xi8>, tensor<5xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<2x35x47x5xi32> return %0 : tensor<2x35x47x5xi32> } @@ -135,12 +146,13 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : // CHECK: %[[PAD_RESULT:.+]] = tosa.pad %[[RESHAPE_RESULT_1]], %[[RESULT_PAD]] // CHECK: %[[RESHAPE_ARG2:.+]] = tosa.reshape %arg2 {new_shape = array} // CHECK: %[[ADD:.+]] = tosa.add %[[PAD_RESULT]], %[[RESHAPE_ARG2]] - %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2 { + %input_zp = "tosa.const"() {value = dense<-103> : tensor<1xi8>} : () -> tensor<1xi8> + %weight_zp = "tosa.const"() {value = dense<93> : tensor<1xi8>} : () -> tensor<1xi8> + %2 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp { acc_type = i32, out_pad = array, out_shape = array, - stride = array, - quantization_info = #tosa.conv_quant} : - (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>) -> tensor<1x19x2x1xi32> + stride = array} : + (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32> "func.return" (%2) : (tensor<1x19x2x1xi32>) -> () }