Skip to content

Commit

Permalink
refactor(ONNX): renames getValueList helper to createScalarSublist
Browse files Browse the repository at this point in the history
- Before:
  - "get": implies retrieval of some private property
  - "Value": restatement of the return type `Value`
  - "List": assumed result of casting the returned instance
- After:
  - "create": contextualizes the need to pass in `rewriter`
  - "Scalar": contextualizes the opaque return type
  - "Sublist": the relationship between the first parameter and the returned result
  • Loading branch information
bjacobgordon committed Jan 28, 2025
1 parent 445ad8f commit 2e2300b
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ Value createScalarAs(
}

template <typename SomeTorchScalarType>
Value getValueList(
Value createScalarSublist(
/* at */ Location givenLoc,
/* from */ Value given1DTensor,
/* throughBackStartingWith */ int64_t givenIndex,
Expand Down Expand Up @@ -2825,13 +2825,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(

if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList<Torch::FloatType>(
scalesValueList = createScalarSublist<Torch::FloatType>(
binder.getLoc(), scaleOperand, foremostSupportedDim, rewriter);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList<Torch::IntType>(
sizesValueList = createScalarSublist<Torch::IntType>(
binder.getLoc(), sizeOperand, foremostSupportedDim, rewriter);
}
if (isa<Torch::NoneType>(scalesValueList.getType()) &&
Expand Down Expand Up @@ -3356,7 +3356,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.op, "supports upto 3d upsampling only");

auto foremostSupportedDim = TorchImageTensor::heightDim;
Value scalesValueList = getValueList<Torch::FloatType>(
Value scalesValueList = createScalarSublist<Torch::FloatType>(
binder.getLoc(), scales, foremostSupportedDim, rewriter);
if (mode == "linear") {
if (resultRank == 4)
Expand Down

0 comments on commit 2e2300b

Please sign in to comment.