
Commit 683a441

bingzheliu authored and meta-codesync[bot] committed
Use user-passed parameter to determine if it is in combine mode
Summary: Instead of calculating whether it is the 2nd a2a inside a2avD, we use the user-passed parameter combine to determine if it is the 2nd a2a.

Reviewed By: cenzhaometa

Differential Revision: D86346838

fbshipit-source-id: 8567fd95877b36c179ece264b4e19db8a958ba1c
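The kernel previously derived this condition from the input shape (summing inputChunkCountPerRank and comparing against sendcountsLength, as the removed lines in AllToAllvDynamic.cuh show); now the caller's combine flag is carried through the op and kernel args and read directly. Below is a minimal, self-contained sketch of the before/after decision, using hypothetical toy types rather than the library's real structs:

#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical stand-in for the kernel-args struct: the commit adds an explicit
// flag instead of deriving "is this the 2nd (combine) a2a" from input shape.
struct ToyA2AvArgs {
  std::vector<std::size_t> inputChunkCountPerRank; // chunks contributed per rank
  std::size_t sendcountsLength{0};                 // total slots in sendcounts
  bool combine{false};                             // set by the caller (new behavior)
};

// Old behavior (removed by this commit): infer the mode from the data.
bool isCombineDerived(const ToyA2AvArgs& a) {
  std::size_t total = std::accumulate(
      a.inputChunkCountPerRank.begin(), a.inputChunkCountPerRank.end(),
      std::size_t{0});
  return total < a.sendcountsLength;
}

// New behavior: trust the caller-provided flag.
bool isCombinePassed(const ToyA2AvArgs& a) {
  return a.combine;
}

int main() {
  ToyA2AvArgs args{{2, 2, 2, 2}, 16, true};
  // The derived check and the explicit flag agree here, but only the explicit
  // flag reflects the caller's intent unambiguously.
  return isCombineDerived(args) == isCombinePassed(args) ? 0 : 1;
}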
1 parent 268eb77 commit 683a441

File tree

7 files changed: +36 -25 lines changed


comms/ctran/algos/AllToAll/AllToAllvDynamic.cuh

Lines changed: 16 additions & 16 deletions
@@ -164,7 +164,7 @@ __device__ __forceinline__ void recvImplNonContig(
 int groupIdx,
 int ngroups,
 size_t maxRecvcount,
-bool nonContigIndices) {
+bool combine) {
 const auto localRank = statex->localRank();
 const auto nLocalRanks = statex->nLocalRanks();

@@ -189,7 +189,7 @@ __device__ __forceinline__ void recvImplNonContig(
 // writes it to the recvCountsTmpbufGPU buffer.
 devSyncWaitStep(sync, groupIdx, 0);
 mySendIndicesBlockLength = recvIndicesPeerAllToAllvDynamicBufsMap[0];
-if (threadIdx.x == 0 && groupIdx == 0 && !nonContigIndices) {
+if (threadIdx.x == 0 && groupIdx == 0 && !combine) {
 for (int i = 0; i < sendcountsLength; i++) {
 recvCountsTmpbufGPU[recvPeerGlobal * sendcountsLength + i] =
 recvcountsPeerAllToAllvDynamicBufsMap[i];

@@ -198,7 +198,7 @@ __device__ __forceinline__ void recvImplNonContig(
 devSyncSetStep(sync, groupIdx, CTRAN_ALGO_STEP_RESET);

 size_t recvOffsets = 0, lastRecvIndex = 0;
-if (nonContigIndices) {
+if (combine) {
 lastRecvIndex = sendcountsLength * statex->rank() / statex->nRanks();
 }
 for (int i = 0; i < mySendIndicesBlockLength; i++) {

@@ -304,7 +304,7 @@ __device__ __forceinline__ void selfCopyNonContig(
 int groupIdx,
 bool groupType,
 size_t maxRecvcount,
-bool nonContigIndices) {
+bool combine) {
 // Now we calculate the startSendIndex on-the-fly,
 // which may not be efficient. If the inputChunkCountPerRank can be
 // on CPU, we can calculate it on CPU and pass it to GPU.

@@ -317,7 +317,7 @@ __device__ __forceinline__ void selfCopyNonContig(
 startSendIndex += inputChunkCountPerRank[i];
 }

-if (!nonContigIndices && groupIdx == 0 && groupType == GROUP_RECV) {
+if (!combine && groupIdx == 0 && groupType == GROUP_RECV) {
 ctranKernCopy<size_t>(
 sendcounts,
 recvCountsTmpbufGPU + rank * sendcountsLength,

@@ -326,7 +326,7 @@ __device__ __forceinline__ void selfCopyNonContig(
 1);
 }

-if (nonContigIndices) {
+if (combine) {
 curOffsetIndex = sendcountsLength * rank / nRanks;
 }

@@ -383,7 +383,7 @@ __device__ __forceinline__ void ncclKernelAllToAllvDynamicCommon(
 int* flag,
 CtranKernelAllToAllvDynamicArgs args,
 ALGOTYPE algoType,
-bool nonContigIndices = false) {
+bool combine = false) {
 const auto gtIdx = blockDim.x * blockIdx.x + threadIdx.x;

 const auto rank = statex->rank();

@@ -461,7 +461,7 @@ __device__ __forceinline__ void ncclKernelAllToAllvDynamicCommon(
 groupIdx,
 groupType,
 args.nonContig.maxRecvcount,
-nonContigIndices);
+combine);
 if (groupType == GROUP_RECV) {
 recvImplNonContig(
 recvbuffs,

@@ -471,7 +471,7 @@ __device__ __forceinline__ void ncclKernelAllToAllvDynamicCommon(
 groupIdx,
 ngroups,
 args.nonContig.maxRecvcount,
-nonContigIndices);
+combine);
 } else {
 sendImplNonContig(
 sendbuffs,

@@ -510,7 +510,7 @@ __device__ __forceinline__ void ncclKernelAllToAllvDynamicCommon(
 // Copy back to recvcounts for DYNAMIC and DYNAMIC_SPLIT
 // or if it is first a2a for DYNAMIC_SPLIT_NON_CONTIG
 if (groupIdx == 0 && groupType == GROUP_RECV &&
-(algoType != DYNAMIC_SPLIT_NON_CONTIG || !nonContigIndices)) {
+(algoType != DYNAMIC_SPLIT_NON_CONTIG || !combine)) {
 ctranKernCopy<size_t>(
 recvCountsTmpbufGPU,
 reinterpret_cast<size_t*>(args.actualRecvcounts),

@@ -528,7 +528,7 @@ __device__ __forceinline__ void ncclKernelAllToAllvDynamicCommon(
 template <typename T>
 __device__ __forceinline__ void generateSendbuffs(
 CtranKernelAllToAllvDynamicArgs& args,
-bool nonContigIndices = false) {
+bool combine = false) {
 const auto gtIdx = blockDim.x * blockIdx.x + threadIdx.x;
 const size_t* sendSplitLengths = (size_t*)args.sendcounts;
 args.split.sendbuffsPtrShmDev =

@@ -548,7 +548,7 @@ __device__ __forceinline__ void generateSendbuffs(
 // and hence need to reset the sendbuff offset.
 // The length of each rank is equal to maxsendcounts/ranks.
 // i / numCountsPerRank is the rank number.
-if (nonContigIndices && (i % numCountsPerRank == 0)) {
+if (combine && (i % numCountsPerRank == 0)) {
 sendbuffsGPU[i] = sendbuffsGPU[0] +
 (args.nonContig.maxSendcount / statex->nRanks()) *
 (i / numCountsPerRank);

@@ -592,14 +592,14 @@ __global__ void ncclKernelAllToAllvDynamicSplitNonContig(
 CtranKernelAllToAllvDynamicArgs args) {
 devStateLoadToShm(devState);

-bool nonContigIndices = false;
 int totalSendIndicesLength = 0;
 for (int i = 0; i < statex->nRanks(); i++) {
 totalSendIndicesLength += args.nonContig.inputChunkCountPerRank[i];
 }
-nonContigIndices = (totalSendIndicesLength < args.sendcountsLength);

-generateSendbuffs<T>(args, nonContigIndices);
+bool combine = args.nonContig.combine;
+
+generateSendbuffs<T>(args, combine);

 ctranKernCopy<size_t>(
 args.nonContig.inputChunkIndices,

@@ -630,7 +630,7 @@ __global__ void ncclKernelAllToAllvDynamicSplitNonContig(
 }

 ncclKernelAllToAllvDynamicCommon<T>(
-flag, args, DYNAMIC_SPLIT_NON_CONTIG, nonContigIndices);
+flag, args, DYNAMIC_SPLIT_NON_CONTIG, combine);
 }

 #define DECL_CTRAN_ALLTOALLVDYNAMIC_KERN(T) \

comms/ctran/algos/AllToAll/AllToAllvDynamicCommon.cc

Lines changed: 9 additions & 4 deletions
@@ -31,7 +31,8 @@
 ibPutReqs, \
 ibRecvCtrlReqs, \
 maxRecvcount, \
-maxSendcount)); \
+maxSendcount, \
+combine)); \
 } else { \
 FB_COMMCHECK(peerPutContig( \
 comm, \

@@ -218,7 +219,8 @@ commResult_t ctranAllToAllvDynamicIbImpl(
 CtranComm* comm,
 std::unique_ptr<CtranMapperTimestamp> timestamp,
 KernelElem* elem,
-void* recvbuff) {
+void* recvbuff,
+bool combine) {
 const auto& statex = comm->statex_;
 const int myRank = statex->rank();
 const int nRanks = statex->nRanks();

@@ -391,7 +393,8 @@ commResult_t opIbImpl(
 comm,
 std::move(timestamp),
 op->alltoallv_dynamic.kElem,
-op->alltoallv_dynamic.recvbuff);
+op->alltoallv_dynamic.recvbuff,
+op->alltoallv_dynamic.combine);
 }

 commResult_t setupGpeOp(

@@ -406,7 +409,8 @@ commResult_t setupGpeOp(
 uint64_t opCount,
 std::vector<std::unique_ptr<struct OpElem>>& opGroup,
 KernelElem* elem,
-void* recvbuff) {
+void* recvbuff,
+bool combine) {
 std::unique_ptr<struct OpElem> op =
 std::unique_ptr<struct OpElem>(new OpElem(opType, comm, opCount));
 op->alltoallv_dynamic.sendbuffs = sendbuffs;

@@ -417,6 +421,7 @@ commResult_t setupGpeOp(
 op->alltoallv_dynamic.maxRecvcount = maxRecvcount;
 op->alltoallv_dynamic.kElem = elem;
 op->alltoallv_dynamic.recvbuff = recvbuff;
+op->alltoallv_dynamic.combine = combine;

 opGroup.push_back(std::move(op));

comms/ctran/algos/AllToAll/AllToAllvDynamicCommon.h

Lines changed: 5 additions & 4 deletions
@@ -45,7 +45,8 @@ commResult_t setupGpeOp(
 uint64_t opCount,
 std::vector<std::unique_ptr<struct OpElem>>& opGroup,
 KernelElem* elem,
-void* recvbuff = nullptr);
+void* recvbuff = nullptr,
+bool combine = false);

 template <typename PerfConfig = DefaultPerfCollConfig>
 commResult_t peerPutNonContig(

@@ -64,6 +65,7 @@ commResult_t peerPutNonContig(
 std::vector<std::unique_ptr<CtranMapperRequest>>& ibRecvCtrlReqs,
 size_t maxRecvcount,
 size_t maxSendcount,
+bool combine,
 bool skipWaitRecvCtrl = false) {
 // Prepare basic info for nonContig send
 size_t* sendIndices = reinterpret_cast<size_t*>(comm->ctran_->algo->getTmpBuf(

@@ -97,15 +99,14 @@ commResult_t peerPutNonContig(
 for (int r = 0; r < comm->statex_->nRanks(); r++) {
 totalBlock += sendIndicesBlockLengthsTmpbufCPU[r];
 }
-bool nonContigIndices = (totalBlock < sendcountsLength);

 // Calculate the offset of each recvbuff, considering if it is 1st or 2nd
 // all2allv.
 std::vector<size_t> remoteRecvBuffsBytesOffset(sendcountsLength);
 remoteRecvBuffsBytesOffset[0] = 0;
 int numCountsPerRank = sendcountsLength / nRanks;
 for (int i = 1; i < sendcountsLength; i++) {
-if (nonContigIndices && (i % numCountsPerRank == 0)) {
+if (combine && (i % numCountsPerRank == 0)) {
 remoteRecvBuffsBytesOffset[i] = 0;
 } else {
 remoteRecvBuffsBytesOffset[i] += remoteRecvBuffsBytesOffset[i - 1] +

@@ -153,7 +154,7 @@ commResult_t peerPutNonContig(
 // Allgather sendcounts
 // Skip sending sendcounts if it is second all2allv.
 // TODO: using hints instead of nonContigIndices to determine this.
-if (!nonContigIndices) {
+if (!combine) {
 puts.emplace_back(
 CtranMapperPutMsg{
 .sbuf = reinterpret_cast<size_t*>(sendCountsTmpbufGPU),

comms/ctran/algos/AllToAll/AllToAllvDynamicPImpl.cc

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ using ctran::alltoallvdynamicp::PersistArgs;
 completedIbRecvCtrlReqs, \
 pArgs->maxRecvCount, \
 pArgs->maxSendCount, \
+op->alltoallv_dynamic.combine, \
 /* skipWaitRecvCtrl */ true)); \
 /* Wait for all puts to complete */ \
 for (auto& req : ibPutReqs) { \

comms/ctran/algos/AllToAll/AllToallvDynamicSplitNonContig.cc

Lines changed: 3 additions & 1 deletion
@@ -84,6 +84,7 @@ commResult_t ctranAlltoallvDynamicSplitNonContig(
 maxRecvcount;
 config.args.collective.alltoallv_dynamic.nonContig.maxSendcount =
 maxSendcount;
+config.args.collective.alltoallv_dynamic.nonContig.combine = combine;

 if (recvbuff != nullptr) {
 for (int i = 0; i < comm->statex_->nRanks(); i++) {

@@ -106,7 +107,8 @@ commResult_t ctranAlltoallvDynamicSplitNonContig(
 opCount,
 opGroup,
 elem,
-recvbuff));
+recvbuff,
+combine));

 XCHECK(alltoallvDynamicSplitNonContigKerns.contains(datatype))
 << "alltoallvDynamicSplitNonContigKerns does not contain datatype "

comms/ctran/gpe/CtranGpe.h

Lines changed: 1 addition & 0 deletions
@@ -163,6 +163,7 @@ struct OpElem {
 KernelElem* kElem;
 // Persistent args for persistent alltoallv_dynamic.
 void* pArgs;
+bool combine;
 } alltoallv_dynamic;
 struct {
 const void* sendbuff;

comms/ctran/gpe/CtranGpeDev.h

Lines changed: 1 addition & 0 deletions
@@ -260,6 +260,7 @@ struct CtranKernelAllToAllvDynamicArgs {
 size_t maxInputChunkCountPerRank{0};
 size_t maxRecvcount{0};
 size_t maxSendcount{0};
+bool combine;
 } nonContig;
 struct {
 } contig;
