Skip to content

Commit

Permalink
Bug fixes for 1. recursive parent search in sorting users 2. traversi…
Browse files Browse the repository at this point in the history
…ng regions to propagate values in correct order
  • Loading branch information
arpitj1 committed Jan 31, 2025
1 parent 4a7efe7 commit 6d8832f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 3 deletions.
6 changes: 3 additions & 3 deletions lib/polygeist/Passes/LinalgDebufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ bool comesBefore(Operation *a, Operation *b) {
aAncestor = parent;
}

llvm_unreachable("Operations do not share a common ancestor");
//llvm_unreachable("Operations do not share a common ancestor");
//// Recursive case: compare parent operations
//return comesBefore(aParent, bParent);
return comesBefore(aParent, bParent);
}

std::vector<Operation *> getSortedUsers(Value val) {
Expand Down Expand Up @@ -375,7 +375,7 @@ struct LinalgDebufferization : public OpRewritePattern<func::FuncOp> {

// Propagate value through each region
Value currentValue = currentTensor;
for (Region* region : llvm::reverse(regions)) {
for (Region* region : regions) {
Block& block = region->front();
Operation* terminator = block.getTerminator();
Operation *parentOp = region->getParentOp();
Expand Down
75 changes: 75 additions & 0 deletions test/polygeist-opt/debufferize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,81 @@
}
}

module @in_place_cond_add_followed_by_add2{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) {
%c0 = arith.constant 0 : index
//%buffer = memref.alloca() : memref<128xf32>
scf.if %cond2 {
scf.if %cond {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
%sum2 = arith.addf %sum, %value : f32
linalg.yield %sum2 : f32
}
}
}
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
return
}
}

module @in_place_cond_add_followed_by_add3{
func.func @in_place_add(%buffer: memref<128xf32> {llvm.noalias}, %value: f32, %cond: i1, %cond2: i1) {
%c0 = arith.constant 0 : index
//%buffer = memref.alloca() : memref<128xf32>
scf.if %cond2 {
scf.if %cond {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
%sum2 = arith.addf %sum, %value : f32
linalg.yield %sum2 : f32
}
}
}
scf.if %cond2 {
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
%sum2 = arith.addf %sum, %value : f32
%sum3 = arith.addf %sum2, %value : f32
linalg.yield %sum3 : f32
}
}
linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]
} ins(%buffer : memref<128xf32>)
outs(%buffer : memref<128xf32>) {
^bb0(%in: f32, %out: f32):
%sum = arith.addf %in, %value : f32
linalg.yield %sum : f32
}
return
}
}

module @conv_2 {
func.func @main(%0: memref<515x67xi32> {llvm.noalias}, %1: memref<4x4xi32> {llvm.noalias}, %2: memref<512x64xi32> {llvm.noalias}) -> i32 attributes {llvm.linkage = #llvm.linkage<external>} {
%c0_i32 = arith.constant 0 : i32
Expand Down

0 comments on commit 6d8832f

Please sign in to comment.