@@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred {
41
41
? ::mlir::IntegerAttr ()
42
42
: ::imex::getIntAttr (builder, COPY_ALWAYS ? true : false , 1 );
43
43
44
- auto aTyp = av. getType (). cast <::imex::ndarray::NDArrayType>();
44
+ auto aTyp = ::mlir:: cast<::imex::ndarray::NDArrayType>(av. getType () );
45
45
auto outTyp = imex::dist::cloneWithShape (aTyp, shape ());
46
46
47
47
auto op =
@@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred {
106
106
// construct NDArrayType with same shape and given dtype
107
107
::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
108
108
auto mlirElType = ::imex::ndarray::toMLIR (builder, ndDType);
109
- auto arType = av. getType (). dyn_cast <::imex::ndarray::NDArrayType>();
109
+ auto arType = ::mlir:: dyn_cast<::imex::ndarray::NDArrayType>(av. getType () );
110
110
if (!arType) {
111
111
throw std::invalid_argument (
112
112
" Encountered unexpected ndarray type in astype." );
@@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred {
157
157
jit::DepManager &dm) override {
158
158
auto av = dm.getDependent (builder, Registry::get (_a));
159
159
160
- auto srcType = av. getType (). dyn_cast <::imex::ndarray::NDArrayType>();
160
+ auto srcType = ::mlir:: dyn_cast<::imex::ndarray::NDArrayType>(av. getType () );
161
161
if (!srcType) {
162
162
throw std::invalid_argument (
163
163
" Encountered unexpected ndarray type in to_device." );
@@ -205,6 +205,57 @@ struct DeferredToDevice : public Deferred {
205
205
}
206
206
};
207
207
208
+ struct DeferredPermuteDims : public Deferred {
209
+ id_type _array;
210
+ shape_type _axes;
211
+
212
+ DeferredPermuteDims () = default ;
213
+ DeferredPermuteDims (const array_i::future_type &array,
214
+ const shape_type &shape, const shape_type &axes)
215
+ : Deferred(array.dtype(), shape, array.device(), array.team()),
216
+ _array (array.guid()), _axes(axes) {}
217
+
218
+ bool generate_mlir (::mlir::OpBuilder &builder, const ::mlir::Location &loc,
219
+ jit::DepManager &dm) override {
220
+ auto arrayValue = dm.getDependent (builder, Registry::get (_array));
221
+
222
+ auto axesAttr = builder.getDenseI64ArrayAttr (_axes);
223
+
224
+ auto aTyp =
225
+ ::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType ());
226
+ auto outTyp = imex::dist::cloneWithShape (aTyp, shape ());
227
+
228
+ auto op = builder.create <::imex::ndarray::PermuteDimsOp>(
229
+ loc, outTyp, arrayValue, axesAttr);
230
+
231
+ dm.addVal (
232
+ this ->guid (), op,
233
+ [this ](uint64_t rank, void *l_allocated, void *l_aligned,
234
+ intptr_t l_offset, const intptr_t *l_sizes,
235
+ const intptr_t *l_strides, void *o_allocated, void *o_aligned,
236
+ intptr_t o_offset, const intptr_t *o_sizes,
237
+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
238
+ intptr_t r_offset, const intptr_t *r_sizes,
239
+ const intptr_t *r_strides, std::vector<int64_t > &&loffs) {
240
+ auto t = mk_tnsr (this ->guid (), _dtype, this ->shape (), this ->device (),
241
+ this ->team (), l_allocated, l_aligned, l_offset,
242
+ l_sizes, l_strides, o_allocated, o_aligned, o_offset,
243
+ o_sizes, o_strides, r_allocated, r_aligned, r_offset,
244
+ r_sizes, r_strides, std::move (loffs));
245
+ this ->set_value (std::move (t));
246
+ });
247
+
248
+ return false ;
249
+ }
250
+
251
+ FactoryId factory () const override { return F_PERMUTEDIMS; }
252
+
253
+ template <typename S> void serialize (S &ser) {
254
+ ser.template value <sizeof (_array)>(_array);
255
+ // ser.template value<sizeof(_axes)>(_axes);
256
+ }
257
+ };
258
+
208
259
FutureArray *ManipOp::reshape (const FutureArray &a, const shape_type &shape,
209
260
const py::object ©) {
210
261
auto doCopy = copy.is_none ()
@@ -229,7 +280,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a,
229
280
return new FutureArray (defer<DeferredToDevice>(a.get (), device));
230
281
}
231
282
283
+ FutureArray *ManipOp::permute_dims (const FutureArray &array,
284
+ const shape_type &axes) {
285
+ auto shape = array.get ().shape ();
286
+
287
+ // verifyPermuteArray
288
+ if (shape.size () != axes.size ()) {
289
+ throw std::invalid_argument (" axes must have the same length as the shape" );
290
+ }
291
+ for (auto i = 0ul ; i < shape.size (); ++i) {
292
+ if (std::find (axes.begin (), axes.end (), i) == axes.end ()) {
293
+ throw std::invalid_argument (" axes must contain all dimensions" );
294
+ }
295
+ }
296
+
297
+ auto permutedShape = shape_type (shape.size ());
298
+ for (auto i = 0ul ; i < shape.size (); ++i) {
299
+ permutedShape[i] = shape[axes[i]];
300
+ }
301
+
302
+ return new FutureArray (
303
+ defer<DeferredPermuteDims>(array.get (), permutedShape, axes));
304
+ }
305
+
232
306
FACTORY_INIT (DeferredReshape, F_RESHAPE);
233
307
FACTORY_INIT (DeferredAsType, F_ASTYPE);
234
308
FACTORY_INIT (DeferredToDevice, F_TODEVICE);
309
+ FACTORY_INIT (DeferredPermuteDims, F_PERMUTEDIMS);
310
+
235
311
} // namespace SHARPY
0 commit comments