Skip to content

Commit 84013c5

Browse files
committed
initial commit
0 parents  commit 84013c5

11 files changed

+1039
-0
lines changed

PySortingNetworks/AutoCompleter.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import random
2+
3+
def _base(sn, geninputs, select, doprint=False):
4+
sn = sn.copy()
5+
6+
def callback(d_in, d_out):
7+
#print("failed", d_in, d_out)
8+
nonlocal suggest
9+
zero = True
10+
for idx in range(len(d_out)):
11+
if zero and d_out[idx]:
12+
firstNonZero = idx
13+
zero = False
14+
elif not zero and not d_out[idx]:
15+
lastZero = idx
16+
test = d_out[:]
17+
assert test[firstNonZero] > test[lastZero] and firstNonZero < lastZero
18+
test[firstNonZero], test[lastZero] = test[lastZero], test[firstNonZero]
19+
if all(test[i] <= test[i+1] for i in range(len(test)-1)):
20+
if not (firstNonZero, lastZero) in suggest:
21+
suggest[(firstNonZero, lastZero)] = 0
22+
suggest[(firstNonZero, lastZero)] += 1
23+
24+
suggest = None
25+
while True:
26+
if suggest is None:
27+
suggest = dict()
28+
passed, failed = sn.test( geninputs(), callback)
29+
if doprint:
30+
print("nops", sn.nops(), "passed", passed, "failed", failed, "suggest", len(suggest))
31+
if failed == 0:
32+
break
33+
def applySuggestion(sgst):
34+
nonlocal suggest
35+
test = sn.copy()
36+
test.P(*sgst)
37+
suggest = dict()
38+
passed, failed = test.test(geninputs(), callback)
39+
return passed, failed, suggest
40+
passed, failed, suggest, (i,j) = select(suggest, applySuggestion)
41+
sn.P(i,j)
42+
if doprint:
43+
print("nops", sn.nops(), "passed", passed, "failed", failed, "suggest", len(suggest))
44+
#print(sn)
45+
return sn
46+
47+
def greedy(sn, geninputs, doprint=False, maxNumSuggestsToTry=None):
48+
def select(suggest, applySuggestion):
49+
best = None
50+
if maxNumSuggestsToTry is not None:
51+
last_suggest = sorted(suggest.keys(), key=lambda x: -suggest[x])[:maxNumSuggestsToTry]
52+
else:
53+
last_suggest = list(suggest.keys())
54+
random.shuffle(last_suggest)
55+
for i,j in last_suggest:
56+
passed, failed, suggest = applySuggestion( (i,j) )
57+
if best is None or (passed > best[0]) or (passed == best[0] and len(suggest) > len(best[2])):
58+
best = passed, failed, suggest, (i,j)
59+
return best
60+
return _base(sn, geninputs, select, doprint)
61+
62+
def randbest5(sn, geninputs, doprint=False):
63+
def select(suggest, applySuggestion):
64+
best = []
65+
last_suggest = list(suggest)
66+
random.shuffle(last_suggest)
67+
for i,j in last_suggest:
68+
passed, failed, suggest = applySuggestion( (i,j) )
69+
best.append( (passed, failed, suggest, (i,j)) )
70+
best.sort(key = lambda x: (x[1], -len(x[2])))
71+
best = best[:5]
72+
random.shuffle(best)
73+
return best[0]
74+
return _base(sn, geninputs, select, doprint)
75+

