diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4b996a9..edaea21 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,13 +19,13 @@ repos: - id: trailing-whitespace exclude: '.*\.patch' - repo: https://github.com/psf/black - rev: 24.3.0 + rev: 24.8.0 hooks: - id: black args: ["--line-length", "80"] language_version: python3 - repo: https://github.com/PyCQA/bandit - rev: '1.7.8' + rev: '1.7.9' hooks: - id: bandit args: ["-c", ".bandit.yml"] @@ -35,7 +35,7 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pocc/pre-commit-hooks diff --git a/CMakeLists.txt b/CMakeLists.txt index e50b714..4e58301 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -149,6 +149,7 @@ include_directories( ${PROJECT_SOURCE_DIR}/third_party/bitsery/include ${MPI_INCLUDE_PATH} ${pybind11_INCLUDE_DIRS} + ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} ${IMEX_INCLUDE_DIRS}) diff --git a/examples/transpose.py b/examples/transpose.py new file mode 100644 index 0000000..5b0e09d --- /dev/null +++ b/examples/transpose.py @@ -0,0 +1,184 @@ +""" +Transpose benchmark + + Matrix transpose benchmark for sharpy and numpy backends. + +Examples: + + # Run 1000 iterations of 1000*1000 matrix on sharpy backend + python transpose.py -r 10 -c 1000 -b sharpy -i 1000 + + # MPI parallel run + mpiexec -n 3 python transpose.py -r 1000 -c 1000 -b sharpy -i 1000 + +""" + +import argparse +import time as time_mod + +import numpy + +import sharpy + +try: + import mpi4py + + mpi4py.rc.finalize = False + from mpi4py import MPI + + comm_rank = MPI.COMM_WORLD.Get_rank() + comm = MPI.COMM_WORLD +except ImportError: + comm_rank = 0 + comm = None + + +def info(s): + if comm_rank == 0: + print(s) + + +def sp_transpose(arr): + brr = sharpy.permute_dims(arr, [1, 0]) + return brr + + +def np_transpose(arr): + brr = arr.transpose() + return brr.copy() + + +def initialize(np, row, col, dtype): + arr = np.arange(0, row * col, 1, dtype=dtype) + return np.reshape(arr, (row, col)) + + +def run(row, col, backend, iterations, datatype): + if backend == "sharpy": + import sharpy as np + from sharpy import fini, init, sync + + transpose = sp_transpose + + init(False) + elif backend == "numpy": + import numpy as np + + if comm is not None: + assert ( + comm.Get_size() == 1 + ), "Numpy backend only supports serial execution." 
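Note on the benchmark kernels above: sp_transpose maps to sharpy.permute_dims(arr, [1, 0]), i.e. a general axis permutation rather than a 2-D-only transpose, while np_transpose copies the NumPy view. A minimal NumPy-only sketch of the same semantics (illustration only, not sharpy code):

    import numpy as np

    def permute_dims(arr, axes):
        # Reorder the axes of `arr`; axes must be a permutation of range(arr.ndim).
        return np.transpose(arr, axes)

    a = np.arange(10).reshape(2, 5)
    b = permute_dims(a, [1, 0])   # shape (5, 2), same as a.T
    assert b.shape == (5, 2)
    assert np.array_equal(b, a.transpose())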
+ + fini = sync = lambda x=None: None + transpose = np_transpose + else: + raise ValueError(f'Unknown backend: "{backend}"') + + dtype = { + "f32": np.float32, + "f64": np.float64, + }[datatype] + + info(f"Using backend: {backend}") + info(f"Number of row: {row}") + info(f"Number of column: {col}") + info(f"Datatype: {datatype}") + + arr = initialize(np, row, col, dtype) + sync() + + # verify + if backend == "sharpy": + brr = sp_transpose(arr) + crr = np_transpose(sharpy.to_numpy(arr)) + assert numpy.allclose(sharpy.to_numpy(brr), crr) + + def eval(): + tic = time_mod.perf_counter() + transpose(arr) + sync() + toc = time_mod.perf_counter() + return toc - tic + + # warm-up run + t_warm = eval() + + # evaluate + info(f"Running {iterations} iterations") + time_list = [] + for i in range(iterations): + time_list.append(eval()) + + # get max time over mpi ranks + if comm is not None: + t_warm = comm.allreduce(t_warm, MPI.MAX) + time_list = comm.allreduce(time_list, MPI.MAX) + + t_min = numpy.min(time_list) + t_max = numpy.max(time_list) + t_med = numpy.median(time_list) + init_overhead = t_warm - t_med + if backend == "sharpy": + info(f"Estimated initialization overhead: {init_overhead:.5f} s") + info(f"Min. duration: {t_min:.5f} s") + info(f"Max. duration: {t_max:.5f} s") + info(f"Median duration: {t_med:.5f} s") + + fini() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Run transpose benchmark", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "-r", + "--row", + type=int, + default=10000, + help="Number of row.", + ) + parser.add_argument( + "-c", + "--column", + type=int, + default=10000, + help="Number of column.", + ) + + parser.add_argument( + "-b", + "--backend", + type=str, + default="sharpy", + choices=["sharpy", "numpy"], + help="Backend to use.", + ) + + parser.add_argument( + "-i", + "--iterations", + type=int, + default=10, + help="Number of iterations to run.", + ) + + parser.add_argument( + "-d", + "--datatype", + type=str, + default="f64", + choices=["f32", "f64"], + help="Datatype for model state variables", + ) + + args = parser.parse_args() + run( + args.row, + args.column, + args.backend, + args.iterations, + args.datatype, + ) diff --git a/imex_version.txt b/imex_version.txt index 35d0562..536e988 100644 --- a/imex_version.txt +++ b/imex_version.txt @@ -1 +1 @@ -5a7bb80ede5fe4fa8d56ee0dd77c4e5c1327fe09 +8ae485bbfb1303a414b375e25130fcaa4c02127a diff --git a/setup.py b/setup.py index 3b8c3e6..d30feff 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import multiprocessing import os import pathlib @@ -44,7 +45,10 @@ def build_cmake(self, ext): os.chdir(str(build_temp)) self.spawn(["cmake", str(cwd)] + cmake_args) if not self.dry_run: - self.spawn(["cmake", "--build", ".", "-j5"] + build_args) + self.spawn( + ["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"] + + build_args + ) # Troubleshooting: if fail on line above then delete all possible # temporary CMake files including "CMakeCache.txt" in top level dir. 
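The benchmark's run() above times each transpose with perf_counter, reduces timings across MPI ranks with MAX so the slowest rank defines each iteration, and reports min/max/median plus the warm-up overhead. A condensed sketch of that measurement pattern, assuming mpi4py is available (names here are illustrative):

    import time
    import numpy
    from mpi4py import MPI

    comm = MPI.COMM_WORLD

    def timed(fn, sync=lambda: None):
        # Time one call, including a sync so deferred work is flushed.
        tic = time.perf_counter()
        fn()
        sync()
        toc = time.perf_counter()
        # Max over ranks: the slowest rank determines the iteration time.
        return comm.allreduce(toc - tic, MPI.MAX)

    def bench(fn, iterations, sync=lambda: None):
        t_warm = timed(fn, sync)  # warm-up, includes compile/init cost
        times = [timed(fn, sync) for _ in range(iterations)]
        t_med = numpy.median(times)
        return {"min": min(times), "max": max(times), "median": t_med,
                "init_overhead": t_warm - t_med}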
os.chdir(str(cwd)) diff --git a/sharpy/__init__.py b/sharpy/__init__.py index b7784c0..cb1cc47 100644 --- a/sharpy/__init__.py +++ b/sharpy/__init__.py @@ -130,6 +130,10 @@ def _validate_device(device): exec( f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))" ) + elif func == "permute_dims": + exec( + f"{func} = lambda this, axes: ndarray(_csp.ManipOp.permute_dims(this._t, axes))" + ) for func in api.api_categories["ReduceOp"]: FUNC = func.upper() diff --git a/sharpy/array_api.py b/sharpy/array_api.py index 421d43b..add8a86 100644 --- a/sharpy/array_api.py +++ b/sharpy/array_api.py @@ -179,6 +179,7 @@ "roll", # (x, /, shift, *, axis=None) "squeeze", # (x, /, axis) "stack", # (arrays, /, *, axis=0) + "permute_dims", # (x: array, /, axes: Tuple[int, ...]) → array ], "LinAlgOp": [ "matmul", # (x1, x2, /) diff --git a/src/EWBinOp.cpp b/src/EWBinOp.cpp index 73fe968..a040fb7 100644 --- a/src/EWBinOp.cpp +++ b/src/EWBinOp.cpp @@ -120,7 +120,7 @@ struct DeferredEWBinOp : public Deferred { auto av = dm.getDependent(builder, Registry::get(_a)); auto bv = dm.getDependent(builder, Registry::get(_b)); - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outElemType = ::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype)); auto outTyp = aTyp.cloneWith(shape(), outElemType); diff --git a/src/EWUnyOp.cpp b/src/EWUnyOp.cpp index 9c50553..025ac18 100644 --- a/src/EWUnyOp.cpp +++ b/src/EWUnyOp.cpp @@ -105,7 +105,7 @@ struct DeferredEWUnyOp : public Deferred { jit::DepManager &dm) override { auto av = dm.getDependent(builder, Registry::get(_a)); - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType()); auto ndOpId = sharpy(_op); diff --git a/src/IEWBinOp.cpp b/src/IEWBinOp.cpp index 396b8d5..72aef3a 100644 --- a/src/IEWBinOp.cpp +++ b/src/IEWBinOp.cpp @@ -71,7 +71,7 @@ struct DeferredIEWBinOp : public Deferred { auto av = dm.getDependent(builder, Registry::get(_a)); auto bv = dm.getDependent(builder, Registry::get(_b)); - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType()); auto binop = builder.create<::imex::ndarray::EWBinOp>( diff --git a/src/ManipOp.cpp b/src/ManipOp.cpp index 4a69110..4179540 100644 --- a/src/ManipOp.cpp +++ b/src/ManipOp.cpp @@ -41,7 +41,7 @@ struct DeferredReshape : public Deferred { ? ::mlir::IntegerAttr() : ::imex::getIntAttr(builder, COPY_ALWAYS ? 
true : false, 1); - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outTyp = imex::dist::cloneWithShape(aTyp, shape()); auto op = @@ -106,7 +106,7 @@ struct DeferredAsType : public Deferred { // construct NDArrayType with same shape and given dtype ::imex::ndarray::DType ndDType = dispatch(dtype); auto mlirElType = ::imex::ndarray::toMLIR(builder, ndDType); - auto arType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>(); + auto arType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType()); if (!arType) { throw std::invalid_argument( "Encountered unexpected ndarray type in astype."); @@ -157,7 +157,7 @@ struct DeferredToDevice : public Deferred { jit::DepManager &dm) override { auto av = dm.getDependent(builder, Registry::get(_a)); - auto srcType = av.getType().dyn_cast<::imex::ndarray::NDArrayType>(); + auto srcType = ::mlir::dyn_cast<::imex::ndarray::NDArrayType>(av.getType()); if (!srcType) { throw std::invalid_argument( "Encountered unexpected ndarray type in to_device."); @@ -205,6 +205,56 @@ struct DeferredToDevice : public Deferred { } }; +struct DeferredPermuteDims : public Deferred { + id_type _array; + shape_type _axes; + + DeferredPermuteDims() = default; + DeferredPermuteDims(const array_i::future_type &array, + const shape_type &shape, const shape_type &axes) + : Deferred(array.dtype(), shape, array.device(), array.team()), + _array(array.guid()), _axes(axes) {} + + bool generate_mlir(::mlir::OpBuilder &builder, const ::mlir::Location &loc, + jit::DepManager &dm) override { + auto arrayValue = dm.getDependent(builder, Registry::get(_array)); + + auto axesAttr = builder.getDenseI64ArrayAttr(_axes); + + auto aTyp = + ::mlir::cast<::imex::ndarray::NDArrayType>(arrayValue.getType()); + auto outTyp = imex::dist::cloneWithShape(aTyp, shape()); + + auto op = builder.create<::imex::ndarray::PermuteDimsOp>( + loc, outTyp, arrayValue, axesAttr); + + dm.addVal( + this->guid(), op, + [this](uint64_t rank, void *l_allocated, void *l_aligned, + intptr_t l_offset, const intptr_t *l_sizes, + const intptr_t *l_strides, void *o_allocated, void *o_aligned, + intptr_t o_offset, const intptr_t *o_sizes, + const intptr_t *o_strides, void *r_allocated, void *r_aligned, + intptr_t r_offset, const intptr_t *r_sizes, + const intptr_t *r_strides, std::vector &&loffs) { + auto t = mk_tnsr(this->guid(), _dtype, this->shape(), this->device(), + this->team(), l_allocated, l_aligned, l_offset, + l_sizes, l_strides, o_allocated, o_aligned, o_offset, + o_sizes, o_strides, r_allocated, r_aligned, r_offset, + r_sizes, r_strides, std::move(loffs)); + this->set_value(std::move(t)); + }); + + return false; + } + + FactoryId factory() const override { return F_PERMUTEDIMS; } + + template void serialize(S &ser) { + ser.template value(_array); + } +}; + FutureArray *ManipOp::reshape(const FutureArray &a, const shape_type &shape, const py::object ©) { auto doCopy = copy.is_none() @@ -229,7 +279,32 @@ FutureArray *ManipOp::to_device(const FutureArray &a, return new FutureArray(defer(a.get(), device)); } +FutureArray *ManipOp::permute_dims(const FutureArray &array, + const shape_type &axes) { + auto shape = array.get().shape(); + + // verifyPermuteArray + if (shape.size() != axes.size()) { + throw std::invalid_argument("axes must have the same length as the shape"); + } + for (auto i = 0ul; i < shape.size(); ++i) { + if (std::find(axes.begin(), axes.end(), i) == axes.end()) { + throw std::invalid_argument("axes must 
contain all dimensions"); + } + } + + auto permutedShape = shape_type(shape.size()); + for (auto i = 0ul; i < shape.size(); ++i) { + permutedShape[i] = shape[axes[i]]; + } + + return new FutureArray( + defer(array.get(), permutedShape, axes)); +} + FACTORY_INIT(DeferredReshape, F_RESHAPE); FACTORY_INIT(DeferredAsType, F_ASTYPE); FACTORY_INIT(DeferredToDevice, F_TODEVICE); +FACTORY_INIT(DeferredPermuteDims, F_PERMUTEDIMS); + } // namespace SHARPY diff --git a/src/NDArray.cpp b/src/NDArray.cpp index 709c99e..5f7962d 100644 --- a/src/NDArray.cpp +++ b/src/NDArray.cpp @@ -113,7 +113,9 @@ void NDArray::NDADeleter::operator()(NDArray *a) const { std::cerr << "sharpy fini: detected possible memory leak\n"; } else { auto av = dm.addDependent(builder, a); - builder.create<::imex::ndarray::DeleteOp>(loc, av); + auto deleteOp = builder.create<::imex::ndarray::DeleteOp>(loc, av); + deleteOp->setAttr("bufferization.manual_deallocation", + builder.getUnitAttr()); dm.drop(a->guid()); } return false; diff --git a/src/ReduceOp.cpp b/src/ReduceOp.cpp index 315f17b..3bdfa5a 100644 --- a/src/ReduceOp.cpp +++ b/src/ReduceOp.cpp @@ -61,7 +61,7 @@ struct DeferredReduceOp : public Deferred { // FIXME reduction over individual dimensions is not supported auto av = dm.getDependent(builder, Registry::get(_a)); // return type 0d with same dtype as input - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape()); // reduction op auto mop = sharpy2mlir(_op); diff --git a/src/SetGetItem.cpp b/src/SetGetItem.cpp index 1ba06d8..38af72e 100644 --- a/src/SetGetItem.cpp +++ b/src/SetGetItem.cpp @@ -277,7 +277,7 @@ struct DeferredGetItem : public Deferred { const auto &offs = _slc.offsets(); const auto &sizes = shape(); const auto &strides = _slc.strides(); - auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); + auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType()); auto outTyp = ::imex::dist::cloneWithShape(aTyp, shape()); // now we can create the NDArray op using the above Values diff --git a/src/_sharpy.cpp b/src/_sharpy.cpp index 2e9e6f4..81f40b3 100644 --- a/src/_sharpy.cpp +++ b/src/_sharpy.cpp @@ -196,7 +196,9 @@ PYBIND11_MODULE(_sharpy, m) { py::class_(m, "IEWBinOp").def("op", &IEWBinOp::op); py::class_(m, "EWBinOp").def("op", &EWBinOp::op); py::class_(m, "ReduceOp").def("op", &ReduceOp::op); - py::class_(m, "ManipOp").def("reshape", &ManipOp::reshape); + py::class_(m, "ManipOp") + .def("reshape", &ManipOp::reshape) + .def("permute_dims", &ManipOp::permute_dims); py::class_(m, "LinAlgOp").def("vecdot", &LinAlgOp::vecdot); py::class_(m, "SHARPYFuture") diff --git a/src/idtr.cpp b/src/idtr.cpp index 3a42b26..11869ea 100644 --- a/src/idtr.cpp +++ b/src/idtr.cpp @@ -570,6 +570,402 @@ _idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, oData.data(), oData.sizes(), oData.strides()); } +namespace { + +/// +/// An util class of multi-dimensional index +/// +class id { +public: + id(size_t dims) : _values(dims) {} + id(size_t dims, int64_t *value) : _values(value, value + dims) {} + id(const std::vector &values) : _values(values) {} + id(const std::vector &&values) : _values(std::move(values)) {} + + /// Permute this id by axes and return a new id + id permute(std::vector axes) const { + std::vector new_values(_values.size()); + for (size_t i = 0; i < _values.size(); i++) { + new_values[i] = _values[axes[i]]; + } + return 
id(std::move(new_values)); + } + + int64_t operator[](size_t i) const { return _values[i]; } + int64_t &operator[](size_t i) { return _values[i]; } + + /// Subtract another id from this id and return a new id + id operator-(const id &rhs) const { + std::vector new_values(_values.size()); + for (size_t i = 0; i < _values.size(); i++) { + new_values[i] = _values[i] - rhs._values[i]; + } + return id(std::move(new_values)); + } + + /// Subtract another id from this id and return a new id + id operator-(const int64_t *rhs) const { + std::vector new_values(_values.size()); + for (size_t i = 0; i < _values.size(); i++) { + new_values[i] = _values[i] - rhs[i]; + } + return id(std::move(new_values)); + } + + /// Increase the last dimension value of this id which bounds by shape + /// + /// Example: + /// In shape (2,2) : (0,0)->(0,1)->(1,0)->(1,1)->(0,0) + void next(const int64_t *shape) { + size_t i = _values.size(); + while (i--) { + ++_values[i]; + if (_values[i] < shape[i]) { + return; + } + _values[i] = 0; + } + } + + size_t size() { return _values.size(); } + +private: + std::vector _values; +}; + +/// +/// An wrapper template class for distribute multi-dimensional array +/// +template class ndarray { +public: + ndarray(int64_t nDims, int64_t *gShape, int64_t *gOffsets, void *lData, + int64_t *lShape, int64_t *lStrides) + : _nDims(nDims), _gShape(gShape), _gOffsets(gOffsets), _lData((T *)lData), + _lShape(lShape), _lStrides(lStrides) {} + + /// Return the first global index of local data + id firstLocalIndex() const { return id(_nDims, _gOffsets); } + + /// Interate all global indices in local data + void localIndices(const std::function &callback) const { + size_t size = lSize(); + id idx = firstLocalIndex(); + while (size--) { + callback(idx); + idx.next(_gShape); + } + } + + /// Interate all global indices of the array + void globalIndices(const std::function &callback) const { + size_t size = gSize(); + id idx(_nDims); + while (size--) { + callback(idx); + idx.next(_gShape); + } + } + + int64_t getLocalDataOffset(const id &idx) const { + auto localIdx = idx - _gOffsets; + int64_t offset = 0; + for (int64_t i = 0; i < _nDims - 1; ++i) { + offset = (offset + localIdx[i]) * _lShape[i + 1]; + } + offset += localIdx[_nDims - 1]; + return offset; + } + + /// Using global index to access its data + T &operator[](const id &idx) { return _lData[getLocalDataOffset(idx)]; } + T operator[](const id &idx) const { return _lData[getLocalDataOffset(idx)]; } + + id gShape() { return id(_nDims, _gShape); } + id lShape() { return id(_nDims, _lShape); } + + size_t gSize() const { + return std::accumulate(_gShape, _gShape + _nDims, 1, + std::multiplies()); + } + + size_t lSize() const { + return std::accumulate(_lShape, _lShape + _nDims, 1, + std::multiplies()); + } + +private: + int64_t _nDims; + int64_t *_gShape; + int64_t *_gOffsets; + T *_lData; + int64_t *_lShape; + int64_t *_lStrides; +}; + +struct Parts { + int64_t iStart; + int64_t iEnd; + int64_t oStart; + int64_t oEnd; +}; + +size_t getInputRank(const std::vector &parts, int64_t dim0) { + for (size_t i = 0; i < parts.size(); i++) { + if (dim0 >= parts[i].iStart && dim0 < parts[i].iEnd) { + return i; + } + } + assert(false && "unreachable"); + return 0; +} + +size_t getOutputRank(const std::vector &parts, int64_t dim0) { + for (size_t i = 0; i < parts.size(); i++) { + if (dim0 >= parts[i].oStart && dim0 < parts[i].oEnd) { + return i; + } + } + assert(false && "unreachable"); + return 0; +} + +template class WaitPermute { +public: + 
WaitPermute(SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl, + SHARPY::rank_type cRank, SHARPY::rank_type nRanks, + std::vector &&parts, std::vector &&axes, + std::vector oGShape, ndarray &&input, + ndarray &&output, std::vector &&receiveBuffer, + std::vector &&receiveOffsets, + std::vector &&receiveSizes) + : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)), + axes(std::move(axes)), oGShape(std::move(oGShape)), + input(std::move(input)), output(std::move(output)), + receiveBuffer(std::move(receiveBuffer)), + receiveOffsets(std::move(receiveOffsets)), + receiveSizes(std::move(receiveSizes)) {} + + void operator()() { + tc->wait(hdl); + std::vector> receiveRankBuffer(nRanks); + for (size_t rank = 0; rank < nRanks; ++rank) { + auto &rankBuffer = receiveRankBuffer[rank]; + rankBuffer.insert( + rankBuffer.end(), receiveBuffer.begin() + receiveOffsets[rank], + receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]); + } + + std::vector receiveRankBufferCount(nRanks, 0); + input.globalIndices([&](const id &inputIndex) { + id outputIndex = inputIndex.permute(axes); + auto rank = getOutputRank(parts, outputIndex[0]); + if (rank != cRank) + return; + rank = getInputRank(parts, inputIndex[0]); + auto &count = receiveRankBufferCount[rank]; + output[outputIndex] = receiveRankBuffer[rank][count++]; + }); + } + +private: + SHARPY::Transceiver *tc; + SHARPY::Transceiver::WaitHandle hdl; + SHARPY::rank_type cRank; + SHARPY::rank_type nRanks; + std::vector parts; + std::vector axes; + std::vector oGShape; + ndarray input; + ndarray output; + std::vector receiveBuffer; + std::vector receiveOffsets; + std::vector receiveSizes; +}; + +} // namespace + +/// @brief permute array +/// We assume array is partitioned along the first dimension (only) and +/// partitions are ordered by ranks +template +WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype, + SHARPY::Transceiver *tc, int64_t iNDims, + int64_t *iGShapePtr, int64_t *iOffsPtr, + void *iDataPtr, int64_t *iDataShapePtr, + int64_t *iDataStridesPtr, int64_t *oOffsPtr, + void *oDataPtr, int64_t *oDataShapePtr, + int64_t *oDataStridesPtr, int64_t *axesPtr) { +#ifdef NO_TRANSCEIVER + initMPIRuntime(); + tc = SHARPY::getTransceiver(); +#endif + if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr || + !iDataStridesPtr || !oOffsPtr || !oDataPtr || !oDataShapePtr || + !oDataStridesPtr || !tc) { + throw std::invalid_argument("Fatal: received nullptr in reshape"); + } + + std::vector oGShape(iNDims); + for (int64_t i = 0; i < iNDims; ++i) { + oGShape[i] = iGShapePtr[axesPtr[i]]; + } + auto *oGShapePtr = oGShape.data(); + const auto oNDims = iNDims; + + assert(std::accumulate(&iGShapePtr[0], &iGShapePtr[iNDims], 1, + std::multiplies()) == + std::accumulate(&oGShapePtr[0], &oGShapePtr[oNDims], 1, + std::multiplies())); + assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oNDims], 0, + std::plus()) == 0); + + const auto nRanks = tc->nranks(); + const auto cRank = tc->rank(); + if (nRanks <= cRank) { + throw std::out_of_range("Fatal: rank must be < number of ranks"); + } + + int64_t icSz = std::accumulate(&iGShapePtr[1], &iGShapePtr[iNDims], 1, + std::multiplies()); + assert(icSz == std::accumulate(&iDataShapePtr[1], &iDataShapePtr[iNDims], 1, + std::multiplies())); + int64_t mySz = icSz * iDataShapePtr[0]; + if (mySz / icSz != iDataShapePtr[0]) { + throw std::overflow_error("Fatal: Integer overflow in reshape"); + } + int64_t myOff = iOffsPtr[0] * icSz; + if (myOff / icSz != iOffsPtr[0]) { + throw 
std::overflow_error("Fatal: Integer overflow in reshape"); + } + int64_t myEnd = myOff + mySz; + if (myEnd < myOff) { + throw std::overflow_error("Fatal: Integer overflow in reshape"); + } + + int64_t oCSz = std::accumulate(&oGShapePtr[1], &oGShapePtr[oNDims], 1, + std::multiplies()); + assert(oCSz == std::accumulate(&oDataShapePtr[1], &oDataShapePtr[oNDims], 1, + std::multiplies())); + int64_t myOSz = oCSz * oDataShapePtr[0]; + if (myOSz / oCSz != oDataShapePtr[0]) { + throw std::overflow_error("Fatal: Integer overflow in reshape"); + } + int64_t myOOff = oOffsPtr[0] * oCSz; + if (myOOff / oCSz != oOffsPtr[0]) { + throw std::overflow_error("Fatal: Integer overflow in reshape"); + } + int64_t myOEnd = myOOff + myOSz; + if (myOEnd < myOOff) { + throw std::overflow_error("Fatal: Integer overflow in reshape"); + } + + // First we allgather the current and target partitioning + std::vector parts(nRanks); + parts[cRank].iStart = iOffsPtr[0]; + parts[cRank].iEnd = iOffsPtr[0] + iDataShapePtr[0]; + parts[cRank].oStart = oOffsPtr[0]; + parts[cRank].oEnd = oOffsPtr[0] + oDataShapePtr[0]; + std::vector counts(nRanks, 4); + std::vector dspl(nRanks); + for (auto i = 0ul; i < nRanks; ++i) { + dspl[i] = 4 * i; + } + tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64, + SHARPY::REPLICATED); + + // Transpose + ndarray input(iNDims, iGShapePtr, iOffsPtr, iDataPtr, iDataShapePtr, + iDataStridesPtr); + ndarray output(oNDims, oGShapePtr, oOffsPtr, oDataPtr, oDataShapePtr, + oDataStridesPtr); + std::vector axes(axesPtr, axesPtr + iNDims); + + std::vector sendBuffer; + std::vector receiveBuffer(output.lSize()); + std::vector sendSizes(nRanks); + std::vector sendOffsets(nRanks); + std::vector receiveSizes(nRanks); + std::vector receiveOffsets(nRanks); + + { + std::vector> sendRankBuffer(nRanks); + + input.localIndices([&](const id &inputIndex) { + id outputIndex = inputIndex.permute(axes); + auto rank = getOutputRank(parts, outputIndex[0]); + sendRankBuffer[rank].push_back(input[inputIndex]); + }); + + int lastOffset = 0; + for (size_t rank = 0; rank < nRanks; rank++) { + sendSizes[rank] = sendRankBuffer[rank].size(); + sendOffsets[rank] = lastOffset; + sendBuffer.insert(sendBuffer.end(), sendRankBuffer[rank].begin(), + sendRankBuffer[rank].end()); + lastOffset += sendSizes[rank]; + } + + output.localIndices([&](const id &outputIndex) { + id inputIndex = outputIndex.permute(axes); + auto rank = getInputRank(parts, inputIndex[0]); + ++receiveSizes[rank]; + }); + for (size_t rank = 1; rank < nRanks; rank++) { + receiveOffsets[rank] = receiveOffsets[rank - 1] + receiveSizes[rank - 1]; + } + } + + auto hdl = tc->alltoall(sendBuffer.data(), sendSizes.data(), + sendOffsets.data(), sharpytype, receiveBuffer.data(), + receiveSizes.data(), receiveOffsets.data()); + + auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), + std::move(axes), std::move(oGShape), std::move(input), + std::move(output), std::move(receiveBuffer), + std::move(receiveOffsets), std::move(receiveSizes)); + + assert(parts.empty() && axes.empty() && receiveBuffer.empty() && + receiveOffsets.empty() && receiveSizes.empty()); + + if (no_async) { + wait(); + return nullptr; + } + + return mkWaitHandle(std::move(wait)); +} + +/// @brief permute array +template +WaitHandleBase * +_idtr_copy_permute(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, + int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, + void *iDataDescr, int64_t oNOffs, void *oLOffsDescr, + int64_t oNDims, void *oDataDescr, int64_t axesSzs, + void 
*axesDescr) { + + if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oLOffsDescr || + !oDataDescr || !axesDescr) { + throw std::invalid_argument( + "Fatal error: received nullptr in update_halo."); + } + + auto sharpyType = SHARPY::DTYPE::value; + + // Construct unranked memrefs for metadata and data + MRIdx1d iGShape(iNSzs, iGShapeDescr); + MRIdx1d iOffs(iNOffs, iLOffsDescr); + SHARPY::UnrankedMemRefType iData(iNDims, iDataDescr); + MRIdx1d oOffs(oNOffs, oLOffsDescr); + SHARPY::UnrankedMemRefType oData(oNDims, oDataDescr); + MRIdx1d axes(axesSzs, axesDescr); + + return _idtr_copy_permute(sharpyType, tc, iNDims, iGShape.data(), + iOffs.data(), iData.data(), iData.sizes(), + iData.strides(), oOffs.data(), oData.data(), + oData.sizes(), oData.strides(), axes.data()); +} + extern "C" { #define TYPED_COPY_RESHAPE(_sfx, _typ) \ void *_idtr_copy_reshape_##_sfx( \ @@ -592,6 +988,28 @@ TYPED_COPY_RESHAPE(i16, int16_t); TYPED_COPY_RESHAPE(i8, int8_t); TYPED_COPY_RESHAPE(i1, bool); +#define TYPED_COPY_PERMUTE(_sfx, _typ) \ + void *_idtr_copy_permute_##_sfx( \ + SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \ + int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \ + int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, void *oLDescr, \ + int64_t axesSzs, void *axesDescr) { \ + return _idtr_copy_permute<_typ>( \ + tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNOffs, \ + oLOffsDescr, oNDims, oLDescr, axesSzs, axesDescr); \ + } \ + _Pragma(STRINGIFY(weak _mlir_ciface__idtr_copy_permute_##_sfx = \ + _idtr_copy_permute_##_sfx)) + +TYPED_COPY_PERMUTE(f64, double); +TYPED_COPY_PERMUTE(f32, float); +TYPED_COPY_PERMUTE(i64, int64_t); +TYPED_COPY_PERMUTE(i32, int32_t); +TYPED_COPY_PERMUTE(i16, int16_t); +TYPED_COPY_PERMUTE(i8, int8_t); +// FIXME: bool is not supported yet due to std::vector +// TYPED_COPY_PERMUTE(i1, bool); + } // extern "C" // struct for caching meta data for update_halo diff --git a/src/include/sharpy/CppTypes.hpp b/src/include/sharpy/CppTypes.hpp index 62eda48..a823d20 100644 --- a/src/include/sharpy/CppTypes.hpp +++ b/src/include/sharpy/CppTypes.hpp @@ -339,6 +339,7 @@ enum FactoryId : int { F_REDUCEOP, F_REPLICATE, F_RESHAPE, + F_PERMUTEDIMS, F_SERVICE, F_SETITEM, F_ASTYPE, diff --git a/src/include/sharpy/ManipOp.hpp b/src/include/sharpy/ManipOp.hpp index e78e3dd..5b19d1e 100644 --- a/src/include/sharpy/ManipOp.hpp +++ b/src/include/sharpy/ManipOp.hpp @@ -20,5 +20,8 @@ struct ManipOp { static FutureArray *to_device(const FutureArray &a, const std::string &device); + + static FutureArray *permute_dims(const FutureArray &array, + const shape_type &axes); }; } // namespace SHARPY diff --git a/src/jit/mlir.cpp b/src/jit/mlir.cpp index 0810dd5..d84186e 100644 --- a/src/jit/mlir.cpp +++ b/src/jit/mlir.cpp @@ -691,7 +691,8 @@ static const std::string cpu_pipeline = "one-shot-bufferize," "canonicalize," "imex-remove-temporaries," - "func.func(buffer-deallocation)," + "buffer-deallocation-pipeline," + "convert-bufferization-to-memref," "func.func(convert-linalg-to-parallel-loops)," "func.func(scf-parallel-loop-fusion)," "drop-regions," diff --git a/test/test_manip.py b/test/test_manip.py index 3ce2407..f2e6bcd 100644 --- a/test/test_manip.py +++ b/test/test_manip.py @@ -93,3 +93,46 @@ def test_todevice_host2gpu(self): a = sp.arange(0, 8, 1, sp.int32) b = a.to_device(device="GPU") assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7]) + + def test_permute_dims1(self): + a = sp.arange(0, 10, 1, sp.int64) + b = sp.reshape(a, (2, 5)) + c1 = 
sp.to_numpy(sp.permute_dims(b, [1, 0]))
+        c2 = sp.to_numpy(b).transpose(1, 0)
+        assert numpy.allclose(c1, c2)
+
+    def test_permute_dims2(self):
+        # === sharpy
+        sp_a = sp.arange(0, 2 * 3 * 4, 1)
+        sp_a = sp.reshape(sp_a, [2, 3, 4])
+
+        # b = a.swapaxes(1,0).swapaxes(1,2)
+        sp_b = sp.permute_dims(sp_a, (1, 0, 2))  # 2x3x4 -> 3x2x4
+        sp_b = sp.permute_dims(sp_b, (0, 2, 1))  # 3x2x4 -> 3x4x2
+
+        # c = b.swapaxes(1,2).swapaxes(1,0)
+        sp_c = sp.permute_dims(sp_b, (0, 2, 1))
+        sp_c = sp.permute_dims(sp_c, (1, 0, 2))
+
+        assert numpy.allclose(sp.to_numpy(sp_a), sp.to_numpy(sp_c))
+
+        # d = a.swapaxes(2,1).swapaxes(2,0)
+        sp_d = sp.permute_dims(sp_a, (0, 2, 1))
+        sp_d = sp.permute_dims(sp_d, (2, 1, 0))
+
+        # e = d.swapaxes(2,1).swapaxes(0,1)
+        sp_e = sp.permute_dims(sp_d, (0, 2, 1))
+        sp_e = sp.permute_dims(sp_e, (1, 0, 2))
+
+        # === numpy
+        np_a = numpy.arange(0, 2 * 3 * 4, 1)
+        np_a = numpy.reshape(np_a, [2, 3, 4])
+
+        np_b = np_a.swapaxes(1, 0).swapaxes(1, 2)
+        assert numpy.allclose(sp.to_numpy(sp_b), np_b)
+
+        np_d = np_a.swapaxes(2, 1).swapaxes(2, 0)
+        assert numpy.allclose(sp.to_numpy(sp_d), np_d)
+
+        np_e = np_d.swapaxes(2, 1).swapaxes(0, 1)
+        assert numpy.allclose(sp.to_numpy(sp_e), np_e)
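ManipOp::permute_dims above validates that axes has the same length as the input shape and names every dimension exactly once, then builds the output shape as shape[axes[i]]. A Python sketch of the equivalent checks (illustrative, not the bound C++ API):

    def permuted_shape(shape, axes):
        # axes must name every input dimension exactly once.
        if len(axes) != len(shape):
            raise ValueError("axes must have the same length as the shape")
        if sorted(axes) != list(range(len(shape))):
            raise ValueError("axes must contain all dimensions")
        # Output extent i comes from input dimension axes[i].
        return tuple(shape[a] for a in axes)

    assert permuted_shape((2, 3, 4), (1, 0, 2)) == (3, 2, 4)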
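The id/ndarray helpers added in src/idtr.cpp walk global indices in row-major order (id::next behaves like an odometer bounded by the global shape) and map a global index to a row-major offset inside the rank's local block, relative to its global offsets. A small Python sketch of those two pieces, under the same row-major assumption:

    def next_index(idx, shape):
        # Row-major odometer: increment the last axis, carrying into earlier axes.
        idx = list(idx)
        for i in reversed(range(len(idx))):
            idx[i] += 1
            if idx[i] < shape[i]:
                break
            idx[i] = 0
        return tuple(idx)

    def local_offset(global_idx, offsets, local_shape):
        # Offset of a global index inside this rank's row-major local block.
        local_idx = [g - o for g, o in zip(global_idx, offsets)]
        off = 0
        for i in range(len(local_idx) - 1):
            off = (off + local_idx[i]) * local_shape[i + 1]
        return off + local_idx[-1]

    assert next_index((0, 1), (2, 2)) == (1, 0)
    assert local_offset((3, 2), offsets=(2, 0), local_shape=(2, 5)) == 7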
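_idtr_copy_permute assumes the array is partitioned along the first dimension only; each rank shares its input/output row ranges, routes every local element to the rank that owns the first coordinate of its permuted index, exchanges the packed buffers with alltoall, and the receiver writes values back in global-index order. A serial Python simulation of that routing logic, with row-wise partitions given as (start, end) pairs and in-memory lists standing in for the MPI exchange (everything here is illustrative):

    import numpy as np
    from itertools import product

    def owner(parts, dim0):
        # Rank whose row range [start, end) contains dim0.
        return next(r for r, (s, e) in enumerate(parts) if s <= dim0 < e)

    def permute_distributed(global_arr, axes, in_parts, out_parts):
        out = np.empty(tuple(global_arr.shape[a] for a in axes), global_arr.dtype)
        nranks = len(in_parts)
        # Pack: each source rank walks its local global indices in row-major
        # order and buckets values by the rank owning the permuted index's row.
        send = [[[] for _ in range(nranks)] for _ in range(nranks)]
        for r, (s, e) in enumerate(in_parts):
            for idx in product(range(s, e), *(range(n) for n in global_arr.shape[1:])):
                out_idx = tuple(idx[a] for a in axes)
                send[r][owner(out_parts, out_idx[0])].append(global_arr[idx])
        # Unpack: walk global input indices in the same order and pop the next
        # value from the buffer of the rank that produced it.
        counts = [[0] * nranks for _ in range(nranks)]
        for idx in product(*(range(n) for n in global_arr.shape)):
            out_idx = tuple(idx[a] for a in axes)
            dst = owner(out_parts, out_idx[0])
            src = owner(in_parts, idx[0])
            out[out_idx] = send[src][dst][counts[dst][src]]
            counts[dst][src] += 1
        return out

    a = np.arange(12).reshape(3, 4)
    b = permute_distributed(a, (1, 0), in_parts=[(0, 2), (2, 3)], out_parts=[(0, 2), (2, 4)])
    assert np.array_equal(b, a.transpose())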