Skip to content

Commit a32e71f

Browse files
author
Orbax Authors
committed
Add logs in consistent_restore_mesh to examine the invariant.
PiperOrigin-RevId: 703218403
1 parent ac2d276 commit a32e71f

File tree

1 file changed

+49
-4
lines changed
  • checkpoint/orbax/checkpoint/experimental/emergency

1 file changed

+49
-4
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/multihost.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,34 @@ def consistent_restore_mesh(
132132
"""
133133
# Map how device ids changed across restarts.
134134
device_id_across_restarts = {}
135-
for i in range(len(previous_distributed_to_device_ids)):
136-
for j in range(len(previous_distributed_to_device_ids[i])):
137-
previous_id = previous_distributed_to_device_ids[i][j]
138-
current_id = current_distributed_to_device_ids[i][j]
135+
assert len(previous_distributed_to_device_ids) == len(
136+
current_distributed_to_device_ids
137+
)
138+
139+
logging.debug(
140+
'previous_distributed_to_device_ids: %s',
141+
previous_distributed_to_device_ids,
142+
)
143+
logging.debug(
144+
'current_distributed_to_device_ids: %s',
145+
current_distributed_to_device_ids,
146+
)
147+
# TODO(b/376748289): remove the following variables after bug is fixed.
148+
previous_device_to_distributed_id = {}
149+
current_device_to_distributed_id = {}
150+
for distributed_id in range(len(previous_distributed_to_device_ids)):
151+
logging.debug(
152+
'distributed_id: %s, previous_device_ids: %s, current_device_ids: %s',
153+
distributed_id,
154+
previous_distributed_to_device_ids[distributed_id],
155+
current_distributed_to_device_ids[distributed_id],
156+
)
157+
for j in range(len(previous_distributed_to_device_ids[distributed_id])):
158+
previous_id = previous_distributed_to_device_ids[distributed_id][j]
159+
current_id = current_distributed_to_device_ids[distributed_id][j]
139160
device_id_across_restarts[previous_id] = current_id
161+
previous_device_to_distributed_id[previous_id] = distributed_id
162+
current_device_to_distributed_id[current_id] = distributed_id
140163
logging.debug(
141164
'device_id_across_restarts (key: previous_id, value: current_id): %s',
142165
device_id_across_restarts,
@@ -150,6 +173,28 @@ def consistent_restore_mesh(
150173
jax_devices_by_id[device_id_across_restarts[id]]
151174
for id in previous_flattened_mesh_device_ids
152175
]
176+
logging.debug(
177+
'previous_flattened_mesh_device_ids: %s',
178+
previous_flattened_mesh_device_ids,
179+
)
180+
new_flattened_mesh_device_ids = [d.id for d in new_flattened_mesh_devices]
181+
logging.debug(
182+
'new_flattened_mesh_device_ids: %s',
183+
new_flattened_mesh_device_ids,
184+
)
185+
186+
previous_flattened_distributed_ids = [
187+
previous_device_to_distributed_id[id]
188+
for id in previous_flattened_mesh_device_ids
189+
]
190+
current_flattened_distributed_ids = [
191+
current_device_to_distributed_id[id]
192+
for id in new_flattened_mesh_device_ids
193+
]
194+
# The following is the invariant considering the distributed ids are
195+
# the same across restarts.
196+
assert previous_flattened_distributed_ids == current_flattened_distributed_ids
197+
153198
new_mesh_devices = np.array(new_flattened_mesh_devices).reshape(
154199
user_mesh.devices.shape
155200
)

0 commit comments

Comments
 (0)