Skip to content

Commit 6991fc0

Browse files
committed
Use dyn_cast for safety in casting RankedTensorType
1 parent 4fa273a commit 6991fc0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,15 +605,15 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
605605
PatternRewriter &rewriter) {
606606
assert(op);
607607

608-
auto outTy = cast<RankedTensorType>(op->getResult(0).getType());
609-
if (outTy.getRank() != 2)
610608
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
611609
if (boolAttr.getValue())
612610
return State::GUARANTEED;
613611
else
614612
return State::NOTGUARANTEED;
615613
}
616614

615+
auto outTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
616+
if (!outTy || outTy.getRank() != 2)
617617
return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
618618
if (outTy.getDimSize(0) != outTy.getDimSize(1))
619619
return State::NOTGUARANTEED; // quick check and exit

0 commit comments

Comments
 (0)