Skip to content

Commit

Permalink
Bug fix for erasing the op correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
arpitj1 committed Jan 31, 2025
1 parent e20708c commit 4a7efe7
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 180 deletions.
162 changes: 46 additions & 116 deletions lib/polygeist/Passes/LinalgDebufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,76 +51,6 @@ bool comesBefore(Operation *a, Operation *b) {
if (isAncestor(a, b)) return true;
if (isAncestor(b, a)) return false;

//Block *aBlock = a->getBlock();
//Block *bBlock = b->getBlock();

//// Same block: compare operation order
//if (aBlock == bBlock) {
// for (Operation &op : aBlock->getOperations()) {
// if (&op == a) return true;
// if (&op == b) return false;
// }
// llvm_unreachable("Operations not found in their parent block");
//}

//// Different blocks: compare region hierarchy
//Region *aRegion = aBlock->getParent();
//Region *bRegion = bBlock->getParent();

//// Same region: compare block order
//if (aRegion == bRegion) {
// //auto aBlockIt = std::find(aRegion->begin(), aRegion->end(), aBlock);
// //auto bBlockIt = std::find(aRegion->begin(), aRegion->end(), bBlock);
// //return aBlockIt < bBlockIt;
// //const int aIndex = std::distance(aRegion->begin(), aRegion->find(aBlock));
// //const int bIndex = std::distance(aRegion->begin(), aRegion->find(bBlock));
// //return aIndex < bIndex;
// auto get_block_pos = [](Region *region, Block *block) {
// auto &blocks = region->getBlocks();
// auto it = llvm::find_if(blocks, [block](Block &b) {
// return &b == block; // Address comparison
// });
// assert(it != blocks.end() && "Block not found in region");
// return std::distance(blocks.begin(), it);
// //return std::distance(region->getBlocks().begin(),
// // llvm::find(region->getBlocks(), block));
// };
// return get_block_pos(aRegion, aBlock) <
// get_block_pos(aRegion, bBlock);
//}

//// Different regions: compare parent operations
//Operation *aParent = aRegion->getParentOp();
//Operation *bParent = bRegion->getParentOp();

//// Same parent op: compare region order
//if (aParent == bParent) {
// //auto aRegionIt = std::find(aParent->getRegions().begin(),
// // aParent->getRegions().end(), aRegion);
// //auto bRegionIt = std::find(bParent->getRegions().begin(),
// // bParent->getRegions().end(), bRegion);
// //return aRegionIt < bRegionIt;
// //auto get_region_position = [](Operation *parent, Region *target) {
// //return std::distance(
// // parent->getRegions.begin(),
// // llvm::find_if(parent->getRegions(), [&](Region &r) {
// // return &r == target; // Compare region addresses
// // })
// // );
// //};

// auto get_region_position = [](Operation *parent, Region *target) {
// auto regions = parent->getRegions(); // Get reference to region list
// auto begin = regions.begin();
// auto it = llvm::find_if(regions, [&](Region &r) {
// return &r == target;
// });
// return std::distance(begin, it);
// };
// return get_region_position(aParent, aRegion) <
// get_region_position(aParent, bRegion);
//}

Operation *aParent = a->getParentOp();
Operation *bParent = b->getParentOp();
// Walk up b's hierarchy until we reach a's level
Expand Down Expand Up @@ -224,50 +154,42 @@ bool comesBefore(Operation *a, Operation *b) {
std::vector<Operation *> getSortedUsers(Value val) {
std::vector<Operation*> users;
for (Operation *user : val.getUsers()) {
users.push_back(user);
auto it = std::find_if(users.begin(), users.end(),
[user](const Operation* op) {
return op == user;
});
if(it == users.end())
users.push_back(user);
}

//TODO: problem is this only works for 1 level
// Sort the users based on their topological order
std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) {
return comesBefore(a,b);
//if (a->getBlock() == b->getBlock()) {
// return a->isBeforeInBlock(b);
//}
//if (a->getParentRegion() == b->getParentRegion()) {
// Block *blockA = a->getBlock();
// Block *blockB = b->getBlock();
// return std::distance(blockA->getParent()->begin(), blockA->getIterator()) <
// std::distance(blockB->getParent()->begin(), blockB->getIterator());
//}

//return a->getParentRegion()->isAncestor(b->getParentRegion());
});

return users;
}

