Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c901fa7

Browse files
committedSep 9, 2024·
fix test
1 parent a3908ab commit c901fa7

File tree

5 files changed

+94
-44
lines changed

5 files changed

+94
-44
lines changed
 

‎examples/transposed3d.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import sharpy as sp
2+
import numpy as np
3+
4+
5+
def sp_tranposed3d_1():
6+
a = sp.arange(0,2*3*4,1)
7+
a = sp.reshape(a,[2,3,4])
8+
9+
# b = a.swapaxes(1,0).swapaxes(1,2)
10+
b = sp.permute_dims(a, (1,0,2)) # 2x4x4 -> 4x2x4 || 4x4x4
11+
b = sp.permute_dims(b, (0,2,1)) # 4x2x4 -> 4x4x2 || 4x4x4
12+
13+
# c = b.swapaxes(1,2).swapaxes(1,0)
14+
c = sp.permute_dims(b, (0,2,1))
15+
c = sp.permute_dims(c, (1,0,2))
16+
17+
assert(np.allclose(sp.to_numpy(a), sp.to_numpy(c)))
18+
return b
19+
20+
def sp_tranposed3d_2():
21+
a = sp.arange(0,2*3*4,1)
22+
a = sp.reshape(a,[2,3,4])
23+
24+
# b = a.swapaxes(2,1).swapaxes(2,0)
25+
b = sp.permute_dims(a, (0,2,1))
26+
b = sp.permute_dims(b, (2,1,0))
27+
28+
# c = b.swapaxes(2,1).swapaxes(0,1)
29+
c = sp.permute_dims(b, (0,2,1))
30+
c = sp.permute_dims(c, (1,0,2))
31+
32+
return c
33+
34+
def np_tranposed3d_1():
35+
a = np.arange(0,2*3*4,1)
36+
a = np.reshape(a,[2,3,4])
37+
b = a.swapaxes(1,0).swapaxes(1,2)
38+
return b
39+
40+
def np_tranposed3d_2():
41+
a = np.arange(0,2*3*4,1)
42+
a = np.reshape(a,[2,3,4])
43+
b = a.swapaxes(2,1).swapaxes(2,0)
44+
c = b.swapaxes(2,1).swapaxes(0,1)
45+
return c
46+
47+
sp.init(False)
48+
49+
b1 = sp_tranposed3d_1()
50+
assert(np.allclose(sp.to_numpy(b1), np_tranposed3d_1()))
51+
52+
b2 = sp_tranposed3d_2()
53+
assert(np.allclose(sp.to_numpy(b2), np_tranposed3d_2()))
54+
55+
sp.fini()

‎setup.py

-2
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,8 @@ def build_cmake(self, ext):
4343
build_args = ["--config", config]
4444

