Skip to content

Commit c47bc9e

Browse files
authored
Merge pull request #7 from ryanstocks00/recv
Use recv in naive distributor
2 parents a02209b + b3ad32a commit c47bc9e

4 files changed

Lines changed: 228 additions & 150 deletions

File tree

include/dynampi/impl/hierarchical_distributor.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ template <typename TaskT, typename ResultT, typename... Options>
2626
class HierarchicalMPIWorkDistributor : public BaseMPIWorkDistributor<TaskT, ResultT, Options...> {
2727
using Base = BaseMPIWorkDistributor<TaskT, ResultT, Options...>;
2828

29+
public:
2930
struct Config {
3031
MPI_Comm comm = MPI_COMM_WORLD;
3132
int manager_rank = 0;

include/dynampi/impl/naive_distributor.hpp

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class NaiveMPIWorkDistributor {
3030
MPI_Comm comm = MPI_COMM_WORLD;
3131
int manager_rank = 0;
3232
bool auto_run_workers = true;
33+
bool use_immediate_recv = false;
34+
int max_result_size = 1024; // Maximum expected size for RESULT messages when using immediate
35+
// recv. Must be large enough to hold the largest expected RESULT
36+
// message. If a message exceeds this size, behavior is undefined.
3337
};
3438

3539
private:
@@ -100,10 +104,11 @@ class NaiveMPIWorkDistributor {
100104
void run_worker() {
101105
assert(_communicator.rank() != _config.manager_rank && "Worker cannot run on the manager rank");
102106
using task_type = MPI_Type<TaskT>;
103-
_communicator.send(nullptr, _config.manager_rank, Tag::REQUEST);
107+
// Send REQUEST as 0 elements of ResultT so manager can recv_any(ResultT&) for both REQUEST and
108+
// RESULT
109+
_communicator.template send_empty<ResultT>(_config.manager_rank, Tag::REQUEST);
104110
while (true) {
105-
MPI_Status status;
106-
DYNAMPI_MPI_CHECK(MPI_Probe, (MPI_ANY_SOURCE, MPI_ANY_TAG, _communicator.get(), &status));
111+
MPI_Status status = _communicator.probe();
107112
if (status.MPI_TAG == Tag::DONE) {
108113
_communicator.recv_empty_message(_config.manager_rank, Tag::DONE);
109114
break;
@@ -241,31 +246,75 @@ class NaiveMPIWorkDistributor {
241246

242247
int worker_for_idx(int idx) const { return (idx < _config.manager_rank) ? idx : (idx + 1); }
243248

249+
void process_result_message(const MPI_Status& status, ResultT&& result, int count) {
250+
using result_type = MPI_Type<ResultT>;
251+
int worker_idx = status.MPI_SOURCE - (status.MPI_SOURCE > _config.manager_rank);
252+
int64_t task_idx = _worker_current_task_indices[worker_idx];
253+
_worker_current_task_indices[worker_idx] = -1;
254+
assert(task_idx >= 0 && "Task index should be valid");
255+
if (static_cast<uint64_t>(task_idx) >= _results.size()) {
256+
_results.resize(task_idx + 1);
257+
}
258+
if constexpr (result_type::resize_required) {
259+
result_type::resize(_results[task_idx], count);
260+
}
261+
_results[task_idx] = std::move(result);
262+
_results_received++;
263+
}
264+
244265
void receive_from_any_worker() {
245266
assert(_communicator.rank() == _config.manager_rank &&
246267
"Only the manager can receive results and send tasks");
247268
assert(_communicator.size() > 1 &&
248269
"There should be at least one worker to receive results from");
249270
using result_type = MPI_Type<ResultT>;
250271
MPI_Status status;
251-
DYNAMPI_MPI_CHECK(MPI_Probe, (MPI_ANY_SOURCE, MPI_ANY_TAG, _communicator.get(), &status));
252-
if (status.MPI_TAG == Tag::RESULT) {
253-
int64_t task_idx = _worker_current_task_indices[status.MPI_SOURCE -
254-
(status.MPI_SOURCE > _config.manager_rank)];
255-
_worker_current_task_indices[status.MPI_SOURCE - (status.MPI_SOURCE > _config.manager_rank)] =
256-
-1;
257-
assert(task_idx >= 0 && "Task index should be valid");
258-
if (static_cast<uint64_t>(task_idx) >= _results.size()) {
259-
_results.resize(task_idx + 1);
272+
273+
if (_config.use_immediate_recv) {
274+
// Immediate receive mode: REQUEST and RESULT both use type ResultT (REQUEST = 0 elements).
275+
// recv_any(buffer) receives into the same buffer type for both.
276+
if constexpr (result_type::resize_required) {
277+
ResultT buffer;
278+
result_type::resize(buffer, _config.max_result_size);
279+
status = _communicator.recv_any(buffer);
280+
281+
if (status.MPI_TAG == Tag::RESULT) {
282+
int count;
283+
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, result_type::value, &count));
284+
// Resize buffer to actual received count (may be less than max_result_size)
285+
result_type::resize(buffer, count);
286+
process_result_message(status, std::move(buffer), count);
287+
} else {
288+
assert(status.MPI_TAG == Tag::REQUEST && "Unexpected tag received");
289+
}
290+
} else {
291+
ResultT buffer;
292+
status = _communicator.recv_any(buffer);
293+
294+
if (status.MPI_TAG == Tag::RESULT) {
295+
int count;
296+
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, result_type::value, &count));
297+
process_result_message(status, std::move(buffer), count);
298+
} else {
299+
assert(status.MPI_TAG == Tag::REQUEST && "Unexpected tag received");
300+
}
260301
}
261-
int count;
262-
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, result_type::value, &count));
263-
result_type::resize(_results[task_idx], count);
264-
_communicator.recv(_results[task_idx], status.MPI_SOURCE, Tag::RESULT);
265-
_results_received++;
266302
} else {
267-
assert(status.MPI_TAG == Tag::REQUEST && "Unexpected tag received in worker");
268-
_communicator.recv_empty_message(status.MPI_SOURCE, Tag::REQUEST);
303+
// Probe mode: use probe to check message size before receiving
304+
status = _communicator.probe();
305+
if (status.MPI_TAG == Tag::RESULT) {
306+
int count;
307+
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, result_type::value, &count));
308+
ResultT buffer;
309+
if constexpr (result_type::resize_required) {
310+
result_type::resize(buffer, count);
311+
}
312+
_communicator.recv(buffer, status.MPI_SOURCE, Tag::RESULT);
313+
process_result_message(status, std::move(buffer), count);
314+
} else {
315+
assert(status.MPI_TAG == Tag::REQUEST && "Unexpected tag received in worker");
316+
_communicator.recv_empty<ResultT>(status.MPI_SOURCE, Tag::REQUEST);
317+
}
269318
}
270319
_free_worker_indices.push(status.MPI_SOURCE);
271320
}

include/dynampi/mpi/mpi_communicator.hpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,16 +156,44 @@ class MPICommunicator {
156156
template <typename T>
157157
inline void recv(T& data, int source, int tag = 0) {
158158
using mpi_type = MPI_Type<T>;
159+
MPI_Status status;
159160
DYNAMPI_MPI_CHECK(MPI_Recv, (mpi_type::ptr(data), mpi_type::count(data), mpi_type::value,
160-
source, tag, _comm, MPI_STATUS_IGNORE));
161+
source, tag, _comm, &status));
161162
if constexpr (statistics_mode != StatisticsMode::None) {
162163
_statistics.recv_count++;
164+
int actual_count;
165+
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, mpi_type::value, &actual_count));
163166
int size;
164167
MPI_Type_size(mpi_type::value, &size);
165-
_statistics.bytes_received += mpi_type::count(data) * size;
168+
_statistics.bytes_received += actual_count * size;
166169
}
167170
}
168171

