Skip to content

Commit 7d8b4f8

Browse files
siyengar authored and meta-codesync[bot] committed
add pipelining send/recv
Summary: Instead of having a single data buffer (or staging buffer), allocate `pipelineDepth` buffers.

```
 * Pipeline Flow (shown for pipelineDepth = 4):
 *   Time │ Sender              │ Receiver
 *   ─────┼─────────────────────┼──────────────────
 *   0    │ Step 0 → Slot[0]    │ -
 *   1    │ Step 1 → Slot[1]    │ Step 0 ← Slot[0]
 *   2    │ Step 2 → Slot[2]    │ Step 1 ← Slot[1]
 *   3    │ Step 3 → Slot[3]    │ Step 2 ← Slot[2]
 *   4    │ Step 4 → Slot[0]    │ Step 3 ← Slot[3]
 *   5    │ Step 5 → Slot[1]    │ Step 4 ← Slot[0]
 *
 * The sender can be ahead by up to 3 steps, hiding NVLink latency and
 * improving throughput
```

Reviewed By: cenzhaometa

Differential Revision: D88463702

fbshipit-source-id: d1414c50e605be6c8a6585b3b0019150d3ce266b
1 parent 4f34173 commit 7d8b4f8

File tree

4 files changed

+900
-225
lines changed

4 files changed

+900
-225
lines changed

comms/pipes/P2pNvlTransport.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,30 @@ P2pNvlTransport::P2pNvlTransport(
1515
nRanks_(nRanks),
1616
mpiBootstrap_(mpiBootstrap),
1717
config_(p2pNvlTransportConfig) {
18+
// Calculate total buffer sizes with pipelining
19+
const std::size_t totalDataBufferSize =
20+
config_.pipelineDepth * config_.dataBufferSize;
21+
22+
// Calculate state buffer size based on chunk size and pipeline depth
23+
const std::size_t numChunksPerStep =
24+
(config_.dataBufferSize + config_.chunkSize - 1) / config_.chunkSize;
25+
const std::size_t totalNumChunks = config_.pipelineDepth * numChunksPerStep;
26+
const std::size_t stateBufferSize = totalNumChunks * sizeof(ChunkState<int>);
27+
1828
dataBuffer_d_ =
19-
std::make_unique<meta::comms::DeviceBuffer>(config_.dataBufferSize);
29+
std::make_unique<meta::comms::DeviceBuffer>(totalDataBufferSize);
2030
dataBufferHandler_ = std::make_unique<meta::comms::IpcMemHandler>(
2131
mpiBootstrap_, myRank, nRanks_);
2232
dataBufferHandler_->addSelfDeviceMemPtr(dataBuffer_d_->get());
2333

24-
// Calculate state buffer size based on chunk size
25-
// stateBufferSize = dataBufferSize / chunkSize * sizeof(ChunkState<int>)
26-
const std::size_t numChunks =
27-
(config_.dataBufferSize + config_.chunkSize - 1) / config_.chunkSize;
28-
const std::size_t stateBufferSize = numChunks * sizeof(ChunkState<int>);
29-
3034
stateBuffer_d_ = std::make_unique<meta::comms::DeviceBuffer>(stateBufferSize);
3135
stateBufferHandler_ = std::make_unique<meta::comms::IpcMemHandler>(
3236
mpiBootstrap_, myRank, nRanks_);
3337
stateBufferHandler_->addSelfDeviceMemPtr(stateBuffer_d_->get());
3438

35-
// Initialize state buffer to -1
39+
// Initialize state buffer to -1 for all pipeline slots
3640
auto statePtr = static_cast<ChunkState<int>*>(stateBuffer_d_->get());
37-
std::vector<ChunkState<int>> initStates(numChunks, ChunkState<int>(-1));
41+
std::vector<ChunkState<int>> initStates(totalNumChunks, ChunkState<int>(-1));
3842
auto cudaErr = cudaMemcpy(
3943
statePtr, initStates.data(), stateBufferSize, cudaMemcpyDefault);
4044
if (cudaErr != cudaSuccess) {
@@ -50,7 +54,9 @@ void P2pNvlTransport::exchange() {
5054

5155
P2pNvlTransportDevice P2pNvlTransport::getTransportDevice(int peerRank) {
5256
P2pNvlTransportOptions options{
53-
.dataBufferSize = config_.dataBufferSize, .chunkSize = config_.chunkSize};
57+
.dataBufferSize = config_.dataBufferSize,
58+
.chunkSize = config_.chunkSize,
59+
.pipelineDepth = config_.pipelineDepth};
5460

5561
LocalState localState{
5662
.dataBuffer = static_cast<char*>(dataBuffer_d_->get()),

comms/pipes/P2pNvlTransport.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ namespace comms::pipes {
1313
// Sizing parameters for the host-side P2P NVL transport. The transport
// constructor uses these to allocate and initialize its data and state
// buffers.
struct P2pNvlTransportConfig {
1414
// Size in bytes of one data (staging) buffer slot. The constructor allocates
// pipelineDepth * dataBufferSize bytes in total.
std::size_t dataBufferSize{0};
1515
// Chunk granularity in bytes. The constructor computes the number of chunks
// per slot as ceil(dataBufferSize / chunkSize) and allocates one
// ChunkState<int> per chunk.
// NOTE(review): the default of 0 would divide by zero in that computation;
// callers presumably must set this to a non-zero value — confirm.
std::size_t chunkSize{0};
16+
// Number of data buffer slots (and matching groups of chunk states) to
// allocate for pipelining.
// NOTE(review): the default of 0 yields zero-sized data/state buffers;
// presumably callers must set this to >= 1 (1 matching the previous
// single-buffer behavior) — confirm.
std::size_t pipelineDepth{0};
1617
};
1718

1819
// Host-side P2P NVL transport that exchanges IPC buffer handles between ranks

0 commit comments

Comments
 (0)