diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 3fae773c3e7..a1bc038c9a1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -214,8 +214,13 @@ class MemRefTypeInterface } Type getShadowType(Type self, unsigned width) const { - assert(width == 1 && "unsupported width != 1"); - return self; + if (width == 1) + return self; + + auto MRT = llvm::cast(self); + SmallVector out_shape = {width}; + out_shape.append(MRT.getShape().begin(), MRT.getShape().end()); + return MRT.clone(out_shape); } Value createConjOp(Type self, OpBuilder &builder, Location loc,