@@ -47,7 +47,10 @@ def __hash__(self):
47
47
48
48
def __eq__ (self , other_memit ) -> bool :
49
49
"""In action and out actions are the same."""
50
- return self .in_act == other_memit .in_act and self .out_act == other_memit .out_act
50
+ return (
51
+ self .in_act == other_memit .in_act
52
+ and self .out_act == other_memit .out_act
53
+ )
51
54
52
55
def __lt__ (self , other_memit ) -> bool :
53
56
return repr (self ) < repr (other_memit )
@@ -107,10 +110,9 @@ def get_accessible_transitions(
107
110
accessible_transitions = dict ()
108
111
for trans in transition_iterator (transitions ):
109
112
if trans .state in accessible_states :
110
- accessible_transitions [(trans .state , trans .last_opponent_action )] = (
111
- trans .next_state ,
112
- trans .next_action ,
113
- )
113
+ accessible_transitions [
114
+ (trans .state , trans .last_opponent_action )
115
+ ] = (trans .next_state , trans .next_action )
114
116
115
117
return accessible_transitions
116
118
@@ -177,7 +179,9 @@ def get_memory_from_transitions(
177
179
transitions = get_accessible_transitions (transitions , initial_state )
178
180
179
181
# Get the incoming actions for each state.
180
- incoming_action_by_state = defaultdict (set ) # type: DefaultDict[int, Set[Action]]
182
+ incoming_action_by_state = defaultdict (
183
+ set
184
+ ) # type: DefaultDict[int, Set[Action]]
181
185
for trans in transition_iterator (transitions ):
182
186
incoming_action_by_state [trans .next_state ].add (trans .next_action )
183
187
@@ -189,17 +193,23 @@ def get_memory_from_transitions(
189
193
# That is to say that the opponent could do anything
190
194
for out_action in all_actions :
191
195
# More recent in action history
192
- starting_node = Memit (trans .next_action , trans .next_state , out_action )
196
+ starting_node = Memit (
197
+ trans .next_action , trans .next_state , out_action
198
+ )
193
199
# All incoming paths to current state
194
200
for in_action in incoming_action_by_state [trans .state ]:
195
201
# Less recent in action history
196
- ending_node = Memit (in_action , trans .state , trans .last_opponent_action )
202
+ ending_node = Memit (
203
+ in_action , trans .state , trans .last_opponent_action
204
+ )
197
205
memit_edges [starting_node ].add (ending_node )
198
206
199
207
all_memits = list (memit_edges .keys ())
200
208
201
209
pair_nodes = set ()
202
- pair_edges = defaultdict (set ) # type: DefaultDict[MemitPair, Set[MemitPair]]
210
+ pair_edges = defaultdict (
211
+ set
212
+ ) # type: DefaultDict[MemitPair, Set[MemitPair]]
203
213
# Loop through all pairs of memits.
204
214
for x , y in [(x , y ) for x in all_memits for y in all_memits ]:
205
215
if x == y and x .state == y .state :
@@ -226,7 +236,9 @@ def get_memory_from_transitions(
226
236
next_action_by_memit = dict ()
227
237
for trans in transition_iterator (transitions ):
228
238
for in_action in incoming_action_by_state [trans .state ]:
229
- memit_key = Memit (in_action , trans .state , trans .last_opponent_action )
239
+ memit_key = Memit (
240
+ in_action , trans .state , trans .last_opponent_action
241
+ )
230
242
next_action_by_memit [memit_key ] = trans .next_action
231
243
232
244
# Calculate the longest path.
@@ -251,3 +263,4 @@ def get_memory_from_transitions(
251
263
if len (next_action_set ) == 1 :
252
264
return 0
253
265
return 1
266
+
0 commit comments