Skip to content
This repository was archived by the owner on Feb 20, 2020. It is now read-only.

Commit cde2527

Browse files
committed
Initial version. Very dirty.
0 parents  commit cde2527

File tree

5 files changed

+256
-0
lines changed

5 files changed

+256
-0
lines changed

.gitignore

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
5+
# C extensions
6+
*.so
7+
8+
# Distribution / packaging
9+
.Python
10+
env/
11+
bin/
12+
build/
13+
develop-eggs/
14+
dist/
15+
eggs/
16+
lib/
17+
lib64/
18+
parts/
19+
sdist/
20+
var/
21+
*.egg-info/
22+
.installed.cfg
23+
*.egg
24+
25+
# Installer logs
26+
pip-log.txt
27+
pip-delete-this-directory.txt
28+
29+
# Unit test / coverage reports
30+
.tox/
31+
.coverage
32+
.cache
33+
nosetests.xml
34+
coverage.xml
35+
36+
# Translations
37+
*.mo
38+
39+
# Mr Developer
40+
.mr.developer.cfg
41+
.project
42+
.pydevproject
43+
44+
# Rope
45+
.ropeproject
46+
47+
# Django stuff:
48+
*.log
49+
*.pot
50+
51+
# Sphinx documentation
52+
docs/_build/
53+
54+
# ipython notebook
55+
.ipynb_checkpoints

README.md

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Tired of numpy.dot?
2+
3+
Note: this is pre alpha!
4+
5+
TODO change that!
6+
7+
TODO Ideally this should be integrated in numpy.
8+
9+
`mdot` chains multiplication calls and allows you to write
10+
```python
11+
mdot(A, B, C, D)
12+
```
13+
instead of
14+
```python
15+
np.dot(np.dot(np.dot(A, B), C), D)
16+
A.dot(B).dot(C).dot(D)
17+
```
18+
19+
Did I mention that it automatically speeds up the multiplication by setting the
20+
parens in an optimal fashion:
21+
```python
22+
>>> %timeit np.dot(np.dot(np.dot(A, B), C), D)
23+
1 loops, best of 3: 694 ms per loop
24+
>>> %timeit mdot(A, B, C, D)
25+
100 loops, best of 3: 5.18 ms per loop
26+
```
27+
28+
Still not satisfied? Get rid of the overhead of calculating the optimal
29+
parens once and then use the expression:
30+
```python
31+
>>> print_optimal(D, A, B, C, names=list("DABC"))
32+
"np.dot(np.dot(D, np.dot(A, B)), C)"
33+
```

mdot.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from __future__ import print_function
2+
3+
import numpy as np
4+
import mytimer
5+
6+
7+
#@mytimer.timeit
8+
def chain_order_rec(args):
    """Compute the cheapest parenthesization of a chain of matrix products.

    This is the classic matrix-chain-order dynamic program.  With ``p`` the
    list of dimensions described below, the recurrence is::

        m[i, i] = 0
        m[i, j] = min(m[i, k] + m[k+1, j] + p[i] * p[k+1] * p[j+1]
                      for k in range(i, j))

    Parameters
    ----------
    args : sequence of 2-d arrays
        The matrices in multiplication order.  Only ``shape`` is read.

    Returns
    -------
    m : ndarray, shape (n, n)
        ``m[i, j]`` (``i <= j``) is the minimal number of scalar
        multiplications needed to compute ``args[i] ... args[j]``.
    s : ndarray of int, shape (n, n)
        ``s[i, j]`` is the index ``k`` at which the product
        ``args[i] ... args[j]`` is split into ``(args[i] ... args[k])`` and
        ``(args[k+1] ... args[j])`` in the optimal solution.
    """
    # p is the list of the row counts of all matrices plus the column count
    # of the last matrix.
    # Example:
    #   A_{10x100}, B_{100x5}, C_{5x50} --> p = [10, 100, 5, 50]
    # The cost of multiplying AB is then 10 * 100 * 5.
    p = [arg.shape[0] for arg in args]
    p.append(args[-1].shape[1])

    n = len(p) - 1
    # costs of the subproblems
    m = np.zeros((n, n))
    # split positions; integer dtype so the entries can be used as indices
    s = np.zeros((n, n), dtype=int)

    # Solve sub-chains in order of increasing length so that both m[i, k] and
    # m[k+1, j] are already final when m[i, j] is computed.
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best_cost, best_split = min(
                (m[i, k] + m[k + 1, j] + p[i] * p[k + 1] * p[j + 1], k)
                for k in range(i, j)
            )
            m[i, j] = best_cost
            s[i, j] = best_split

    return m, s
