Skip to content

Commit 49175e4

Browse files
author
Orbax Authors
committed
Removed atomic operations from mesh_consistency.save_process_metadata
PiperOrigin-RevId: 759460002
1 parent 4b5f801 commit 49175e4

File tree

1 file changed

+5
-20
lines changed

1 file changed

+5
-20
lines changed

checkpoint/orbax/checkpoint/experimental/emergency/mesh_consistency.py

Lines changed: 5 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -20,13 +20,10 @@
2020
from absl import logging
2121
from etils import epath
2222
import jax
23-
from orbax.checkpoint import options as options_lib
2423
from orbax.checkpoint.experimental.emergency import multihost as emergency_multihost
25-
from orbax.checkpoint.path import atomicity_defaults
2624
from orbax.checkpoint.path import step as step_lib
2725

2826

29-
3027
_PROCESS_METADATA_FOLDER = 'process_metadata'
3128
_PROCESS_METADATA_FILE_NAME = 'process_metadata.json'
3229
_GLOBAL_PROCESS_METADATA_FILE_NAME = 'global_process_metadata.json'
@@ -76,29 +73,17 @@ async def save_process_metadata(
7673
"""Saves process metadata to local storage. Runs on every process."""
7774
metadata_folder = process_metadata_folder(directory)
7875
logging.info('Saving process index metadata at %s', metadata_folder)
79-
if metadata_folder.exists():
80-
logging.warning(
81-
'Process metadata folder already exists at %s. Overwriting.',
82-
metadata_folder,
76+
if not metadata_folder.exists():
77+
raise FileNotFoundError(
78+
f'Process metadata folder does not exist at {metadata_folder}.'
8379
)
84-
metadata_folder.rmtree()
8580

86-
multiprocessing_options = options_lib.MultiprocessingOptions(
87-
primary_host=None
88-
)
89-
tmp_path = atomicity_defaults.get_default_temporary_path_class(
90-
metadata_folder
91-
).from_final(metadata_folder, multiprocessing_options=multiprocessing_options)
92-
await tmp_path.create()
93-
94-
(tmp_path.get() / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text(
81+
(metadata_folder / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text(
9582
json.dumps(distributed_to_device_ids)
9683
)
97-
(tmp_path.get() / _MESH_METADATA_FILE_NAME).write_text(
84+
(metadata_folder / _MESH_METADATA_FILE_NAME).write_text(
9885
json.dumps([int(id) for id in global_mesh.device_ids.flatten()])
9986
)
100-
tmp_path.finalize(
101-
)
10287

10388

10489
def consistent_restore_mesh_from_metadata(

0 commit comments

Comments (0)