4545
os.chdir(str(build_temp))
46-
print('!!!!!!!!!!', ["cmake", str(cwd)] + cmake_args)
4746
self.spawn(["cmake", str(cwd)] + cmake_args)
4847
if not self.dry_run:
49-
print('!!!!!!!!!!', ["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"] + build_args)
5048
self.spawn(
5149
["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
5250
+ build_args

‎src/idtr.cpp

+34-18
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,15 @@ template <typename T> class ndarray {
646646
}
647647
}
648648

649+
void globalIndices(const std::function<void(const id &)> &callback) const {
650+
size_t size = gSize();
651+
id idx(_nDims);
652+
while (size--) {
653+
callback(idx);
654+
idx.next(_gShape);
655+
}
656+
}
657+
649658
int64_t getLocalDataOffset(const id &idx) const {
650659
auto localIdx = idx - _gOffsets;
651660
int64_t offset = 0;
@@ -711,14 +720,16 @@ size_t getOutputRank(const std::vector<Parts> &parts, int64_t dim0) {
711720
template <typename T> class WaitPermute {
712721
public:
713722
WaitPermute(SHARPY::Transceiver *tc, SHARPY::Transceiver::WaitHandle hdl,
714-
SHARPY::rank_type nRanks, std::vector<Parts> &&parts,
715-
std::vector<int64_t> &&axes, std::vector<int64_t> oGShape,
723+
SHARPY::rank_type cRank, SHARPY::rank_type nRanks,
724+
std::vector<Parts> &&parts, std::vector<int64_t> &&axes,
725+
std::vector<int64_t> oGShape, ndarray<T> &&input,
716726
ndarray<T> &&output, std::vector<T> &&receiveBuffer,
717727
std::vector<int> &&receiveOffsets,
718728
std::vector<int> &&receiveSizes)
719-
: tc(tc), hdl(hdl), nRanks(nRanks), parts(std::move(parts)),
729+
: tc(tc), hdl(hdl), cRank(cRank), nRanks(nRanks), parts(std::move(parts)),
720730
axes(std::move(axes)), oGShape(std::move(oGShape)),
721-
output(std::move(output)), receiveBuffer(std::move(receiveBuffer)),
731+
input(std::move(input)), output(std::move(output)),
732+
receiveBuffer(std::move(receiveBuffer)),
722733
receiveOffsets(std::move(receiveOffsets)),
723734
receiveSizes(std::move(receiveSizes)) {}
724735

@@ -733,9 +744,12 @@ template <typename T> class WaitPermute {
733744
}
734745

735746
std::vector<size_t> receiveRankBufferCount(nRanks, 0);
736-
output.localIndices([&](const id &outputIndex) {
737-
id inputIndex = outputIndex.permute(axes);
738-
auto rank = getInputRank(parts, inputIndex[0]);
747+
input.globalIndices([&](const id &inputIndex) {
748+
id outputIndex = inputIndex.permute(axes);
749+
auto rank = getOutputRank(parts, outputIndex[0]);
750+
if (rank != cRank)
751+
return;
752+
rank = getInputRank(parts, inputIndex[0]);
739753
auto &count = receiveRankBufferCount[rank];
740754
output[outputIndex] = receiveRankBuffer[rank][count++];
741755
});
@@ -744,10 +758,12 @@ template <typename T> class WaitPermute {
744758
private:
745759
SHARPY::Transceiver *tc;
746760
SHARPY::Transceiver::WaitHandle hdl;
761+
SHARPY::rank_type cRank;
747762
SHARPY::rank_type nRanks;
748763
std::vector<Parts> parts;
749764
std::vector<int64_t> axes;
750765
std::vector<int64_t> oGShape;
766+
ndarray<T> input;
751767
ndarray<T> output;
752768
std::vector<T> receiveBuffer;
753769
std::vector<int> receiveOffsets;
@@ -791,9 +807,9 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
791807
assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oNDims], 0,
792808
std::plus<int64_t>()) == 0);
793809

794-
auto nRanks = tc->nranks();
795-
auto rank = tc->rank();
796-
if (nRanks <= rank) {
810+
const auto nRanks = tc->nranks();
811+
const auto cRank = tc->rank();
812+
if (nRanks <= cRank) {
797813
throw std::out_of_range("Fatal: rank must be < number of ranks");
798814
}
799815

@@ -833,10 +849,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
833849

834850
// First we allgather the current and target partitioning
835851
std::vector<Parts> parts(nRanks);
836-
parts[rank].iStart = iOffsPtr[0];
837-
parts[rank].iEnd = iOffsPtr[0] + iDataShapePtr[0];
838-
parts[rank].oStart = oOffsPtr[0];
839-
parts[rank].oEnd = oOffsPtr[0] + oDataShapePtr[0];
852+
parts[cRank].iStart = iOffsPtr[0];
853+
parts[cRank].iEnd = iOffsPtr[0] + iDataShapePtr[0];
854+
parts[cRank].oStart = oOffsPtr[0];
855+
parts[cRank].oEnd = oOffsPtr[0] + oDataShapePtr[0];
840856
std::vector<int> counts(nRanks, 4);
841857
std::vector<int> dspl(nRanks);
842858
for (auto i = 0ul; i < nRanks; ++i) {
@@ -891,10 +907,10 @@ WaitHandleBase *_idtr_copy_permute(SHARPY::DTypeId sharpytype,
891907
sendOffsets.data(), sharpytype, receiveBuffer.data(),
892908
receiveSizes.data(), receiveOffsets.data());
893909

894-
auto wait = WaitPermute(tc, hdl, nRanks, std::move(parts), std::move(axes),
895-
std::move(oGShape), std::move(output),
896-
std::move(receiveBuffer), std::move(receiveOffsets),
897-
std::move(receiveSizes));
910+
auto wait = WaitPermute(tc, hdl, cRank, nRanks, std::move(parts),
911+
std::move(axes), std::move(oGShape), std::move(input),
912+
std::move(output), std::move(receiveBuffer),
913+
std::move(receiveOffsets), std::move(receiveSizes));
898914

899915
assert(parts.empty() && axes.empty() && receiveBuffer.empty() &&
900916
receiveOffsets.empty() && receiveSizes.empty());

‎test/test_manip.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,8 @@ def test_todevice_host2gpu(self):
9595
assert numpy.allclose(sp.to_numpy(b), [0, 1, 2, 3, 4, 5, 6, 7])
9696

9797
def test_permute_dims(self):
98-
def doit(aapi, **kwargs):
99-
a = aapi.arange(0, 12 * 11, 1, aapi.int32, **kwargs)
100-
return aapi.permite_dims(a, [1, 0])
101-
102-
assert runAndCompare(doit)
98+
a = sp.arange(0, 10, 1, sp.int64)
99+
b = sp.reshape(a, (2, 5))
100+
c1 = sp.to_numpy(sp.permute_dims(b, [1, 0]))
101+
c2 = sp.to_numpy(b).transpose(1, 0)
102+
assert numpy.allclose(c1, c2)

‎test/test_permute.py

-19
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.