Skip to content

Commit e4895f0

Browse files
committed
implement permute_dims
1 parent 42a3edd commit e4895f0

19 files changed

+582
-15
lines changed

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ repos:
1919
- id: trailing-whitespace
2020
exclude: '.*\.patch'
2121
- repo: https://github.com/psf/black
22-
rev: 24.3.0
22+
rev: 24.8.0
2323
hooks:
2424
- id: black
2525
args: ["--line-length", "80"]
2626
language_version: python3
2727
- repo: https://github.com/PyCQA/bandit
28-
rev: '1.7.8'
28+
rev: '1.7.9'
2929
hooks:
3030
- id: bandit
3131
args: ["-c", ".bandit.yml"]
@@ -35,7 +35,7 @@ repos:
3535
- id: isort
3636
name: isort (python)
3737
- repo: https://github.com/pycqa/flake8
38-
rev: 7.0.0
38+
rev: 7.1.1
3939
hooks:
4040
- id: flake8
4141
- repo: https://github.com/pocc/pre-commit-hooks

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ include_directories(
149149
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
150150
${MPI_INCLUDE_PATH}
151151
${pybind11_INCLUDE_DIRS}
152+
${LLVM_INCLUDE_DIRS}
152153
${MLIR_INCLUDE_DIRS}
153154
${IMEX_INCLUDE_DIRS})
154155

examples/transposed3d.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import numpy as np
2+
3+
import sharpy as sp
4+
5+
6+
def sp_tranposed3d_1():
7+
a = sp.arange(0, 2 * 3 * 4, 1)
8+
a = sp.reshape(a, [2, 3, 4])
9+
10+
# b = a.swapaxes(1,0).swapaxes(1,2)
11+
b = sp.permute_dims(a, (1, 0, 2)) # 2x4x4 -> 4x2x4 || 4x4x4
12+
b = sp.permute_dims(b, (0, 2, 1)) # 4x2x4 -> 4x4x2 || 4x4x4
13+
14+
# c = b.swapaxes(1,2).swapaxes(1,0)
15+
c = sp.permute_dims(b, (0, 2, 1))
16+
c = sp.permute_dims(c, (1, 0, 2))
17+
18+
assert np.allclose(sp.to_numpy(a), sp.to_numpy(c))
19+
return b
20+
21+
22+
def sp_tranposed3d_2():
23+
a = sp.arange(0, 2 * 3 * 4, 1)
24+
a = sp.reshape(a, [2, 3, 4])
25+
26+
# b = a.swapaxes(2,1).swapaxes(2,0)
27+
b = sp.permute_dims(a, (0, 2, 1))
28+
b = sp.permute_dims(b, (2, 1, 0))
29+
30+
# c = b.swapaxes(2,1).swapaxes(0,1)
31+
c = sp.permute_dims(b, (0, 2, 1))
32+
c = sp.permute_dims(c, (1, 0, 2))
33+
34+
return c
35+
36+
37+
def np_tranposed3d_1():
38+
a = np.arange(0, 2 * 3 * 4, 1)
39+
a = np.reshape(a, [2, 3, 4])
40+
b = a.swapaxes(1, 0).swapaxes(1, 2)
41+
return b
42+
43+
44+
def np_tranposed3d_2():
45+
a = np.arange(0, 2 * 3 * 4, 1)
46+
a = np.reshape(a, [2, 3, 4])
47+
b = a.swapaxes(2, 1).swapaxes(2, 0)
48+
c = b.swapaxes(2, 1).swapaxes(0, 1)
49+
return c
50+
51+
52+
sp.init(False)
53+
54+
b1 = sp_tranposed3d_1()
55+
assert np.allclose(sp.to_numpy(b1), np_tranposed3d_1())
56+
57+
b2 = sp_tranposed3d_2()
58+
assert np.allclose(sp.to_numpy(b2), np_tranposed3d_2())
59+
60+
sp.fini()

setup.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import multiprocessing
12
import os
23
import pathlib
34

