Skip to content

Commit

Permalink
Add logs in consistent_restore_mesh to examine the invariant.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703218403
  • Loading branch information
Orbax Authors committed Dec 5, 2024
1 parent ac2d276 commit a32e71f
Showing 1 changed file with 49 additions and 4 deletions.
53 changes: 49 additions & 4 deletions checkpoint/orbax/checkpoint/experimental/emergency/multihost.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,34 @@ def consistent_restore_mesh(
"""
# Map how device ids changed across restarts.
device_id_across_restarts = {}
for i in range(len(previous_distributed_to_device_ids)):
for j in range(len(previous_distributed_to_device_ids[i])):
previous_id = previous_distributed_to_device_ids[i][j]
current_id = current_distributed_to_device_ids[i][j]
assert len(previous_distributed_to_device_ids) == len(
current_distributed_to_device_ids
)

logging.debug(
'previous_distributed_to_device_ids: %s',
previous_distributed_to_device_ids,
)
logging.debug(
'current_distributed_to_device_ids: %s',
current_distributed_to_device_ids,
)
# TODO(b/376748289): remove the following variables after bug is fixed.
previous_device_to_distributed_id = {}
current_device_to_distributed_id = {}
for distributed_id in range(len(previous_distributed_to_device_ids)):
logging.debug(
'distributed_id: %s, previous_device_ids: %s, current_device_ids: %s',
distributed_id,
previous_distributed_to_device_ids[distributed_id],
current_distributed_to_device_ids[distributed_id],
)
for j in range(len(previous_distributed_to_device_ids[distributed_id])):
previous_id = previous_distributed_to_device_ids[distributed_id][j]
current_id = current_distributed_to_device_ids[distributed_id][j]
device_id_across_restarts[previous_id] = current_id
previous_device_to_distributed_id[previous_id] = distributed_id
current_device_to_distributed_id[current_id] = distributed_id
logging.debug(
'device_id_across_restarts (key: previous_id, value: current_id): %s',
device_id_across_restarts,
Expand All @@ -150,6 +173,28 @@ def consistent_restore_mesh(
jax_devices_by_id[device_id_across_restarts[id]]
for id in previous_flattened_mesh_device_ids
]
logging.debug(
'previous_flattened_mesh_device_ids: %s',
previous_flattened_mesh_device_ids,
)
new_flattened_mesh_device_ids = [d.id for d in new_flattened_mesh_devices]
logging.debug(
'new_flattened_mesh_device_ids: %s',
new_flattened_mesh_device_ids,
)

previous_flattened_distributed_ids = [
previous_device_to_distributed_id[id]
for id in previous_flattened_mesh_device_ids
]
current_flattened_distributed_ids = [
current_device_to_distributed_id[id]
for id in new_flattened_mesh_device_ids
]
# Invariant: distributed ids are expected to stay the same across restarts,
# so each mesh position must map to the same distributed id before and after.
assert previous_flattened_distributed_ids == current_flattened_distributed_ids

new_mesh_devices = np.array(new_flattened_mesh_devices).reshape(
user_mesh.devices.shape
)
Expand Down

0 comments on commit a32e71f

Please sign in to comment.