diff --git a/dash/include/dash/algorithm/Copy.h b/dash/include/dash/algorithm/Copy.h index b0c42304a..812c2da22 100644 --- a/dash/include/dash/algorithm/Copy.h +++ b/dash/include/dash/algorithm/Copy.h @@ -95,14 +95,19 @@ namespace internal { template struct local_copy_chunk { - const InputValueType *src; - OutputValueType *dest; - const size_t size; + using nonconst_input_value_type = typename std::remove_const::type; + using nonconst_output_value_type = typename std::remove_const::type; + + const nonconst_input_value_type *src; + nonconst_output_value_type *dest; + const size_t size; }; template -void do_local_copies( - std::vector>& chunks) +using local_copy_chunks = std::vector>; + +template +void do_local_copies(local_copy_chunks& chunks) { for (auto& chunk : chunks) { std::copy(chunk.src, chunk.src + chunk.size, chunk.dest); @@ -136,10 +141,13 @@ template < typename ValueType, typename GlobInputIt > ValueType * copy_impl( - GlobInputIt begin, - GlobInputIt end, - ValueType * out_first, - std::vector * handles) + GlobInputIt begin, + GlobInputIt end, + ValueType * out_first, + std::vector * handles, + local_copy_chunks< + typename GlobInputIt::value_type, + ValueType> & local_chunks) { DASH_LOG_TRACE("dash::internal::copy_impl() global -> local", "in_first:", begin.pos(), @@ -166,8 +174,6 @@ ValueType * copy_impl( ContiguousRangeSet range_set{begin, end}; - std::vector> local_chunks; - // // Copy elements from every unit: // @@ -211,8 +217,6 @@ ValueType * copy_impl( num_elem_copied += num_copy_elem; } - do_local_copies(local_chunks); - DASH_ASSERT_EQ(num_elem_copied, num_elem_total, "Failed to find all contiguous subranges in range"); @@ -233,10 +237,13 @@ template < typename ValueType, typename GlobOutputIt > GlobOutputIt copy_impl( - ValueType * begin, - ValueType * end, - GlobOutputIt out_first, - std::vector * handles) + ValueType * begin, + ValueType * end, + GlobOutputIt out_first, + std::vector * handles, + local_copy_chunks< + ValueType, + typename GlobOutputIt::value_type> & local_chunks) { DASH_LOG_TRACE("dash::copy_impl() local -> global", "in_first:", begin, @@ -266,8 +273,6 @@ GlobOutputIt copy_impl( ContiguousRangeSet range_set{out_first, out_last}; - std::vector> local_chunks; - auto in_first = begin; // @@ -312,8 +317,6 @@ GlobOutputIt copy_impl( num_elem_copied += num_copy_elem; } - do_local_copies(local_chunks); - DASH_ASSERT_EQ(num_elem_copied, num_elem_total, "Failed to find all contiguous subranges in range"); @@ -350,9 +353,10 @@ dash::Future copy_async( } auto handles = std::make_shared>(); - - auto out_last = dash::internal::copy_impl(in_first, in_last, - out_first, handles.get()); + dash::internal::local_copy_chunks local_chunks; + auto out_last = dash::internal::copy_impl(in_first, in_last, out_first, + handles.get(), local_chunks); + dash::internal::do_local_copies(local_chunks); if (handles->empty()) { DASH_LOG_TRACE("dash::copy_async", "all transfers completed"); @@ -434,24 +438,29 @@ ValueType * copy( } ValueType *out_last; + dash::internal::local_copy_chunks local_chunks; if (UseHandles) { std::vector handles; out_last = dash::internal::copy_impl(in_first, in_last, out_first, - &handles); + &handles, + local_chunks); if (!handles.empty()) { DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete,", "num_handles: ", handles.size()); dart_waitall_local(handles.data(), handles.size()); } + dash::internal::do_local_copies(local_chunks); } else { out_last = dash::internal::copy_impl(in_first, in_last, out_first, - nullptr); + nullptr, + local_chunks); DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete"); + dash::internal::do_local_copies(local_chunks); dart_flush_local_all(in_first.dart_gptr()); } @@ -484,10 +493,13 @@ dash::Future copy_async( } auto handles = std::make_shared>(); + dash::internal::local_copy_chunks local_chunks; auto out_last = dash::internal::copy_impl(in_first, in_last, out_first, - handles.get()); + handles.get(), + local_chunks); + dash::internal::do_local_copies(local_chunks); if (handles->empty()) { return dash::Future(out_last); @@ -556,12 +568,15 @@ GlobOutputIt copy( DASH_LOG_TRACE("dash::copy()", "blocking, local to global"); // handles to wait on at the end GlobOutputIt out_last; + dash::internal::local_copy_chunks local_chunks; if (UseHandles) { std::vector handles; out_last = dash::internal::copy_impl(in_first, in_last, out_first, - &handles); + &handles, + local_chunks); + dash::internal::do_local_copies(local_chunks); if (!handles.empty()) { DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete,", @@ -572,8 +587,10 @@ GlobOutputIt copy( out_last = dash::internal::copy_impl(in_first, in_last, out_first, - nullptr); + nullptr, + local_chunks); DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete"); + dash::internal::do_local_copies(local_chunks); dart_flush_all(out_first.dart_gptr()); } return out_last; @@ -629,25 +646,166 @@ copy_async( } #endif +struct ActiveDestination{}; +struct ActiveSource{}; + /** * Specialization of \c dash::copy as global-to-global blocking copy * operation. * * \ingroup DashAlgorithms */ -template +template < + class GlobInputIt, + class GlobOutputIt, + bool UseHandles = false> GlobOutputIt copy( - GlobInputIt /*in_first*/, - GlobInputIt /*in_last*/, - GlobOutputIt /*out_first*/) + GlobInputIt in_first, + GlobInputIt in_last, + GlobOutputIt out_first, + ActiveDestination /*unused*/) +{ + DASH_LOG_TRACE("dash::copy()", "blocking, global to global, active destination"); + + using size_type = typename GlobInputIt::size_type; + using input_value_type = typename GlobInputIt::value_type; + using output_value_type = typename GlobOutputIt::value_type; + + size_type num_elem_total = dash::distance(in_first, in_last); + if (num_elem_total <= 0) { + DASH_LOG_TRACE("dash::copy", "input range empty"); + return out_first; + } + + auto g_out_first = out_first; + auto g_out_last = g_out_first + num_elem_total; + + internal::ContiguousRangeSet range_set{g_out_first, g_out_last}; + + const auto & out_team = out_first.team(); + out_team.barrier(); + + std::vector handles; + std::vector* handles_arg = UseHandles ? &handles : nullptr; + + dash::internal::local_copy_chunks local_chunks; + + size_type num_elem_processed = 0; + + for (auto range : range_set) { + + auto cur_out_first = range.first; + auto num_copy_elem = range.second; + + DASH_ASSERT_GT(num_copy_elem, 0, + "Number of elements to copy is 0"); + + // handle local data only + if (cur_out_first.is_local()) { + auto dest_ptr = cur_out_first.local(); + auto src_ptr = in_first + num_elem_processed; + internal::copy_impl(src_ptr, + src_ptr + num_copy_elem, + dest_ptr, + handles_arg, + local_chunks); + } + num_elem_processed += num_copy_elem; + } + + dash::internal::do_local_copies(local_chunks); + + if (!handles.empty()) { + DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete,", + "num_handles: ", handles.size()); + dart_waitall_local(handles.data(), handles.size()); + } else if (!UseHandles) { + dart_flush_local_all(in_first.dart_gptr()); + } + out_team.barrier(); + + DASH_ASSERT_EQ(num_elem_processed, num_elem_total, + "Failed to find all contiguous subranges in range"); + + return g_out_last; +} + +/** + * Specialization of \c dash::copy as global-to-global blocking copy + * operation. + * + * \ingroup DashAlgorithms + */ +template < + class GlobInputIt, + class GlobOutputIt, + bool UseHandles = false> +GlobOutputIt copy( + GlobInputIt in_first, + GlobInputIt in_last, + GlobOutputIt out_first, + ActiveSource /*unused*/) { DASH_LOG_TRACE("dash::copy()", "blocking, global to global"); - // TODO: - // - Implement adapter for local-to-global dash::copy here - // - Return if global input range has no local sub-range + using size_type = typename GlobInputIt::size_type; + using input_value_type = typename GlobInputIt::value_type; + using output_value_type = typename GlobOutputIt::value_type; + + size_type num_elem_total = dash::distance(in_first, in_last); + if (num_elem_total <= 0) { + DASH_LOG_TRACE("dash::copy", "input range empty"); + return out_first; + } + + internal::ContiguousRangeSet range_set{in_first, in_last}; + + const auto & in_team = in_first.team(); + in_team.barrier(); + + std::vector handles; + std::vector* handles_arg = UseHandles ? &handles : nullptr; + + dash::internal::local_copy_chunks local_chunks; + + size_type num_elem_processed = 0; + + for (auto range : range_set) { + + auto cur_in_first = range.first; + auto num_copy_elem = range.second; + + DASH_ASSERT_GT(num_copy_elem, 0, + "Number of elements to copy is 0"); + + // handle local data only + if (cur_in_first.is_local()) { + auto src_ptr = cur_in_first.local(); + auto dest_ptr = out_first + num_elem_processed; + internal::copy_impl(src_ptr, + src_ptr + num_copy_elem, + dest_ptr, + handles_arg, + local_chunks); + } + num_elem_processed += num_copy_elem; + } + + internal::do_local_copies(local_chunks); + + if (!handles.empty()) { + DASH_LOG_TRACE("dash::copy", "Waiting for remote transfers to complete,", + "num_handles: ", handles.size()); + dart_waitall(handles.data(), handles.size()); + } else if (!UseHandles) { + dart_flush_all(out_first.dart_gptr()); + } + in_team.barrier(); + + DASH_ASSERT_EQ(num_elem_processed, num_elem_total, + "Failed to find all contiguous subranges in range"); - return GlobOutputIt(); + return out_first + num_elem_total; } #endif // DOXYGEN diff --git a/dash/test/algorithm/CopyTest.cc b/dash/test/algorithm/CopyTest.cc index 8fdae7be1..333de13a2 100644 --- a/dash/test/algorithm/CopyTest.cc +++ b/dash/test/algorithm/CopyTest.cc @@ -1040,3 +1040,71 @@ TEST_F(CopyTest, InputOutputTypeTest) ASSERT_TRUE_U((dash::internal::is_dash_copyable::value)); } + +TEST_F(CopyTest, MatrixTransfersGlobalToGlobal) +{ + if (_dash_size < 2) { + SKIP_TEST_MSG("At least 2 units required for this test."); + } + + using TeamSpecT = dash::TeamSpec<2>; + using MatrixT = dash::NArray; + using PatternT = typename MatrixT::pattern_type; + using SizeSpecT = dash::SizeSpec<2>; + using DistSpecT = dash::DistributionSpec<2>; + + auto& team_all = dash::Team::All(); + TeamSpecT team_all_spec(team_all.size(), 1); + team_all_spec.balance_extents(); + + auto size_spec = SizeSpecT(4*team_all_spec.extent(1), + 4*team_all_spec.extent(1)); + auto dist_spec = DistSpecT(dash::BLOCKED, dash::BLOCKED); + + MatrixT grid_more(size_spec, dist_spec, team_all, team_all_spec); + dash::fill(grid_more.begin(), grid_more.end(), (double)team_all.myid()); + team_all.barrier(); + + // create a smaller team + dash::Team& team_fewer= team_all.split(2); + team_all.barrier(); + if (!team_fewer.is_null() && 0 == team_fewer.position()) { + TeamSpecT team_fewer_spec(team_fewer.size(), 1); + team_fewer_spec.balance_extents(); + + MatrixT grid_fewer(size_spec, dist_spec, team_fewer, team_fewer_spec); + dash::fill(grid_fewer.begin(), grid_fewer.end(), -1.0); + + auto lextents= grid_fewer.pattern().local_extents(); + + dash::copy(grid_more.begin(), grid_more.end(), + grid_fewer.begin(), dash::ActiveDestination()); + + if (team_fewer.myid() == 0) { + auto gextents = grid_fewer.extents(); + for (uint32_t y = 0; y < gextents[0]; ++y) { + for (uint32_t x = 0; x < gextents[1]; ++x) { + ASSERT_EQ_U(grid_more(y, x), grid_fewer(y, x)); + } + } + } + + team_fewer.barrier(); + + dash::fill(grid_fewer.begin(), grid_fewer.end(), (double)team_fewer.myid()); + + dash::copy(grid_fewer.begin(), grid_fewer.end(), + grid_more.begin(), dash::ActiveSource()); + + if (team_fewer.myid() == 0) { + auto gextents = grid_fewer.extents(); + for (uint32_t y = 0; y < gextents[0]; ++y) { + for (uint32_t x = 0; x < gextents[1]; ++x) { + ASSERT_EQ_U(grid_more(y, x), grid_fewer(y, x)); + } + } + } + + team_fewer.barrier(); + } +} diff --git a/dash/test/container/MatrixTest.cc b/dash/test/container/MatrixTest.cc index db65cef75..6bb322fea 100644 --- a/dash/test/container/MatrixTest.cc +++ b/dash/test/container/MatrixTest.cc @@ -745,9 +745,10 @@ TEST_F(MatrixTest, BlockCopy) LOG_MESSAGE("Team barrier passed"); // Copy block 1 of matrix_a to block 0 of matrix_b: - dash::copy(matrix_a.block(1).begin(), - matrix_a.block(1).end(), - matrix_b.block(0).begin()); + dash::copy(matrix_a.block(1).begin(), + matrix_a.block(1).end(), + matrix_b.block(0).begin(), + dash::ActiveSource()); LOG_MESSAGE("Wait for team barrier ..."); dash::barrier();