|
20 | 20 | from absl import logging
|
21 | 21 | from etils import epath
|
22 | 22 | import jax
|
23 |
| -from orbax.checkpoint import options as options_lib |
24 | 23 | from orbax.checkpoint.experimental.emergency import multihost as emergency_multihost
|
25 |
| -from orbax.checkpoint.path import atomicity_defaults |
26 | 24 | from orbax.checkpoint.path import step as step_lib
|
27 | 25 |
|
28 | 26 |
|
29 |
| - |
30 | 27 | _PROCESS_METADATA_FOLDER = 'process_metadata'
|
31 | 28 | _PROCESS_METADATA_FILE_NAME = 'process_metadata.json'
|
32 | 29 | _GLOBAL_PROCESS_METADATA_FILE_NAME = 'global_process_metadata.json'
|
@@ -76,29 +73,17 @@ async def save_process_metadata(
|
76 | 73 | """Saves process metadata to local storage. Runs on every process."""
|
77 | 74 | metadata_folder = process_metadata_folder(directory)
|
78 | 75 | logging.info('Saving process index metadata at %s', metadata_folder)
|
79 |
| - if metadata_folder.exists(): |
80 |
| - logging.warning( |
81 |
| - 'Process metadata folder already exists at %s. Overwriting.', |
82 |
| - metadata_folder, |
| 76 | + if not metadata_folder.exists(): |
| 77 | + raise FileNotFoundError( |
| 78 | + f'Process metadata folder does not exist at {metadata_folder}.' |
83 | 79 | )
|
84 |
| - metadata_folder.rmtree() |
85 | 80 |
|
86 |
| - multiprocessing_options = options_lib.MultiprocessingOptions( |
87 |
| - primary_host=None |
88 |
| - ) |
89 |
| - tmp_path = atomicity_defaults.get_default_temporary_path_class( |
90 |
| - metadata_folder |
91 |
| - ).from_final(metadata_folder, multiprocessing_options=multiprocessing_options) |
92 |
| - await tmp_path.create() |
93 |
| - |
94 |
| - (tmp_path.get() / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text( |
| 81 | + (metadata_folder / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text( |
95 | 82 | json.dumps(distributed_to_device_ids)
|
96 | 83 | )
|
97 |
| - (tmp_path.get() / _MESH_METADATA_FILE_NAME).write_text( |
| 84 | + (metadata_folder / _MESH_METADATA_FILE_NAME).write_text( |
98 | 85 | json.dumps([int(id) for id in global_mesh.device_ids.flatten()])
|
99 | 86 | )
|
100 |
| - tmp_path.finalize( |
101 |
| - ) |
102 | 87 |
|
103 | 88 |
|
104 | 89 | def consistent_restore_mesh_from_metadata(
|
|
0 commit comments