Implement permute_dims #12

Merged · 15 commits · Nov 4, 2024
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -19,13 +19,13 @@ repos:
      - id: trailing-whitespace
        exclude: '.*\.patch'
  - repo: https://github.com/psf/black
    rev: 24.3.0
    rev: 24.8.0
    hooks:
      - id: black
        args: ["--line-length", "80"]
        language_version: python3
  - repo: https://github.com/PyCQA/bandit
    rev: '1.7.8'
    rev: '1.7.9'
    hooks:
      - id: bandit
        args: ["-c", ".bandit.yml"]
@@ -35,7 +35,7 @@ repos:
      - id: isort
        name: isort (python)
  - repo: https://github.com/pycqa/flake8
    rev: 7.0.0
    rev: 7.1.1
    hooks:
      - id: flake8
  - repo: https://github.com/pocc/pre-commit-hooks
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -149,6 +149,7 @@ include_directories(
  ${PROJECT_SOURCE_DIR}/third_party/bitsery/include
  ${MPI_INCLUDE_PATH}
  ${pybind11_INCLUDE_DIRS}
  ${LLVM_INCLUDE_DIRS}
  ${MLIR_INCLUDE_DIRS}
  ${IMEX_INCLUDE_DIRS})

60 changes: 60 additions & 0 deletions examples/transposed3d.py
@@ -0,0 +1,60 @@
import numpy as np

import sharpy as sp


def sp_transposed3d_1():
    a = sp.arange(0, 2 * 3 * 4, 1)
    a = sp.reshape(a, [2, 3, 4])

    # b = a.swapaxes(1,0).swapaxes(1,2)
    b = sp.permute_dims(a, (1, 0, 2))  # 2x3x4 -> 3x2x4
    b = sp.permute_dims(b, (0, 2, 1))  # 3x2x4 -> 3x4x2

    # c = b.swapaxes(1,2).swapaxes(1,0)
    c = sp.permute_dims(b, (0, 2, 1))
    c = sp.permute_dims(c, (1, 0, 2))

    # undoing the permutations must restore the original array
    assert np.allclose(sp.to_numpy(a), sp.to_numpy(c))
    return b


def sp_transposed3d_2():
    a = sp.arange(0, 2 * 3 * 4, 1)
    a = sp.reshape(a, [2, 3, 4])

    # b = a.swapaxes(2,1).swapaxes(2,0)
    b = sp.permute_dims(a, (0, 2, 1))
    b = sp.permute_dims(b, (2, 1, 0))

    # c = b.swapaxes(2,1).swapaxes(0,1)
    c = sp.permute_dims(b, (0, 2, 1))
    c = sp.permute_dims(c, (1, 0, 2))

    return c


def np_transposed3d_1():
    a = np.arange(0, 2 * 3 * 4, 1)
    a = np.reshape(a, [2, 3, 4])
    b = a.swapaxes(1, 0).swapaxes(1, 2)
    return b


def np_transposed3d_2():
    a = np.arange(0, 2 * 3 * 4, 1)
    a = np.reshape(a, [2, 3, 4])
    b = a.swapaxes(2, 1).swapaxes(2, 0)
    c = b.swapaxes(2, 1).swapaxes(0, 1)
    return c


sp.init(False)

b1 = sp_transposed3d_1()
assert np.allclose(sp.to_numpy(b1), np_transposed3d_1())

b2 = sp_transposed3d_2()
assert np.allclose(sp.to_numpy(b2), np_transposed3d_2())

sp.fini()
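A note on the chained calls in this example: applying permute_dims with axes p1 and then with axes p2 is the same as a single permute_dims with axes q, where q[i] = p1[p2[i]]. A minimal numpy-only sketch of this identity (the helper compose_axes is hypothetical, not part of the PR):

import numpy as np

def compose_axes(p1, p2):
    # Permuting with p1, then p2, equals one permutation q with q[i] = p1[p2[i]].
    return tuple(p1[i] for i in p2)

a = np.arange(2 * 3 * 4).reshape(2, 3, 4)
# (1, 0, 2) followed by (0, 2, 1) is the same as (1, 2, 0):
assert compose_axes((1, 0, 2), (0, 2, 1)) == (1, 2, 0)
assert np.array_equal(
    np.transpose(np.transpose(a, (1, 0, 2)), (0, 2, 1)),
    np.transpose(a, (1, 2, 0)),
)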
6 changes: 5 additions & 1 deletion setup.py
@@ -1,3 +1,4 @@
import multiprocessing
import os
import pathlib

@@ -44,7 +45,10 @@ def build_cmake(self, ext):
        os.chdir(str(build_temp))
        self.spawn(["cmake", str(cwd)] + cmake_args)
        if not self.dry_run:
            self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
            self.spawn(
                ["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
                + build_args
            )
        # Troubleshooting: if fail on line above then delete all possible
        # temporary CMake files including "CMakeCache.txt" in top level dir.
        os.chdir(str(cwd))
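For illustration, a standalone sketch of this change's effect: build parallelism now follows the host's CPU count instead of a hard-coded five jobs.

import multiprocessing

# On a 16-core machine this prints ["cmake", "--build", ".", "-j16"].
jobs = multiprocessing.cpu_count()
print(["cmake", "--build", ".", f"-j{jobs}"])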
4 changes: 4 additions & 0 deletions sharpy/__init__.py
@@ -130,6 +130,10 @@ def _validate_device(device):
        exec(
            f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
        )
    elif func == "permute_dims":
        exec(
            f"{func} = lambda this, axes: ndarray(_csp.ManipOp.permute_dims(this._t, axes))"
        )

for func in api.api_categories["ReduceOp"]:
    FUNC = func.upper()
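The exec above generates a module-level function whose first argument is the array, so the new binding is called as in this minimal usage sketch (mirroring examples/transposed3d.py):

import sharpy as sp

sp.init(False)
a = sp.reshape(sp.arange(0, 24, 1), [2, 3, 4])
b = sp.permute_dims(a, (2, 0, 1))  # shape (2, 3, 4) -> (4, 2, 3)
print(sp.to_numpy(b).shape)
sp.fini()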
1 change: 1 addition & 0 deletions sharpy/array_api.py
@@ -179,6 +179,7 @@
"roll", # (x, /, shift, *, axis=None)
"squeeze", # (x, /, axis)
"stack", # (arrays, /, *, axis=0)
"permute_dims", # (x: array, /, axes: Tuple[int, ...]) → array
],
"LinAlgOp": [
"matmul", # (x1, x2, /)
2 changes: 1 addition & 1 deletion src/EWBinOp.cpp
@@ -120,7 +120,7 @@ struct DeferredEWBinOp : public Deferred {
    auto av = dm.getDependent(builder, Registry::get(_a));
    auto bv = dm.getDependent(builder, Registry::get(_b));

    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outElemType =
        ::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
    auto outTyp = aTyp.cloneWith(shape(), outElemType);
2 changes: 1 addition & 1 deletion src/EWUnyOp.cpp
@@ -105,7 +105,7 @@ struct DeferredEWUnyOp : public Deferred {
                     jit::DepManager &dm) override {
    auto av = dm.getDependent(builder, Registry::get(_a));

    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());

    auto ndOpId = sharpy(_op);
2 changes: 1 addition & 1 deletion src/IEWBinOp.cpp
@@ -71,7 +71,7 @@ struct DeferredIEWBinOp : public Deferred {
    auto av = dm.getDependent(builder, Registry::get(_a));
    auto bv = dm.getDependent(builder, Registry::get(_b));

    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());

    auto binop = builder.create<::imex::ndarray::EWBinOp>(
82 changes: 79 additions & 3 deletions src/ManipOp.cpp
@@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred {
        ? ::mlir::IntegerAttr()
        : ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1);

    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outTyp = imex::dist::cloneWithShape(aTyp, shape());

    auto op =
@@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred {
    // construct NDArrayType with same shape and given dtype
    ::imex::ndarray::DType ndDType = dispatch<convDType>(dtype);
    auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType);
    auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
    auto arType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
    if (!arType) {
      throw std::invalid_argument(
          "Encountered unexpected ndarray type in astype.");
@@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred {
                     jit::DepManager &dm) override {
    auto av = dm.getDependent(builder, Registry::get(_a));

    auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>();
    auto srcType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType());
    if (!srcType) {
      throw std::invalid_argument(
          "Encountered unexpected ndarray type in to_device.");
@@ -205,6 +205,57 @@ struct DeferredToDevice : public Deferred {
  }
};

struct DeferredPermuteDims : public Deferred {
  id_type _array;
  shape_type _axes;

  DeferredPermuteDims() = default;
  DeferredPermuteDims(const array_i::future_type &array,
                      const shape_type &shape, const shape_type &axes)
      : Deferred(array.dtype(), shape, array.device(), array.team()),
        _array(array.guid()), _axes(axes) {}

  bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc,
                     jit::DepManager &dm) override {
    auto arrayValue = dm.getDependent(builder, Registry::get(_array));

    auto axesAttr = builder.getDenseI64ArrayAttr(_axes);

    auto aTyp =
        ::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType());
    auto outTyp = imex::dist::cloneWithShape(aTyp, shape());

    auto op = builder.create<::imex::ndarray::PermuteDimsOp>(
        loc, outTyp, arrayValue, axesAttr);

    dm.addVal(
        this->guid(), op,
        [this](uint64_t rank, void *l_allocated, void *l_aligned,
               intptr_t l_offset, const intptr_t *l_sizes,
               const intptr_t *l_strides, void *o_allocated, void *o_aligned,
               intptr_t o_offset, const intptr_t *o_sizes,
               const intptr_t *o_strides, void *r_allocated, void *r_aligned,
               intptr_t r_offset, const intptr_t *r_sizes,
               const intptr_t *r_strides, std::vector<int64_t> &&loffs) {
          auto t = mk_tnsr(this->guid(), _dtype, this->shape(), this->device(),
                           this->team(), l_allocated, l_aligned, l_offset,
                           l_sizes, l_strides, o_allocated, o_aligned, o_offset,
                           o_sizes, o_strides, r_allocated, r_aligned, r_offset,
                           r_sizes, r_strides, std::move(loffs));
          this->set_value(std::move(t));
        });

    return false;
  }

  FactoryId factory() const override { return F_PERMUTEDIMS; }

  template <typename S> void serialize(S &ser) {
    ser.template value<sizeof(_array)>(_array);
    // ser.template value<sizeof(_axes)>(_axes);
  }
};

FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape,
                              const py::object &copy) {
  auto doCopy = copy.is_none()
@@ -229,7 +280,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a,
  return new FutureArray(defer<DeferredToDevice>(a.get(), device));
}

FutureArray *ManipOp::permute_dims(const FutureArray &array,
                                   const shape_type &axes) {
  auto shape = array.get().shape();

  // verify axes is a valid permutation of the dimensions
  if (shape.size() != axes.size()) {
    throw std::invalid_argument("axes must have the same length as the shape");
  }
  for (auto i = 0ul; i < shape.size(); ++i) {
    if (std::find(axes.begin(), axes.end(), i) == axes.end()) {
      throw std::invalid_argument("axes must contain all dimensions");
    }
  }

  auto permutedShape = shape_type(shape.size());
  for (auto i = 0ul; i < shape.size(); ++i) {
    permutedShape[i] = shape[axes[i]];
  }

  return new FutureArray(
      defer<DeferredPermuteDims>(array.get(), permutedShape, axes));
}

FACTORY_INIT(DeferredReshape, F_RESHAPE);
FACTORY_INIT(DeferredAsType, F_ASTYPE);
FACTORY_INIT(DeferredToDevice, F_TODEVICE);
FACTORY_INIT(DeferredPermuteDims, F_PERMUTEDIMS);

} // namespace SHARPY
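The axes check and output-shape computation in ManipOp::permute_dims follow the array-API rule: axes must be a permutation of 0..ndim-1, and output dimension i takes its extent from input dimension axes[i]. An equivalent Python sketch (illustrative only, not part of the PR):

def permuted_shape(shape, axes):
    if len(axes) != len(shape):
        raise ValueError("axes must have the same length as the shape")
    if sorted(axes) != list(range(len(shape))):
        raise ValueError("axes must contain all dimensions")
    # output dim i takes its extent from input dim axes[i]
    return tuple(shape[ax] for ax in axes)

assert permuted_shape((2, 3, 4), (1, 0, 2)) == (3, 2, 4)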
4 changes: 3 additions & 1 deletion src/NDArray.cpp
@@ -113,7 +113,9 @@ void NDArray::NDADeleter::operator()(NDArray *a) const {
      std::cerr << "sharpy fini: detected possible memory leak\n";
    } else {
      auto av = dm.addDependent(builder, a);
      builder.create<::imex::ndarray::DeleteOp>(loc, av);
      auto deleteOp = builder.create<::imex::ndarray::DeleteOp>(loc, av);
      deleteOp->setAttr("bufferization.manual_deallocation",
                        builder.getUnitAttr());
      dm.drop(a->guid());
    }
    return false;
2 changes: 1 addition & 1 deletion src/ReduceOp.cpp
@@ -61,7 +61,7 @@ struct DeferredReduceOp : public Deferred {
    // FIXME reduction over individual dimensions is not supported
    auto av = dm.getDependent(builder, Registry::get(_a));
    // return type 0d with same dtype as input
    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());
    // reduction op
    auto mop = sharpy2mlir(_op);
2 changes: 1 addition & 1 deletion src/SetGetItem.cpp
@@ -277,7 +277,7 @@ struct DeferredGetItem : public Deferred {
    const auto &offs = _slc.offsets();
    const auto &sizes = shape();
    const auto &strides = _slc.strides();
    auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
    auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
    auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape());

    // now we can create the NDArray op using the above Values
4 changes: 3 additions & 1 deletion src/_sharpy.cpp
@@ -196,7 +196,9 @@ PYBIND11_MODULE(_sharpy, m) {
  py::class_<IEWBinOp>(m, "IEWBinOp").def("op", &IEWBinOp::op);
  py::class_<EWBinOp>(m, "EWBinOp").def("op", &EWBinOp::op);
  py::class_<ReduceOp>(m, "ReduceOp").def("op", &ReduceOp::op);
  py::class_<ManipOp>(m, "ManipOp").def("reshape", &ManipOp::reshape);
  py::class_<ManipOp>(m, "ManipOp")
      .def("reshape", &ManipOp::reshape)
      .def("permute_dims", &ManipOp::permute_dims);
  py::class_<LinAlgOp>(m, "LinAlgOp").def("vecdot", &LinAlgOp::vecdot);

  py::class_<FutureArray>(m, "SHARPYFuture")