Skip to content

Commit 4520d97

Browse files
authored
Implement permute_dims (#12)
* implement permute_dims
1 parent 38362e7 commit 4520d97

20 files changed

+755
-16
lines changed

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ repos:
1919
- id: trailing-whitespace
2020
exclude: '.*\.patch'
2121
- repo: https://github.com/psf/black
22-
rev: 24.3.0
22+
rev: 24.8.0
2323
hooks:
2424
- id: black
2525
args: ["--line-length", "80"]
2626
language_version: python3
2727
- repo: https://github.com/PyCQA/bandit
28-
rev: '1.7.8'
28+
rev: '1.7.9'
2929
hooks:
3030
- id: bandit
3131
args: ["-c", ".bandit.yml"]
@@ -35,7 +35,7 @@ repos:
3535
- id: isort
3636
name: isort (python)
3737
- repo: https://github.com/pycqa/flake8
38-
rev: 7.0.0
38+
rev: 7.1.1
3939
hooks:
4040
- id: flake8
4141
- repo: https://github.com/pocc/pre-commit-hooks

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ include_directories(
149149
${PROJECT_SOURCE_DIR}/third_party/bitsery/include
150150
${MPI_INCLUDE_PATH}
151151
${pybind11_INCLUDE_DIRS}
152+
${LLVM_INCLUDE_DIRS}
152153
${MLIR_INCLUDE_DIRS}
153154
${IMEX_INCLUDE_DIRS})
154155

examples/transpose.py

+184
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
"""
2+
Transpose benchmark
3+
4+
Matrix transpose benchmark for sharpy and numpy backends.
5+
6+
Examples:
7+
8+
# Run 1000 iterations of 1000*1000 matrix on sharpy backend
9+
python transpose.py -r 10 -c 1000 -b sharpy -i 1000
10+
11+
# MPI parallel run
12+
mpiexec -n 3 python transpose.py -r 1000 -c 1000 -b sharpy -i 1000
13+
14+
"""
15+
16+
import argparse
17+
import time as time_mod
18+
19+
import numpy
20+
21+
import sharpy
22+
23+
try:
24+
import mpi4py
25+
26+
mpi4py.rc.finalize = False
27+
from mpi4py import MPI
28+
29+
comm_rank = MPI.COMM_WORLD.Get_rank()
30+
comm = MPI.COMM_WORLD
31+
except ImportError:
32+
comm_rank = 0
33+
comm = None
34+
35+
36+
def info(s):
37+
if comm_rank == 0:
38+
print(s)
39+
40+
41+
def sp_transpose(arr):
42+
brr = sharpy.permute_dims(arr, [1, 0])
43+
return brr
44+
45+
46+
def np_transpose(arr):
47+
brr = arr.transpose()
48+
return brr.copy()
49+
50+
51+
def initialize(np, row, col, dtype):
52+
arr = np.arange(0, row * col, 1, dtype=dtype)
53+
return np.reshape(arr, (row, col))
54+
55+
56+
def run(row, col, backend, iterations, datatype):
57+
if backend == "sharpy":
58+
import sharpy as np
59+
from sharpy import fini, init, sync
60+
61+
transpose = sp_transpose
62+
63+
init(False)
64+
elif backend == "numpy":
65+
import numpy as np
66+
67+
if comm is not None:
68+
assert (
69+
comm.Get_size() == 1
70+
), "Numpy backend only supports serial execution."
71+
72+
fini = sync = lambda x=None: None
73+
transpose = np_transpose
74+
else:
75+
raise ValueError(f'Unknown backend: "{backend}"')
76+
77+
dtype = {
78+
"f32": np.float32,
79+
"f64": np.float64,
80+
}[datatype]
81+
82+
info(f"Using backend: {backend}")
83+
info(f"Number of row: {row}")
84+
info(f"Number of column: {col}")
85+
info(f"Datatype: {datatype}")
86+
87+
arr = initialize(np, row, col, dtype)
88+
sync()
89+
90+
# verify
91+
if backend == "sharpy":
92+
brr = sp_transpose(arr)
93+
crr = np_transpose(sharpy.to_numpy(arr))
94+
assert numpy.allclose(sharpy.to_numpy(brr), crr)
95+
96+
def eval():
97+
tic = time_mod.perf_counter()
98+
transpose(arr)
99+
sync()
100+
toc = time_mod.perf_counter()
101+
return toc - tic
102+
103+
# warm-up run
104+
t_warm = eval()
105+
106+
# evaluate
107+
info(f"Running {iterations} iterations")
108+
time_list = []
109+
for i in range(iterations):
110+
time_list.append(eval())
111+
112+
# get max time over mpi ranks
113+
if comm is not None:
114+
t_warm = comm.allreduce(t_warm, MPI.MAX)
115+
time_list = comm.allreduce(time_list, MPI.MAX)
116+
117+
t_min = numpy.min(time_list)
118+
t_max = numpy.max(time_list)
119+
t_med = numpy.median(time_list)
120+
init_overhead = t_warm - t_med
121+
if backend == "sharpy":
122+
info(f"Estimated initialization overhead: {init_overhead:.5f} s")
123+
info(f"Min. duration: {t_min:.5f} s")
124+
info(f"Max. duration: {t_max:.5f} s")
125+
info(f"Median duration: {t_med:.5f} s")
126+
127+
fini()
128+
129+
130+
if __name__ == "__main__":
131+
parser = argparse.ArgumentParser(
132+
description="Run transpose benchmark",
133+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
134+
)
135+
136+
parser.add_argument(
137+
"-r",
138+
"--row",
139+
type=int,
140+
default=10000,
141+
help="Number of row.",
142+
)
143+
parser.add_argument(
144+
"-c",
145+
"--column",
146+
type=int,
147+
default=10000,
148+
help="Number of column.",
149+
)
150+
151+
parser.add_argument(
152+
"-b",
153+
"--backend",
154+
type=str,
155+
default="sharpy",
156+
choices=["sharpy", "numpy"],
157+
help="Backend to use.",
158+
)
159+
160+
parser.add_argument(
161+
"-i",
162+
"--iterations",
163+
type=int,
164+
default=10,
165+
help="Number of iterations to run.",
166+
)
167+
168+
parser.add_argument(
169+
"-d",
170+
"--datatype",
171+
type=str,
172+
default="f64",
173+
choices=["f32", "f64"],
174+
help="Datatype for model state variables",
175+
)
176+
177+
args = parser.parse_args()
178+
run(
179+
args.row,
180+
args.column,
181+
args.backend,
182+
args.iterations,
183+
args.datatype,
184+
)

imex_version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
5a7bb80ede5fe4fa8d56ee0dd77c4e5c1327fe09
1+
8ae485bbfb1303a414b375e25130fcaa4c02127a

setup.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import multiprocessing
12
import os
23
import pathlib
34

@@ -44,7 +45,10 @@ def build_cmake(self, ext):
4445
os.chdir(str(build_temp))
4546
self.spawn(["cmake", str(cwd)] + cmake_args)
4647
if not self.dry_run:
47-
self.spawn(["cmake", "--build", ".", "-j5"] + build_args)
48+
self.spawn(
49+
["cmake", "--build", ".", f"-j{multiprocessing.cpu_count()}"]
50+
+ build_args
51+
)
4852
# Troubleshooting: if fail on line above then delete all possible
4953
# temporary CMake files including "CMakeCache.txt" in top level dir.
5054
os.chdir(str(cwd))

sharpy/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ def _validate_device(device):
130130
exec(
131131
f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))"
132132
)
133+
elif func == "permute_dims":
134+
exec(
135+
f"{func} = lambda this, axes: ndarray(_csp.ManipOp.permute_dims(this._t, axes))"
136+
)
133137