38+
39+
40+
#@mytimer.timeit
41+
def multiply_r(args, s, i, j):
    """Recursively multiply ``args[i] ... args[j]`` following the split table.

    Parameters
    ----------
    args : sequence of arrays
        The matrices of the chain.
    s : ndarray, shape (n, n)
        Split table; ``s[i, j]`` is the index after which the sub-chain
        ``i..j`` is cut in two.  Entries are cast to ``int`` before use, so a
        float-typed table works as well.
    i, j : int
        Inclusive bounds of the sub-chain to multiply.

    Returns
    -------
    The product of the sub-chain, or ``args[i]`` itself when ``i == j``.
    """
    # Indices may arrive as numpy floats taken from ``s``; NumPy refuses
    # non-integer indices, so normalise them once up front.
    i, j = int(i), int(j)
    if i == j:
        return args[i]

    split = int(s[i, j])
    left = multiply_r(args, s, i, split)
    right = multiply_r(args, s, split + 1, j)
    return np.dot(left, right)
47+
48+
49+
def _print_parens(args, s, i, j, names=None):
    """Print the nested ``np.dot`` expression for the sub-chain ``i..j``.

    Nothing is returned; the expression is written to stdout without a
    trailing newline.

    Parameters
    ----------
    args : sequence
        The matrices of the chain; only passed along to the recursive calls.
    s : ndarray, shape (n, n)
        Split table; ``s[i, j]`` is the index after which the sub-chain
        ``i..j`` is cut in two.  Entries are cast to ``int`` before use.
    i, j : int
        Inclusive bounds of the sub-chain to print.
    names : sequence of str, optional
        Names to print for the matrices.  When missing or empty, matrix ``k``
        is printed as ``M_k``.
    """
    # Indices may arrive as numpy floats taken from ``s``; NumPy refuses
    # non-integer indices, so normalise them once up front.
    i, j = int(i), int(j)
    if i == j:
        label = names[i] if names else "M_{}".format(i)
        print(label, end="")
        return

    split = int(s[i, j])
    print("np.dot(", end="")
    _print_parens(args, s, i, split, names)
    print(", ", end="")
    _print_parens(args, s, split + 1, j, names)
    print(")", end="")
62+
63+
64+
def print_optimal(*args, **kwargs):
    """Print the nested ``np.dot`` expression with the fewest scalar
    multiplications for the given matrices.

    The optional keyword argument ``names`` supplies the names to print for
    the matrices.
    """
    last = len(args) - 1
    _, splits = chain_order_rec(args)
    _print_parens(args, splits, 0, last, names=kwargs.get("names", None))
72+
73+
74+
def mdot(*args, **kwargs):
    """Multiply the given arrays, like a chained ``np.dot``.

    Parameters
    ----------
    *args : arrays
        The matrices to multiply, in order.  A single argument is returned
        as is.
    optimize : bool, optional (keyword only, default True)
        If true, pick the parenthesization that minimizes the number of
        scalar multiplications before multiplying.  If false, multiply
        strictly from left to right.

    Returns
    -------
    The product of all arguments.

    Example for the costs:
    A_{10x100}, B_{100x5}, C_{5x50}

    cost((AB)C) = 5000 + 2500 = 7500
    cost(A(BC)) = 50000 + 25000 = 75000

    TODO extend and document.
    """
    if len(args) == 1:
        return args[0]

    if kwargs.get("optimize", True):
        _, s = chain_order_rec(args)
        return multiply_r(args, s, 0, len(args) - 1)

    # ``reduce`` is not a builtin on Python 3; functools provides it on both
    # Python 2 and 3.
    from functools import reduce
    return reduce(np.dot, args)

