diff --git a/lib/polygeist/Passes/LinalgDebufferize.cpp b/lib/polygeist/Passes/LinalgDebufferize.cpp index 0590f710db3d..64358256d7df 100644 --- a/lib/polygeist/Passes/LinalgDebufferize.cpp +++ b/lib/polygeist/Passes/LinalgDebufferize.cpp @@ -146,9 +146,9 @@ bool comesBefore(Operation *a, Operation *b) { aAncestor = parent; } - llvm_unreachable("Operations do not share a common ancestor"); + //llvm_unreachable("Operations do not share a common ancestor"); //// Recursive case: compare parent operations - //return comesBefore(aParent, bParent); + return comesBefore(aParent, bParent); } std::vector getSortedUsers(Value val) { @@ -375,7 +375,7 @@ struct LinalgDebufferization : public OpRewritePattern { // Propagate value through each region Value currentValue = currentTensor; - for (Region* region : llvm::reverse(regions)) { + for (Region* region : regions) { Block& block = region->front(); Operation* terminator = block.getTerminator(); Operation *parentOp = region->getParentOp(); diff --git a/test/polygeist-opt/debufferize.mlir b/test/polygeist-opt/debufferize.mlir index bd28c13d7c51..183e81d98489 100644 --- a/test/polygeist-opt/debufferize.mlir +++ b/test/polygeist-opt/debufferize.mlir @@ -132,6 +132,81 @@ } } + module @in_place_cond_add_followed_by_add2{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + + module @in_place_cond_add_followed_by_add3{ + func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) { + %c0 = arith.constant 0 : index + //%buffer = memref.alloca() : memref<128xf32> + scf.if %cond2 { + scf.if %cond { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + linalg.yield %sum2 : f32 + } + } + } + scf.if %cond2 { + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + %sum2 = arith.addf %sum, %value : f32 + %sum3 = arith.addf %sum2, %value : f32 + linalg.yield %sum3 : f32 + } + } + linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"] + } ins(%buffer : memref<128xf32>) + outs(%buffer : memref<128xf32>) { + ^bb0(%in: f32, %out: f32): + %sum = arith.addf %in, %value : f32 + linalg.yield %sum : f32 + } + return + } + } + module @conv_2 { func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage} { %c0_i32 = arith.constant 0 : i32