From 4a7efe78d132f0b8ed49b8a30201b86728ea174e Mon Sep 17 00:00:00 2001 From: Arpit Jaiswal Date: Fri, 31 Jan 2025 13:47:19 -0800 Subject: [PATCH] Bug fix for erasing the op correctly --- lib/polygeist/Passes/LinalgDebufferize.cpp | 162 ++++++--------------- test/polygeist-opt/debufferize.mlir | 127 ++++++++-------- 2 files changed, 109 insertions(+), 180 deletions(-) diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 7c2a57405d8e..0590f710db3d 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -51,76 +51,6 @@ bool comesBefore(Operation *a, Operation *b) { if (isAncestor(a, b)) return true; if (isAncestor(b, a)) return false; - //Block *aBlock = a->getBlock(); - //Block *bBlock = b->getBlock(); - - //// Same block: compare operation order - //if (aBlock == bBlock) { - // for (Operation &op : aBlock->getOperations()) { - // if (&op == a) return true; - // if (&op == b) return false; - // } - // llvm_unreachable("Operations not found in their parent block"); - //} - - //// Different blocks: compare region hierarchy - //Region *aRegion = aBlock->getParent(); - //Region *bRegion = bBlock->getParent(); - - //// Same region: compare block order - //if (aRegion == bRegion) { - // //auto aBlockIt = std::find(aRegion->begin(), aRegion->end(), aBlock); - // //auto bBlockIt = std::find(aRegion->begin(), aRegion->end(), bBlock); - // //return aBlockIt < bBlockIt; - // //const int aIndex = std::distance(aRegion->begin(), aRegion->find(aBlock)); - // //const int bIndex = std::distance(aRegion->begin(), aRegion->find(bBlock)); - // //return aIndex < bIndex; - // auto get_block_pos = [](Region *region, Block *block) { - // auto &blocks = region->getBlocks(); - // auto it = llvm::find_if(blocks, [block](Block &b) { - // return &b == block; // Address comparison - // }); - // assert(it != blocks.end() && "Block not found in region"); - // return std::distance(blocks.begin(), it); - // //return std::distance(region->getBlocks().begin(), - // // llvm::find(region->getBlocks(), block)); - // }; - // return get_block_pos(aRegion, aBlock) < - // get_block_pos(aRegion, bBlock); - //} - - //// Different regions: compare parent operations - //Operation *aParent = aRegion->getParentOp(); - //Operation *bParent = bRegion->getParentOp(); - - //// Same parent op: compare region order - //if (aParent == bParent) { - // //auto aRegionIt = std::find(aParent->getRegions().begin(), - // // aParent->getRegions().end(), aRegion); - // //auto bRegionIt = std::find(bParent->getRegions().begin(), - // // bParent->getRegions().end(), bRegion); - // //return aRegionIt < bRegionIt; - // //auto get_region_position = [](Operation *parent, Region *target) { - // //return std::distance( - // // parent->getRegions.begin(), - // // llvm::find_if(parent->getRegions(), [&](Region &r) { - // // return &r == target; // Compare region addresses - // // }) - // // ); - // //}; - - // auto get_region_position = [](Operation *parent, Region *target) { - // auto regions = parent->getRegions(); // Get reference to region list - // auto begin = regions.begin(); - // auto it = llvm::find_if(regions, [&](Region &r) { - // return &r == target; - // }); - // return std::distance(begin, it); - // }; - // return get_region_position(aParent, aRegion) < - // get_region_position(aParent, bRegion); - //} - Operation *aParent = a->getParentOp(); Operation *bParent = b->getParentOp(); // Walk up b's hierarchy until we reach a's level @@ -224,50 +154,42 @@ bool comesBefore(Operation *a, Operation *b) { std::vector getSortedUsers(Value val) { std::vector users; for (Operation *user : val.getUsers()) { - users.push_back(user); + auto it = std::find_if(users.begin(), users.end(), + [user](const Operation* op) { + return op == user; + }); + if(it == users.end()) + users.push_back(user); } - //TODO: problem is this only works for 1 level - // Sort the users based on their topological order std::sort(users.begin(), users.end(), [](Operation *a, Operation *b) { return comesBefore(a,b); - //if (a->getBlock() == b->getBlock()) { - // return a->isBeforeInBlock(b); - //} - //if (a->getParentRegion() == b->getParentRegion()) { - // Block *blockA = a->getBlock(); - // Block *blockB = b->getBlock(); - // return std::distance(blockA->getParent()->begin(), blockA->getIterator()) < - // std::distance(blockB->getParent()->begin(), blockB->getIterator()); - //} - - //return a->getParentRegion()->isAncestor(b->getParentRegion()); }); return users; } -std::vector getSortedUsers(Operation *op) { - // Find the parent function - auto funcOp = op->getParentOfType(); - if (!funcOp) - return {}; +// std::vector getSortedUsers(Operation *op) { +// // Find the parent function +// auto funcOp = op->getParentOfType(); +// if (!funcOp) +// return {}; - // Map to store order of operations - llvm::DenseMap opOrder; - size_t order = 0; +// // Map to store order of operations +// llvm::DenseMap opOrder; +// size_t order = 0; - funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; }); +// funcOp.walk([&](Operation *curOp) { opOrder[curOp] = order++; }); - std::vector sortedUsers(op->getUsers().begin(), - op->getUsers().end()); +// std::vector sortedUsers(op->getUsers().begin(), +// op->getUsers().end()); - std::sort( - sortedUsers.begin(), sortedUsers.end(), - [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); +// std::sort( +// sortedUsers.begin(), sortedUsers.end(), +// [&](Operation *a, Operation *b) { return opOrder[a] < opOrder[b]; }); - return sortedUsers; -} +// return sortedUsers; +// } Region* findCommonAncestorRegion(Operation* a, Operation* b) { DenseMap regionCounts; @@ -351,15 +273,15 @@ struct LinalgDebufferization : public OpRewritePattern { auto module = funcOp->getParentOfType(); - SmallVector opsToDelete; - llvm::SmallPtrSet opsToDeleteSet; + //SmallVector opsToDelete; + //llvm::SmallPtrSet opsToDeleteSet; // Tracks both old linalg.generics and linalg.generics with repeated values // in ins and outs - llvm::SmallPtrSet processedGenericOps; LogicalResult passResult = failure(); auto handleMemref = [&](Value memVal) -> LogicalResult { + llvm::SmallPtrSet processedGenericOps; auto module = memVal.getParentRegion()->getParentOfType(); if (!memVal.getType().isa()) { @@ -428,8 +350,8 @@ struct LinalgDebufferization : public OpRewritePattern { if (auto genericOp = dyn_cast(user)) { // auto genericOp = cast(user); - if (processedGenericOps.count(genericOp) > 0) - continue; + //if (processedGenericOps.count(genericOp) > 0) + // continue; rewriter.setInsertionPointAfter(genericOp); SmallVector newInputs; @@ -556,17 +478,22 @@ struct LinalgDebufferization : public OpRewritePattern { currentTensor = newGenericOp.getResult(newCurrentTensorIndex); } - processedGenericOps.insert(genericOp.getOperation()); + //processedGenericOps.insert(genericOp.getOperation()); // Delete the original genericOp - genericOp.erase(); + //unsigned numUsers = std::distance(genericOp.getResults().getUsers().begin(), genericOp.getResults().getUsers().end()); + //llvm::outs() << "Number of generic op uses: " << numUsers << "\n"; + //genericOp.erase(); + rewriter.eraseOp(genericOp); //WalkResult::interrupt(); //opsToDelete.push_back(genericOp.getOperation()); } } - - auto toMemrefOp = rewriter.create( - memVal.getLoc(), memrefType, currentTensor); - rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + + //if(currentTensor != prevTensor) { + auto toMemrefOp = rewriter.create( + memVal.getLoc(), memrefType, currentTensor); + rewriter.create(memVal.getLoc(), toMemrefOp, memVal); + //} // opsToDelete.push_back(allocaOp.getOperation()); return success(); }; @@ -584,13 +511,15 @@ struct LinalgDebufferization : public OpRewritePattern { handleMemref(alloca); } - if (llvm::any_of(llvm::map_range(funcOp.getArguments(), handleMemref), [](LogicalResult res) {return res.succeeded();})) + for(auto arg: funcOp.getArguments()){ + handleMemref(arg); + } passResult = success(); - for (Operation *op : opsToDelete) { - op->erase(); - } - opsToDelete.clear(); + //for (Operation *op : opsToDelete) { + // op->erase(); + //} + //opsToDelete.clear(); return passResult; } @@ -603,6 +532,7 @@ struct LinalgDebufferize : public LinalgDebufferizeBase { } // namespace void LinalgDebufferize::runOnOperation() { + auto module = getOperation()->getParentOfType(); RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index 34e203b9dbb6..bd28c13d7c51 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -132,68 +132,67 @@ } } -// module @conv_2 { -// func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c0_i32 = arith.constant 0 : i32 -// linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { -// ^bb0(%in: i32, %in_0: i32, %out: i32): -// %3 = arith.muli %in, %in_0 : i32 -// %4 = arith.addi %out, %3 : i32 -// linalg.yield %4 : i32 -// } -// return %c0_i32 : i32 -// } -// } + module @conv_2 { + func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c0_i32 = arith.constant 0 : i32 + linalg.generic {indexing_maps = [#map17, #map18, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%0, %1 : memref<515x67xi32>, memref<4x4xi32>) outs(%2 : memref<512x64xi32>) { + ^bb0(%in: i32, %in_0: i32, %out: i32): + %3 = arith.muli %in, %in_0 : i32 + %4 = arith.addi %out, %3 : i32 + linalg.yield %4 : i32 + } + return %c0_i32 : i32 + } + } -// module @harris_score_with_gradient_extra_kernel { -// //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> -// //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> -// func.func @main(%0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { -// %c4_i32 = arith.constant 4 : i32 -// %c0_i32 = arith.constant 0 : i32 -// %alloca = memref.alloca() : memref<512x512xi32> -// %alloca_0 = memref.alloca() : memref<512x512xi32> -// %alloca_1 = memref.alloca() : memref<512x512xi32> -// %alloca_2 = memref.alloca() : memref<516x516xi32> -// %alloca_3 = memref.alloca() : memref<516x516xi32> -// %alloca_4 = memref.alloca() : memref<518x518xi32> -// //%score = memref.alloca() : memref<512x512xi32> -// //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> -// //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> -// //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> -// linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_4, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.addi %out_7, %4 : i32 -// %6 = arith.muli %in, %in_6 : i32 -// %7 = arith.addi %out, %6 : i32 -// linalg.yield %7, %5 : i32, i32 -// } -// linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): -// %4 = arith.muli %in, %in : i32 -// %5 = arith.muli %4, %in_6 : i32 -// %6 = arith.addi %out_8, %5 : i32 -// %7 = arith.muli %in_5, %in_5 : i32 -// %8 = arith.muli %7, %in_6 : i32 -// %9 = arith.addi %out_7, %8 : i32 -// %10 = arith.muli %in, %in_5 : i32 -// %11 = arith.muli %10, %in_6 : i32 -// %12 = arith.addi %out, %11 : i32 -// linalg.yield %12, %9, %6 : i32, i32, i32 -// } -// linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { -// ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): -// %4 = arith.muli %in, %in_5 : i32 -// %5 = arith.muli %in_6, %in_6 : i32 -// %6 = arith.subi %4, %5 : i32 -// %7 = arith.addi %in, %in_5 : i32 -// %8 = arith.muli %7, %c4_i32 : i32 -// %9 = arith.muli %8, %7 : i32 -// %10 = arith.subi %6, %9 : i32 -// linalg.yield %10 : i32 -// } -// return %c0_i32 : i32 -// } -// } \ No newline at end of file + module @harris_score_with_gradient_extra_kernel { + //memref.global "private" @_ZL8coeffs_1 : memref<5x5xi32> = dense<1> + //memref.global "private" @_ZL8coeffs_y : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + //memref.global "private" @_ZL8coeffs_x : memref<3x3xi32> = dense<[[-3, -10, -3], [0, 0, 0], [3, 10, 3]]> + func.func @main(%input: memref<518x518xi32>, %0: memref<3x3xi32> {llvm.noalias}, %1: memref<3x3xi32> {llvm.noalias}, %2: memref<5x5xi32> {llvm.noalias}, %score: memref<512x512xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { + %c4_i32 = arith.constant 4 : i32 + %c0_i32 = arith.constant 0 : i32 + %alloca = memref.alloca() : memref<512x512xi32> + %alloca_0 = memref.alloca() : memref<512x512xi32> + %alloca_1 = memref.alloca() : memref<512x512xi32> + %alloca_2 = memref.alloca() : memref<516x516xi32> + %alloca_3 = memref.alloca() : memref<516x516xi32> + //%score = memref.alloca() : memref<512x512xi32> + //%0 = memref.get_global @_ZL8coeffs_x : memref<3x3xi32> + //%1 = memref.get_global @_ZL8coeffs_y : memref<3x3xi32> + //%2 = memref.get_global @_ZL8coeffs_1 : memref<5x5xi32> + linalg.generic {indexing_maps = [#map17, #map18, #map18, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%input, %0, %1 : memref<518x518xi32>, memref<3x3xi32>, memref<3x3xi32>) outs(%alloca_2, %alloca_3 : memref<516x516xi32>, memref<516x516xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.addi %out_7, %4 : i32 + %6 = arith.muli %in, %in_6 : i32 + %7 = arith.addi %out, %6 : i32 + linalg.yield %7, %5 : i32, i32 + } + linalg.generic {indexing_maps = [#map17, #map17, #map18, #map19, #map19, #map19], iterator_types = ["parallel", "parallel", "reduction", "reduction"]} ins(%alloca_3, %alloca_2, %2 : memref<516x516xi32>, memref<516x516xi32>, memref<5x5xi32>) outs(%alloca, %alloca_0, %alloca_1 : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32, %out_7: i32, %out_8: i32): + %4 = arith.muli %in, %in : i32 + %5 = arith.muli %4, %in_6 : i32 + %6 = arith.addi %out_8, %5 : i32 + %7 = arith.muli %in_5, %in_5 : i32 + %8 = arith.muli %7, %in_6 : i32 + %9 = arith.addi %out_7, %8 : i32 + %10 = arith.muli %in, %in_5 : i32 + %11 = arith.muli %10, %in_6 : i32 + %12 = arith.addi %out, %11 : i32 + linalg.yield %12, %9, %6 : i32, i32, i32 + } + linalg.generic {indexing_maps = [#map22, #map22, #map22, #map22], iterator_types = ["parallel", "parallel"]} ins(%alloca_1, %alloca_0, %alloca : memref<512x512xi32>, memref<512x512xi32>, memref<512x512xi32>) outs(%score : memref<512x512xi32>) { + ^bb0(%in: i32, %in_5: i32, %in_6: i32, %out: i32): + %4 = arith.muli %in, %in_5 : i32 + %5 = arith.muli %in_6, %in_6 : i32 + %6 = arith.subi %4, %5 : i32 + %7 = arith.addi %in, %in_5 : i32 + %8 = arith.muli %7, %c4_i32 : i32 + %9 = arith.muli %8, %7 : i32 + %10 = arith.subi %6, %9 : i32 + linalg.yield %10 : i32 + } + return %c0_i32 : i32 + } + } \ No newline at end of file