mytimer.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import time
2+
3+
4+
def timeit(func=None, loops=1, verbose=False):
    """Decorator that times a function and prints min/max/avg run time.

    Usable both bare (``@timeit``) and with arguments
    (``@timeit(loops=10, verbose=True)``).

    Parameters
    ----------
    func : callable, optional
        The function to wrap.  When omitted, a decorator is returned.
    loops : int
        How many times the function is executed per call of the wrapper.
        Must be at least 1.
    verbose : bool
        If true, additionally print the run time of every single run.

    Returns
    -------
    The wrapped function, or a decorator when ``func`` is None.  The wrapper
    returns the result of the last run.

    Raises
    ------
    ValueError
        If ``loops`` is smaller than 1 (there would be no result to return
        and no average to compute).
    """
    if loops < 1:
        raise ValueError("loops must be at least 1, got %r" % (loops,))

    if func is None:
        def partial_inner(func):
            return timeit(func, loops, verbose)
        return partial_inner

    import functools

    # Single-argument print(...) calls are valid on both Python 2 and 3.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        sums = 0.0
        mins = float("inf")
        maxs = 0.0
        print('====%s Timing====' % func.__name__)
        for i in range(loops):
            t0 = time.time()
            result = func(*args, **kwargs)
            dt = time.time() - t0
            mins = min(mins, dt)
            maxs = max(maxs, dt)
            sums += dt
            if verbose:
                print("\t%r ran in %2.9f sec on run %s" % (
                    func.__name__, dt, i))
        print("%r min run time was %2.9f sec" % (func.__name__, mins))
        print("%r max run time was %2.9f sec" % (func.__name__, maxs))
        print("%r avg run time was %2.9f sec in %s runs" % (
            func.__name__, sums / loops, loops))
        print("==== end ====")
        return result

    return inner

test_mdot.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from mdot import mdot
2+
import numpy as np
3+
4+
5+
##############################################################################
6+
def test_unoptimized_one_parameter():
    """Unoptimized mdot of a single identity matrix equals that matrix."""
    identity = np.eye(3, 3)
    product = mdot(identity, optimize=False)
    assert (product == identity).all()
9+
10+
11+
def test_unoptimized_multiple_parameters():
    """Unoptimized products of two and three identity matrices stay identity."""
    identity = np.eye(3, 3)
    pair = mdot(identity, identity, optimize=False)
    assert (pair == identity).all()
    triple = mdot(identity, identity, identity, optimize=False)
    assert (triple == identity).all()
15+
16+
17+
def test_unoptimized_fancy():
    """Unoptimized product of a random matrix and its inverse is the identity."""
    matrix = np.random.random((3, 3))
    inverse = np.linalg.inv(matrix)
    expected = np.eye(3)
    product = mdot(matrix, inverse, optimize=False)
    assert np.allclose(product, expected)
22+
23+
24+
##############################################################################
25+
def test_optimized_general():
    """Optimized product of eight identity matrices is the identity."""
    identity = np.eye(3, 3)
    chain = [identity] * 8
    product = mdot(*chain, optimize=True)
    assert np.allclose(product, identity)
28+
29+
30+
def test_optimized_fancy():
    """Optimized product of a random matrix, its inverse and the identity
    is the identity and keeps the 3x3 shape."""
    A = np.random.random((3, 3))
    B = np.linalg.inv(A)
    I = np.eye(3)
    assert np.allclose(mdot(A, B, I), I)
    # Assert the shape instead of printing it: the print statement was
    # Python 2-only syntax and verified nothing.
    assert mdot(A, B, I, optimize=True).shape == (3, 3)

0 commit comments

Comments
 (0)