diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index b3608b4394f45..b5a93a0c5a898 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -525,6 +525,11 @@ class RewriterBase : public OpBuilder { } /// This method erases an operation that is known to have no uses. + /// + /// If the current insertion point is before the erased operation, it is + /// adjusted to the following operation (or the end of the block). If the + /// current insertion point is within the erased operation, the insertion + /// point is left in an invalid state. virtual void eraseOp(Operation *op); /// This method erases all operations in a block. @@ -539,6 +544,9 @@ class RewriterBase : public OpBuilder { /// somewhere in the middle (or beginning) of the dest block, the source block /// must have no successors. Otherwise, the resulting IR would have /// unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues = {}); @@ -549,6 +557,9 @@ class RewriterBase : public OpBuilder { /// /// The source block must have no successors. Otherwise, the resulting IR /// would have unreachable operations. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void inlineBlockBefore(Block *source, Operation *op, ValueRange argValues = {}); @@ -558,6 +569,9 @@ class RewriterBase : public OpBuilder { /// /// The dest block must have no successors. Otherwise, the resulting IR would /// have unreachable operation. + /// + /// If the insertion point is within the source block, it is adjusted to the + /// destination block. void mergeBlocks(Block *source, Block *dest, ValueRange argValues = {}); /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 5c98417c874d3..9332f55bd9393 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present(listener); + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index df255cfcf3ec1..b8c40e34c91a7 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1758,6 +1758,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector { return v ? SmallVector{v} : SmallVector(); @@ -1773,6 +1779,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1781,6 +1793,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1887,6 +1905,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); }