Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[torch] aten.eye should use dynamic dims when no static dims are available #3202

Merged
merged 5 commits into from
Apr 30, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,44 +1059,49 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;

if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
int64_t m;
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!outType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (n < 0) {
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
}
if (m < 0) {
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
}

Value none = rewriter.create<ConstantNoneOp>(loc);
auto context = op.getContext();
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);

int64_t n = kUnknownSize;
int64_t m = kUnknownSize;

// prioritize getting shape from output shape
if (outType.hasSizes() && outType.getSizes().size() == 2) {
n = outType.getSizes().front();
m = outType.getSizes().back();
}
// if output shape is not available, try to get shape from input
if (n == kUnknownSize)
matchPattern(op.getN(), m_TorchConstantInt(&n));
if (m == kUnknownSize)
matchPattern(op.getM(), m_TorchConstantInt(&m));

// prepare two unsqueezed ranges that are equal on and only on the diagonal
std::optional<llvm::SmallVector<int64_t, 1>> rangeNSize;
if (n != kUnknownSize)
rangeNSize = {n};
Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
Value rangeN = rewriter.create<AtenArangeOp>(
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/op.getDevice(), /*pin_memory=*/none);

auto arangeType1 =
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
std::optional<llvm::SmallVector<int64_t, 1>> rangeMSize;
if (m != kUnknownSize)
rangeMSize = {m};
Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
Value rangeM = rewriter.create<AtenArangeOp>(
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);

Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
Expand Down
Loading