Skip to content

Commit fca0576

Browse files
committed
fix: temporarily disable dealloc
1 parent c2b5dde commit fca0576

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

src/enzyme_ad/jax/Passes/LowerTritonExtensionOps.cpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,27 @@ struct JITCallScratchMemoryLowering
9393
allocOp.getResult());
9494
rewriter.replaceAllUsesWith(fnBody.getArgument(idx), ptrOp.getResult());
9595

96-
SmallVector<Value> deps;
97-
Operation *lastUser = ptrOp;
98-
for (auto u : ptrOp->getUsers()) {
99-
if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
100-
deps.push_back(gpuLaunchOp.getAsyncToken());
101-
}
102-
103-
if (lastUser->isBeforeInBlock(u)) {
104-
lastUser = u;
105-
}
106-
}
107-
108-
rewriter.setInsertionPointAfter(lastUser);
109-
gpu::DeallocOp::create(rewriter, op.getLoc(),
110-
gpu::AsyncTokenType::get(rewriter.getContext()),
111-
ValueRange(deps), allocOp.getResult());
96+
// clang-format off
97+
// FIXME: This is producing
98+
// error: 'llvm.call' op operand type mismatch for operand 0: '!llvm.ptr<1>' != '!llvm.ptr'
99+
// see current operation: "llvm.call"(%61, %60) <{CConv = #llvm.cconv<ccc>, TailCallKind = #llvm.tailcallkind<none>, callee = @mgpuMemFree, fastmathFlags = #llvm.fastmath<none>, op_bundle_sizes = array<i32>, operandSegmentSizes = array<i32: 2, 0>}> : (!llvm.ptr<1>, !llvm.ptr) -> ()
100+
// SmallVector<Value> deps;
101+
// Operation *lastUser = ptrOp;
102+
// for (auto u : ptrOp->getUsers()) {
103+
// if (auto gpuLaunchOp = dyn_cast<gpu::LaunchFuncOp>(u)) {
104+
// deps.push_back(gpuLaunchOp.getAsyncToken());
105+
// }
106+
107+
// if (lastUser->isBeforeInBlock(u)) {
108+
// lastUser = u;
109+
// }
110+
// }
111+
112+
// rewriter.setInsertionPointAfter(lastUser);
113+
// gpu::DeallocOp::create(rewriter, op.getLoc(),
114+
// gpu::AsyncTokenType::get(rewriter.getContext()),
115+
// ValueRange(deps), allocOp.getResult());
116+
// clang-format on
112117
}
113118

114119
funcOpInterface.eraseArguments(rewriteScratchMemoryIdxs);

0 commit comments

Comments
 (0)