@@ -646,6 +646,15 @@ template <typename T> class ndarray {
646
646
}
647
647
}
648
648
649
// Invokes `callback` once per index, gSize() times in total.
// The index starts as `id idx(_nDims)` and is advanced with
// idx.next(_gShape) after every call; the callback receives it by const
// reference.
// NOTE(review): presumably this visits every index of the global shape in
// the order defined by id::next — confirm against id::next.
+ void globalIndices (const std::function<void (const id &)> &callback) const {
650
+ size_t size = gSize ();
651
+ id idx (_nDims);
652
// Post-decrement test: the body runs exactly `size` times, and not at all
// when gSize() returns 0.
+ while (size--) {
653
+ callback (idx);
654
+ idx.next (_gShape);
655
+ }
656
+ }
657
+
649
658
int64_t getLocalDataOffset (const id &idx) const {
650
659
auto localIdx = idx - _gOffsets;
651
660
int64_t offset = 0 ;
@@ -711,14 +720,16 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
711
720
template <typename T> class WaitPermute {
712
721
public:
713
722
// Constructor. Copies the transceiver pointer `tc`, the wait handle `hdl`,
// and the rank values `cRank` and `nRanks` into members, and moves `parts`,
// `axes`, `oGShape`, `input`, `output`, `receiveBuffer`, `receiveOffsets` and
// `receiveSizes` into the members of the same name. The body is empty.
// `oGShape` is taken by value; the other containers and both ndarrays are
// taken as rvalue references.
// In this diff view the '-' lines are the previous parameter/initializer
// list and the '+' lines are the replacement, which adds `cRank` and `input`.
WaitPermute (SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
714
- SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
715
- std::vector<int64_t > &&axes, std::vector<int64_t > oGShape,
723
+ SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
724
+ std::vector<Parts> &&parts, std::vector<int64_t > &&axes,
725
+ std::vector<int64_t > oGShape, ndarray<T> &&input,
716
726
ndarray<T> &&output, std::vector<T> &&receiveBuffer,
717
727
std::vector<int > &&receiveOffsets,
718
728
std::vector<int > &&receiveSizes)
719
- : tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
729
+ : tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
720
730
axes (std::move(axes)), oGShape(std::move(oGShape)),
721
- output(std::move(output)), receiveBuffer(std::move(receiveBuffer)),
731
+ input(std::move(input)), output(std::move(output)),
732
+ receiveBuffer(std::move(receiveBuffer)),
722
733
receiveOffsets(std::move(receiveOffsets)),
723
734
receiveSizes(std::move(receiveSizes)) {}
724
735
@@ -733,9 +744,12 @@ template <typename T> class WaitPermute {
733
744
}
734
745
735
746
std::vector<size_t > receiveRankBufferCount (nRanks, 0 );
736
- output.localIndices ([&](const id &outputIndex) {
737
- id inputIndex = outputIndex.permute (axes);
738
- auto rank = getInputRank (parts, inputIndex[0 ]);
747
+ input.globalIndices ([&](const id &inputIndex) {
748
+ id outputIndex = inputIndex.permute (axes);
749
+ auto rank = getOutputRank (parts, outputIndex[0 ]);
750
+ if (rank != cRank)
751
+ return ;
752
+ rank = getInputRank (parts, inputIndex[0 ]);
739
753
auto &count = receiveRankBufferCount[rank];
740
754
output[outputIndex] = receiveRankBuffer[rank][count++];
741
755
});
@@ -744,10 +758,12 @@ template <typename T> class WaitPermute {
744
758
private:
745
759
SHARPY::Transceiver *tc;
746
760
SHARPY::Transceiver::WaitHandle hdl;
761
+ SHARPY::rank_type cRank;
747
762
SHARPY::rank_type nRanks;
748
763
std::vector<Parts> parts;
749
764
std::vector<int64_t > axes;
750
765
std::vector<int64_t > oGShape;
766
+ ndarray<T> input;
751
767
ndarray<T> output;
752
768
std::vector<T> receiveBuffer;
753
769
std::vector<int > receiveOffsets;
@@ -791,9 +807,9 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
791
807
assert (std::accumulate (&oOffsPtr[1 ], &oOffsPtr[oNDims], 0 ,
792
808
std::plus<int64_t >()) == 0 );
793
809
794
- auto nRanks = tc->nranks ();
795
- auto rank = tc->rank ();
796
- if (nRanks <= rank ) {
810
+ const auto nRanks = tc->nranks ();
811
+ const auto cRank = tc->rank ();
812
+ if (nRanks <= cRank ) {
797
813
throw std::out_of_range (" Fatal: rank must be < number of ranks" );
798
814
}
799
815
@@ -833,10 +849,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
833
849
834
850
// First we allgather the current and target partitioning
835
851
std::vector<Parts> parts (nRanks);
836
- parts[rank ].iStart = iOffsPtr[0 ];
837
- parts[rank ].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
838
- parts[rank ].oStart = oOffsPtr[0 ];
839
- parts[rank ].oEnd = oOffsPtr[0 ] + oDataShapePtr[0 ];
852
+ parts[cRank ].iStart = iOffsPtr[0 ];
853
+ parts[cRank ].iEnd = iOffsPtr[0 ] + iDataShapePtr[0 ];
854
+ parts[cRank ].oStart = oOffsPtr[0 ];
855
+ parts[cRank ].oEnd = oOffsPtr[0 ] + oDataShapePtr[0 ];
840
856
std::vector<int > counts (nRanks, 4 );
841
857
std::vector<int > dspl (nRanks);
842
858
for (auto i = 0ul ; i < nRanks; ++i) {
@@ -891,10 +907,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
891
907
sendOffsets.data (), sharpytype, receiveBuffer.data (),
892
908
receiveSizes.data (), receiveOffsets.data ());
893
909
894
- auto wait = WaitPermute (tc, hdl, nRanks, std::move (parts) , std::move (axes ),
895
- std::move (oGShape), std::move (output ),
896
- std::move (receiveBuffer ), std::move (receiveOffsets ),
897
- std::move (receiveSizes));
910
+ auto wait = WaitPermute (tc, hdl, cRank, nRanks , std::move (parts ),
911
+ std::move (axes), std::move ( oGShape), std::move (input ),
912
+ std::move (output ), std::move (receiveBuffer ),
913
+ std::move (receiveOffsets), std::move ( receiveSizes));
898
914
899
915
assert (parts.empty () && axes.empty () && receiveBuffer.empty () &&
900
916
receiveOffsets.empty () && receiveSizes.empty ());
0 commit comments