refactor(ONNX): replaces getValueList helper with createScalarSublist #3987

Merged
Commits (23)
b5753c3  refactor(ONNX): removes redundant `auto` annotation in `getValueList`… (bjacobgordon, Jan 15, 2025)
16d8954  refactor(ONNX): leverages `operandType` when declaring `sizes` in `ge… (bjacobgordon, Jan 14, 2025)
9c3b4b5  refactor(ONNX): extracts `someTorchScalarType` within `getValueList` … (bjacobgordon, Jan 27, 2025)
456cd22  refactor(ONNX): consolidates duplicate conditional fragments in `getV… (bjacobgordon, Jan 14, 2025)
cde3ff5  refactor(ONNX): performs direct return in `getValueList` helper (bjacobgordon, Jan 27, 2025)
b45fb98  refactor(ONNX): extracts `someTorchScalarListType` in `getValueList` … (bjacobgordon, Jan 14, 2025)
4158652  refactor(ONNX): leverages type for element in existing list in `getVa… (bjacobgordon, Feb 6, 2025)
810e952  refactor(ONNX): extracts `getTorchScalarType` from `getValueList` helper (bjacobgordon, Jan 29, 2025)
8129353  refactor(ONNX): enforces assignment-usage adjacency for `itemList` in… (bjacobgordon, Jan 14, 2025)
557d25f  refactor(ONNX): extracts `extractTorchScalar` helper method from `get… (bjacobgordon, Jan 23, 2025)
b261571  refactor(ONNX): extracts `loc` within `getValueList` helper (bjacobgordon, Jan 14, 2025)
4393025  refactor(ONNX): adds anchoring labels to parameters in `getValueList`… (bjacobgordon, Jan 24, 2025)
327eeaf  refactor(ONNX): narrows first parameter from `Binder` to `Location` i… (bjacobgordon, Jan 27, 2025)
12c461c  refactor(ONNX): reorders parameters in `getValueList` helper (bjacobgordon, Jan 23, 2025)
0de9edf  refactor(ONNX): exposes start index as parameter in `getValueList` he… (bjacobgordon, Jan 14, 2025)
1eb77d7  refactor(ONNX): extracts `lengthOfFullList` within `getValueList` helper (bjacobgordon, Jan 14, 2025)
0fd39f1  refactor(ONNX): renames `itemList` to `runningScalarSublist` in `getV… (bjacobgordon, Jan 14, 2025)
da42cd8  refactor(ONNX): renames `item` to `eachScalar` in `getValueList` helper (bjacobgordon, Jan 14, 2025)
6bdf1b0  refactor(ONNX): renames `i` to `indexOfEachScalar` in `getValueList` … (bjacobgordon, Jan 21, 2025)
4907243  refactor(ONNX): renames `operand` to `given1DTensor` in `getValueList… (bjacobgordon, Jan 14, 2025)
d7c1af0  refactor(ONNX): renames `operandType` to `some1DTensorType` in `getVa… (bjacobgordon, Jan 14, 2025)
7bb8ab0  refactor(ONNX): renames `sizes` to `sizesOfSome1DTensor` in `getValue… (bjacobgordon, Jan 14, 2025)
dd39a6a  refactor(ONNX): renames `getValueList` helper to `createScalarSublist` (bjacobgordon, Jan 23, 2025)
121 changes: 72 additions & 49 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -180,53 +180,67 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter,
return success();
}

-Value getValueList(OpBinder binder, ConversionPatternRewriter &rewriter,
-                   Value operand) {
-  SmallVector<Value> itemList;
-  auto sizes = dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
-  Torch::BaseTensorType operandType =
-      cast<Torch::BaseTensorType>(operand.getType());
-
-  SmallVector<int64_t> selectSizes;
-  selectSizes.push_back(1);
-  Type selectResultType = operandType.getWithSizesAndDtype(
-      llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
-
-  auto extract = [&rewriter, &binder](Value x, Value v) {
-    auto xTy = cast<Torch::ValueTensorType>(x.getType());
-    Type extractTy = rewriter.getType<Torch::FloatType>();
-    if (isa<IntegerType>(xTy.getDtype()))
-      extractTy = rewriter.getType<Torch::IntType>();
-
-    return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy, v);
-  };
-
-  Value zero = rewriter.create<Torch::ConstantIntOp>(
-      binder.getLoc(), rewriter.getType<Torch::IntType>(),
-      rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
-
-  MLIRContext *context = binder.op->getContext();
-  for (int i = 2; i < sizes[0]; i++) {
-    Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
-        binder.getLoc(), rewriter.getType<Torch::IntType>(),
-        rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
-    Value ext = rewriter.create<Torch::AtenSelectIntOp>(
-        binder.getLoc(), selectResultType, operand, zero, selectIndex);
-    Value item = extract(operand, ext);
-    itemList.push_back(item);
-  }
-  auto xTy = cast<Torch::ValueTensorType>(operand.getType());
-  Value ValueList;
-  if (isa<IntegerType>(xTy.getDtype())) {
-    ValueList = rewriter.create<Torch::PrimListConstructOp>(
-        binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)),
-        itemList);
-  } else {
-    ValueList = rewriter.create<Torch::PrimListConstructOp>(
-        binder.getLoc(), Torch::ListType::get(Torch::FloatType::get(context)),
-        itemList);
-  }
-  return ValueList;
-}
+Type getTorchScalarType(
+    /* forElementIn */ Torch::BaseTensorType givenTensorType,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto elementTypeForGivenTensor = givenTensorType.getDtype();
+
+  if (isa<IntegerType>(elementTypeForGivenTensor))
+    return rewriter.getType<Torch::IntType>();
+  if (isa<FloatType>(elementTypeForGivenTensor))
+    return rewriter.getType<Torch::FloatType>();
+
+  assert(false && "dtype for given tensor expected to be either int or float");
+}
+
+Value extractTorchScalar(
+    /* at */ Location givenLoc,
+    /* from */ int64_t givenIndex,
+    /* in */ Value given1DTensor,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto some1DTensorType = cast<Torch::BaseTensorType>(given1DTensor.getType());
+
+  Type selectionTypeForSome1DTensor = some1DTensorType.getWithSizesAndDtype(
+      ArrayRef<int64_t>{1}, some1DTensorType.getOptionalDtype());
+
+  Value frontDim = rewriter.create<Torch::ConstantIntOp>(givenLoc, 0);
+
+  Value selectionIndex =
+      rewriter.create<Torch::ConstantIntOp>(givenLoc, givenIndex);
+
+  auto someTorchScalarType = getTorchScalarType(some1DTensorType, rewriter);
+
+  Value selectionFromGiven1DTensor = rewriter.create<Torch::AtenSelectIntOp>(
+      givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim,
+      selectionIndex);
+
+  return rewriter.create<Torch::AtenItemOp>(givenLoc, someTorchScalarType,
+                                            selectionFromGiven1DTensor);
+}
+
+Value createScalarSublist(
+    /* at */ Location givenLoc,
+    /* movingForwardsThrough */ Value given1DTensor,
+    /* startingAt */ int64_t givenIndex,
+    /* using */ ConversionPatternRewriter &rewriter) {
+  auto some1DTensorType = cast<Torch::BaseTensorType>(given1DTensor.getType());
+  auto sizesOfSome1DTensor = some1DTensorType.getSizes();
+  auto lengthOfFullList = sizesOfSome1DTensor[0];
+
+  SmallVector<Value> runningScalarSublist;
+
+  for (int indexOfEachScalar = givenIndex; indexOfEachScalar < lengthOfFullList;
+       indexOfEachScalar++) {
+    Value eachScalar = extractTorchScalar(givenLoc, indexOfEachScalar,
+                                          given1DTensor, rewriter);
+    runningScalarSublist.push_back(eachScalar);
+  }
+
+  auto someTorchScalarType = runningScalarSublist.front().getType();
+  Type someTorchScalarListType = Torch::ListType::get(someTorchScalarType);
+
+  return rewriter.create<Torch::PrimListConstructOp>(
+      givenLoc, someTorchScalarListType, runningScalarSublist);
+}
} // namespace
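
For readers less familiar with the rewriter API, here is a minimal standalone C++ analogue of the two new helpers, with plain `std::vector` standing in for MLIR 1-D tensors. The names mirror the PR, but this sketch is purely illustrative and is not part of the change:

```cpp
// Standalone analogue: the real helpers emit torch dialect ops
// (select.int + item) instead of indexing a vector directly.
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors extractTorchScalar: read one element of a 1-D "tensor" by index.
double extractScalar(const std::vector<double> &tensor1D, int64_t index) {
  assert(index >= 0 && index < static_cast<int64_t>(tensor1D.size()));
  return tensor1D[index];
}

// Mirrors createScalarSublist: walk forward from `startingAt` to the end of
// the list, extracting each scalar and collecting the results.
std::vector<double> createScalarSublist(const std::vector<double> &tensor1D,
                                        int64_t startingAt) {
  std::vector<double> runningSublist;
  for (int64_t i = startingAt; i < static_cast<int64_t>(tensor1D.size()); ++i)
    runningSublist.push_back(extractScalar(tensor1D, i));
  return runningSublist;
}

int main() {
  // An ONNX Resize-style scales tensor laid out [N, C, H, W]; the spatial
  // dimensions begin at index 2, matching assumedForemostSpatialDim below.
  std::vector<double> scales = {1.0, 1.0, 2.0, 3.0};
  std::vector<double> spatialScales = createScalarSublist(scales, 2);
  assert(spatialScales.size() == 2);
  assert(spatialScales[0] == 2.0 && spatialScales[1] == 3.0);
  return 0;
}
```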

@@ -2809,14 +2823,21 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}

+      int64_t assumedForemostSpatialDim = 2;
+
       if (operands.size() < 4) {
         Value scaleOperand = operands[2];
-        scalesValueList = getValueList(binder, rewriter, scaleOperand);
+        scalesValueList =
+            createScalarSublist(binder.getLoc(), scaleOperand,
+                                assumedForemostSpatialDim, rewriter);
         sizesValueList = noneVal;
       } else {
         Value sizeOperand = operands[3];
         scalesValueList = noneVal;
-        sizesValueList = getValueList(binder, rewriter, sizeOperand);
+        sizesValueList =
+            createScalarSublist(binder.getLoc(), sizeOperand,
+                                assumedForemostSpatialDim, rewriter);
       }
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
isa<Torch::NoneType>(sizesValueList.getType())) {
@@ -3339,7 +3360,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
return rewriter.notifyMatchFailure(
binder.op, "supports upto 3d upsampling only");

-    Value scalesValueList = getValueList(binder, rewriter, scales);
+    int64_t assumedForemostSpatialDim = 2;
+    Value scalesValueList = createScalarSublist(
+        binder.getLoc(), scales, assumedForemostSpatialDim, rewriter);
if (mode == "linear") {
if (resultRank == 4)
mode = "bilinear";
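A note on the call sites above: `getValueList` always began its walk at index 2, which quietly encoded the NCHW layout convention that dimensions 0 and 1 (batch and channel) are never resized. `createScalarSublist` takes that start index as an explicit parameter, so both the Resize and Upsample lowerings now spell the assumption out as `assumedForemostSpatialDim`.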
6 changes: 4 additions & 2 deletions test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -2803,8 +2803,9 @@ func.func @test_upsample_nearest(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !t
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[MODE:.*]] = torch.constant.str "nearest"
@@ -2824,8 +2825,9 @@ func.func @test_upsample_bilinear(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],f32> -> !torch.float
+// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[INT3:.*]] = torch.constant.int 3
-// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
+// CHECK: %[[SELECT_0:.*]] = torch.aten.select.int %arg1, %[[INT0_0]], %[[INT3]] : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
// CHECK: %[[SCALE_0:.*]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[1],f32> -> !torch.float
// CHECK: %[[SCALE_LIST:.*]] = torch.prim.ListConstruct %[[SCALE]], %[[SCALE_0]] : (!torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[MODE:.*]] = torch.constant.str "bilinear"
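
The only churn in the expected IR is the extra `%[[INT0_0]]` constant: the old helper materialized the dimension-0 index once, outside its loop, whereas `extractTorchScalar` creates a fresh index-0 constant on every extraction, so the second `select` now consumes `%int0_0` instead of reusing `%int0`.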