From 68af84183c349db15b749021b2f600c1d642e2d5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 3 Jul 2025 19:54:19 +0000 Subject: [PATCH 1/2] [mlir][IR][WIP] Set insertion point when erasing an operation --- mlir/lib/IR/PatternMatch.cpp | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5c98417c874d3..5a08c6534b6b6 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -150,12 +150,45 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { eraseOp(op); } +/// Returns the given block iterator if it lies within the block `b`. +/// Otherwise, otherwise finds the ancestor of the given block iterator that +/// lies within `b`. Returns and "empty" iterator if the latter fails. +/// +/// Note: This is a variant of Block::findAncestorOpInBlock that operates on +/// block iterators instead of ops. +static std::pair +findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) { + // Case 1: The iterator lies within the block. + if (itBlock == b) + return std::make_pair(itBlock, it); + + // Otherwise: Find ancestor iterator. Bail if we run out of parent ops. + Operation *parentOp = itBlock->getParentOp(); + if (!parentOp) + return std::make_pair(static_cast(nullptr), Block::iterator()); + Operation *op = b->findAncestorOpInBlock(*parentOp); + if (!op) + return std::make_pair(static_cast(nullptr), Block::iterator()); + return std::make_pair(op->getBlock(), op->getIterator()); +} + /// This method erases an operation that is known to have no uses. The uses of /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present(listener); + // If the current insertion point is before/within the erased operation, we + // need to adjust the insertion point to be after the operation. + if (getInsertionBlock()) { + Block *insertionBlock; + Block::iterator insertionPoint; + std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock( + op->getBlock(), getInsertionBlock(), getInsertionPoint()); + if (insertionBlock && insertionPoint == op->getIterator()) + setInsertionPointAfter(op); + } + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -320,6 +353,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); From 1d70ab4e41faf4a3dd9c6a8d5a4a8ca98794f7cc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 4 Jul 2025 09:48:51 +0000 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/IR/PatternMatch.h | 14 ++++++++ mlir/lib/IR/PatternMatch.cpp | 36 +++---------------- .../Transforms/Utils/DialectConversion.cpp | 23 ++++++++++++ 3 files changed, 41 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b3608b4394f45..b5a93a0c5a898 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -525,6 +525,11 @@ class RewriterBase : public OpBuilder { } /// This method erases an operation that is known to have no uses. + /// + /// If the current insertion point is before the erased operation, it is + /// adjusted to the following operation (or the end of the block). If the + /// current insertion point is within the erased operation, the insertion + /// point is left in an invalid state. virtual void eraseOp(Operation *op); /// This method erases all operations in a block. @@ -539,6 +544,9 @@ class RewriterBase : public OpBuilder { /// somewhere in the middle (or beginning) of the dest block, the source block /// must have no successors. Otherwise, the resulting IR would have /// unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues = {}); @@ -549,6 +557,9 @@ class RewriterBase : public OpBuilder { /// /// The source block must have no successors. Otherwise, the resulting IR /// would have unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void inlineBlockBefore(Block *source, Operation *op, ValueRange argValues = {}); @@ -558,6 +569,9 @@ class RewriterBase : public OpBuilder { /// /// The dest block must have no successors. Otherwise, the resulting IR would /// have unreachable operation. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {}); /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5a08c6534b6b6..9332f55bd9393 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -150,44 +150,16 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { eraseOp(op); } -/// Returns the given block iterator if it lies within the block `b`. -/// Otherwise, otherwise finds the ancestor of the given block iterator that -/// lies within `b`. Returns and "empty" iterator if the latter fails. -/// -/// Note: This is a variant of Block::findAncestorOpInBlock that operates on -/// block iterators instead of ops. -static std::pair -findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) { - // Case 1: The iterator lies within the block. - if (itBlock == b) - return std::make_pair(itBlock, it); - - // Otherwise: Find ancestor iterator. Bail if we run out of parent ops. - Operation *parentOp = itBlock->getParentOp(); - if (!parentOp) - return std::make_pair(static_cast(nullptr), Block::iterator()); - Operation *op = b->findAncestorOpInBlock(*parentOp); - if (!op) - return std::make_pair(static_cast(nullptr), Block::iterator()); - return std::make_pair(op->getBlock(), op->getIterator()); -} - /// This method erases an operation that is known to have no uses. The uses of /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present(listener); - // If the current insertion point is before/within the erased operation, we - // need to adjust the insertion point to be after the operation. - if (getInsertionBlock()) { - Block *insertionBlock; - Block::iterator insertionPoint; - std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock( - op->getBlock(), getInsertionBlock(), getInsertionPoint()); - if (insertionBlock && insertionPoint == op->getIterator()) - setInsertionPointAfter(op); - } + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index df255cfcf3ec1..b8c40e34c91a7 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1758,6 +1758,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector { return v ? SmallVector{v} : SmallVector(); @@ -1773,6 +1779,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1781,6 +1793,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1887,6 +1905,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); }