[mlir] Fix consumer fusion for producer with multiple results #125915
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf
Author: Prashant Kumar (pashu123)
Changes
In the case of consumer fusion where the producer produces multiple results that are all used by a single consumer, e.g.,
    %results:3 = scf.forall ... -> (tensor<...>, tensor<...>, tensor<...>) { // Produces 3 results
all other operands of the tiled consumer need to be updated.
Full diff: https://github.com/llvm/llvm-project/pull/125915.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b548f8ce8b560b1..bca727de3ddb3f6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
}
}
+// If the producer of the operand is a loopLikeOp, then finds the last
+// insertSlice/parallelInsertSlice in the producer op that uses the block
+// argument corresponding to the operand.
+static FailureOr<Operation *>
+getSliceOpFromConsumerOperand(OpOperand &operand) {
+
+ OpResult producerResult = dyn_cast<OpResult>(operand.get());
+ if (!producerResult)
+ return failure();
+
+ LoopLikeOpInterface loopLikeOp =
+ dyn_cast<LoopLikeOpInterface>(producerResult.getOwner());
+ if (!loopLikeOp)
+ return failure();
+
+ // Obtain the BlockArgument corresponding to the result.
+ BlockArgument bbArg =
+ loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()];
+
+ // Finally return the operation corresponding to the yielded value.
+ // Also check whether it's an InsertSliceOp.
+ if (dyn_cast<scf::ForOp>(producerResult.getOwner())) {
+ OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg);
+ Operation *lastOp = dyn_cast<OpResult>(yieldVal->get()).getOwner();
+ auto isInsertSliceOp = isa<tensor::InsertSliceOp>(lastOp);
+ if (!isInsertSliceOp) {
+ return failure();
+ }
+ return lastOp;
+ }
+
+ auto forallOp = dyn_cast<scf::ForallOp>(producerResult.getOwner());
+ if (!forallOp)
+ return failure();
+
+ // Iterate over the terminator operation of the forallOp to find the last
+ // parallelInsertSliceOp that uses the blockArgument.
+ Operation *lastOp = nullptr;
+ forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) {
+ for (mlir::Value operand : op->getOperands()) {
+ if (auto maybeBlockArg = dyn_cast<BlockArgument>(operand)) {
+ if (maybeBlockArg == bbArg) {
+ lastOp = op;
+ }
+ }
+ }
+ });
+
+ if (!lastOp)
+ return failure();
+
+ return lastOp;
+}
+
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
+ SmallVector<OpOperand *> potentialOperands{*maybeConsumerOpOperand};
+ SmallVector<unsigned> potentialOperandResultNos{
+ consumerOpOperand->getOperandNumber()};
+ SmallVector<Operation *> potentialSliceOps{candidateSliceOp};
+
+ // 1b. Get all the other operands of the consumer op and their corresponding
+ // slice ops. In the case of the consumer consuming using multiple results
+ // from the producer, we need to update every operand.
+ for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
+ if (&otherOperand == *maybeConsumerOpOperand)
+ continue;
+ auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand);
+ if (failed(maybePotentialSlice)) {
+ continue;
+ }
+ potentialSliceOps.push_back(*maybePotentialSlice);
+ potentialOperands.push_back(&otherOperand);
+ potentialOperandResultNos.push_back(otherOperand.getOperandNumber());
+ }
+
// There are two possible cases regarding `oldLoopOp` here:
// 1. single `scf.forall` or `scf.for`.
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
@@ -2037,18 +2111,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
- tensor::InsertSliceOp clonedInsertSliceOp;
+
+ SmallVector<tensor::InsertSliceOp> allClonedInsertSliceOps;
+
+ scf::ForallOp newForallOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
- rewriter.setInsertionPoint(candidateSliceOp);
- clonedInsertSliceOp =
- cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+ rewriter.setInsertionPoint(potentialSliceOps.back());
+ }
+
+ for (auto *candidateSliceOp : potentialSliceOps) {
+ if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ allClonedInsertSliceOps.push_back(rewriter.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource(), sliceOp.getDest(),
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+ sliceOp.getMixedStrides()));
+ } else {
+ allClonedInsertSliceOps.push_back(
+ cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)));
+ }
}
// 5.a. Clone consumer op.
@@ -2056,24 +2141,34 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 5.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(clonedInsertSliceOp.getResult());
- });
+ for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+ OpOperand &operandToReplace =
+ clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(it.value().getResult());
+ });
+ }
// 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
- auto ossSliceOp =
- cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+ auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(
+ allClonedInsertSliceOps.front().getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+
if (failed(tileAndFuseResult)) {
return failure();
}
+
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
- rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
+
+ // 6b. Update the tiled consumer op with the new operands.
+ for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+ rewriter.replaceAllUsesWith(
+ tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]),
+ it.value().getSource());
+ }
// 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index a2871b30698c527..14b9ec504c1585e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -282,7 +282,7 @@ module {
return %unpack : tensor<2048xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -343,7 +343,7 @@ module {
return %unpack : tensor<2047xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -404,7 +404,7 @@ module {
return %pack : tensor<4x32x16xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
+module {
+ func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+ %outs = tensor.empty() : tensor<32x32xf32>
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ }
+ }
+ %final_out = tensor.empty() : tensor<64x64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>
+ return %2 : tensor<64x64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
+
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
+
+// CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
+
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: }
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
+
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>
+ }
+ %out_operand = tensor.empty() : tensor<64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>
+ return %2 : tensor<64xf32>
+ }
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+
+// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
+
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[C64:.+]] = arith.constant 64 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
+
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
+// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
+// CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
+// CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
+// CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
+// CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
+// CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
+
+// CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>
In the case of consumer fusion where the producer produces multiple results that are all used by a single consumer, e.g.,

    %results:3 = scf.forall ... -> (tensor<...>, tensor<...>, tensor<...>) {
      // Produces 3 results
      scf.yield %a, %b, %c : tensor<...>, tensor<...>, tensor<...>
    }
    // Consumer uses all 3 results
    %final = consumer %results#0, %results#1, %results#2

all other operands of the tiled consumer need to be updated.
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}

SmallVector<OpOperand *> potentialOperands = {*maybeConsumerOpOperand};
Please leave some comments as to what this is for.
The comment below is actually fine. Move it up above these.
Also, instead of a SmallVector<OpOperand *> and a SmallVector<Operation *> for the slices, just make a SmallVector<std::tuple<OpOperand *, Operation *>>.
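The suggestion above can be sketched in plain C++ as follows. This is only an illustration under stated assumptions: `Operation` and `OpOperand` here are hypothetical stand-in structs (not the MLIR classes), `std::vector` stands in for `SmallVector`, and `pairOperandsWithSlices` is a made-up helper name. The point is that keeping each operand next to its slice op in one container removes the need to keep several parallel vectors index-aligned.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Hypothetical stand-ins so the sketch compiles without MLIR headers.
struct Operation {
  int id;
};
struct OpOperand {
  unsigned operandNumber;
};

// One entry per consumer operand to update: the operand itself and the
// slice op that feeds it.
using OperandAndSlice = std::tuple<OpOperand *, Operation *>;

// Zips two index-aligned lists into a single list of (operand, slice) pairs.
// The name and signature are assumptions made for this sketch.
std::vector<OperandAndSlice>
pairOperandsWithSlices(const std::vector<OpOperand *> &operands,
                       const std::vector<Operation *> &sliceOps) {
  assert(operands.size() == sliceOps.size() && "lists must be index-aligned");
  std::vector<OperandAndSlice> result;
  result.reserve(operands.size());
  for (std::size_t i = 0; i < operands.size(); ++i)
    result.emplace_back(operands[i], sliceOps[i]);
  return result;
}
```

With a single container, later loops can unpack each entry (for example with `std::get<0>` and `std::get<1>`) instead of indexing into separate vectors by position.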
if (failed(maybePotentialSlice)) {
continue;
}
potentialSliceOps.push_back(*maybePotentialSlice);
Do you want to check right here that the producer of otherOperand is the same as the producer of *maybeConsumerOpOperand?
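A minimal sketch of the check being asked about, again in plain C++ with hypothetical stand-in types rather than the MLIR API: each `OpOperand` here simply records a pointer to the op that produced its value (null when the value is not an op result), and `operandsFromSameProducer` is an assumed helper name. Only operands whose producer matches the seed operand's producer are kept.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins so the sketch compiles without MLIR headers.
struct Operation {
  int id;
};
struct OpOperand {
  Operation *producer;    // op defining the operand's value, or nullptr
  unsigned operandNumber; // position of the operand on the consumer
};

// Returns the operands (other than `seed`) whose value is produced by the
// same op as `seed`. Operands with no producer or a different producer are
// skipped. Name and signature are assumptions made for this sketch.
std::vector<OpOperand *>
operandsFromSameProducer(std::vector<OpOperand> &consumerOperands,
                         const OpOperand &seed) {
  std::vector<OpOperand *> result;
  for (OpOperand &operand : consumerOperands) {
    if (&operand == &seed)
      continue;
    if (operand.producer == nullptr || operand.producer != seed.producer)
      continue;
    result.push_back(&operand);
  }
  return result;
}
```

A filter like this would leave operands coming from an unrelated loop untouched; whether that is sufficient for correctness is exactly what the review thread leaves open.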
I was trying to reason through this and I don't think this actually works. This is pretty fragile and can produce incorrect code. Let's chat offline about this and get back. We probably have to drop this change and try to do something a bit more constrained in the downstream use of consumer fusion.