std::vector<Operation *> getSortedUsers(Operation *op) {
// Find the parent function
auto funcOp = op->getParentOfType<func::FuncOp>();
if (!funcOp)
return {};
// std::vector<Operation *> getSortedUsers(Operation *op) {
// // Find the parent function
// auto funcOp = op->getParentOfType<func::FuncOp>();
// if (!funcOp)
// return {};

// Map to store order of operations
llvm::DenseMap<Operation *, size_t> opOrder;
size_t order = 0;
// // Map to store order of operations
// llvm::DenseMap<Operation *, size_t> opOrder;
// size_t order = 0;

funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; });
// funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; });

std::vector<Operation *> sortedUsers(op->getUsers().begin(),
op->getUsers().end());
// std::vector<Operation *> sortedUsers(op->getUsers().begin(),
// op->getUsers().end());

std::sort(
sortedUsers.begin(), sortedUsers.end(),
[&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; });
// std::sort(
// sortedUsers.begin(), sortedUsers.end(),
// [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; });

return sortedUsers;
}
// return sortedUsers;
// }

Region* findCommonAncestorRegion(Operation* a, Operation* b) {
DenseMap<Region*, size_t> regionCounts;
Expand Down Expand Up @@ -351,15 +273,15 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {

auto module = funcOp->getParentOfType<ModuleOp>();

SmallVector<Operation *> opsToDelete;
llvm::SmallPtrSet<Operation *, 16> opsToDeleteSet;
//SmallVector<Operation *> opsToDelete;
//llvm::SmallPtrSet<Operation *, 16> opsToDeleteSet;
// Tracks both old linalg.generics and linalg.generics with repeated values
// in ins and outs
llvm::SmallPtrSet<Operation *, 16> processedGenericOps;

LogicalResult passResult = failure();

auto handleMemref = [&](Value memVal) -> LogicalResult {
llvm::SmallPtrSet<Operation *, 16> processedGenericOps;
auto module = memVal.getParentRegion()->getParentOfType<ModuleOp>();

if (!memVal.getType().isa<MemRefType>()) {
Expand Down Expand Up @@ -428,8 +350,8 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
if (auto genericOp = dyn_cast<linalg::GenericOp>(user)) {

// auto genericOp = cast<linalg::GenericOp>(user);
if (processedGenericOps.count(genericOp) > 0)
continue;
//if (processedGenericOps.count(genericOp) > 0)
// continue;
rewriter.setInsertionPointAfter(genericOp);

SmallVector<Value, 4> newInputs;
Expand Down Expand Up @@ -556,17 +478,22 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
currentTensor = newGenericOp.getResult(newCurrentTensorIndex);
}

processedGenericOps.insert(genericOp.getOperation());
//processedGenericOps.insert(genericOp.getOperation());
// Delete the original genericOp
genericOp.erase();
//unsigned numUsers = std::distance(genericOp.getResults().getUsers().begin(), genericOp.getResults().getUsers().end());
//llvm::outs() << "Number of generic op uses: " << numUsers << "\n";
//genericOp.erase();
rewriter.eraseOp(genericOp);
//WalkResult::interrupt();
//opsToDelete.push_back(genericOp.getOperation());
}
}

auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
memVal.getLoc(), memrefType, currentTensor);
rewriter.create<memref::CopyOp>(memVal.getLoc(), toMemrefOp, memVal);

//if(currentTensor != prevTensor) {
auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
memVal.getLoc(), memrefType, currentTensor);
rewriter.create<memref::CopyOp>(memVal.getLoc(), toMemrefOp, memVal);
//}
// opsToDelete.push_back(allocaOp.getOperation());
return success();
};
Expand All @@ -584,13 +511,15 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {
handleMemref(alloca);
}

if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref), [](LogicalResult res) {return res.succeeded();}))
for(auto arg: funcOp.getArguments()){
handleMemref(arg);
}

passResult = success();
for (Operation *op : opsToDelete) {
op->erase();
}
opsToDelete.clear();
//for (Operation *op : opsToDelete) {
// op->erase();
//}
//opsToDelete.clear();

return passResult;
}
Expand All @@ -603,6 +532,7 @@ struct LinalgDebufferize : public LinalgDebufferizeBase<LinalgDebufferize> {
} // namespace

void LinalgDebufferize::runOnOperation() {
auto module = getOperation()->getParentOfType<ModuleOp>();
RewritePatternSet patterns(&getContext());
patterns.insert<LinalgDebufferization>(&getContext());
patterns.insert<debufferizationAllocaRemoval>(&getContext());
Expand Down
127 changes: 63 additions & 64 deletions test/polygeist-opt/debufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -132,68 +132,67 @@
}
}