@@ -44,7 +45,10 @@ def build_cmake(self, ext):
4445
os.chdir(str(build_temp))
4546
self.spawn(["cmake", str(cwd)] + cmake_args)
4647
if not self.dry_run:
47-
self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
48+
self.spawn(
49+
["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
50+
+ build_args
51+
)
4852
# Troubleshooting: if fail on line above then delete all possible
4953
# temporary CMake files including "CMakeCache.txt" in top level dir.
5054
os.chdir(str(cwd))

sharpy/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def _validate_device(device):
130130
exec(
131131
f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
132132
)
133+
elif func == "permute_dims":
134+
exec(
135+
f"{func} = lambda this, axes: ndarray(_csp.ManipOp.permute_dims(this._t, axes))"
136+
)
133137

134138
for func in api.api_categories["ReduceOp"]:
135139
FUNC = func.upper()

sharpy/array_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
"roll", # (x, /, shift, *, axis=None)
180180
"squeeze", # (x, /, axis)
181181
"stack", # (arrays, /, *, axis=0)
182+
"permute_dims", # (x: array, /, axes: Tuple[int, ...]) → array
182183
],
183184
"LinAlgOp": [
184185
"matmul", # (x1, x2, /)

src/EWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ struct DeferredEWBinOp : public Deferred {
120120
auto av = dm.getDependent(builder, Registry::get(_a));
121121
auto bv = dm.getDependent(builder, Registry::get(_b));
122122

123-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
123+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
124124
auto outElemType =
125125
::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
126126
auto outTyp = aTyp.cloneWith(shape(), outElemType);

src/EWUnyOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct DeferredEWUnyOp : public Deferred {
105105
jit::DepManager &dm) override {
106106
auto av = dm.getDependent(builder, Registry::get(_a));
107107

108-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
108+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
109109
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
110110

111111
auto ndOpId = sharpy(_op);

src/IEWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct DeferredIEWBinOp : public Deferred {
7171
auto av = dm.getDependent(builder, Registry::get(_a));
7272
auto bv = dm.getDependent(builder, Registry::get(_b));
7373

74-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
74+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
7575
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
7676

7777
auto binop = builder.create<::imex::ndarray::EWBinOp>(

src/ManipOp.cpp

+79-3
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred {
4141
? ::mlir::IntegerAttr()
4242
: ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1);
4343

44-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
44+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
4545
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());
4646

4747
auto op =
@@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred {
106106
// construct NDArrayType with same shape and given dtype
107107
::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
108108
auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType);
109-
auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
109+
auto arType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
110110
if (!arType) {
111111
throw std::invalid_argument(
112112
"Encountered unexpected ndarray type in astype.");
@@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred {
157157
jit::DepManager &dm) override {
158158
auto av = dm.getDependent(builder, Registry::get(_a));
159159

160-
auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
160+
auto srcType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
161161
if (!srcType) {
162162
throw std::invalid_argument(
163163
"Encountered unexpected ndarray type in to_device.");
@@ -205,6 +205,57 @@ struct DeferredToDevice : public Deferred {
205205
}
206206
};
207207

208+
struct DeferredPermuteDims : public Deferred {
209+
id_type _array;
210+
shape_type _axes;
211+
212+
DeferredPermuteDims() = default;
213+
DeferredPermuteDims(const array_i::future_type &array,
214+
const shape_type &shape, const shape_type &axes)
215+
: Deferred(array.dtype(), shape, array.device(), array.team()),
216+
_array(array.guid()), _axes(axes) {}
217+
218+
bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
219+
jit::DepManager &dm) override {
220+
auto arrayValue = dm.getDependent(builder, Registry::get(_array));
221+
222+
auto axesAttr = builder.getDenseI64ArrayAttr(_axes);
223+
224+
auto aTyp =
225+
::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType());
226+
auto outTyp = imex::dist::cloneWithShape(aTyp, shape());
227+
228+
auto op = builder.create<::imex::ndarray::PermuteDimsOp>(
229+
loc, outTyp, arrayValue, axesAttr);
230+
231+
dm.addVal(
232+
this->guid(), op,
233+
[this](uint64_t rank, void *l_allocated, void *l_aligned,
234+
intptr_t l_offset, const intptr_t *l_sizes,
235+
const intptr_t *l_strides, void *o_allocated, void *o_aligned,
236+
intptr_t o_offset, const intptr_t *o_sizes,
237+
const intptr_t *o_strides, void *r_allocated, void *r_aligned,
238+
intptr_t r_offset, const intptr_t *r_sizes,
239+
const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
240+
auto t = mk_tnsr(this->guid(), _dtype, this->shape(), this->device(),
241+
this->team(), l_allocated, l_aligned, l_offset,
242+
l_sizes, l_strides, o_allocated, o_aligned, o_offset,
243+
o_sizes, o_strides, r_allocated, r_aligned, r_offset,
244+
r_sizes, r_strides, std::move(loffs));
245+
this->set_value(std::move(t));
246+
});
247+
248+
return false;
249+
}
250+
251+
FactoryId factory() const override { return F_PERMUTEDIMS; }
252+
253+
template <typename S> void serialize(S &ser) {
254+
ser.template value<sizeof(_array)>(_array);
255+
// ser.template value<sizeof(_axes)>(_axes);
256+
}
257+
};
258+
208259
FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
209260
const py::object &copy) {
210261
auto doCopy = copy.is_none()
@@ -229,7 +280,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a,
229280
return new FutureArray(defer<DeferredToDevice>(a.get(), device));
230281
}
231282

283+
FutureArray *ManipOp::permute_dims(const FutureArray &array,
284+
const shape_type &axes) {
285+
auto shape = array.get().shape();
286+
287+
// verifyPermuteArray
288+
if (shape.size() != axes.size()) {
289+
throw std::invalid_argument("axes must have the same length as the shape");
290+
}
291+
for (auto i = 0ul; i < shape.size(); ++i) {
292+
if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
293+
throw std::invalid_argument("axes must contain all dimensions");
294+
}
295+
}
296+
297+
auto permutedShape = shape_type(shape.size());
298+
for (auto i = 0ul; i < shape.size(); ++i) {
299+
permutedShape[i] = shape[axes[i]];
300+
}
301+
302+
return new FutureArray(
303+
defer<DeferredPermuteDims>(array.get(), permutedShape, axes));
304+
}
305+
232306
FACTORY_INIT(DeferredReshape, F_RESHAPE);
233307
FACTORY_INIT(DeferredAsType, F_ASTYPE);
234308
FACTORY_INIT(DeferredToDevice, F_TODEVICE);
309+
FACTORY_INIT(DeferredPermuteDims, F_PERMUTEDIMS);
310+
235311
} // namespace SHARPY

src/NDArray.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,9 @@ void NDArray::NDADeleter::operator()(NDArray *a) const {
113113
std::cerr << "sharpy fini: detected possible memory leak\n";
114114
} else {
115115
auto av = dm.addDependent(builder, a);
116-
builder.create<::imex::ndarray::DeleteOp>(loc, av);
116+
auto deleteOp = builder.create<::imex::ndarray::DeleteOp>(loc, av);
117+
deleteOp->setAttr("bufferization.manual_deallocation",
118+
builder.getUnitAttr());
117119
dm.drop(a->guid());
118120
}
119121
return false;

src/ReduceOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct DeferredReduceOp : public Deferred {
6161
// FIXME reduction over individual dimensions is not supported
6262
auto av = dm.getDependent(builder, Registry::get(_a));
6363
// return type 0d with same dtype as input
64-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
64+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
6565
auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());
6666
// reduction op
6767
auto mop = sharpy2mlir(_op);

src/SetGetItem.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ struct DeferredGetItem : public Deferred {
277277
const auto &offs = _slc.offsets();
278278
const auto &sizes = shape();
279279
const auto &strides = _slc.strides();
280-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
280+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
281281
auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());
282282

283283
// now we can create the NDArray op using the above Values

src/_sharpy.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ PYBIND11_MODULE(_sharpy, m) {
196196
py::class_<IEWBinOp>(m, "IEWBinOp").def("op", &IEWBinOp::op);
197197
py::class_<EWBinOp>(m, "EWBinOp").def("op", &EWBinOp::op);
198198
py::class_<ReduceOp>(m, "ReduceOp").def("op", &ReduceOp::op);
199-
py::class_<ManipOp>(m, "ManipOp").def("reshape", &ManipOp::reshape);
199+
py::class_<ManipOp>(m, "ManipOp")
200+
.def("reshape", &ManipOp::reshape)
201+
.def("permute_dims", &ManipOp::permute_dims);
200202
py::class_<LinAlgOp>(m, "LinAlgOp").def("vecdot", &LinAlgOp::vecdot);
201203

202204
py::class_<FutureArray>(m, "SHARPYFuture")

0 commit comments

Comments
 (0)