Skip to content

Commit d9eb123

Browse files
committed
WaitPermute only allow move
1 parent 4520d97 commit d9eb123

File tree

1 file changed

+24
-7
lines changed

1 file changed

+24
-7
lines changed

src/idtr.cpp

+24-7
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
525525
if (isStrided) {
526526
unpack(rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
527527
oDataPtr);
528-
delete[](char *) rBuff;
528+
delete[] (char *)rBuff;
529529
}
530530
};
531531
assert(sendbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() &&
@@ -735,18 +735,28 @@ template <typename T> class WaitPermute {
735735
SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
736736
std::vector<Parts> &&parts, std::vector<int64_t> &&axes,
737737
std::vector<int64_t> oGShape, ndarray<T> &&input,
738-
ndarray<T> &&output, std::vector<T> &&receiveBuffer,
739-
std::vector<int> &&receiveOffsets,
738+
ndarray<T> &&output, std::vector<T> &&sendBuffer,
739+
std::vector<int> &&sendOffsets, std::vector<int> &&sendSizes,
740+
std::vector<T> &&receiveBuffer, std::vector<int> &&receiveOffsets,
740741
std::vector<int> &&receiveSizes)
741742
: tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
742743
axes(std::move(axes)), oGShape(std::move(oGShape)),
743744
input(std::move(input)), output(std::move(output)),
745+
sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)),
746+
sendSizes(std::move(sendSizes)),
744747
receiveBuffer(std::move(receiveBuffer)),
745748
receiveOffsets(std::move(receiveOffsets)),
746749
receiveSizes(std::move(receiveSizes)) {}
747750

751+
// Only allow move
752+
WaitPermute(const WaitPermute &) = delete;
753+
WaitPermute &operator=(const WaitPermute &) = delete;
754+
WaitPermute(WaitPermute &&) = default;
755+
WaitPermute &operator=(WaitPermute &&) = default;
756+
748757
void operator()() {
749758
tc->wait(hdl);
759+
750760
std::vector<std::vector<T>> receiveRankBuffer(nRanks);
751761
for (size_t rank = 0; rank < nRanks; ++rank) {
752762
auto &rankBuffer = receiveRankBuffer[rank];
@@ -755,6 +765,7 @@ template <typename T> class WaitPermute {
755765
receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]);
756766
}
757767

768+
// FIXME: very low efficiency, need to improve
758769
std::vector<size_t> receiveRankBufferCount(nRanks, 0);
759770
input.globalIndices([&](const id &inputIndex) {
760771
id outputIndex = inputIndex.permute(axes);
@@ -777,6 +788,9 @@ template <typename T> class WaitPermute {
777788
std::vector<int64_t> oGShape;
778789
ndarray<T> input;
779790
ndarray<T> output;
791+
std::vector<T> sendBuffer;
792+
std::vector<int> sendOffsets;
793+
std::vector<int> sendSizes;
780794
std::vector<T> receiveBuffer;
781795
std::vector<int> receiveOffsets;
782796
std::vector<int> receiveSizes;
@@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
870884
for (auto i = 0ul; i < nRanks; ++i) {
871885
dspl[i] = 4 * i;
872886
}
887+
873888
tc->gather(parts.data(), counts.data(), dspl.data(), SHARPY::INT64,
874889
SHARPY::REPLICATED);
875890

@@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
919934
sendOffsets.data(), sharpytype, receiveBuffer.data(),
920935
receiveSizes.data(), receiveOffsets.data());
921936

922-
auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts),
923-
std::move(axes), std::move(oGShape), std::move(input),
924-
std::move(output), std::move(receiveBuffer),
925-
std::move(receiveOffsets), std::move(receiveSizes));
937+
auto wait =
938+
WaitPermute(tc, hdl, cRank, nRanks, std::move(parts), std::move(axes),
939+
std::move(oGShape), std::move(input), std::move(output),
940+
std::move(sendBuffer), std::move(sendOffsets),
941+
std::move(sendSizes), std::move(receiveBuffer),
942+
std::move(receiveOffsets), std::move(receiveSizes));
926943

927944
assert(parts.empty() && axes.empty() && receiveBuffer.empty() &&
928945
receiveOffsets.empty() && receiveSizes.empty());

0 commit comments

Comments
 (0)