@@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
525
525
if (isStrided) {
526
526
unpack (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
527
527
oDataPtr);
528
- delete[] (char *) rBuff;
528
+ delete[] (char *)rBuff;
529
529
}
530
530
};
531
531
assert (sendbuff.empty () && sszs.empty () && soffs.empty () && rszs.empty () &&
@@ -735,18 +735,28 @@ template <typename T> class WaitPermute {
735
735
SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
736
736
std::vector<Parts> &&parts, std::vector<int64_t > &&axes,
737
737
std::vector<int64_t > oGShape, ndarray<T> &&input,
738
- ndarray<T> &&output, std::vector<T> &&receiveBuffer,
739
- std::vector<int > &&receiveOffsets,
738
+ ndarray<T> &&output, std::vector<T> &&sendBuffer,
739
+ std::vector<int > &&sendOffsets, std::vector<int > &&sendSizes,
740
+ std::vector<T> &&receiveBuffer, std::vector<int > &&receiveOffsets,
740
741
std::vector<int > &&receiveSizes)
741
742
: tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
742
743
axes (std::move(axes)), oGShape(std::move(oGShape)),
743
744
input(std::move(input)), output(std::move(output)),
745
+ sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)),
746
+ sendSizes(std::move(sendSizes)),
744
747
receiveBuffer(std::move(receiveBuffer)),
745
748
receiveOffsets(std::move(receiveOffsets)),
746
749
receiveSizes(std::move(receiveSizes)) {}
747
750
751
+ // Move-only type: copying is disabled, moving is allowed.
752
+ WaitPermute (const WaitPermute &) = delete;
753
+ WaitPermute &operator =(const WaitPermute &) = delete ;
754
+ WaitPermute (WaitPermute &&) = default;
755
+ WaitPermute &operator =(WaitPermute &&) = default ;
756
+
748
757
void operator ()() {
749
758
tc->wait (hdl);
759
+
750
760
std::vector<std::vector<T>> receiveRankBuffer (nRanks);
751
761
for (size_t rank = 0 ; rank < nRanks; ++rank) {
752
762
auto &rankBuffer = receiveRankBuffer[rank];
@@ -755,6 +765,7 @@ template <typename T> class WaitPermute {
755
765
receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
756
766
}
757
767
768
+ // FIXME: this is very inefficient and needs to be improved.
758
769
std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
759
770
input.globalIndices ([&](const id &inputIndex) {
760
771
id outputIndex = inputIndex.permute (axes);
@@ -777,6 +788,9 @@ template <typename T> class WaitPermute {
777
788
std::vector<int64_t > oGShape;
778
789
ndarray<T> input;
779
790
ndarray<T> output;
791
+ std::vector<T> sendBuffer;
792
+ std::vector<int > sendOffsets;
793
+ std::vector<int > sendSizes;
780
794
std::vector<T> receiveBuffer;
781
795
std::vector<int > receiveOffsets;
782
796
std::vector<int > receiveSizes;
@@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
870
884
for (auto i = 0ul ; i < nRanks; ++i) {
871
885
dspl[i] = 4 * i;
872
886
}
887
+
873
888
tc->gather (parts.data (), counts.data (), dspl.data (), SHARPY::INT64,
874
889
SHARPY::REPLICATED);
875
890
@@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
919
934
sendOffsets.data (), sharpytype, receiveBuffer.data (),
920
935
receiveSizes.data (), receiveOffsets.data ());
921
936
922
- auto wait = WaitPermute (tc, hdl, cRank, nRanks, std::move (parts),
923
- std::move (axes), std::move (oGShape), std::move (input),
924
- std::move (output), std::move (receiveBuffer),
925
- std::move (receiveOffsets), std::move (receiveSizes));
937
+ auto wait =
938
+ WaitPermute (tc, hdl, cRank, nRanks, std::move (parts), std::move (axes),
939
+ std::move (oGShape), std::move (input), std::move (output),
940
+ std::move (sendBuffer), std::move (sendOffsets),
941
+ std::move (sendSizes), std::move (receiveBuffer),
942
+ std::move (receiveOffsets), std::move (receiveSizes));
926
943
927
944
assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
928
945
receiveOffsets.empty () && receiveSizes.empty ());
0 commit comments