@@ -641,6 +641,7 @@ template <typename T> class ndarray {
641
641
size_t size = lSize ();
642
642
id idx = firstLocalIndex ();
643
643
while (size--) {
644
+ std::cout << " idx: " << idx[0 ] << " , " << idx[1 ] << std::endl;
644
645
callback (idx);
645
646
idx.next (_gShape);
646
647
}
@@ -708,6 +709,52 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
708
709
return 0 ;
709
710
}
710
711
712
+ template <typename T> class WaitPermute {
713
+ public:
714
+ WaitPermute (SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
715
+ SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
716
+ std::vector<int64_t > &&axes, ndarray<T> &&output,
717
+ std::vector<T> &&receiveBuffer, std::vector<int > &&receiveOffsets,
718
+ std::vector<int > &&receiveSizes)
719
+ : tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
720
+ axes (std::move(axes)), output(std::move(output)),
721
+ receiveBuffer(std::move(receiveBuffer)),
722
+ receiveOffsets(std::move(receiveOffsets)),
723
+ receiveSizes(std::move(receiveSizes)) {}
724
+
725
+ void operator ()() {
726
+ tc->wait (hdl);
727
+ std::vector<std::vector<T>> receiveRankBuffer (nRanks);
728
+ for (size_t rank = 0 ; rank < nRanks; ++rank) {
729
+ auto &rankBuffer = receiveRankBuffer[rank];
730
+ rankBuffer.insert (
731
+ rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
732
+ receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
733
+ }
734
+
735
+ std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
736
+ output.localIndices ([&](const id &outputIndex) {
737
+ id inputIndex = outputIndex.permute (axes);
738
+ std::cout << " inputIndex: " << inputIndex[0 ] << " , " << inputIndex[1 ]
739
+ << std::endl;
740
+ auto rank = getInputRank (parts, inputIndex[0 ]);
741
+ auto &count = receiveRankBufferCount[rank];
742
+ output[outputIndex] = receiveRankBuffer[rank][count++];
743
+ });
744
+ }
745
+
746
+ private:
747
+ SHARPY::Transceiver *tc;
748
+ SHARPY::Transceiver::WaitHandle hdl;
749
+ SHARPY::rank_type nRanks;
750
+ std::vector<Parts> parts;
751
+ std::vector<int64_t > axes;
752
+ ndarray<T> output;
753
+ std::vector<T> receiveBuffer;
754
+ std::vector<int > receiveOffsets;
755
+ std::vector<int > receiveSizes;
756
+ };
757
+
711
758
} // namespace
712
759
713
760
// / @brief permute array
@@ -844,27 +891,20 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
844
891
auto hdl = tc->alltoall (sendBuffer.data (), sendSizes.data (),
845
892
sendOffsets.data (), sharpytype, receiveBuffer.data (),
846
893
receiveSizes.data (), receiveOffsets.data ());
847
- tc->wait (hdl);
848
894
849
- {
850
- std::vector<std::vector<T>> receiveRankBuffer (nRanks);
851
- for (size_t rank = 0 ; rank < nRanks; ++rank) {
852
- auto &rankBuffer = receiveRankBuffer[rank];
853
- rankBuffer.insert (
854
- rankBuffer.end (), receiveBuffer.begin () + receiveOffsets[rank],
855
- receiveBuffer.begin () + receiveOffsets[rank] + receiveSizes[rank]);
856
- }
895
+ auto wait = WaitPermute (tc, hdl, nRanks, std::move (parts), std::move (axes),
896
+ std::move (output), std::move (receiveBuffer),
897
+ std::move (receiveOffsets), std::move (receiveSizes));
857
898
858
- std::vector<size_t > receiveRankBufferCount (nRanks);
859
- output.localIndices ([&](const id &outputIndex) {
860
- id inputIndex = outputIndex.permute (axes);
861
- auto rank = getInputRank (parts, inputIndex[0 ]);
862
- auto &count = receiveRankBufferCount[rank];
863
- output[outputIndex] = receiveRankBuffer[rank][count++];
864
- });
899
+ assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
900
+ receiveOffsets.empty () && receiveSizes.empty ());
901
+
902
+ if (no_async) {
903
+ wait ();
904
+ return nullptr ;
865
905
}
866
906
867
- return nullptr ;
907
+ return mkWaitHandle ( std::move ( wait )) ;
868
908
}
869
909
870
910
// / @brief permute array
0 commit comments