@@ -132,11 +132,34 @@ def consistent_restore_mesh(
132132 """
133133 # Map how device ids changed across restarts.
134134 device_id_across_restarts = {}
135- for i in range (len (previous_distributed_to_device_ids )):
136- for j in range (len (previous_distributed_to_device_ids [i ])):
137- previous_id = previous_distributed_to_device_ids [i ][j ]
138- current_id = current_distributed_to_device_ids [i ][j ]
135+ assert len (previous_distributed_to_device_ids ) == len (
136+ current_distributed_to_device_ids
137+ )
138+
139+ logging .debug (
140+ 'previous_distributed_to_device_ids: %s' ,
141+ previous_distributed_to_device_ids ,
142+ )
143+ logging .debug (
144+ 'current_distributed_to_device_ids: %s' ,
145+ current_distributed_to_device_ids ,
146+ )
147+ # TODO(b/376748289): remove the following variables after bug is fixed.
148+ previous_device_to_distributed_id = {}
149+ current_device_to_distributed_id = {}
150+ for distributed_id in range (len (previous_distributed_to_device_ids )):
151+ logging .debug (
152+ 'distributed_id: %s, previous_device_ids: %s, current_device_ids: %s' ,
153+ distributed_id ,
154+ previous_distributed_to_device_ids [distributed_id ],
155+ current_distributed_to_device_ids [distributed_id ],
156+ )
157+ for j in range (len (previous_distributed_to_device_ids [distributed_id ])):
158+ previous_id = previous_distributed_to_device_ids [distributed_id ][j ]
159+ current_id = current_distributed_to_device_ids [distributed_id ][j ]
139160 device_id_across_restarts [previous_id ] = current_id
161+ previous_device_to_distributed_id [previous_id ] = distributed_id
162+ current_device_to_distributed_id [current_id ] = distributed_id
140163 logging .debug (
141164 'device_id_across_restarts (key: previous_id, value: current_id): %s' ,
142165 device_id_across_restarts ,
@@ -150,6 +173,28 @@ def consistent_restore_mesh(
150173 jax_devices_by_id [device_id_across_restarts [id ]]
151174 for id in previous_flattened_mesh_device_ids
152175 ]
176+ logging .debug (
177+ 'previous_flattened_mesh_device_ids: %s' ,
178+ previous_flattened_mesh_device_ids ,
179+ )
180+ new_flattened_mesh_device_ids = [d .id for d in new_flattened_mesh_devices ]
181+ logging .debug (
182+ 'new_flattened_mesh_device_ids: %s' ,
183+ new_flattened_mesh_device_ids ,
184+ )
185+
186+ previous_flattened_distributed_ids = [
187+ previous_device_to_distributed_id [id ]
188+ for id in previous_flattened_mesh_device_ids
189+ ]
190+ current_flattened_distributed_ids = [
191+ current_device_to_distributed_id [id ]
192+ for id in new_flattened_mesh_device_ids
193+ ]
194+ # The following is the invariant considering the distributed ids are
195+ # the same across restarts.
196+ assert previous_flattened_distributed_ids == current_flattened_distributed_ids
197+
153198 new_mesh_devices = np .array (new_flattened_mesh_devices ).reshape (
154199 user_mesh .devices .shape
155200 )
0 commit comments