PySortingNetworks/Base.py

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
class BaseSortingNetwork:
2+
def __init__(self, n, ops=None, sections=None):
3+
assert n >= 0
4+
self.n = n
5+
self.nodes = list(range(n))
6+
self.ops = []
7+
self.sections = None
8+
if ops is not None:
9+
for i,j in ops:
10+
self.P(i, j)
11+
self.sections = sections
12+
13+
def nops(self, nodes=None):
14+
if nodes is None:
15+
return len(self.ops)
16+
res = len(self.ops)
17+
while True:
18+
if self.ops[res-1][0] in nodes or self.ops[res-1][1] in nodes or res == 0:
19+
break
20+
else:
21+
res -= 1
22+
for i,j in self.ops[res:]:
23+
assert i not in nodes and j not in nodes
24+
return res
25+
26+
def section(self, name):
27+
if self.sections is None:
28+
if len(self.ops) > 0:
29+
self.sections = [("<unnamed>", 0)]
30+
else:
31+
self.sections = []
32+
self.sections.append( (name, len(self.ops)) )
33+
34+
def focus(self, sections):
35+
assert self.sections is not None
36+
for idx in list(range(len(self.sections)))[::-1]:
37+
name, opidx = self.sections[idx]
38+
if not name in sections:
39+
nextOpIdx = self.sections[idx+1][1] if idx < len(self.sections)-1 else len(self.ops)
40+
self.ops = self.ops[:opidx] + self.ops[nextOpIdx:]
41+
self.sections = None
42+
43+
def P(self, i, j):
44+
"""swap elements i and j if comparison result is unsorted"""
45+
assert i in self.nodes and j in self.nodes and i != j
46+
self.ops.append( (i,j) )
47+
48+
def __str__(self):
49+
"""return ascii art of the network"""
50+
cols = [(("%3d: "%self.nodes[k])+chr(0x2500)) for k in range(self.n)]
51+
secIdx = 0
52+
for opidx, (i,j) in enumerate(self.ops):
53+
if self.sections is not None and secIdx < len(self.sections):
54+
if opidx == self.sections[secIdx][1]:
55+
secIdx += 1
56+
for k in range(len(cols)):
57+
cols[k] += chr(0x250a)
58+
for k in range(min(i,j)):
59+
cols[k] += chr(0x2500)
60+
if i < j:
61+
cols[i] += chr(0x252c)
62+
for k in range(i+1, j):
63+
cols[k] += chr(0x253c)
64+
cols[j] += chr(0x2534)
65+
else:
66+
cols[j] += chr(0x2565)
67+
for k in range(j+1, i):
68+
cols[k] += chr(0x256b)
69+
cols[i] += chr(0x2568)
70+
for k in range(max(i,j)+1, len(cols)):
71+
cols[k] = cols[k] + chr(0x2500)
72+
return repr(self) + "\n" + ("\n".join(cols))
73+
74+
def __repr__(self):
75+
return f"""
76+
BaseSortingNetwork(
77+
n={self.n},
78+
ops={self.ops},
79+
sections={self.sections}
80+
)"""
81+
82+
def execute(self, d_in):
83+
assert len(d_in) == self.n
84+
assert self.nodes == list(range(self.n))
85+
a = list(d_in[:])
86+
for i,j in self.ops:
87+
if a[i] > a[j]:
88+
a[i], a[j] = a[j], a[i]
89+
return a
90+
91+
def execute_stat(self, d_in, stat):
92+
assert len(stat) == len(self.ops)
93+
assert len(d_in) == self.n
94+
assert self.nodes == list(range(self.n))
95+
a = list(d_in[:])
96+
for opidx, (i,j) in enumerate(self.ops):
97+
if a[i] > a[j]:
98+
a[i], a[j] = a[j], a[i]
99+
stat[opidx] += 1
100+
return a
101+
102+
def test(self, inputs, callback_if_failed=None, callback_if_passed=None):
103+
"""tests the network for each input yielded by the generator and gather statistics"""
104+
passed = 0
105+
failed = 0
106+
for d_in in inputs:
107+
d_out = self.execute(d_in)
108+
if not all(d_out[i] <= d_out[i+1] for i in range(self.n-1)):
109+
failed += 1
110+
if callback_if_failed is not None:
111+
callback_if_failed(d_in, d_out)
112+
else:
113+
passed += 1
114+
if callback_if_passed is not None:
115+
callback_if_passed(d_in, d_out)
116+
return passed, failed
117+
118+
def opidx2secname(self):
119+
res = None
120+
if self.sections is not None:
121+
res = {}
122+
for sidx in range(len(self.sections)):
123+
name, opidx0 = self.sections[sidx]
124+
opidx1 = self.sections[sidx+1][1] if sidx+1 < len(self.sections) else len(self.ops)
125+
for opidx in range(opidx0, opidx1):
126+
res[opidx] = name
127+
return res
128+
129+
def prune(self, input_generator, sections=None):
130+
seqidx = 0
131+
opidx2secname = self.opidx2secname()
132+
if self.sections is not None:
133+
if sections is None:
134+
sections = [name for name, _ in self.sections]
135+
for s in sections:
136+
assert s in [name for name, _ in self.sections]
137+
else:
138+
assert sections is None
139+
stats = [0] * len(self.ops)
140+
for d_in in input_generator():
141+
d_out = self.execute_stat(d_in, stats)
142+
indices = list(range(len(self.ops)))[::-1]
143+
numPruned = 0
144+
for opidx in indices:
145+
if stats[opidx] == 0 and (opidx2secname is None or opidx2secname[opidx] in sections):
146+
# opidx can be removed because it was a noop
147+
numPruned += 1
148+
self.ops = self.ops[:opidx] + self.ops[opidx+1:]
149+
if self.sections is not None:
150+
# check if we need to adapt the section indices
151+
newsections = []
152+
for name, idx in self.sections:
153+
if idx == opidx:
154+
newsections.append( (name, idx+1) )
155+
elif idx > opidx:
156+
newsections.append( (name, idx-1) )
157+
else:
158+
newsections.append( (name, idx) )
159+
return numPruned
160+
161+
def normalize(self, method, sections=None):
162+
if sections is None and self.sections is not None:
163+
sections = [name for name, _ in self.sections]
164+
opidx2secname = self.opidx2secname()
165+
if method == "lower_indices_last":
166+
changed = True
167+
while changed:
168+
changed = False
169+
for i in range(len(self.ops)-1):
170+
if opidx2secname is None or (opidx2secname[i] in sections and opidx2secname[i+1] in sections):
171+
i0, i1 = self.ops[i]
172+
j0, j1 = self.ops[i+1]
173+
if j1 > i1:
174+
# we'd like to swap the elements if possible
175+
if i0 != j0 and i0 != j1 and i1 != j0 and i1 != j1:
176+
self.ops[i], self.ops[i+1] = self.ops[i+1], self.ops[i]
177+
changed = True
178+
return
179+
raise RuntimeError("Unknown method %s" % method)
180+
181+
def copy(self):
182+
res = BaseSortingNetwork(self.n)
183+
res.nodes = self.nodes[:]
184+
res.ops = self.ops[:]
185+
if self.sections is not None:
186+
res.sections = self.sections[:]
187+
return res
188+
189+
@staticmethod
190+
def relabel(sn, node_map):
191+
res = BaseSortingNetwork(sn.n)
192+
res.nodes = [node_map[i] for i in sn.nodes]
193+
if sn.sections is not None:
194+
res.sections = sn.sections[:]
195+
for i,j in sn.ops:
196+
res.P(node_map[i], node_map[j])
197+
return res
198+
199+
@staticmethod
200+
def append(sn1, sn2):
201+
if len(set(sn1.nodes) & set(sn2.nodes)) == 0:
202+
# intersection of nodes is empty
203+
res = BaseSortingNetwork(sn1.n + sn2.n)
204+
res.nodes = sn1.nodes + sn2.nodes
205+
res.nodes.sort()
206+
for i,j in sn1.ops + sn2.ops:
207+
res.P(i, j)
208+
if sn1.sections is not None:
209+
res.sections = sn1.sections[:]
210+
if sn2.sections is not None:
211+
if res.sections is None: res.sections = [("<unnamed>", 0)]
212+
for name, idx in sn2.sections:
213+
res.sections.append( (name, sn1.n + idx) )
214+
return res
215+
if set(sn1.nodes) & set(sn2.nodes) == set(sn1.nodes):
216+
res = BaseSortingNetwork(sn2.n)
217+
res.nodes = sn2.nodes[:]
218+
for i,j in sn1.ops:
219+
res.P(i,j)
220+
for i,j in sn2.ops:
221+
res.P(i,j)
222+
if sn1.sections is not None:
223+
res.sections = sn1.sections[:]
224+
if sn2.sections is not None:
225+
if res.sections is None:
226+
res.sections = [("<unnamed>", 0)]
227+
for name, idx in sn2.sections:
228+
res.sections.append( (name, len(sn1.ops) + idx) )
229+
return res
230+
if set(sn1.nodes) & set(sn2.nodes) == set(sn2.nodes):
231+
res = BaseSortingNetwork(sn1.n)
232+
res.nodes = sn1.nodes[:]
233+
for i,j in sn1.ops:
234+
res.P(i,j)
235+
for i,j in sn2.ops:
236+
res.P(i,j)
237+
if sn1.sections is not None:
238+
res.sections = sn1.sections[:]
239+
if sn2.sections is not None:
240+
if res.sections is None:
241+
res.sections = [("<unnamed>", 0)]
242+
for name, idx in sn2.sections:
243+
res.sections.append( (name, len(sn1.ops) + idx) )
244+
return res
245+
raise RuntimeError("Don't know how to append both networks n1=%s n2=%s" % (sn1.nodes, sn2.nodes))

0 commit comments

Comments
 (0)