@@ -525,7 +525,7 @@ WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype,
525525 if (isStrided) {
526526 unpack (rBuff, sharpytype, oDataShapePtr, oDataStridesPtr, oNDims,
527527 oDataPtr);
528- delete[] (char *) rBuff;
528+ delete[] (char *)rBuff;
529529 }
530530 };
531531 assert (sendbuff.empty () && sszs.empty () && soffs.empty () && rszs.empty () &&
@@ -735,18 +735,28 @@ template <typename T> class WaitPermute {
735735 SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
736736 std::vector<Parts> &&parts, std::vector<int64_t > &&axes,
737737 std::vector<int64_t > oGShape, ndarray<T> &&input,
738- ndarray<T> &&output, std::vector<T> &&receiveBuffer,
739- std::vector<int > &&receiveOffsets,
738+ ndarray<T> &&output, std::vector<T> &&sendBuffer,
739+ std::vector<int > &&sendOffsets, std::vector<int > &&sendSizes,
740+ std::vector<T> &&receiveBuffer, std::vector<int > &&receiveOffsets,
740741 std::vector<int > &&receiveSizes)
741742 : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
742743 axes (std::move(axes)), oGShape(std::move(oGShape)),
743744 input(std::move(input)), output(std::move(output)),
745+ sendBuffer(std::move(sendBuffer)), sendOffsets(std::move(sendOffsets)),
746+ sendSizes(std::move(sendSizes)),
744747 receiveBuffer(std::move(receiveBuffer)),
745748 receiveOffsets(std::move(receiveOffsets)),
746749 receiveSizes(std::move(receiveSizes)) {}
747750
751+ // Only allow move
752+ WaitPermute (const WaitPermute &) = delete;
753+ WaitPermute &operator =(const WaitPermute &) = delete ;
754+ WaitPermute (WaitPermute &&) = default;
755+ WaitPermute &operator =(WaitPermute &&) = default ;
756+
748757 void operator ()() {
749758 tc->wait (hdl);
759+
750760 std::vector<std::vector<T>> receiveRankBuffer (nRanks);
751761 for (size_t rank = 0 ; rank < nRanks; ++rank) {
752762 auto &rankBuffer = receiveRankBuffer[rank];
@@ -755,6 +765,7 @@ template <typename T> class WaitPermute {
755765 receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
756766 }
757767
768+ // FIXME: very low efficiency, need to improve
758769 std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
759770 input.globalIndices ([&](const id &inputIndex) {
760771 id outputIndex = inputIndex.permute (axes);
@@ -777,6 +788,9 @@ template <typename T> class WaitPermute {
777788 std::vector<int64_t > oGShape;
778789 ndarray<T> input;
779790 ndarray<T> output;
791+ std::vector<T> sendBuffer;
792+ std::vector<int > sendOffsets;
793+ std::vector<int > sendSizes;
780794 std::vector<T> receiveBuffer;
781795 std::vector<int > receiveOffsets;
782796 std::vector<int > receiveSizes;
@@ -870,6 +884,7 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
870884 for (auto i = 0ul ; i < nRanks; ++i) {
871885 dspl[i] = 4 * i;
872886 }
887+
873888 tc->gather (parts.data (), counts.data (), dspl.data (), SHARPY::INT64,
874889 SHARPY::REPLICATED);
875890
@@ -919,10 +934,12 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
919934 sendOffsets.data (), sharpytype, receiveBuffer.data (),
920935 receiveSizes.data (), receiveOffsets.data ());
921936
922- auto wait = WaitPermute (tc, hdl, cRank, nRanks, std::move (parts),
923- std::move (axes), std::move (oGShape), std::move (input),
924- std::move (output), std::move (receiveBuffer),
925- std::move (receiveOffsets), std::move (receiveSizes));
937+ auto wait =
938+ WaitPermute (tc, hdl, cRank, nRanks, std::move (parts), std::move (axes),
939+ std::move (oGShape), std::move (input), std::move (output),
940+ std::move (sendBuffer), std::move (sendOffsets),
941+ std::move (sendSizes), std::move (receiveBuffer),
942+ std::move (receiveOffsets), std::move (receiveSizes));
926943
927944 assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
928945 receiveOffsets.empty () && receiveSizes.empty ());
0 commit comments