Support batches and classes for nms lowering
praveen-g-ctt committed Jan 23, 2025
1 parent 481da8d commit 7fdaec0
Showing 1 changed file with 215 additions and 92 deletions.
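
For orientation, the sketch below spells out the per-batch, per-class semantics of ONNX NonMaxSuppression that this change now lowers: boxes of shape [BxNx4], scores of shape [BxCxN], and output rows of the form [batch_index, class_index, box_index], with at most max_output_boxes_per_class rows per (batch, class) pair. It is a minimal standalone C++ reference, not code from this commit; the helper names, the greedy IoU loop, and the corner-format box assumption (center_point_box = 0, y1 <= y2, x1 <= x2) are illustrative, and the actual lowering delegates the per-class suppression to torchvision's nms op.

// Illustrative reference only -- not part of the commit.
#include <algorithm>
#include <array>
#include <cstdint>
#include <numeric>
#include <vector>

using Box = std::array<float, 4>; // [y1, x1, y2, x2], corner format

// Intersection-over-union of two corner-format boxes.
static float iou(const Box &a, const Box &b) {
  float y1 = std::max(a[0], b[0]), x1 = std::max(a[1], b[1]);
  float y2 = std::min(a[2], b[2]), x2 = std::min(a[3], b[3]);
  float inter = std::max(0.0f, y2 - y1) * std::max(0.0f, x2 - x1);
  float areaA = (a[2] - a[0]) * (a[3] - a[1]);
  float areaB = (b[2] - b[0]) * (b[3] - b[1]);
  float uni = areaA + areaB - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Greedy single-class NMS: indices of kept boxes, highest score first.
// This is the role played by torchvision's nms op in the lowering.
static std::vector<int64_t> nmsSingleClass(const std::vector<Box> &boxes,
                                           const std::vector<float> &scores,
                                           float iouThreshold) {
  std::vector<int64_t> order(boxes.size());
  std::iota(order.begin(), order.end(), int64_t{0});
  std::sort(order.begin(), order.end(),
            [&](int64_t i, int64_t j) { return scores[i] > scores[j]; });
  std::vector<int64_t> keep;
  for (int64_t idx : order) {
    bool suppressed = false;
    for (int64_t k : keep)
      if (iou(boxes[idx], boxes[k]) > iouThreshold) {
        suppressed = true;
        break;
      }
    if (!suppressed)
      keep.push_back(idx);
  }
  return keep;
}

// ONNX NonMaxSuppression over boxes[B][N] and scores[B][C][N]: every kept
// box is reported as a [batch_index, class_index, box_index] row, capped at
// maxOutputBoxesPerClass rows per (batch, class) pair.
std::vector<std::array<int64_t, 3>>
nonMaxSuppression(const std::vector<std::vector<Box>> &boxes,
                  const std::vector<std::vector<std::vector<float>>> &scores,
                  int64_t maxOutputBoxesPerClass, float iouThreshold) {
  std::vector<std::array<int64_t, 3>> selected;
  for (int64_t b = 0; b < static_cast<int64_t>(boxes.size()); ++b)
    for (int64_t c = 0; c < static_cast<int64_t>(scores[b].size()); ++c) {
      std::vector<int64_t> keep =
          nmsSingleClass(boxes[b], scores[b][c], iouThreshold);
      int64_t limit = std::min(static_cast<int64_t>(keep.size()),
                               maxOutputBoxesPerClass);
      for (int64_t i = 0; i < limit; ++i)
        selected.push_back({b, c, keep[i]});
    }
  return selected;
}

The two nested loops correspond to the nmsBatchLoop and nmsClassLoop PrimLoopOps emitted in the diff below.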
307 changes: 215 additions & 92 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -9,6 +9,7 @@

#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h"
#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
@@ -3703,30 +3704,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

// TODO: Support multiple batches and classes
// Squeeze the boxes and scores tensor.
// In ONNX, the boxes tensor has shape [BxNx4] while torchvision
// expects it to be of shape [Nx4]. Similarly, the scores tensor
// has shape [BxCxN] in ONNX while torchvision expects it to be
// of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");
FailureOr<Value> squeezedScores =
Torch::squeezeTensor(rewriter, binder.op, loc, 0, scores);
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(rewriter, binder.op, loc, 0,
squeezedScores.value());
if (failed(squeezedScores))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Support score_threshold input
// Filter out the boxes if the score < score_threshold
@@ -3750,12 +3732,26 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
"unimplemented: score_threshold should be <= min(scores)"));
}

// Get max_output_boxes_per_class and iou_threshold
Value cst0 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cst1 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value cst2 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2));
Value cst3 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(3));
Value cst4 = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(4));

Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(true));
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
loc, rewriter.getBoolAttr(false));

Value maxOutputBoxesPerClass = cst0;

// Get max_output_boxes_per_class and iou_threshold
Value iouThreshold = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(0.0));
if (operands.size() > 3 &&
@@ -3769,87 +3765,214 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
loc, rewriter.getType<Torch::IntType>(), operands[2]);
}

auto boxesTensorType = cast<Torch::ValueTensorType>(boxes.getType());
auto scoreTensorType = cast<Torch::ValueTensorType>(scores.getType());
auto boxSlicedType = rewriter.getType<Torch::ValueTensorType>(
boxesTensorType.getSizes().slice(1), boxesTensorType.getDtype());
auto scoreSlicedType = rewriter.getType<Torch::ValueTensorType>(
scoreTensorType.getSizes().slice(1), scoreTensorType.getDtype());

auto numBatches =
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst0);
auto numClasses =
rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst1);
// auto numBoxes =
// rewriter.create<Torch::AtenSizeIntOp>(loc, scores, cst2);

std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(resultType).getOptionalSizes();
if (!resultShape.has_value())
return rewriter.notifyMatchFailure(
binder.op, "Expected result tensor to have shape");

Value numResults = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(resultShape->front()));

auto intTy = rewriter.getType<Torch::IntType>();
auto intListTy = rewriter.getType<Torch::ListType>(intTy);

Value resultShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc, intListTy, SmallVector<Value>{numResults, cst3});

Value finalResult = rewriter.create<Torch::AtenEmptyMemoryFormatOp>(
loc, resultType, resultShapeList, /*dtype=*/cst4,
/*layout=*/cstNone,
/*device=*/cstNone, /*pinMemory=*/cstNone,
/*memoryFormat=*/cstNone);

auto nmsTy = Torch::ValueTensorType::get(
binder.op->getContext(), SmallVector<int64_t>{-1},
rewriter.getIntegerType(64, /*signed=*/true));
Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxes, scores, iouThreshold);

// Slice the result if numOutputBoxes (N) > max_output_boxes_per_class
Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
loc, numOutputBoxes, maxOutputBoxesPerClass);
auto emptyTensorTy = rewriter.getType<Torch::ValueTensorType>(
SmallVector<int64_t>{}, nmsTy.getDtype());

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));
auto ifSlice = rewriter.create<Torch::PrimIfOp>(
loc, TypeRange({nmsResultTy}), boxesCond);
auto nmsBatchLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numBatches, cstTrue,
ValueRange({finalResult, /*Index to finalResult*/ cst0}));
{
// Batch loop body
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getThenRegion(),
ifSlice.getThenRegion().begin());
Block *batchLoopBody = rewriter.createBlock(
&nmsBatchLoop.getRegion(), nmsBatchLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto batchIV = batchLoopBody->getArgument(0);
auto currRes = batchLoopBody->getArgument(1);
auto finalResIdx = batchLoopBody->getArgument(2);

auto boxValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, boxSlicedType, boxes, cst0, batchIV);

Value curResult = rewriter.create<Torch::AtenSliceTensorOp>(
loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0,
/*end=*/maxOutputBoxesPerClass, /*step=*/cst1);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
auto batchValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, batchIV);

auto nmsClassLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numClasses, cstTrue,
ValueRange({currRes, finalResIdx}));

{
// Class loop body
PatternRewriter::InsertionGuard guard(rewriter);
Block *classLoopBody = rewriter.createBlock(
&nmsClassLoop.getRegion(), nmsClassLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto classIV = classLoopBody->getArgument(0);
auto currRes = classLoopBody->getArgument(1);
auto finalResIdx = classLoopBody->getArgument(2);

auto scoreSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreSlicedType, scores, cst0, cst0);
auto scoreSelectType =
dyn_cast<Torch::ValueTensorType>(scoreSelect.getType());
auto scoreValueType = rewriter.getType<Torch::ValueTensorType>(
scoreSelectType.getSizes().slice(1),
scoreSelectType.getDtype());

auto scoreValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, scoreValueType, scoreSelect, cst0, classIV);

auto classValue = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, classIV);

Value result = rewriter.create<Torch::TorchvisionNmsOp>(
loc, nmsTy, boxValue, scoreValue, iouThreshold);

Value numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);

numOutputBoxes = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, numOutputBoxes);
Value maxBoxesPerClass =
rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, emptyTensorTy, maxOutputBoxesPerClass);
auto minVal = rewriter.create<Torch::AtenMinimumOp>(
loc, numOutputBoxes.getType(), numOutputBoxes,
maxBoxesPerClass);
numOutputBoxes =
rewriter.create<Torch::AtenItemOp>(loc, intTy, minVal);

auto nmsResultTy = Torch::ValueTensorType::get(
binder.op->getContext(),
SmallVector<int64_t>{resultType.getSizes()[0]},
rewriter.getIntegerType(64, /*signed=*/true));

result = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);

// Loop through the nms result
auto nmsLoop = rewriter.create<Torch::PrimLoopOp>(
loc, TypeRange({resultType, intTy}), numOutputBoxes, cstTrue,
ValueRange({currRes, finalResIdx}));
{
PatternRewriter::InsertionGuard guard(rewriter);
Block *loopBody = rewriter.createBlock(
&nmsLoop.getRegion(), nmsLoop.getRegion().begin(),
TypeRange({intTy, resultType, intTy}), {loc, loc, loc});
auto iter = loopBody->getArgument(0);
auto currRes = loopBody->getArgument(1);
auto idxCst = loopBody->getArgument(2);

// Update batch

auto outputTensorSliceType =
rewriter.getType<Torch::ValueTensorType>(
SmallVector<int64_t>{3}, nmsTy.getDtype());

auto batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, currRes, cst0, idxCst);

auto batchSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, batchDim3D, cst0, cst0);

auto bCopy = rewriter.create<Torch::AtenCopyOp>(
loc, batchSelect.getType(), batchSelect, batchValue,
cstFalse);

batchDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, currRes, cst0, idxCst);

auto scatterBatch = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, batchDim3D, bCopy, cst0, cst0);
// Yield result
auto batchResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, currRes, scatterBatch, cst0, idxCst);

// Class values
auto classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, batchResult, cst0, idxCst);
// Class indices
auto classSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, classDim3D, cst0, cst1);

auto cCopy = rewriter.create<Torch::AtenCopyOp>(
loc, classSelect.getType(), classSelect, classValue,
cstFalse);

classDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, batchResult, cst0, idxCst);

auto scatterClass = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, classDim3D, cCopy, cst0, cst1);
auto classRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, batchResult, scatterClass, cst0, idxCst);

auto resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, classRes, cst0, idxCst);
// Box indices
auto resSelect = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, resDim3D, cst0, cst2);

auto nmsResultValue = rewriter.create<Torch::AtenSelectIntOp>(
loc, emptyTensorTy, result, cst0, iter);

auto rCopy = rewriter.create<Torch::AtenCopyOp>(
loc, resSelect.getType(), resSelect, nmsResultValue,
cstFalse);

resDim3D = rewriter.create<Torch::AtenSelectIntOp>(
loc, outputTensorSliceType, classRes, cst0, idxCst);

auto scatterRes = rewriter.create<Torch::AtenSelectScatterOp>(
loc, outputTensorSliceType, resDim3D, rCopy, cst0, cst2);
Value nmsResult = rewriter.create<Torch::AtenSelectScatterOp>(
loc, resultType, classRes, scatterRes, cst0, idxCst);

Value next =
rewriter.create<Torch::AtenAddIntOp>(loc, idxCst, cst1);

rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue, ValueRange({nmsResult, next}));
}
rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue,
ValueRange({nmsLoop.getResult(0), nmsLoop.getResult(1)}));
}
rewriter.create<Torch::PrimLoopConditionOp>(
loc, cstTrue,
ValueRange(
{nmsClassLoop.getResult(0), nmsClassLoop.getResult(1)}));
}
{
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.createBlock(&ifSlice.getElseRegion(),
ifSlice.getElseRegion().begin());

Value curResult = rewriter.create<Torch::TensorStaticInfoCastOp>(
loc, nmsResultTy, result);
rewriter.create<Torch::PrimIfYieldOp>(loc, curResult);
}
result = ifSlice.getResult(0);

// The result generated by the torchvision.nms op is of shape [n], while
// ONNX expects it to be of shape [n, 3]. Hence, we unsqueeze the tensor
// and make it of shape [n, 1] and then concatenate it with a zero
// tensor of shape [n, 2] to make it of shape [n, 3].
FailureOr<Value> unsqueezedResult =
Torch::unsqueezeTensor(rewriter, binder.op, result, cst1);
if (failed(unsqueezedResult))
return rewriter.notifyMatchFailure(
binder.op, "failed to unsqueeze result tensor");
result = unsqueezedResult.value();

numOutputBoxes =
rewriter.create<Torch::AtenSizeIntOp>(loc, result, cst0);
SmallVector<Value> zerosShapeValues{numOutputBoxes};
zerosShapeValues.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(2)));
Value zerosShapeList = rewriter.create<Torch::PrimListConstructOp>(
loc,
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
zerosShapeValues);
std::optional<ArrayRef<int64_t>> resultShape =
cast<Torch::ValueTensorType>(result.getType()).getOptionalSizes();
if (!resultShape.has_value())
return rewriter.notifyMatchFailure(
binder.op, "expected result tensor to have shape");
llvm::SmallVector<int64_t> zerosShape = {resultShape->front(), 2};
auto zerosTy = Torch::ValueTensorType::get(
resultType.getContext(), zerosShape, resultType.getOptionalDtype());
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
Value zeros = rewriter.create<Torch::AtenZerosOp>(
loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone);

Type listElemType =
cast<Torch::BaseTensorType>(resultType)
.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
/*optionalDtype=*/nullptr);
Type listType = Torch::ListType::get(listElemType);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, listType, SmallVector<Value>{zeros, result});
rewriter.replaceOpWithNewOp<Torch::AtenCatOp>(binder.op, resultType,
tensorList, cst1);
rewriter.replaceOp(binder.op, nmsBatchLoop.getResult(0));
return success();
});
}
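
The innermost nmsLoop above writes one [batch_index, class_index, box_index] row per kept box by chaining aten.select, aten.copy, and aten.select_scatter, and threads a running write position through the loop-carried values. A rough scalar equivalent of one iteration (a sketch only; writeRow is a made-up name, and the real lowering manipulates Torch value tensors rather than raw arrays):

#include <array>
#include <cstdint>
#include <vector>

// One iteration of the innermost loop, in scalar form: overwrite row `idx`
// of the [num_selected, 3] buffer with the batch index, class index, and
// NMS-selected box index, then advance the write cursor.
void writeRow(std::vector<std::array<int64_t, 3>> &result, int64_t &idx,
              int64_t batch, int64_t cls, int64_t boxIdx) {
  result[idx] = {batch, cls, boxIdx}; // select_scatter of the updated row
  ++idx;                              // next = idx + 1, yielded by the loop
}

Carrying the write cursor through the batch, class, and per-box loops is what lets results for different batches and classes land in consecutive rows of the shared output buffer.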
