@@ -132,11 +132,34 @@ def consistent_restore_mesh(
132
132
"""
133
133
# Map how device ids changed across restarts.
134
134
device_id_across_restarts = {}
135
- for i in range (len (previous_distributed_to_device_ids )):
136
- for j in range (len (previous_distributed_to_device_ids [i ])):
137
- previous_id = previous_distributed_to_device_ids [i ][j ]
138
- current_id = current_distributed_to_device_ids [i ][j ]
135
+ assert len (previous_distributed_to_device_ids ) == len (
136
+ current_distributed_to_device_ids
137
+ )
138
+
139
+ logging .debug (
140
+ 'previous_distributed_to_device_ids: %s' ,
141
+ previous_distributed_to_device_ids ,
142
+ )
143
+ logging .debug (
144
+ 'current_distributed_to_device_ids: %s' ,
145
+ current_distributed_to_device_ids ,
146
+ )
147
+ # TODO(b/376748289): remove the following variables after bug is fixed.
148
+ previous_device_to_distributed_id = {}
149
+ current_device_to_distributed_id = {}
150
+ for distributed_id in range (len (previous_distributed_to_device_ids )):
151
+ logging .debug (
152
+ 'distributed_id: %s, previous_device_ids: %s, current_device_ids: %s' ,
153
+ distributed_id ,
154
+ previous_distributed_to_device_ids [distributed_id ],
155
+ current_distributed_to_device_ids [distributed_id ],
156
+ )
157
+ for j in range (len (previous_distributed_to_device_ids [distributed_id ])):
158
+ previous_id = previous_distributed_to_device_ids [distributed_id ][j ]
159
+ current_id = current_distributed_to_device_ids [distributed_id ][j ]
139
160
device_id_across_restarts [previous_id ] = current_id
161
+ previous_device_to_distributed_id [previous_id ] = distributed_id
162
+ current_device_to_distributed_id [current_id ] = distributed_id
140
163
logging .debug (
141
164
'device_id_across_restarts (key: previous_id, value: current_id): %s' ,
142
165
device_id_across_restarts ,
@@ -150,6 +173,28 @@ def consistent_restore_mesh(
150
173
jax_devices_by_id [device_id_across_restarts [id ]]
151
174
for id in previous_flattened_mesh_device_ids
152
175
]
176
+ logging .debug (
177
+ 'previous_flattened_mesh_device_ids: %s' ,
178
+ previous_flattened_mesh_device_ids ,
179
+ )
180
+ new_flattened_mesh_device_ids = [d .id for d in new_flattened_mesh_devices ]
181
+ logging .debug (
182
+ 'new_flattened_mesh_device_ids: %s' ,
183
+ new_flattened_mesh_device_ids ,
184
+ )
185
+
186
+ previous_flattened_distributed_ids = [
187
+ previous_device_to_distributed_id [id ]
188
+ for id in previous_flattened_mesh_device_ids
189
+ ]
190
+ current_flattened_distributed_ids = [
191
+ current_device_to_distributed_id [id ]
192
+ for id in new_flattened_mesh_device_ids
193
+ ]
194
+ # The following is the invariant considering the distributed ids are
195
+ # the same across restarts.
196
+ assert previous_flattened_distributed_ids == current_flattened_distributed_ids
197
+
153
198
new_mesh_devices = np .array (new_flattened_mesh_devices ).reshape (
154
199
user_mesh .devices .shape
155
200
)
0 commit comments