@@ -30,6 +30,10 @@ class NaiveMPIWorkDistributor {
3030 MPI_Comm comm = MPI_COMM_WORLD;
3131 int manager_rank = 0 ;
3232 bool auto_run_workers = true ;
33+ bool use_immediate_recv = false ;
34+ int max_result_size = 1024 ; // Maximum expected size for RESULT messages when using immediate
35+ // recv. Must be large enough to hold the largest expected RESULT
36+ // message. If a message exceeds this size, behavior is undefined.
3337 };
3438
3539 private:
@@ -100,10 +104,11 @@ class NaiveMPIWorkDistributor {
100104 void run_worker () {
101105 assert (_communicator.rank () != _config.manager_rank && " Worker cannot run on the manager rank" );
102106 using task_type = MPI_Type<TaskT>;
103- _communicator.send (nullptr , _config.manager_rank , Tag::REQUEST);
107+ // Send REQUEST as 0 elements of ResultT so manager can recv_any(ResultT&) for both REQUEST and
108+ // RESULT
109+ _communicator.template send_empty <ResultT>(_config.manager_rank , Tag::REQUEST);
104110 while (true ) {
105- MPI_Status status;
106- DYNAMPI_MPI_CHECK (MPI_Probe, (MPI_ANY_SOURCE, MPI_ANY_TAG, _communicator.get (), &status));
111+ MPI_Status status = _communicator.probe ();
107112 if (status.MPI_TAG == Tag::DONE) {
108113 _communicator.recv_empty_message (_config.manager_rank , Tag::DONE);
109114 break ;
@@ -241,31 +246,75 @@ class NaiveMPIWorkDistributor {
241246
242247 int worker_for_idx (int idx) const { return (idx < _config.manager_rank ) ? idx : (idx + 1 ); }
243248
// Manager-side bookkeeping for one completed task: translate the sender's MPI rank
// into a worker slot, clear that worker's in-flight task index, and store the
// moved-in result at the task's position in _results.
// (Diff view: original "NNN+" line-number prefixes are preserved below.)
249+ void process_result_message (const MPI_Status& status, ResultT&& result, int count) {
250+ using result_type = MPI_Type<ResultT>;
// Ranks above the manager are shifted down by one so worker indices are dense
// (the manager rank itself owns no worker slot).
251+ int worker_idx = status.MPI_SOURCE - (status.MPI_SOURCE > _config.manager_rank );
252+ int64_t task_idx = _worker_current_task_indices[worker_idx];
// -1 marks "no task in flight" for this worker; the assert catches a RESULT
// arriving from a worker we never handed a task to.
253+ _worker_current_task_indices[worker_idx] = -1 ;
254+ assert (task_idx >= 0 && " Task index should be valid" );
// Results may arrive out of order; grow the results array on demand.
255+ if (static_cast <uint64_t >(task_idx) >= _results.size ()) {
256+ _results.resize (task_idx + 1 );
257+ }
// NOTE(review): this resizes the destination slot, but the move-assignment on the
// next line replaces the slot wholesale — the resize looks redundant. Confirm
// whether result_type::resize has a needed side effect; otherwise drop it.
258+ if constexpr (result_type::resize_required) {
259+ result_type::resize (_results[task_idx], count);
260+ }
261+ _results[task_idx] = std::move (result);
262+ _results_received++;
263+ }
264+
// Manager-side: block until one message (REQUEST or RESULT) arrives from any worker,
// record a RESULT via process_result_message, then return the sender to the free pool.
// Two receive strategies are selectable via _config.use_immediate_recv; the '-' lines
// below are the older probe-only path this diff replaces.
244265 void receive_from_any_worker () {
245266 assert (_communicator.rank () == _config.manager_rank &&
246267 " Only the manager can receive results and send tasks" );
247268 assert (_communicator.size () > 1 &&
248269 " There should be at least one worker to receive results from" );
249270 using result_type = MPI_Type<ResultT>;
250271 MPI_Status status;
251- DYNAMPI_MPI_CHECK (MPI_Probe, (MPI_ANY_SOURCE, MPI_ANY_TAG, _communicator.get (), &status));
252- if (status.MPI_TAG == Tag::RESULT) {
253- int64_t task_idx = _worker_current_task_indices[status.MPI_SOURCE -
254- (status.MPI_SOURCE > _config.manager_rank )];
255- _worker_current_task_indices[status.MPI_SOURCE - (status.MPI_SOURCE > _config.manager_rank )] =
256- -1 ;
257- assert (task_idx >= 0 && " Task index should be valid" );
258- if (static_cast <uint64_t >(task_idx) >= _results.size ()) {
259- _results.resize (task_idx + 1 );
263272
// Immediate-recv mode: no probe. REQUEST is sent by workers as 0 elements of ResultT
// (see run_worker), so a single recv_any(ResultT&) accepts either tag. Variable-size
// result types must pre-reserve _config.max_result_size; a larger incoming RESULT is
// undefined behavior per the Config comment.
273+ if (_config.use_immediate_recv ) {
274+ // Immediate receive mode: REQUEST and RESULT both use type ResultT (REQUEST = 0 elements).
275+ // recv_any(buffer) receives into the same buffer type for both.
276+ if constexpr (result_type::resize_required) {
277+ ResultT buffer;
278+ result_type::resize (buffer, _config.max_result_size );
279+ status = _communicator.recv_any (buffer);
280+
281+ if (status.MPI_TAG == Tag::RESULT) {
282+ int count;
283+ DYNAMPI_MPI_CHECK (MPI_Get_count, (&status, result_type::value, &count));
284+ // Resize buffer to actual received count (may be less than max_result_size)
285+ result_type::resize (buffer, count);
286+ process_result_message (status, std::move (buffer), count);
287+ } else {
// A REQUEST carries no payload; recv_any already consumed it, nothing to store.
288+ assert (status.MPI_TAG == Tag::REQUEST && " Unexpected tag received" );
289+ }
290+ } else {
// Fixed-size ResultT: a single-element buffer suffices; max_result_size is unused here.
291+ ResultT buffer;
292+ status = _communicator.recv_any (buffer);
293+
294+ if (status.MPI_TAG == Tag::RESULT) {
295+ int count;
296+ DYNAMPI_MPI_CHECK (MPI_Get_count, (&status, result_type::value, &count));
297+ process_result_message (status, std::move (buffer), count);
298+ } else {
299+ assert (status.MPI_TAG == Tag::REQUEST && " Unexpected tag received" );
300+ }
301+ }
260301 }
261- int count;
262- DYNAMPI_MPI_CHECK (MPI_Get_count, (&status, result_type::value, &count));
263- result_type::resize (_results[task_idx], count);
264- _communicator.recv (_results[task_idx], status.MPI_SOURCE , Tag::RESULT);
265- _results_received++;
266302 } else {
// Probe mode: learn the exact element count first, then size the buffer and do a
// targeted recv from the probed source — no max_result_size cap needed.
267- assert (status.MPI_TAG == Tag::REQUEST && " Unexpected tag received in worker" );
268- _communicator.recv_empty_message (status.MPI_SOURCE , Tag::REQUEST);
303+ // Probe mode: use probe to check message size before receiving
304+ status = _communicator.probe ();
305+ if (status.MPI_TAG == Tag::RESULT) {
306+ int count;
307+ DYNAMPI_MPI_CHECK (MPI_Get_count, (&status, result_type::value, &count));
308+ ResultT buffer;
309+ if constexpr (result_type::resize_required) {
310+ result_type::resize (buffer, count);
311+ }
312+ _communicator.recv (buffer, status.MPI_SOURCE , Tag::RESULT);
313+ process_result_message (status, std::move (buffer), count);
314+ } else {
// NOTE(review): this assert message says "in worker" but this function runs on the
// manager (see the rank assert at the top); the sibling branch above says just
// "Unexpected tag received". Looks copied from worker code — align the messages.
315+ assert (status.MPI_TAG == Tag::REQUEST && " Unexpected tag received in worker" );
316+ _communicator.recv_empty <ResultT>(status.MPI_SOURCE , Tag::REQUEST);
317+ }
269318 }
// Whether the message was a REQUEST or a RESULT, the sending rank is now idle.
270319 _free_worker_indices.push (status.MPI_SOURCE );
271320 }
0 commit comments