1
- from typing import NewType , Sequence , Mapping , Optional , Dict , List , Tuple
1
+ from typing import NewType , Sequence , Mapping , Optional , Dict , List , Tuple , Set , Union
2
+ from numbers import Real
2
3
from mypy_extensions import TypedDict
3
4
import logging
4
5
from fractions import Fraction
6
+ from abc import ABC , abstractmethod
5
7
6
8
SourceObject = NewType ('SourceObject' , str )
7
9
TargetObject = NewType ('TargetObject' , str )
18
20
19
21
logger = logging .getLogger (__name__ )
20
22
23
+ class WeightedMap (ABC ):
24
+ '''
25
+ '''
21
26
22
- class WeightedMap (list ):
27
+ @abstractmethod
28
+ def __getitem__ (self , key ):
29
+ '''Returns the weight corresponding to (source, target)
30
+ if key is a tuple (source, target).
31
+ Else, if key is a str, returns the list of the WeightedNodes
32
+ such that its source is key
33
+ '''
34
+
35
+ @abstractmethod
36
+ def get_sources (self ) -> Set [SourceObject ]:
37
+ '''Returns the set of sources'''
38
+
39
+ @abstractmethod
40
+ def total_weight (self ) -> Real :
41
+ '''Returns the sum of all the weights'''
42
+
43
+ @abstractmethod
44
+ def apply (self , fun ) -> None :
45
+ '''applies function `fun` to all weights'''
46
+
47
+ @abstractmethod
48
+ def add_weight (self , w : WeightedNode ) -> None :
49
+ '''adds the node to the map'''
50
+
51
+
52
+ class ListWeightedMap (list , WeightedMap ):
23
53
24
54
def __init__ (self , nodes : Sequence [WeightedNode ]):
25
55
super ().__init__ (nodes )
@@ -38,11 +68,64 @@ def __getitem__(self, key):
38
68
39
69
raise KeyError (f"Can't get item, argument must be str, None or tuple. Got { key } " )
40
70
71
+ def add_weight (self , w ):
72
+ self .append (w )
73
+
74
+ def total_weight (self ):
75
+ return sum (w ['weight' ] for w in self )
76
+
77
+ def get_sources (self ):
78
+ return {a ['from' ] for a in self }
79
+
80
+ def apply (self , fun ):
81
+ for ftw in self :
82
+ ftw ['weight' ] = fun (ftw ['weight' ])
83
+
84
+
85
+ class DictWeightedMap (WeightedMap ):
86
+
87
+ def __init__ (self , nodes : Sequence [WeightedNode ]):
88
+ self .sources = set (w ['from' ] for w in nodes )
89
+ self .targets = set (w ['to' ] for w in nodes )
90
+ self .weights = {(w ['from' ], w ['to' ]): w ['weight' ]
91
+ for w in nodes
92
+ }
93
+
94
+ def __getitem__ (self , key ):
95
+
96
+ if isinstance (key , str ) or key is None :
97
+ return [
98
+ {'from' : ft [0 ], 'to' : ft [1 ], 'weight' : w }
99
+ for (ft , w ) in self .weights .items ()
100
+ if ft [0 ] == key
101
+ ]
102
+
103
+ elif isinstance (key , tuple ):
104
+ return self .weights .get (key , None )
105
+
106
+ raise KeyError (f"Can't get item, argument must be str, None or tuple. Got { key } " )
107
+
108
+ def add_weight (self , w ):
109
+ self .sources .add (w ['from' ])
110
+ self .targets .add (w ['to' ])
111
+ self .weights [(w ['from' ], w ['to' ])] = w ['weight' ]
112
+
113
+ def total_weight (self ):
114
+ return sum (self .weights .values ())
115
+
116
+ def get_sources (self ):
117
+ return self .sources
118
+
119
+ def apply (self , fun ):
120
+ for ft , w in self .weights .items ():
121
+ self .weights [ft ] = fun (w )
122
+
123
+
41
124
42
125
class Source :
43
126
44
127
def __init__ (self , wmap : WeightedMap , instances : Mapping [Optional [SourceObject ], int ]):
45
- self .collection = { a [ 'from' ] for a in wmap }
128
+ self .collection = wmap . get_sources ()
46
129
self .wmap = wmap
47
130
self .instances = instances
48
131
@@ -104,17 +187,16 @@ def __init__(self,
104
187
'''
105
188
106
189
if limit_denominator :
107
- for weight in wmap :
108
- weight ['weight' ] = Fraction (weight ['weight' ]).limit_denominator (limit_denominator )
190
+ wmap .apply (lambda w : Fraction (w ).limit_denominator (limit_denominator ))
109
191
sources_total_qty = sum (sources .values ())
110
192
targets_total_qty = sum (targets .values ())
111
193
112
- max_val = sum ( abs ( m [ 'weight' ]) for m in wmap )
194
+ max_val = wmap . total_weight ( )
113
195
for t in targets .keys ():
114
- wmap .append ({'from' : None , 'to' : t , 'weight' : max_val + 1 })
196
+ wmap .add_weight ({'from' : None , 'to' : t , 'weight' : max_val + 1 })
115
197
for s in sources .keys ():
116
- wmap .append ({'from' : s , 'to' : None , 'weight' : max_val + 1 })
117
- wmap .append ({'from' : None , 'to' : None , 'weight' : - 1 })
198
+ wmap .add_weight ({'from' : s , 'to' : None , 'weight' : max_val + 1 })
199
+ wmap .add_weight ({'from' : None , 'to' : None , 'weight' : - 1 })
118
200
119
201
sources [None ] = targets_total_qty
120
202
targets [None ] = sources_total_qty
0 commit comments