134138
for func in api.api_categories["ReduceOp"]:
135139
FUNC = func.upper()

sharpy/array_api.py

+1
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@
179179
"roll", # (x, /, shift, *, axis=None)
180180
"squeeze", # (x, /, axis)
181181
"stack", # (arrays, /, *, axis=0)
182+
"permute_dims", # (x: array, /, axes: Tuple[int, ...]) → array
182183
],
183184
"LinAlgOp": [
184185
"matmul", # (x1, x2, /)

src/EWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ struct DeferredEWBinOp : public Deferred {
120120
auto av = dm.getDependent(builder, Registry::get(_a));
121121
auto bv = dm.getDependent(builder, Registry::get(_b));
122122

123-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
123+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
124124
auto outElemType =
125125
::imex::ndarray::toMLIR(builder, SHARPY::jit::getPTDType(_dtype));
126126
auto outTyp = aTyp.cloneWith(shape(), outElemType);

src/EWUnyOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ struct DeferredEWUnyOp : public Deferred {
105105
jit::DepManager &dm) override {
106106
auto av = dm.getDependent(builder, Registry::get(_a));
107107

108-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
108+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
109109
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
110110

111111
auto ndOpId = sharpy(_op);

src/IEWBinOp.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct DeferredIEWBinOp : public Deferred {
7171
auto av = dm.getDependent(builder, Registry::get(_a));
7272
auto bv = dm.getDependent(builder, Registry::get(_b));
7373

74-
auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>();
74+
auto aTyp = ::mlir::cast<::imex::ndarray::NDArrayType>(av.getType());
7575
auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType());
7676

7777
auto binop = builder.create<::imex::ndarray::EWBinOp>(

0 commit comments

Comments
 (0)