diff --git a/include/dr/mhp/algorithms/sort.hpp b/include/dr/mhp/algorithms/sort.hpp index e3eba82728..ffa65d7ea1 100644 --- a/include/dr/mhp/algorithms/sort.hpp +++ b/include/dr/mhp/algorithms/sort.hpp @@ -47,30 +47,25 @@ void local_sort(R &r, Compare &&comp) { } } -// TODO: quite a long function, refactor to make the code more clear -template -void dist_sort(R &r, Compare &&comp) { - using valT = typename R::value_type; +/* elements of dist_sort */ +template +void splitters(Seg &lsegment, Compare &&comp, + std::vector &vec_split_i, + std::vector &vec_split_s) { - const std::size_t _comm_rank = default_comm().rank(); const std::size_t _comm_size = default_comm().size(); // dr-style ignore - auto &&lsegment = local_segment(r); - /* sort local segment */ - - __detail::local_sort(lsegment, comp); - + std::vector vec_split_v(_comm_size - 1); std::vector vec_lmedians(_comm_size + 1); std::vector vec_gmedians((_comm_size + 1) * _comm_size); const double _step_m = static_cast(rng::size(lsegment)) / static_cast(_comm_size); - /* calculate splitting values and indices - find n-1 dividers splitting each - * segment into equal parts */ + /* calculate splitting values and indices - find n-1 dividers splitting + * each segment into equal parts */ for (std::size_t _i = 0; _i < rng::size(vec_lmedians); _i++) { - // vec_lmedians[_i] = lsegment[(_i + 1) * _step_m]; vec_lmedians[_i] = lsegment[_i * _step_m]; } vec_lmedians.back() = lsegment.back(); @@ -79,20 +74,10 @@ void dist_sort(R &r, Compare &&comp) { rng::sort(rng::begin(vec_gmedians), rng::end(vec_gmedians), comp); - /* find splitting values - medians of dividers */ - - std::vector vec_split_v(_comm_size - 1); - for (std::size_t _i = 0; _i < _comm_size - 1; _i++) { vec_split_v[_i] = vec_gmedians[(_i + 1) * (_comm_size + 1) - 1]; } - /* calculate splitting indices (start of buffers) and sizes of buffers to send - */ - - std::vector vec_split_i(_comm_size, 0); - std::vector vec_split_s(_comm_size, 0); - std::size_t segidx = 0, vidx = 1; while (vidx < _comm_size && segidx < rng::size(lsegment)) { @@ -107,77 +92,18 @@ void dist_sort(R &r, Compare &&comp) { } assert(rng::size(lsegment) > vec_split_i[vidx - 1]); vec_split_s[vidx - 1] = rng::size(lsegment) - vec_split_i[vidx - 1]; +} - /* send data size to each node */ - std::vector vec_rsizes(_comm_size, 0); - std::vector vec_rindices(_comm_size, 0); // recv buffers - - default_comm().alltoall(vec_split_s, vec_rsizes, 1); - - std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(), - vec_rindices.begin(), 0); - - // const std::size_t _recv_elems = - // std::reduce(vec_rsizes.begin(), vec_rsizes.end()); - - const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back(); - - /* send and receive data belonging to each node, then redistribute - * data to achieve size of data equal to size of local segment */ - - std::vector vec_recv_elems(_comm_size); - MPI_Request req_recvelems; - MPI_Status stat_recvelemes; - - default_comm().i_all_gather(_recv_elems, vec_recv_elems, &req_recvelems); - -#ifdef SYCL_LANGUAGE_VERSION - auto policy = dpl_policy(); - sycl::usm_allocator alloc(policy.queue()); - std::vector vec_recvdata(_recv_elems, alloc); -#else - std::vector vec_recvdata(_recv_elems); -#endif - - default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata, - vec_rsizes, vec_rindices); - - /* vec recvdata is partially sorted, implementation of merge on GPU is - * desirable */ - __detail::local_sort(vec_recvdata, comp); - - MPI_Wait(&req_recvelems, &stat_recvelemes); - - const std::size_t _total_elems = - std::reduce(vec_recv_elems.begin(), vec_recv_elems.end()); - - assert(_total_elems == rng::size(r)); - - std::vector vec_shift(_comm_size - 1); - - const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size; +template +void shift_data(const int shift_left, const int shift_right, auto &vec_recvdata, + auto &vec_left, auto &vec_right) { - vec_shift[0] = desired_elems_num - vec_recv_elems[0]; - for (std::size_t _i = 1; _i < _comm_size - 1; _i++) { - vec_shift[_i] = vec_shift[_i - 1] + desired_elems_num - vec_recv_elems[_i]; - } - - const int shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1]; - const int shift_right = - _comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank]; + const std::size_t _comm_rank = default_comm().rank(); MPI_Request req_l, req_r; MPI_Status stat_l, stat_r; const communicator::tag t = communicator::tag::halo_index; -#ifdef SYCL_LANGUAGE_VERSION - std::vector vec_left(std::max(shift_left, 0), alloc); - std::vector vec_right(std::max(shift_right, 0), alloc); -#else - std::vector vec_left(std::max(shift_left, 0)); - std::vector vec_right(std::max(shift_right, 0)); -#endif - if (static_cast(rng::size(vec_recvdata)) < -shift_left) { // Too little data in recv buffer to shift left - first get from right, then // send left @@ -223,7 +149,11 @@ void dist_sort(R &r, Compare &&comp) { if (shift_right != 0) MPI_Wait(&req_r, &stat_r); } +} +template +void copy_results(auto &lsegment, const int shift_left, const int shift_right, + auto &vec_recvdata, auto &vec_left, auto &vec_right) { const std::size_t invalidate_left = std::max(-shift_left, 0); const std::size_t invalidate_right = std::max(-shift_right, 0); @@ -243,7 +173,6 @@ void dist_sort(R &r, Compare &&comp) { lsegment.data() + size_l + size_d, size_r); e_d = sycl_queue().copy(vec_recvdata.data() + invalidate_left, lsegment.data() + size_l, size_d); - if (size_l > 0) e_l.wait(); if (size_r > 0) @@ -263,7 +192,100 @@ void dist_sort(R &r, Compare &&comp) { std::memcpy(lsegment.data() + size_l, vec_recvdata.data() + invalidate_left, size_d * sizeof(valT)); } +} + +template +void dist_sort(R &r, Compare &&comp) { + + using valT = typename R::value_type; + + const std::size_t _comm_rank = default_comm().rank(); + const std::size_t _comm_size = default_comm().size(); // dr-style ignore + +#ifdef SYCL_LANGUAGE_VERSION + auto policy = dpl_policy(); + sycl::usm_allocator alloc(policy.queue()); +#endif + + auto &&lsegment = local_segment(r); + + std::vector vec_split_i(_comm_size, 0); + std::vector vec_split_s(_comm_size, 0); + std::vector vec_rsizes(_comm_size, 0); + std::vector vec_rindices(_comm_size, 0); + std::vector vec_recv_elems(_comm_size, 0); + std::size_t _total_elems = 0; + + __detail::local_sort(lsegment, comp); + + /* find splitting values - limits of areas to send to other processes */ + __detail::splitters(lsegment, comp, vec_split_i, vec_split_s); + + default_comm().alltoall(vec_split_s, vec_rsizes, 1); + + /* prepare data to send and receive */ + std::exclusive_scan(vec_rsizes.begin(), vec_rsizes.end(), + vec_rindices.begin(), 0); + + const std::size_t _recv_elems = vec_rindices.back() + vec_rsizes.back(); + + /* send and receive data belonging to each node, then redistribute + * data to achieve size of data equal to size of local segment */ + + MPI_Request req_recvelems; + + default_comm().i_all_gather(_recv_elems, vec_recv_elems, &req_recvelems); + + /* buffer for received data */ +#ifdef SYCL_LANGUAGE_VERSION + std::vector vec_recvdata(_recv_elems, alloc); +#else + std::vector vec_recvdata(_recv_elems); +#endif + + /* send data not belonging and receive data belonging to local processes + */ + default_comm().alltoallv(lsegment, vec_split_s, vec_split_i, vec_recvdata, + vec_rsizes, vec_rindices); + + /* TODO: vec recvdata is partially sorted, implementation of merge on GPU is + * desirable */ + __detail::local_sort(vec_recvdata, comp); + + MPI_Wait(&req_recvelems, MPI_STATUS_IGNORE); + + _total_elems = std::reduce(vec_recv_elems.begin(), vec_recv_elems.end()); + + /* prepare data for shift to neighboring processes */ + std::vector vec_shift(_comm_size - 1); + + const auto desired_elems_num = (_total_elems + _comm_size - 1) / _comm_size; + + vec_shift[0] = desired_elems_num - vec_recv_elems[0]; + for (std::size_t _i = 1; _i < _comm_size - 1; _i++) { + vec_shift[_i] = vec_shift[_i - 1] + desired_elems_num - vec_recv_elems[_i]; + } + + const int shift_left = _comm_rank == 0 ? 0 : -vec_shift[_comm_rank - 1]; + const int shift_right = + _comm_rank == _comm_size - 1 ? 0 : vec_shift[_comm_rank]; + +#ifdef SYCL_LANGUAGE_VERSION + std::vector vec_left(std::max(shift_left, 0), alloc); + std::vector vec_right(std::max(shift_right, 0), alloc); +#else + std::vector vec_left(std::max(shift_left, 0)); + std::vector vec_right(std::max(shift_right, 0)); +#endif + + /* shift data if necessary, to have exactly the number of elements equal to + * lsegment size */ + __detail::shift_data(shift_left, shift_right, vec_recvdata, vec_left, + vec_right); + /* copy results to distributed vector's local segment */ + __detail::copy_results(lsegment, shift_left, shift_right, vec_recvdata, + vec_left, vec_right); } // __detail::dist_sort } // namespace __detail