@@ -642,137 +642,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
     return model
 
 
-def prepare_universal_checkpoint_from_latest(output_dir):
-    """Populate the universal checkpoint in output_dir/step_folder
-    - 1. read output_dir/latest to get step_folder
-    - 2. populate tmp dir in output_dir/step_folder/tmp
-    - 3. populate zero checkpoints in output_dir/step_folder/zero
-    - 4. create output_dir/latest_universal
-
-    Steps 1-4 are idempotent. Atomicity holds in the sense that
-    output_dir/latest_universal is written only after step 4 completes,
-    and only once it exists can the universal checkpoint be loaded.
-
-    Be aware that this creates an extra dir `zero/` in the checkpoint dir,
-    which doubles the DS checkpoint storage requirement.
-    - DS checkpoints store 3X model parameters in 32bit.
-    - e.g., 6X the size of a parameter-only checkpoint in 16bit.
-
-    Note that this requires a recent version of deepspeed (>= 0.14.3).
-    It still works if the model did not save universal checkpoint info,
-    but only when advanced features like tensor parallelism (TP) and
-    pipeline parallelism (PP) are turned off.
-    """
-
-    log_rank_0(
-        f"\033[93mPreparing universal checkpoint in {output_dir}\033[0m", to_print=True
-    )
-    # Third Party
-    from transformers.utils.import_utils import _is_package_available
-
-    _, ds_version = _is_package_available("deepspeed", return_version=True)
-    if ds_version < "0.14.3":
-        raise ValueError("universal checkpoint only supported on deepspeed >= 0.14.3")
-
-    start = time.time()
-    if torch.distributed.get_rank() == 0:
-        try:
-            # Third Party
-            from deepspeed.checkpoint import DeepSpeedCheckpoint
-            from deepspeed.checkpoint.ds_to_universal import (
-                PARAM_SHAPES,
-                UNIVERSAL_CHECKPOINT_INFO,
-                _check_for_required_state,
-                _extract_zero_shard_files,
-                _merge_tp_slice_files,
-                _save_optimizer_state,
-            )
-        except ImportError as exc:
-            raise ImportError(
-                "DeepSpeed-specific checkpoints cannot be saved without DeepSpeed>=0.14.3 installed"
-            ) from exc
-
-        # read the latest file to get the step folder
-        latest_file = output_dir / "latest"
-        with open(latest_file) as f:
-            step_folder = f.read()
-
-        # will process the checkpoint in the latest step folder
-        input_folder = os.path.join(output_dir, step_folder)
-
-        # create args for the scripts below
-        class UniversalCheckpointArgs:
-            num_extract_workers: int = 1
-            num_merge_workers: int = 1
-            output_folder: str = input_folder  # just put in same place
-            strict: bool = True  # strict checkpoint
-
-        args = UniversalCheckpointArgs()
-
-        # get the checkpoint
-        ds_checkpoint = DeepSpeedCheckpoint(input_folder)
-
-        # hack, force this to null if we did not properly save
-        # any universal checkpoint information
-        # - this will not support any pipeline replication and other
-        #   replication such as TP, row parallelism, vocab, sub_params
-        if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
-            warnings.warn(
-                "Universal checkpoint information not found, setting it to "
-                "an empty dictionary."
-            )
-            ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
-            assert ds_checkpoint.tp_degree == 1, (
-                "if universal checkpointing info is missing, TP must be absent"
-            )
-            assert ds_checkpoint.pp_degree == 1, (
-                "if universal checkpointing info is missing, PP must be absent"
-            )
-        _check_for_required_state(ds_checkpoint)
-
-        slice_shapes = []
-        for mp_rank_file in ds_checkpoint.mp_rank_files:
-            mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
-            slice_shapes += mp_sd[PARAM_SHAPES]
-
-        # fix back to normal flat dict, merge duplicates for tp>1
-        slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
-        temp_dir = os.path.join(args.output_folder, "tmp")
-
-        log_rank_0(
-            f"\033[93m1. Extracting ZeRO fragments into {temp_dir}\033[0m",
-            to_print=True,
-        )
-        _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
-
-        zero_output_folder = os.path.join(args.output_folder, "zero")
-
-        log_rank_0(
-            f"\033[93m2. Merging slices into {zero_output_folder}\033[0m", to_print=True
-        )
-        _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
-
-        log_rank_0(
-            f"\033[93m3. Saving common optimizer states into {zero_output_folder}\033[0m",
-            to_print=True,
-        )
-        _save_optimizer_state(args, ds_checkpoint)
-
-        log_rank_0(
-            f"\033[93m4. Removing temp directory {temp_dir}\033[0m", to_print=True
-        )
-        shutil.rmtree(temp_dir, ignore_errors=True)
-
-        latest_file = os.path.join(output_dir, "latest_universal")
-        log_rank_0(f"\033[93m5. Creating {latest_file}\033[0m", to_print=True)
-        with open(latest_file, "w") as f:
-            f.write(step_folder)
-
-    dist.barrier()
-    log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")
-
-
 @contextmanager
 def ensure_loadable_dolomite_checkpoint(
     model_name_or_path: str,
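
Note: a minimal sketch of how the removed helper was meant to be driven, for reference. The engine object, the output path, and the config snippet below are assumptions, not part of this diff; DeepSpeed consults latest_universal only when universal-checkpoint loading is enabled in its config (the checkpoint.load_universal knob in recent releases).

# Hypothetical usage sketch; `engine` stands for a DeepSpeed engine built
# with `"checkpoint": {"load_universal": True}` in its config, so that
# engine.load_checkpoint() reads `latest_universal` instead of `latest`.
from pathlib import Path

output_dir = Path("experiments/run1")  # hypothetical checkpoint root

# rank 0 performs the conversion; all ranks meet at the barrier inside
prepare_universal_checkpoint_from_latest(output_dir)

# resume, possibly on a different world size: DeepSpeed re-shards the
# merged fp32 states in output_dir/<step>/zero into new ZeRO fragments
engine.load_checkpoint(str(output_dir))
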
@@ -1050,44 +919,6 @@ def _get_state_dict_patched(model, unwrap=False):
     accelerator.get_state_dict = get_state_dict_unpatched
 
 
-# this is native deepspeed saving with optimizer, scheduler
-def save_model_ds_native(
-    args,
-    model,
-    tokenizer,  # pylint: disable=unused-argument
-    samples_seen,
-):
-    # to get a statedict from a zero checkpoint, all you need to do is
-    # - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
-    # - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt')
-    # - sum([math.prod(x.shape) for x in sd.values()])  # check the size (should be correct)
-
-    log_rank_0(
-        f"\033[93mSaving model+optimizer+scheduler in ds_native format at samples_seen: {samples_seen}\033[0m",
-        to_print=True,
-    )
-    start = time.time()
-    # used to save huggingface format, so we can use it for hf.from_pretrained
-    output_dir = Path(args.output_dir) / "ds_native"
-    tag = f"samples_{samples_seen}"
-    use_lora = args.lora_r > 0
-
-    # NOTE: this is a distributed save
-    # if it's lora, we only save the adapters
-    # - so we exclude frozen if use_lora==True
-    model.save_checkpoint(
-        output_dir,
-        exclude_frozen_parameters=use_lora,
-        tag=tag,  # this will create the subdirectory with the correct name
-    )
-
-    # for now we are not saving the tokenizer, config, etc.,
-    # so it is not totally "HF compatible"
-
-    log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
-    log_rank_0(f"saving took {time.time() - start} seconds")
-
-
 def set_random_seed(seed):
     if seed is not None:
         random.seed(seed)
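
Note: the comment block inside the removed save_model_ds_native spells out a recovery recipe; below is a self-contained sketch of it. The checkpoint directory and tag are hypothetical (following the samples_{n} tag scheme above); get_fp32_state_dict_from_zero_checkpoint is the public helper in deepspeed.utils.zero_to_fp32.

# Recover an fp32 state dict from a ZeRO checkpoint written by
# save_model_ds_native, then sanity-check its parameter count.
import math

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

sd = get_fp32_state_dict_from_zero_checkpoint("out/ds_native", tag="samples_42")
num_params = sum(math.prod(t.shape) for t in sd.values())
print(f"recovered {num_params:,} parameters")  # should match the model size
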