// module @conv_2 {
// func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// %c0_i32 = arith.constant 0 : i32
// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
// ^bb0(%in: i32, %in_0: i32, %out: i32):
// %3 = arith.muli %in, %in_0 : i32
// %4 = arith.addi %out, %3 : i32
// linalg.yield %4 : i32
// }
// return %c0_i32 : i32
// }
// }
module @conv_2 {
func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
%c0_i32 = arith.constant 0 : i32
linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) {
^bb0(%in: i32, %in_0: i32, %out: i32):
%3 = arith.muli %in, %in_0 : i32
%4 = arith.addi %out, %3 : i32
linalg.yield %4 : i32
}
return %c0_i32 : i32
}
}

// module @harris_score_with_gradient_extra_kernel {
// //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1>
// //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
// func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
// %c4_i32 = arith.constant 4 : i32
// %c0_i32 = arith.constant 0 : i32
// %alloca = memref.alloca() : memref<512x512xi32>
// %alloca_0 = memref.alloca() : memref<512x512xi32>
// %alloca_1 = memref.alloca() : memref<512x512xi32>
// %alloca_2 = memref.alloca() : memref<516x516xi32>
// %alloca_3 = memref.alloca() : memref<516x516xi32>
// %alloca_4 = memref.alloca() : memref<518x518xi32>
// //%score = memref.alloca() : memref<512x512xi32>
// //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32>
// //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32>
// //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32>
// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.addi %out_7, %4 : i32
// %6 = arith.muli %in, %in_6 : i32
// %7 = arith.addi %out, %6 : i32
// linalg.yield %7, %5 : i32, i32
// }
// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32):
// %4 = arith.muli %in, %in : i32
// %5 = arith.muli %4, %in_6 : i32
// %6 = arith.addi %out_8, %5 : i32
// %7 = arith.muli %in_5, %in_5 : i32
// %8 = arith.muli %7, %in_6 : i32
// %9 = arith.addi %out_7, %8 : i32
// %10 = arith.muli %in, %in_5 : i32
// %11 = arith.muli %10, %in_6 : i32
// %12 = arith.addi %out, %11 : i32
// linalg.yield %12, %9, %6 : i32, i32, i32
// }
// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) {
// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32):
// %4 = arith.muli %in, %in_5 : i32
// %5 = arith.muli %in_6, %in_6 : i32
// %6 = arith.subi %4, %5 : i32
// %7 = arith.addi %in, %in_5 : i32
// %8 = arith.muli %7, %c4_i32 : i32
// %9 = arith.muli %8, %7 : i32
// %10 = arith.subi %6, %9 : i32
// linalg.yield %10 : i32
// }
// return %c0_i32 : i32
// }
// }
module @harris_score_with_gradient_extra_kernel {
//memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1>
//memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
//memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]>
func.func @main(%input: memref<518x518xi32>, %0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
%c4_i32 = arith.constant 4 : i32
%c0_i32 = arith.constant 0 : i32
%alloca = memref.alloca() : memref<512x512xi32>
%alloca_0 = memref.alloca() : memref<512x512xi32>
%alloca_1 = memref.alloca() : memref<512x512xi32>
%alloca_2 = memref.alloca() : memref<516x516xi32>
%alloca_3 = memref.alloca() : memref<516x516xi32>
//%score = memref.alloca() : memref<512x512xi32>
//%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32>
//%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32>
//%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32>
linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) {
^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32):
%4 = arith.muli %in, %in_5 : i32
%5 = arith.addi %out_7, %4 : i32
%6 = arith.muli %in, %in_6 : i32
%7 = arith.addi %out, %6 : i32
linalg.yield %7, %5 : i32, i32
}
linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) {
^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32):
%4 = arith.muli %in, %in : i32
%5 = arith.muli %4, %in_6 : i32
%6 = arith.addi %out_8, %5 : i32
%7 = arith.muli %in_5, %in_5 : i32
%8 = arith.muli %7, %in_6 : i32
%9 = arith.addi %out_7, %8 : i32
%10 = arith.muli %in, %in_5 : i32
%11 = arith.muli %10, %in_6 : i32
%12 = arith.addi %out, %11 : i32
linalg.yield %12, %9, %6 : i32, i32, i32
}
linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) {
^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32):
%4 = arith.muli %in, %in_5 : i32
%5 = arith.muli %in_6, %in_6 : i32
%6 = arith.subi %4, %5 : i32
%7 = arith.addi %in, %in_5 : i32
%8 = arith.muli %7, %c4_i32 : i32
%9 = arith.muli %8, %7 : i32
%10 = arith.subi %6, %9 : i32
linalg.yield %10 : i32
}
return %c0_i32 : i32
}
}

0 comments on commit 4a7efe7

Please sign in to comment.