|
| 1 | +class BaseSortingNetwork: |
| 2 | + def __init__(self, n, ops=None, sections=None): |
| 3 | + assert n >= 0 |
| 4 | + self.n = n |
| 5 | + self.nodes = list(range(n)) |
| 6 | + self.ops = [] |
| 7 | + self.sections = None |
| 8 | + if ops is not None: |
| 9 | + for i,j in ops: |
| 10 | + self.P(i, j) |
| 11 | + self.sections = sections |
| 12 | + |
| 13 | + def nops(self, nodes=None): |
| 14 | + if nodes is None: |
| 15 | + return len(self.ops) |
| 16 | + res = len(self.ops) |
| 17 | + while True: |
| 18 | + if self.ops[res-1][0] in nodes or self.ops[res-1][1] in nodes or res == 0: |
| 19 | + break |
| 20 | + else: |
| 21 | + res -= 1 |
| 22 | + for i,j in self.ops[res:]: |
| 23 | + assert i not in nodes and j not in nodes |
| 24 | + return res |
| 25 | + |
| 26 | + def section(self, name): |
| 27 | + if self.sections is None: |
| 28 | + if len(self.ops) > 0: |
| 29 | + self.sections = [("<unnamed>", 0)] |
| 30 | + else: |
| 31 | + self.sections = [] |
| 32 | + self.sections.append( (name, len(self.ops)) ) |
| 33 | + |
| 34 | + def focus(self, sections): |
| 35 | + assert self.sections is not None |
| 36 | + for idx in list(range(len(self.sections)))[::-1]: |
| 37 | + name, opidx = self.sections[idx] |
| 38 | + if not name in sections: |
| 39 | + nextOpIdx = self.sections[idx+1][1] if idx < len(self.sections)-1 else len(self.ops) |
| 40 | + self.ops = self.ops[:opidx] + self.ops[nextOpIdx:] |
| 41 | + self.sections = None |
| 42 | + |
| 43 | + def P(self, i, j): |
| 44 | + """swap elements i and j if comparison result is unsorted""" |
| 45 | + assert i in self.nodes and j in self.nodes and i != j |
| 46 | + self.ops.append( (i,j) ) |
| 47 | + |
| 48 | + def __str__(self): |
| 49 | + """return ascii art of the network""" |
| 50 | + cols = [(("%3d: "%self.nodes[k])+chr(0x2500)) for k in range(self.n)] |
| 51 | + secIdx = 0 |
| 52 | + for opidx, (i,j) in enumerate(self.ops): |
| 53 | + if self.sections is not None and secIdx < len(self.sections): |
| 54 | + if opidx == self.sections[secIdx][1]: |
| 55 | + secIdx += 1 |
| 56 | + for k in range(len(cols)): |
| 57 | + cols[k] += chr(0x250a) |
| 58 | + for k in range(min(i,j)): |
| 59 | + cols[k] += chr(0x2500) |
| 60 | + if i < j: |
| 61 | + cols[i] += chr(0x252c) |
| 62 | + for k in range(i+1, j): |
| 63 | + cols[k] += chr(0x253c) |
| 64 | + cols[j] += chr(0x2534) |
| 65 | + else: |
| 66 | + cols[j] += chr(0x2565) |
| 67 | + for k in range(j+1, i): |
| 68 | + cols[k] += chr(0x256b) |
| 69 | + cols[i] += chr(0x2568) |
| 70 | + for k in range(max(i,j)+1, len(cols)): |
| 71 | + cols[k] = cols[k] + chr(0x2500) |
| 72 | + return repr(self) + "\n" + ("\n".join(cols)) |
| 73 | + |
| 74 | + def __repr__(self): |
| 75 | + return f""" |
| 76 | +BaseSortingNetwork( |
| 77 | + n={self.n}, |
| 78 | + ops={self.ops}, |
| 79 | + sections={self.sections} |
| 80 | +)""" |
| 81 | + |
| 82 | + def execute(self, d_in): |
| 83 | + assert len(d_in) == self.n |
| 84 | + assert self.nodes == list(range(self.n)) |
| 85 | + a = list(d_in[:]) |
| 86 | + for i,j in self.ops: |
| 87 | + if a[i] > a[j]: |
| 88 | + a[i], a[j] = a[j], a[i] |
| 89 | + return a |
| 90 | + |
| 91 | + def execute_stat(self, d_in, stat): |
| 92 | + assert len(stat) == len(self.ops) |
| 93 | + assert len(d_in) == self.n |
| 94 | + assert self.nodes == list(range(self.n)) |
| 95 | + a = list(d_in[:]) |
| 96 | + for opidx, (i,j) in enumerate(self.ops): |
| 97 | + if a[i] > a[j]: |
| 98 | + a[i], a[j] = a[j], a[i] |
| 99 | + stat[opidx] += 1 |
| 100 | + return a |
| 101 | + |
| 102 | + def test(self, inputs, callback_if_failed=None, callback_if_passed=None): |
| 103 | + """tests the network for each input yielded by the generator and gather statistics""" |
| 104 | + passed = 0 |
| 105 | + failed = 0 |
| 106 | + for d_in in inputs: |
| 107 | + d_out = self.execute(d_in) |
| 108 | + if not all(d_out[i] <= d_out[i+1] for i in range(self.n-1)): |
| 109 | + failed += 1 |
| 110 | + if callback_if_failed is not None: |
| 111 | + callback_if_failed(d_in, d_out) |
| 112 | + else: |
| 113 | + passed += 1 |
| 114 | + if callback_if_passed is not None: |
| 115 | + callback_if_passed(d_in, d_out) |
| 116 | + return passed, failed |
| 117 | + |
| 118 | + def opidx2secname(self): |
| 119 | + res = None |
| 120 | + if self.sections is not None: |
| 121 | + res = {} |
| 122 | + for sidx in range(len(self.sections)): |
| 123 | + name, opidx0 = self.sections[sidx] |
| 124 | + opidx1 = self.sections[sidx+1][1] if sidx+1 < len(self.sections) else len(self.ops) |
| 125 | + for opidx in range(opidx0, opidx1): |
| 126 | + res[opidx] = name |
| 127 | + return res |
| 128 | + |
| 129 | + def prune(self, input_generator, sections=None): |
| 130 | + seqidx = 0 |
| 131 | + opidx2secname = self.opidx2secname() |
| 132 | + if self.sections is not None: |
| 133 | + if sections is None: |
| 134 | + sections = [name for name, _ in self.sections] |
| 135 | + for s in sections: |
| 136 | + assert s in [name for name, _ in self.sections] |
| 137 | + else: |
| 138 | + assert sections is None |
| 139 | + stats = [0] * len(self.ops) |
| 140 | + for d_in in input_generator(): |
| 141 | + d_out = self.execute_stat(d_in, stats) |
| 142 | + indices = list(range(len(self.ops)))[::-1] |
| 143 | + numPruned = 0 |
| 144 | + for opidx in indices: |
| 145 | + if stats[opidx] == 0 and (opidx2secname is None or opidx2secname[opidx] in sections): |
| 146 | + # opidx can be removed because it was a noop |
| 147 | + numPruned += 1 |
| 148 | + self.ops = self.ops[:opidx] + self.ops[opidx+1:] |
| 149 | + if self.sections is not None: |
| 150 | + # check if we need to adapt the section indices |
| 151 | + newsections = [] |
| 152 | + for name, idx in self.sections: |
| 153 | + if idx == opidx: |
| 154 | + newsections.append( (name, idx+1) ) |
| 155 | + elif idx > opidx: |
| 156 | + newsections.append( (name, idx-1) ) |
| 157 | + else: |
| 158 | + newsections.append( (name, idx) ) |
| 159 | + return numPruned |
| 160 | + |
| 161 | + def normalize(self, method, sections=None): |
| 162 | + if sections is None and self.sections is not None: |
| 163 | + sections = [name for name, _ in self.sections] |
| 164 | + opidx2secname = self.opidx2secname() |
| 165 | + if method == "lower_indices_last": |
| 166 | + changed = True |
| 167 | + while changed: |
| 168 | + changed = False |
| 169 | + for i in range(len(self.ops)-1): |
| 170 | + if opidx2secname is None or (opidx2secname[i] in sections and opidx2secname[i+1] in sections): |
| 171 | + i0, i1 = self.ops[i] |
| 172 | + j0, j1 = self.ops[i+1] |
| 173 | + if j1 > i1: |
| 174 | + # we'd like to swap the elements if possible |
| 175 | + if i0 != j0 and i0 != j1 and i1 != j0 and i1 != j1: |
| 176 | + self.ops[i], self.ops[i+1] = self.ops[i+1], self.ops[i] |
| 177 | + changed = True |
| 178 | + return |
| 179 | + raise RuntimeError("Unknown method %s" % method) |
| 180 | + |
| 181 | + def copy(self): |
| 182 | + res = BaseSortingNetwork(self.n) |
| 183 | + res.nodes = self.nodes[:] |
| 184 | + res.ops = self.ops[:] |
| 185 | + if self.sections is not None: |
| 186 | + res.sections = self.sections[:] |
| 187 | + return res |
| 188 | + |
| 189 | + @staticmethod |
| 190 | + def relabel(sn, node_map): |
| 191 | + res = BaseSortingNetwork(sn.n) |
| 192 | + res.nodes = [node_map[i] for i in sn.nodes] |
| 193 | + if sn.sections is not None: |
| 194 | + res.sections = sn.sections[:] |
| 195 | + for i,j in sn.ops: |
| 196 | + res.P(node_map[i], node_map[j]) |
| 197 | + return res |
| 198 | + |
| 199 | + @staticmethod |
| 200 | + def append(sn1, sn2): |
| 201 | + if len(set(sn1.nodes) & set(sn2.nodes)) == 0: |
| 202 | + # intersection of nodes is empty |
| 203 | + res = BaseSortingNetwork(sn1.n + sn2.n) |
| 204 | + res.nodes = sn1.nodes + sn2.nodes |
| 205 | + res.nodes.sort() |
| 206 | + for i,j in sn1.ops + sn2.ops: |
| 207 | + res.P(i, j) |
| 208 | + if sn1.sections is not None: |
| 209 | + res.sections = sn1.sections[:] |
| 210 | + if sn2.sections is not None: |
| 211 | + if res.sections is None: res.sections = [("<unnamed>", 0)] |
| 212 | + for name, idx in sn2.sections: |
| 213 | + res.sections.append( (name, sn1.n + idx) ) |
| 214 | + return res |
| 215 | + if set(sn1.nodes) & set(sn2.nodes) == set(sn1.nodes): |
| 216 | + res = BaseSortingNetwork(sn2.n) |
| 217 | + res.nodes = sn2.nodes[:] |
| 218 | + for i,j in sn1.ops: |
| 219 | + res.P(i,j) |
| 220 | + for i,j in sn2.ops: |
| 221 | + res.P(i,j) |
| 222 | + if sn1.sections is not None: |
| 223 | + res.sections = sn1.sections[:] |
| 224 | + if sn2.sections is not None: |
| 225 | + if res.sections is None: |
| 226 | + res.sections = [("<unnamed>", 0)] |
| 227 | + for name, idx in sn2.sections: |
| 228 | + res.sections.append( (name, len(sn1.ops) + idx) ) |
| 229 | + return res |
| 230 | + if set(sn1.nodes) & set(sn2.nodes) == set(sn2.nodes): |
| 231 | + res = BaseSortingNetwork(sn1.n) |
| 232 | + res.nodes = sn1.nodes[:] |
| 233 | + for i,j in sn1.ops: |
| 234 | + res.P(i,j) |
| 235 | + for i,j in sn2.ops: |
| 236 | + res.P(i,j) |
| 237 | + if sn1.sections is not None: |
| 238 | + res.sections = sn1.sections[:] |
| 239 | + if sn2.sections is not None: |
| 240 | + if res.sections is None: |
| 241 | + res.sections = [("<unnamed>", 0)] |
| 242 | + for name, idx in sn2.sections: |
| 243 | + res.sections.append( (name, len(sn1.ops) + idx) ) |
| 244 | + return res |
| 245 | + raise RuntimeError("Don't know how to append both networks n1=%s n2=%s" % (sn1.nodes, sn2.nodes)) |
0 commit comments