Skip to content

Commit

Permalink
[mlir][tosa] Make Convolution Zero Points Inputs (llvm#122939)
Browse files Browse the repository at this point in the history
The TOSA-v1.0 specification moves the "zero point" parameters of the
convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and
TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs
and update any transformations, materializations and lit tests
appropriately.

Rename the "filter" argument of `tosa.transpose_conv2d` to weight to
align with the TOSA specification.

Remove the quantization_info attribute on the convolution operations.

Co-authored-by: TatWai Chong <[email protected]>
  • Loading branch information
2 people authored and Icohedron committed Feb 11, 2025
1 parent 355f83e commit fb81f74
Show file tree
Hide file tree
Showing 19 changed files with 576 additions and 198 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
"operands attr-dict `:` functional-type(operands, results)";
}

// The "SameVariadicOperandSize" trait allows us to pass optional arguments
// for multiple zero points in convolution ops.
class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
: Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
[SameVariadicOperandSize])> {
}

#endif // TOSA_OP_BASE
118 changes: 118 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -29,6 +30,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class PatternRewriter;
Expand Down Expand Up @@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t);
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

namespace mlir {
namespace tosa {

// Create a rank-1 const tensor for zero point of the source tensor.
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Type srcElemType, int64_t zp = 0);

// Get zero point value from the attribute argument.
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);

// Verify if zero point falls into valid range.
template <typename T>
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
!std::is_same_v<T, DepthwiseConv2DOp> &&
!std::is_same_v<T, TransposeConv2DOp>) {
return failure();
}

if (!zpElemType.isIntOrFloat())
return failure();

if (!zpElemType.isInteger(8) && zp != 0)
return failure();

if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
return failure();

if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
return failure();

return success();
}

// Helper type trait to determine if an operation is a tosa convolution.
template <typename Op>
struct IsTosaConv : std::false_type {};

template <>
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};

template <typename Op>
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;

// Helper struct to hold the zero points of a TOSA convolution operation as
// named 64-bit integer fields.
struct ConvZpPair {
ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
: inputZp(inputZp), weightZp(weightZp) {}
std::int64_t inputZp;
std::int64_t weightZp;
};

// Helper function which attempts to extract the zero points from a TOSA
// convolution by matching them against defining ops which should be tosa.const
// operations.
//
// There are three possible results:
// 1. Failed to extract the zero-points i.e. they should exist and don't or they
// do exist but are invalid.
// 2. Succeeded in extracting zero-points.
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
// convolution.
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
template <typename TosaConvOp>
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
// Strictly speaking the base TOSA spec requires that for non int8 types
// zero points must be zero. However, in the dialect these operands are
// optional and only required for int8. They have no semantic meaning for
// non-quantized types and can therefore be safely ignored. This is case 3.
if (auto opElementTY =
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
!opElementTY.isInteger(8))
return FailOrMaybeZP(std::nullopt);

// Now we know we should have a zero point check it is valid.
if (!op.getInputZp())
return rewriter.notifyMatchFailure(op, "missing input zero point");

// Helper to extract the zero point by matching its definition against a
// constant.
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
ElementsAttr zpAttr;
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
return std::nullopt;

int64_t zp;
if (tosa::getZeroPoint(zpAttr, zp).failed())
return std::nullopt;

return std::make_optional(zp);
};

auto maybeInputZp = extractZeroPoint(op.getInputZp());
if (!maybeInputZp)
return rewriter.notifyMatchFailure(op, "unable to extract input zp");

if (!op.getWeightZp())
return rewriter.notifyMatchFailure(op, "missing weight zero point");

auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
if (!maybeWeightZp)
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");

return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
}
} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
22 changes: 13 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: conv2d
//===----------------------------------------------------------------------===//
def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
let summary = "2D Convolution Operator";

let description = [{
Expand All @@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand All @@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
let summary = "3D Convolution operator";

let description = [{
Expand All @@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand All @@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
//===----------------------------------------------------------------------===//
// Operator: depthwise_conv2d
//===----------------------------------------------------------------------===//
def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
let summary = "Depthwise 2D Convolution operator";

let description = [{
Expand All @@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand Down Expand Up @@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
let summary = "Transpose 2D Convolution operator.";

let description = [{
Expand All @@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,9 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
// following def can be removed.
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;

#endif // TOSA_TYPES_BASE
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);

std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
Value weight);

//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
Value a, Value b);
Expand Down
57 changes: 26 additions & 31 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
bool isQuantized = op.getQuantizationInfo().has_value();

auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();

auto maybeZps = failureOrMaybeZps.value();

if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
Expand All @@ -284,22 +289,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
int64_t iZp = quantizationInfo.getInputZp();

if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (iZp < intMin || iZp > intMax)
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");

zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}

llvm::SmallVector<int64_t> pad;
Expand All @@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
Expand Down Expand Up @@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
if (maybeZps) {
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);

auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Expand Down Expand Up @@ -440,39 +441,31 @@ class DepthwiseConvConverter
/*inputSizeDims=*/{1, 2},
/*kernelSizeDims=*/{0, 1}, rewriter);

bool isQuantized = op->hasAttr("quantization_info");
IntegerAttr iZp;
IntegerAttr kZp;
if (isQuantized) {
auto quantizationInfo =
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
}
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();

auto maybeZps = failureOrMaybeZps.value();

auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo =
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
int64_t iZp = quantizationInfo.getInputZp();

if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (iZp < intMin || iZp > intMax)
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv op quantization has zp outside of input "
"range");

zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}

llvm::SmallVector<int64_t> pad;
Expand Down Expand Up @@ -512,7 +505,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

if (!isQuantized) {
if (!maybeZps) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
Expand All @@ -539,8 +532,10 @@ class DepthwiseConvConverter
.getResult(0);
rewriter.replaceOp(op, result);
} else {
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
Value conv =
rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
Expand Down
Loading

0 comments on commit fb81f74

Please sign in to comment.