172+
// Probe for a message, returns status
173+
inline MPI_Status probe(int source = MPI_ANY_SOURCE, int tag = MPI_ANY_TAG) {
174+
MPI_Status status;
175+
DYNAMPI_MPI_CHECK(MPI_Probe, (source, tag, _comm, &status));
176+
return status;
177+
}
178+
179+
// Receive with MPI_ANY_SOURCE/MPI_ANY_TAG and return status
180+
template <typename T>
181+
inline MPI_Status recv_any(T& data, int source = MPI_ANY_SOURCE, int tag = MPI_ANY_TAG) {
182+
using mpi_type = MPI_Type<T>;
183+
MPI_Status status;
184+
DYNAMPI_MPI_CHECK(MPI_Recv, (mpi_type::ptr(data), mpi_type::count(data), mpi_type::value,
185+
source, tag, _comm, &status));
186+
if constexpr (statistics_mode != StatisticsMode::None) {
187+
_statistics.recv_count++;
188+
int actual_count;
189+
DYNAMPI_MPI_CHECK(MPI_Get_count, (&status, mpi_type::value, &actual_count));
190+
int size;
191+
MPI_Type_size(mpi_type::value, &size);
192+
_statistics.bytes_received += actual_count * size;
193+
}
194+
return status;
195+
}
196+
169197
template <typename T>
170198
inline void broadcast(T& data, int root = 0) {
171199
using mpi_type = MPI_Type<T>;
@@ -192,6 +220,28 @@ class MPICommunicator {
192220
}
193221
}
194222

223+
/// Sends 0 elements of type T (same type as recv buffer) so that recv_any(T&) can receive any
224+
/// worker message (REQUEST or RESULT) into a single buffer type.
225+
template <typename T>
226+
inline void send_empty(int dest, int tag = 0) {
227+
using mpi_type = MPI_Type<T>;
228+
DYNAMPI_MPI_CHECK(MPI_Send, (nullptr, 0, mpi_type::value, dest, tag, _comm));
229+
if constexpr (statistics_mode != StatisticsMode::None) {
230+
_statistics.send_count++;
231+
}
232+
}
233+
234+
/// Receives 0 elements of type T. Use when the sender used send_empty<T>.
235+
template <typename T>
236+
inline void recv_empty(int source, int tag = 0) {
237+
using mpi_type = MPI_Type<T>;
238+
DYNAMPI_MPI_CHECK(MPI_Recv,
239+
(nullptr, 0, mpi_type::value, source, tag, _comm, MPI_STATUS_IGNORE));
240+
if constexpr (statistics_mode != StatisticsMode::None) {
241+
_statistics.recv_count++;
242+
}
243+
}
244+
195245
[[nodiscard]] MPI_Comm get() const { return _comm; }
196246
};
197247

0 commit comments

Comments
 (0)