@@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred {
4141 ? ::mlir::IntegerAttr ()
4242 : ::imex::getIntAttr (builder, COPY_ALWAYS ? true : false , 1 );
4343
44- auto aTyp = av. getType (). cast <::imex::ndarray::NDArrayType>();
44+ auto aTyp = ::mlir:: cast<::imex::ndarray::NDArrayType>(av. getType () );
4545 auto outTyp = imex::dist::cloneWithShape (aTyp, shape ());
4646
4747 auto op =
@@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred {
106106 // construct NDArrayType with same shape and given dtype
107107 ::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
108108 auto mlirElType = ::imex::ndarray::toMLIR (builder, ndDType);
109- auto arType = av. getType (). dyn_cast <::imex::ndarray::NDArrayType>();
109+ auto arType = ::mlir:: dyn_cast<::imex::ndarray::NDArrayType>(av. getType () );
110110 if (!arType) {
111111 throw std::invalid_argument (
112112 " Encountered unexpected ndarray type in astype." );
@@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred {
157157 jit::DepManager &dm) override {
158158 auto av = dm.getDependent (builder, Registry::get (_a));
159159
160- auto srcType = av. getType (). dyn_cast <::imex::ndarray::NDArrayType>();
160+ auto srcType = ::mlir:: dyn_cast<::imex::ndarray::NDArrayType>(av. getType () );
161161 if (!srcType) {
162162 throw std::invalid_argument (
163163 " Encountered unexpected ndarray type in to_device." );
@@ -205,6 +205,57 @@ struct DeferredToDevice : public Deferred {
205205 }
206206};
207207
208+ struct DeferredPermuteDims : public Deferred {
209+ id_type _array;
210+ shape_type _axes;
211+
212+ DeferredPermuteDims () = default ;
213+ DeferredPermuteDims (const array_i::future_type &array,
214+ const shape_type &shape, const shape_type &axes)
215+ : Deferred(array.dtype(), shape, array.device(), array.team()),
216+ _array (array.guid()), _axes(axes) {}
217+
218+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
219+ jit::DepManager &dm) override {
220+ auto arrayValue = dm.getDependent (builder, Registry::get (_array));
221+
222+ auto axesAttr = builder.getDenseI64ArrayAttr (_axes);
223+
224+ auto aTyp =
225+ ::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType ());
226+ auto outTyp = imex::dist::cloneWithShape (aTyp, shape ());
227+
228+ auto op = builder.create <::imex::ndarray::PermuteDimsOp>(
229+ loc, outTyp, arrayValue, axesAttr);
230+
231+ dm.addVal (
232+ this ->guid (), op,
233+ [this ](uint64_t rank, void *l_allocated, void *l_aligned,
234+ intptr_t l_offset, const intptr_t *l_sizes,
235+ const intptr_t *l_strides, void *o_allocated, void *o_aligned,
236+ intptr_t o_offset, const intptr_t *o_sizes,
237+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
238+ intptr_t r_offset, const intptr_t *r_sizes,
239+ const intptr_t *r_strides, std::vector<int64_t > &&loffs) {
240+ auto t = mk_tnsr (this ->guid (), _dtype, this ->shape (), this ->device (),
241+ this ->team (), l_allocated, l_aligned, l_offset,
242+ l_sizes, l_strides, o_allocated, o_aligned, o_offset,
243+ o_sizes, o_strides, r_allocated, r_aligned, r_offset,
244+ r_sizes, r_strides, std::move (loffs));
245+ this ->set_value (std::move (t));
246+ });
247+
248+ return false ;
249+ }
250+
251+ FactoryId factory () const override { return F_PERMUTEDIMS; }
252+
253+ template <typename S> void serialize (S &ser) {
254+ ser.template value <sizeof (_array)>(_array);
255+ // ser.template value<sizeof(_axes)>(_axes);
256+ }
257+ };
258+
208259FutureArray *ManipOp::reshape (const FutureArray &a, const shape_type &shape,
209260 const py::object ©) {
210261 auto doCopy = copy.is_none ()
@@ -229,7 +280,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a,
229280 return new FutureArray (defer<DeferredToDevice>(a.get (), device));
230281}
231282
283+ FutureArray *ManipOp::permute_dims (const FutureArray &array,
284+ const shape_type &axes) {
285+ auto shape = array.get ().shape ();
286+
287+ // verifyPermuteArray
288+ if (shape.size () != axes.size ()) {
289+ throw std::invalid_argument (" axes must have the same length as the shape" );
290+ }
291+ for (auto i = 0ul ; i < shape.size (); ++i) {
292+ if (std::find (axes.begin (), axes.end (), i) == axes.end ()) {
293+ throw std::invalid_argument (" axes must contain all dimensions" );
294+ }
295+ }
296+
297+ auto permutedShape = shape_type (shape.size ());
298+ for (auto i = 0ul ; i < shape.size (); ++i) {
299+ permutedShape[i] = shape[axes[i]];
300+ }
301+
302+ return new FutureArray (
303+ defer<DeferredPermuteDims>(array.get (), permutedShape, axes));
304+ }
305+
232306FACTORY_INIT (DeferredReshape, F_RESHAPE);
233307FACTORY_INIT (DeferredAsType, F_ASTYPE);
234308FACTORY_INIT (DeferredToDevice, F_TODEVICE);
309+ FACTORY_INIT (DeferredPermuteDims, F_PERMUTEDIMS);
310+
235311} // namespace SHARPY
0 commit comments