Skip to content

Commit c75d76b

Browse files
committed
add new DictWeightedMap class.
Also: * make new ABC class WeightedMap and * rename previous WeightedMap to ListWeightedMap * used operations on WeightedMap are abstracted. Alas, DictWeightedMap is not faster than ListWeightedMap.
1 parent e1b6b8b commit c75d76b

File tree

1 file changed

+91
-9
lines changed

1 file changed

+91
-9
lines changed

allocation/allocating.py

+91-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import NewType, Sequence, Mapping, Optional, Dict, List, Tuple
1+
from typing import NewType, Sequence, Mapping, Optional, Dict, List, Tuple, Set, Union
2+
from numbers import Real
23
from mypy_extensions import TypedDict
34
import logging
45
from fractions import Fraction
6+
from abc import ABC, abstractmethod
57

68
SourceObject = NewType('SourceObject', str)
79
TargetObject = NewType('TargetObject', str)
@@ -18,8 +20,36 @@
1820

1921
logger = logging.getLogger(__name__)
2022

23+
class WeightedMap(ABC):
24+
'''
25+
'''
2126

22-
class WeightedMap(list):
27+
@abstractmethod
28+
def __getitem__(self, key):
29+
'''Returns the weight corresponding to (source, target)
30+
if key is a tuple (source, target).
31+
Else, if key is a str, returns the list of the WeightedNodes
32+
such that its source is key
33+
'''
34+
35+
@abstractmethod
36+
def get_sources(self) -> Set[SourceObject]:
37+
'''Returns the set of sources'''
38+
39+
@abstractmethod
40+
def total_weight(self) -> Real:
41+
'''Returns the sum of all the weights'''
42+
43+
@abstractmethod
44+
def apply(self, fun) -> None:
45+
'''applies function `fun` to all weights'''
46+
47+
@abstractmethod
48+
def add_weight(self, w: WeightedNode) -> None:
49+
'''adds the node to the map'''
50+
51+
52+
class ListWeightedMap(list, WeightedMap):
2353

2454
def __init__(self, nodes: Sequence[WeightedNode]):
2555
super().__init__(nodes)
@@ -38,11 +68,64 @@ def __getitem__(self, key):
3868

3969
raise KeyError(f"Can't get item, argument must be str, None or tuple. Got {key}")
4070

71+
def add_weight(self, w):
72+
self.append(w)
73+
74+
def total_weight(self):
75+
return sum(w['weight'] for w in self)
76+
77+
def get_sources(self):
78+
return {a['from'] for a in self}
79+
80+
def apply(self, fun):
81+
for ftw in self:
82+
ftw['weight'] = fun(ftw['weight'])
83+
84+
85+
class DictWeightedMap(WeightedMap):
86+
87+
def __init__(self, nodes: Sequence[WeightedNode]):
88+
self.sources = set(w['from'] for w in nodes)
89+
self.targets = set(w['to'] for w in nodes)
90+
self.weights = {(w['from'], w['to']): w['weight']
91+
for w in nodes
92+
}
93+
94+
def __getitem__(self, key):
95+
96+
if isinstance(key, str) or key is None:
97+
return [
98+
{'from': ft[0], 'to': ft[1], 'weight': w}
99+
for (ft, w) in self.weights.items()
100+
if ft[0] == key
101+
]
102+
103+
elif isinstance(key, tuple):
104+
return self.weights.get(key, None)
105+
106+
raise KeyError(f"Can't get item, argument must be str, None or tuple. Got {key}")
107+
108+
def add_weight(self, w):
109+
self.sources.add(w['from'])
110+
self.targets.add(w['to'])
111+
self.weights[(w['from'], w['to'])] = w['weight']
112+
113+
def total_weight(self):
114+
return sum(self.weights.values())
115+
116+
def get_sources(self):
117+
return self.sources
118+
119+
def apply(self, fun):
120+
for ft, w in self.weights.items():
121+
self.weights[ft] = fun(w)
122+
123+
41124

42125
class Source:
43126

44127
def __init__(self, wmap: WeightedMap, instances: Mapping[Optional[SourceObject], int]):
45-
self.collection = {a['from'] for a in wmap}
128+
self.collection = wmap.get_sources()
46129
self.wmap = wmap
47130
self.instances = instances
48131

@@ -104,17 +187,16 @@ def __init__(self,
104187
'''
105188

106189
if limit_denominator:
107-
for weight in wmap:
108-
weight['weight'] = Fraction(weight['weight']).limit_denominator(limit_denominator)
190+
wmap.apply(lambda w: Fraction(w).limit_denominator(limit_denominator))
109191
sources_total_qty = sum(sources.values())
110192
targets_total_qty = sum(targets.values())
111193

112-
max_val = sum(abs(m['weight']) for m in wmap)
194+
max_val = wmap.total_weight()
113195
for t in targets.keys():
114-
wmap.append({'from': None, 'to': t, 'weight': max_val + 1})
196+
wmap.add_weight({'from': None, 'to': t, 'weight': max_val + 1})
115197
for s in sources.keys():
116-
wmap.append({'from': s, 'to': None, 'weight': max_val + 1})
117-
wmap.append({'from': None, 'to': None, 'weight': -1})
198+
wmap.add_weight({'from': s, 'to': None, 'weight': max_val + 1})
199+
wmap.add_weight({'from': None, 'to': None, 'weight': -1})
118200

119201
sources[None] = targets_total_qty
120202
targets[None] = sources_total_qty

0 commit comments

Comments
 (0)