@@ -3086,7 +3086,7 @@ struct SliceElementwise final
30863086 auto elem = op.getOperand().getDefiningOp();
30873087 if (!elem)
30883088 return failure();
3089- if (!elem->hasTrait<mlir::OpTrait::Elementwise>( ))
3089+ if (!stablehlo::hasTraitElementwise(elem ))
30903090 return failure();
30913091 if (llvm::hasSingleElement(elem->getUsers())) {
30923092 SmallVector<Value> ops;
@@ -3772,7 +3772,7 @@ struct FullReduceReshapeOrTranspose final
37723772 reshapeOrTransposes.push_back(rs);
37733773 continue;
37743774 }
3775- if (!curOp->hasTrait<mlir::OpTrait::Elementwise>( ))
3775+ if (!hasTraitElementwise(curOp ))
37763776 return failure();
37773777 if (!isMemoryEffectFree(curOp))
37783778 return failure();
@@ -10469,7 +10469,7 @@ struct SliceReshapeElementwise final
1046910469 auto elem = reshape.getOperand().getDefiningOp();
1047010470 if (!elem)
1047110471 return failure();
10472- if (!elem->hasTrait<mlir::OpTrait::Elementwise>( ))
10472+ if (!stablehlo::hasTraitElementwise(elem ))
1047310473 return failure();
1047410474 if (!llvm::hasSingleElement(elem->getUsers()))
1047510475 return failure();
@@ -10515,7 +10515,7 @@ struct TransposeElementwise final
1051510515 if (!elem)
1051610516 return failure();
1051710517
10518- if (!elem->hasTrait<mlir::OpTrait::Elementwise>( ))
10518+ if (!stablehlo::hasTraitElementwise(elem ))
1051910519 return failure();
1052010520
1052110521 bool singleUser = llvm::hasSingleElement(elem->getUsers());
@@ -10913,7 +10913,7 @@ struct ReshapeElementwise final
1091310913 if (onlySingleUser && !llvm::hasSingleElement(elem->getUsers()))
1091410914 return failure();
1091510915
10916- if (!elem->hasTrait<mlir::OpTrait::Elementwise>( ))
10916+ if (!stablehlo::hasTraitElementwise(elem ))
1091710917 return failure();
1091810918
1091910919 SmallVector<Value> ops;
@@ -18958,8 +18958,8 @@ LogicalResult commUnaryOpElementwise(bool onlySingleUser, EnzymeOp op,
1895818958
1895918959 bool anyModified = false;
1896018960 for (auto elem : llvm::make_early_inc_range(op->getUsers())) {
18961- if (!elem->template hasTrait<mlir::OpTrait::Elementwise>() ||
18962- elem->getNumResults() != 1 || elem-> getNumOperands() != 1)
18961+ if (!hasTraitElementwise( elem) || elem->getNumResults() != 1 ||
18962+ elem->getNumOperands() != 1)
1896318963 continue;
1896418964
1896518965 auto newOp = rewriter.create(
@@ -20270,7 +20270,7 @@ struct ConcatElementwise final
2027020270 if (isa<stablehlo::ConvertOp>(vdefOp)) // Conflicts with ConvertConcat
2027120271 return failure();
2027220272
20273- if (vdefOp->hasTrait<mlir::OpTrait::Elementwise>( )) {
20273+ if (hasTraitElementwise(vdefOp )) {
2027420274 if (concatOpOperands.size() != 0) {
2027520275 if (!OperationEquivalence::isEquivalentTo(
2027620276 concatOpOperands[0], vdefOp,
@@ -20671,7 +20671,7 @@ struct ConcatReshapeElementwise final
2067120671 if (!vdefOp)
2067220672 return failure();
2067320673
20674- if (vdefOp->hasTrait<mlir::OpTrait::Elementwise>( )) {
20674+ if (hasTraitElementwise(vdefOp )) {
2067520675 if (concatOpOperands.size() != 0) {
2067620676 if (!OperationEquivalence::isEquivalentTo(
2067720677 concatOpOperands[0], vdefOp,
@@ -21259,7 +21259,7 @@ struct GatherElementwise
2125921259 PatternRewriter &rewriter) const {
2126021260 auto gatherInput = op.getOperand();
2126121261 auto defOp = gatherInput.getDefiningOp();
21262- if (!defOp || !defOp->hasTrait<mlir::OpTrait::Elementwise>( ))
21262+ if (!defOp || !hasTraitElementwise(defOp ))
2126321263 return rewriter.notifyMatchFailure(op,
2126421264 "GatherOp with non-elementwise input");
2126521265
0 commit comments