diff --git a/Src/Base/AMReX_ParallelDescriptor.H b/Src/Base/AMReX_ParallelDescriptor.H index a41d393833..39119bdbb5 100644 --- a/Src/Base/AMReX_ParallelDescriptor.H +++ b/Src/Base/AMReX_ParallelDescriptor.H @@ -612,9 +612,9 @@ while ( false ) */ inline int SeqNum () noexcept { return ParallelContext::get_inc_mpi_tag(); } - template Message Asend(const T*, size_t n, int pid, int tag); - template Message Asend(const T*, size_t n, int pid, int tag, MPI_Comm comm); - template Message Asend(const std::vector& buf, int pid, int tag); + template [[nodiscard]] Message Asend(const T*, size_t n, int pid, int tag); + template [[nodiscard]] Message Asend(const T*, size_t n, int pid, int tag, MPI_Comm comm); + template [[nodiscard]] Message Asend(const std::vector& buf, int pid, int tag); template Message Arecv(T*, size_t n, int pid, int tag); template Message Arecv(T*, size_t n, int pid, int tag, MPI_Comm comm); diff --git a/Src/Particle/AMReX_ParticleCommunication.H b/Src/Particle/AMReX_ParticleCommunication.H index 00bf423478..147b74a68f 100644 --- a/Src/Particle/AMReX_ParticleCommunication.H +++ b/Src/Particle/AMReX_ParticleCommunication.H @@ -96,9 +96,12 @@ struct ParticleCopyPlan mutable Vector m_build_stats; mutable Vector m_build_rreqs; - mutable Vector m_particle_stats; + mutable Vector m_particle_rstats; mutable Vector m_particle_rreqs; + mutable Vector m_particle_sstats; + mutable Vector m_particle_sreqs; + Vector m_snd_num_particles; Vector m_rcv_num_particles; @@ -533,12 +536,15 @@ void communicateParticlesStart (const PC& pc, ParticleCopyPlan& plan, const SndB plan.m_nrcvs = int(RcvProc.size()); - plan.m_particle_stats.resize(0); - plan.m_particle_stats.resize(plan.m_nrcvs); + plan.m_particle_rstats.resize(0); + plan.m_particle_rstats.resize(plan.m_nrcvs); plan.m_particle_rreqs.resize(0); plan.m_particle_rreqs.resize(plan.m_nrcvs); + plan.m_particle_sstats.resize(0); + plan.m_particle_sreqs.resize(0); + const int SeqNum = ParallelDescriptor::SeqNum(); // Post receives. @@ -571,10 +577,12 @@ void communicateParticlesStart (const PC& pc, ParticleCopyPlan& plan, const SndB AMREX_ASSERT(plan.m_snd_counts[i] % ParallelDescriptor::sizeof_selected_comm_data_type(plan.m_snd_num_particles[i]*psize) == 0); AMREX_ASSERT(Who >= 0 && Who < NProcs); - ParallelDescriptor::Send((char const*)(snd_buffer.dataPtr()+snd_offset), Cnt, Who, SeqNum, - ParallelContext::CommunicatorSub()); + plan.m_particle_sreqs.push_back(ParallelDescriptor::Asend((char const*)(snd_buffer.dataPtr()+snd_offset), Cnt, Who, SeqNum, + ParallelContext::CommunicatorSub()).req()); } + plan.m_particle_sstats.resize(plan.m_particle_sreqs.size()); + amrex::ignore_unused(pc); #else amrex::ignore_unused(pc,plan,snd_buffer,rcv_buffer); diff --git a/Src/Particle/AMReX_ParticleCommunication.cpp b/Src/Particle/AMReX_ParticleCommunication.cpp index 51cb2866f6..0c995930e1 100644 --- a/Src/Particle/AMReX_ParticleCommunication.cpp +++ b/Src/Particle/AMReX_ParticleCommunication.cpp @@ -158,6 +158,8 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize) m_build_rreqs[i] = ParallelDescriptor::Arecv((char*) (m_rcv_data.dataPtr() + offset), Cnt, Who, SeqNum, ParallelContext::CommunicatorSub()).req(); } + Vector snd_reqs; + Vector snd_stats; for (auto i : m_neighbor_procs) { if (i == MyProc) { continue; } @@ -169,8 +171,8 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize) AMREX_ASSERT(Who >= 0 && Who < NProcs); AMREX_ASSERT(Cnt < std::numeric_limits::max()); - ParallelDescriptor::Send((char*) snd_data[i].data(), Cnt, Who, SeqNum, - ParallelContext::CommunicatorSub()); + snd_reqs.push_back(ParallelDescriptor::Asend((char*) snd_data[i].data(), Cnt, Who, SeqNum, + ParallelContext::CommunicatorSub()).req()); } m_snd_counts.resize(0); @@ -199,6 +201,10 @@ void ParticleCopyPlan::buildMPIStart (const ParticleBufferMap& map, Long psize) m_snd_pad_correction_d.resize(m_snd_pad_correction_h.size()); Gpu::copy(Gpu::hostToDevice, m_snd_pad_correction_h.begin(), m_snd_pad_correction_h.end(), m_snd_pad_correction_d.begin()); + + snd_stats.resize(0); + snd_stats.resize(snd_reqs.size()); + ParallelDescriptor::Waitall(snd_reqs, snd_stats); #else amrex::ignore_unused(map,psize); #endif @@ -265,8 +271,10 @@ void ParticleCopyPlan::doHandShakeLocal (const Vector& Snds, Vector& #ifdef AMREX_USE_MPI const int SeqNum = ParallelDescriptor::SeqNum(); const auto num_rcvs = static_cast(m_neighbor_procs.size()); - Vector stats(num_rcvs); + Vector rstats(num_rcvs); Vector rreqs(num_rcvs); + Vector sstats(num_rcvs); + Vector sreqs(num_rcvs); // Post receives for (int i = 0; i < num_rcvs; ++i) @@ -288,13 +296,14 @@ void ParticleCopyPlan::doHandShakeLocal (const Vector& Snds, Vector& AMREX_ASSERT(Who >= 0 && Who < ParallelContext::NProcsSub()); - ParallelDescriptor::Send(&Snds[Who], Cnt, Who, SeqNum, - ParallelContext::CommunicatorSub()); + sreqs[i] = ParallelDescriptor::Asend(&Snds[Who], Cnt, Who, SeqNum, + ParallelContext::CommunicatorSub()).req(); } if (num_rcvs > 0) { - ParallelDescriptor::Waitall(rreqs, stats); + ParallelDescriptor::Waitall(sreqs, sstats); + ParallelDescriptor::Waitall(rreqs, rstats); } #else amrex::ignore_unused(Snds,Rcvs); @@ -339,8 +348,10 @@ void ParticleCopyPlan::doHandShakeGlobal (const Vector& Snds, Vector ParallelDescriptor::Mpi_typemap::type(), MPI_SUM, ParallelContext::CommunicatorSub()); - Vector stats(num_rcvs); + Vector rstats(num_rcvs); Vector rreqs(num_rcvs); + Vector sstats; + Vector sreqs; Vector num_bytes_rcv(num_rcvs); for (int i = 0; i < static_cast(num_rcvs); ++i) @@ -352,15 +363,17 @@ void ParticleCopyPlan::doHandShakeGlobal (const Vector& Snds, Vector { if (Snds[i] == 0) { continue; } const Long Cnt = 1; - MPI_Send( &Snds[i], Cnt, ParallelDescriptor::Mpi_typemap::type(), i, SeqNum, - ParallelContext::CommunicatorSub()); + sreqs.push_back(ParallelDescriptor::Asend( &Snds[i], Cnt, i, SeqNum, ParallelContext::CommunicatorSub()).req()); } - MPI_Waitall(static_cast(num_rcvs), rreqs.data(), stats.data()); + sstats.resize(0); + sstats.resize(sreqs.size()); + ParallelDescriptor::Waitall(sreqs, sstats); + ParallelDescriptor::Waitall(rreqs, rstats); for (int i = 0; i < num_rcvs; ++i) { - const auto Who = stats[i].MPI_SOURCE; + const auto Who = rstats[i].MPI_SOURCE; Rcvs[Who] = num_bytes_rcv[i]; } #else @@ -372,9 +385,13 @@ void amrex::communicateParticlesFinish (const ParticleCopyPlan& plan) { BL_PROFILE("amrex::communicateParticlesFinish"); #ifdef AMREX_USE_MPI + if (plan.m_NumSnds > 0) + { + ParallelDescriptor::Waitall(plan.m_particle_sreqs, plan.m_particle_sstats); + } if (plan.m_nrcvs > 0) { - ParallelDescriptor::Waitall(plan.m_particle_rreqs, plan.m_particle_stats); + ParallelDescriptor::Waitall(plan.m_particle_rreqs, plan.m_particle_rstats); } #else amrex::ignore_unused(plan);