Skip to content

Commit 6767196

Browse files
committed
wip
1 parent b7e7062 commit 6767196

File tree

2 files changed

+58
-18
lines changed

2 files changed

+58
-18
lines changed

Diff for: src/idtr.cpp

+57-17
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,7 @@ template <typename T> class ndarray {
641641
size_t size = lSize();
642642
id idx = firstLocalIndex();
643643
while (size--) {
644+
std::cout << "idx: " << idx[0] << ", " << idx[1] << std::endl;
644645
callback(idx);
645646
idx.next(_gShape);
646647
}
@@ -708,6 +709,52 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
708709
return 0;
709710
}
710711

712+
template <typename T> class WaitPermute {
713+
public:
714+
WaitPermute(SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
715+
SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
716+
std::vector<int64_t> &&axes, ndarray<T> &&output,
717+
std::vector<T> &&receiveBuffer, std::vector<int> &&receiveOffsets,
718+
std::vector<int> &&receiveSizes)
719+
: tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
720+
axes(std::move(axes)), output(std::move(output)),
721+
receiveBuffer(std::move(receiveBuffer)),
722+
receiveOffsets(std::move(receiveOffsets)),
723+
receiveSizes(std::move(receiveSizes)) {}
724+
725+
void operator()() {
726+
tc->wait(hdl);
727+
std::vector<std::vector<T>> receiveRankBuffer(nRanks);
728+
for (size_t rank = 0; rank < nRanks; ++rank) {
729+
auto &rankBuffer = receiveRankBuffer[rank];
730+
rankBuffer.insert(
731+
rankBuffer.end(), receiveBuffer.begin() + receiveOffsets[rank],
732+
receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]);
733+
}
734+
735+
std::vector<size_t> receiveRankBufferCount(nRanks, 0);
736+
output.localIndices([&](const id &outputIndex) {
737+
id inputIndex = outputIndex.permute(axes);
738+
std::cout << "inputIndex: " << inputIndex[0] << ", " << inputIndex[1]
739+
<< std::endl;
740+
auto rank = getInputRank(parts, inputIndex[0]);
741+
auto &count = receiveRankBufferCount[rank];
742+
output[outputIndex] = receiveRankBuffer[rank][count++];
743+
});
744+
}
745+
746+
private:
747+
SHARPY::Transceiver *tc;
748+
SHARPY::Transceiver::WaitHandle hdl;
749+
SHARPY::rank_type nRanks;
750+
std::vector<Parts> parts;
751+
std::vector<int64_t> axes;
752+
ndarray<T> output;
753+
std::vector<T> receiveBuffer;
754+
std::vector<int> receiveOffsets;
755+
std::vector<int> receiveSizes;
756+
};
757+
711758
} // namespace
712759

713760
/// @brief permute array
@@ -844,27 +891,20 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
844891
auto hdl = tc->alltoall(sendBuffer.data(), sendSizes.data(),
845892
sendOffsets.data(), sharpytype, receiveBuffer.data(),
846893
receiveSizes.data(), receiveOffsets.data());
847-
tc->wait(hdl);
848894

849-
{
850-
std::vector<std::vector<T>> receiveRankBuffer(nRanks);
851-
for (size_t rank = 0; rank < nRanks; ++rank) {
852-
auto &rankBuffer = receiveRankBuffer[rank];
853-
rankBuffer.insert(
854-
rankBuffer.end(), receiveBuffer.begin() + receiveOffsets[rank],
855-
receiveBuffer.begin() + receiveOffsets[rank] + receiveSizes[rank]);
856-
}
895+
auto wait = WaitPermute(tc, hdl, nRanks, std::move(parts), std::move(axes),
896+
std::move(output), std::move(receiveBuffer),
897+
std::move(receiveOffsets), std::move(receiveSizes));
857898

858-
std::vector<size_t> receiveRankBufferCount(nRanks);
859-
output.localIndices([&](const id &outputIndex) {
860-
id inputIndex = outputIndex.permute(axes);
861-
auto rank = getInputRank(parts, inputIndex[0]);
862-
auto &count = receiveRankBufferCount[rank];
863-
output[outputIndex] = receiveRankBuffer[rank][count++];
864-
});
899+
assert(parts.empty() && axes.empty() && receiveBuffer.empty() &&
900+
receiveOffsets.empty() && receiveSizes.empty());
901+
902+
if (no_async) {
903+
wait();
904+
return nullptr;
865905
}
866906

867-
return nullptr;
907+
return mkWaitHandle(std::move(wait));
868908
}
869909

870910
/// @brief permute array

Diff for: src/jit/mlir.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ static const std::string cpu_pipeline =
691691
"one-shot-bufferize,"
692692
"canonicalize,"
693693
"imex-remove-temporaries,"
694-
"func.func(buffer-deallocation),"
694+
"buffer-deallocation-pipeline,"
695695
"func.func(convert-linalg-to-parallel-loops),"
696696
"func.func(scf-parallel-loop-fusion),"
697697
"drop-regions,"

0 commit comments

Comments
 (0)