112 changes: 35 additions & 77 deletions src/Dialect/ONNX/Transforms/Decompose.cpp
@@ -853,6 +853,10 @@ bool ShouldDecomposeConvTransposeOpToPhasedConvs(Value convTransposeResult,
bool fourPhaseDecomposition = (stridesShape[0] == 2);
bool ninePhaseDecomposition = (stridesShape[0] == 3);
if (fourPhaseDecomposition) {
+ if (outputShape[0] != 1) {
+ // Currently support batch=1
+ return false;
+ }
if (kernelShape[0] == 6 && padsShape[0] == 2 &&
llvm::all_equal(padsShape)) {
// Currently support only with pads [2, 2, 2, 2]
@@ -1711,102 +1715,56 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
startOnnxConstant, endOnnxConstant, axisOnnxConstant,
stepOnnxConstant);
}

- // The four convOutputs are adjusted to add an extra dimension at the
- // innermost level.
- SmallVector<int64_t> outputShapePlusOneDim(convOutputShape);
- outputShapePlusOneDim.push_back(1);
- auto onnxConstForReshapeAddOneDim =
- getONNXConstOpFromVector(rewriter, loc, outputShapePlusOneDim);
-
- auto reshapeOutputType =
- RankedTensorType::get(outputShapePlusOneDim, elementType);
-
- auto reshapeOutputAddOneDimConv1 = rewriter.create<ONNXReshapeOp>(
- loc, reshapeOutputType, conv1, onnxConstForReshapeAddOneDim);
- auto reshapeOutputAddOneDimConv2 = rewriter.create<ONNXReshapeOp>(
- loc, reshapeOutputType, conv2, onnxConstForReshapeAddOneDim);
- auto reshapeOutputAddOneDimConv3 = rewriter.create<ONNXReshapeOp>(
- loc, reshapeOutputType, conv3, onnxConstForReshapeAddOneDim);
- auto reshapeOutputAddOneDimConv4 = rewriter.create<ONNXReshapeOp>(
- loc, reshapeOutputType, conv4, onnxConstForReshapeAddOneDim);
-
- SmallVector<int64_t> outputShapeLevel1Concat(outputShapePlusOneDim);
- outputShapeLevel1Concat[outputShapeLevel1Concat.size() - 1] = 2;
- auto level1ConcatOutputType =
- RankedTensorType::get(outputShapeLevel1Concat, elementType);
+ // Four conv outputs are merged in channel dim
+ SmallVector<int64_t> outputShapeOfConcat = {
+ 1, convOutputShape[1] * 4, convOutputShape[2], convOutputShape[3]};
+ auto concatOutputType =
+ RankedTensorType::get(outputShapeOfConcat, elementType);
// For the case where the convtranspose kernel is [4, 4] with pads
// [1, 1, 1, 1], the phased conv outputs have to be concatenated in the
// reverse order. This is observed by comparing the phased conv outputs
// with the convtranspose output.
bool reverseConcatOrder = (needWeightsPadding || (kernelShape[0] == 4));
- // Below concats result will have the innermost dim as 2.
+ // The concat output will have 4 times the channels of a single conv.
auto firstConcat =
(reverseConcatOrder)
- ? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
- ValueRange{
- reshapeOutputAddOneDimConv3, reshapeOutputAddOneDimConv1},
- -1)
- : rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
- ValueRange{
- reshapeOutputAddOneDimConv1, reshapeOutputAddOneDimConv3},
- -1);
- auto secondConcat =
- (reverseConcatOrder)
- ? rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
- ValueRange{
- reshapeOutputAddOneDimConv2, reshapeOutputAddOneDimConv4},
- -1)
- : rewriter.create<ONNXConcatOp>(loc, level1ConcatOutputType,
- ValueRange{
- reshapeOutputAddOneDimConv4, reshapeOutputAddOneDimConv2},
- -1);
+ ? rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+ ValueRange{conv2, conv4, conv3, conv1}, 1)
+ : rewriter.create<ONNXConcatOp>(loc, concatOutputType,
+ ValueRange{conv1, conv3, conv4, conv2}, 1);

- // Reshaping to modify the two innermost levels, ensuring the second
- // innermost level is set to 1
- SmallVector<int64_t> outputShapeForDimAdjust(convOutputShape);
- auto dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
- outputShapeForDimAdjust[outputShapeForDimAdjust.size() - 1] = 1;
- outputShapeForDimAdjust.push_back(dimValueAtLastIndex);
+ // Here we reshape the concatenated channels (4 * Conv_channels) into
+ // groups of 2x2 channels. This can be visualized as
+ // H_chan(2) * W_chan(2) * C_real, followed by a transpose into
+ // Conv_channels, H, H_chan, W, W_chan. The adjacent H and H_chan are then
+ // merged into H, and likewise W and W_chan are merged into W. This doubles
+ // H and W while keeping the channel count the same.

+ SmallVector<int64_t> outputShapeForDimAdjust = {
+ 2, 2, convOutputShape[1], convOutputShape[2], convOutputShape[3]};

auto onnxConstForReshapeDimAdjust =
getONNXConstOpFromVector(rewriter, loc, outputShapeForDimAdjust);

auto reshapeOutputForDimAdjustType =
RankedTensorType::get(outputShapeForDimAdjust, elementType);
- auto reshapeOutputDimAdjustOfFirstConcat =
+ auto reshapeOutputDimAdjust =
rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
firstConcat, onnxConstForReshapeDimAdjust);
- auto reshapeOutputDimAdjustOfSecondConcat =
- rewriter.create<ONNXReshapeOp>(loc, reshapeOutputForDimAdjustType,
- secondConcat, onnxConstForReshapeDimAdjust);

- SmallVector<int64_t> outputShapeForFinalConcat(outputShapeForDimAdjust);
- outputShapeForFinalConcat[outputShapeForFinalConcat.size() - 2] = 2;
+ SmallVector<int64_t> transposeOuputShape = {
+ convOutputShape[1], convOutputShape[2], 2, convOutputShape[3], 2};

- auto finalConcatOutputType =
- RankedTensorType::get(outputShapeForFinalConcat, elementType);
+ auto transposeOutputType =
+ RankedTensorType::get(transposeOuputShape, elementType);

- // Final Concat is performed on the two reshaped outputs at the
- // second innermost level
- auto finalConcat =
- (reverseConcatOrder)
- ? rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
- ValueRange{reshapeOutputDimAdjustOfSecondConcat,
- reshapeOutputDimAdjustOfFirstConcat},
- -2)
- : rewriter.create<ONNXConcatOp>(loc, finalConcatOutputType,
- ValueRange{reshapeOutputDimAdjustOfFirstConcat,
- reshapeOutputDimAdjustOfSecondConcat},
- -2);
+ auto permArrayAttr = rewriter.getI64ArrayAttr({2, 3, 0, 4, 1});

- SmallVector<int64_t> outputShapeForResult(convOutputShape);
- dimValueAtLastIndex = convOutputShape[convOutputShape.size() - 1] * 2;
- auto dimValueAtSecondLastIndex =
- convOutputShape[convOutputShape.size() - 2] * 2;
- outputShapeForResult[outputShapeForResult.size() - 2] =
- dimValueAtSecondLastIndex;
- outputShapeForResult[outputShapeForResult.size() - 1] = dimValueAtLastIndex;
+ auto transpose = rewriter.create<ONNXTransposeOp>(
+ loc, transposeOutputType, reshapeOutputDimAdjust, permArrayAttr);

+ SmallVector<int64_t> outputShapeForResult = {
+ 1, convOutputShape[1], convOutputShape[2] * 2, convOutputShape[3] * 2};

auto onnxConstForLastReshape =
getONNXConstOpFromVector(rewriter, loc, outputShapeForResult);
@@ -1816,7 +1774,7 @@ Value decomposeIntoPhasedConvs(PatternRewriter &rewriter, Location loc,
// Result is reshaped back to match the original convtranspose output
// dimensions
auto finalOutput = rewriter.create<ONNXReshapeOp>(
- loc, finalOutputType, finalConcat, onnxConstForLastReshape);
+ loc, finalOutputType, transpose, onnxConstForLastReshape);
return finalOutput;
}
if (numPhases == 9) {
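To make the data movement of the new Concat / Reshape / Transpose / Reshape chain easier to follow, here is a minimal standalone sketch. It is plain C++, not onnx-mlir code; the sizes C = 1, H = 2, W = 3 and the integer phase tags are made-up illustration values, and it uses the non-reversed concat order {conv1, conv3, conv4, conv2} from the diff. It emulates the op sequence with index arithmetic and prints which phase output feeds each pixel of the [1, C, 2H, 2W] result.

// Standalone illustration only (not part of the pass). Emulates
//   Concat(axis=1) -> Reshape([2,2,C,H,W]) -> Transpose({2,3,0,4,1})
//   -> Reshape([1,C,2H,2W])
// for the non-reversed order {conv1, conv3, conv4, conv2}.
#include <array>
#include <cstdio>
#include <vector>

int main() {
  const int C = 1, H = 2, W = 3; // arbitrary example sizes (batch = 1)

  // Four phase outputs, each [1, C, H, W]; every element is tagged with its
  // phase id (1..4) so it stays recognizable after the reshapes.
  std::array<std::vector<int>, 4> conv;
  for (int id = 1; id <= 4; ++id)
    conv[id - 1].assign(C * H * W, id);

  // Concat along the channel axis in the order {conv1, conv3, conv4, conv2},
  // giving a contiguous [1, 4C, H, W] buffer.
  std::vector<int> concat;
  for (int id : {1, 3, 4, 2})
    concat.insert(concat.end(), conv[id - 1].begin(), conv[id - 1].end());

  // Reshape to [2, 2, C, H, W] only reinterprets that buffer: element
  // (a, b, c, h, w) lives at flat index (((a*2 + b)*C + c)*H + h)*W + w.
  auto at = [&](int a, int b, int c, int h, int w) {
    return concat[(((a * 2 + b) * C + c) * H + h) * W + w];
  };

  // Transpose with perm {2, 3, 0, 4, 1} yields [C, H, 2, W, 2]; the final
  // reshape to [1, C, 2H, 2W] merges (h, a) -> 2h + a and (w, b) -> 2w + b.
  std::vector<int> out(C * 2 * H * 2 * W);
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int a = 0; a < 2; ++a)
        for (int w = 0; w < W; ++w)
          for (int b = 0; b < 2; ++b)
            out[(c * 2 * H + 2 * h + a) * 2 * W + 2 * w + b] = at(a, b, c, h, w);

  // Print the [2H, 2W] plane of channel 0: each entry is the phase id that
  // produced that pixel, showing the 2x2 interleave of the four phase convs.
  for (int r = 0; r < 2 * H; ++r) {
    for (int col = 0; col < 2 * W; ++col)
      std::printf("%d ", out[r * 2 * W + col]);
    std::printf("\n");
  }
  return 0;
}

With these tags the sketch prints rows alternating 1 3 1 3 ... and 4 2 4 2 ..., i.e. the four phase outputs end up interleaved on a 2x2 pixel grid (a depth-to-space style rearrangement); the reversed order {conv2, conv4, conv3, conv1} runs through the same index arithmetic and only swaps which phase lands on which row/column parity.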