
Commit 425f5ec

Merge pull request #567 from JamesKunstle/remove-dead-code
remove old Deepspeed-native code

2 parents: ccac4fd + fdc02fd

2 files changed: 0 additions, 241 deletions

src/instructlab/training/main_ds.py

Lines changed: 0 additions & 72 deletions

@@ -2,14 +2,12 @@
 # Standard
 from copy import deepcopy
-from pathlib import Path
 import argparse
 import datetime
 import functools
 import logging
 import math
 import os
-import re
 import subprocess
 import time
 import warnings
@@ -32,10 +30,8 @@
 try:
     # Third Party
     from deepspeed.ops.adam import FusedAdam
-    from deepspeed.runtime.zero.utils import ZeRORuntimeException
 except ImportError:
     FusedAdam = None
-    ZeRORuntimeException = None
 local_rank = int(os.getenv("LOCAL_RANK", "0"))
 if __name__ == "__main__" and (not local_rank or local_rank == 0):
     warnings.warn(
@@ -83,7 +79,6 @@
     ensure_loadable_dolomite_checkpoint,
     load_latest_full_state,
     prepare_peft_model,
-    prepare_universal_checkpoint_from_latest,
     save_checkpoint,
     save_hf_format_accelerate,
     set_random_seed,
@@ -298,63 +293,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
     return model, lr_scheduler, optimizer, accelerator


-# this function is to check if the checkpoint provided can be resumed
-def maybe_resume_training(args, model):
-    local_rank = int(os.environ["LOCAL_RANK"])
-
-    # DS's loading function will not raise if fails to reload a checkpoint
-    # - if lora is used, then the checkpoints will only be for the adapters
-    #   so we need to disable load_module_strict
-    # - load checkpoint will find the latest checkpoint
-    # - it will also load the optimizer and scheduler states by default
-    load_module_strict = args.lora_r == 0  # can only be true if lora is not used
-    output_dir = Path(args.output_dir) / "ds_native"
-
-    try:
-        # attempt to load a regular checkpoint first
-        model.load_checkpoint(output_dir, load_module_strict=load_module_strict)
-    except ZeRORuntimeException as e:
-        if str(e).startswith("The checkpoint being loaded used a DP world size of"):
-            # if it fails with the above exception, then a universal
-            # checkpoint is required
-
-            # prepare the universal checkpoint
-            # - by reading 'latest' to get the resumable checkpoint
-            prepare_universal_checkpoint_from_latest(output_dir)
-
-            # need to do this to trigger the universal checkpoint
-            # loading
-            model._config.load_universal_checkpoint = True
-
-            # then attempt to load again
-            model.load_checkpoint(output_dir, load_module_strict=load_module_strict)
-
-            # reset to regular checkpoint loading
-            model._config.load_universal_checkpoint = False
-        else:
-            raise e  # reraise
-
-    # do this to figure out the last_step
-    latest_file = output_dir / "latest"
-    try:
-        with open(latest_file) as f:
-            # there is some assumption here that the ds_native
-            # checkpoints are tagged as <something>_(samples_seen)
-            step_folder = f.read()
-        (samples_seen,) = re.match("\w+_(\d+)", step_folder).groups()
-        samples_seen = int(samples_seen)
-
-        last_step = samples_seen // args.effective_batch_size
-        args.__dict__["last_step"] = last_step
-        if local_rank == 0:
-            logger.info("Found checkpoint at %d, resuming training", last_step)
-    except FileNotFoundError:
-        pass
-
-    # we will update the start step here
-    return model
-
-
 def train(
     args,
     model,
@@ -512,16 +450,6 @@ def train(
         base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank)
         torch.distributed.barrier()

-        # if (
-        #     args.save_samples_ds is not None
-        #     and global_step * batch_size % args.save_samples_ds == 0
-        # ):
-        #     save_model_ds_native(
-        #         args,
-        #         model,
-        #         tokenizer,
-        #         global_step * args.samples_per_gpu * world_size,
-        #     )
         global_step += 1
         if local_rank == 0:
             inner_pb.update(1)

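For context on what this file loses: the sketch below (not part of the repository) shows the bare DeepSpeed engine calls that the removed maybe_resume_training() wrapped. The helper name resume_from_ds_native, its parameters, and the "samples_<N>" tag parsing only mirror the deleted code and are assumptions here; engine.load_checkpoint() is the actual DeepSpeed API being exercised. The repository's surviving resume path, load_latest_full_state, remains in the import list above.

import re
from pathlib import Path


def resume_from_ds_native(engine, output_dir: str, effective_batch_size: int) -> int:
    """Hypothetical helper: restore a DeepSpeed-native checkpoint and return the last step."""
    ckpt_dir = Path(output_dir) / "ds_native"

    # engine.load_checkpoint() consults the `latest` tag file in ckpt_dir and, by
    # default, restores module, optimizer, and LR-scheduler state. It returns
    # (load_path, client_state); load_path is None when no checkpoint is found.
    load_path, _client_state = engine.load_checkpoint(str(ckpt_dir))
    if load_path is None:
        return 0  # nothing to resume from

    # The removed code tagged checkpoints as "<something>_<samples_seen>",
    # e.g. "samples_12800"; recover the sample count from the latest tag.
    tag = (ckpt_dir / "latest").read_text().strip()
    match = re.match(r"\w+_(\d+)", tag)
    samples_seen = int(match.group(1)) if match else 0
    return samples_seen // effective_batch_size  # last completed global step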
src/instructlab/training/utils.py

Lines changed: 0 additions & 169 deletions

@@ -642,137 +642,6 @@ def make_inputs_require_grad(module, input, output): # pylint: disable=unused-a
     return model


-def prepare_universal_checkpoint_from_latest(output_dir):
-    """Populate the universal checkpoint in output_dir/step_folder
-    - 1. read output_dir/latest to get step_folder
-    - 2. populate tmp dir in output_dir/step_folder/tmp
-    - 3. populate zero checkpoints in output_dir/step_folder/zero
-    - 4. create output_dir/latest_universal
-
-    Items 1, 2, 3, 4 are idempotent. There is atomicity in the sense that
-    only after 4 is completed, then the output_dir/latest_universal
-    checkpoint is created in which then the universal checkpoint
-    can be loaded.
-
-    Be aware that this creates an extra dir `zero/` in the checkpoint dir,
-    which doubles the DS checkpoint storage requirement.
-    - DS checkpoints store 3X model parameters in 32bit.
-    - e.g., will be 6X a model parameter-only checkpoint in 16bit.
-
-    Note that this requires a latest version of deepspeed. It kind of works if
-    the model is not saving universal checkpoint info, but only in the
-    the case where advanced features like tensor parallel (TP) and
-    pipeline parallel (PP) are turned off.
-    """
-
-    log_rank_0(
-        f"\033[93mPreparing universal checkpoint in {output_dir}\033[0m", to_print=True
-    )
-    # Third Party
-    from transformers.utils.import_utils import _is_package_available
-
-    _, ds_version = _is_package_available("deepspeed", return_version=True)
-    if ds_version < "0.14.3":
-        raise ValueError("universal checkpoint only supported on deepspeed >= 0.14.3")
-
-    start = time.time()
-    if torch.distributed.get_rank() == 0:
-        try:
-            # Third Party
-            from deepspeed.checkpoint import DeepSpeedCheckpoint
-            from deepspeed.checkpoint.ds_to_universal import (
-                PARAM_SHAPES,
-                UNIVERSAL_CHECKPOINT_INFO,
-                _check_for_required_state,
-                _extract_zero_shard_files,
-                _merge_tp_slice_files,
-                _save_optimizer_state,
-            )
-        except ImportError as exc:
-            raise ImportError(
-                "DeepSpeed-specific checkpoints cannot be saved without DeepSpeed>=0.14.3 installed"
-            ) from exc
-
-        # read the latest file to get the step folder
-        latest_file = output_dir / "latest"
-        with open(latest_file) as f:
-            step_folder = f.read()
-
-        # will process the checkpoint in the latest step folder
-        input_folder = os.path.join(output_dir, step_folder)
-
-        # create args for the scripts below
-        class UniversalCheckpointArgs:
-            num_extract_workers: int = 1
-            num_merge_workers: int = 1
-            output_folder: str = input_folder  # just put in same place
-            strict: bool = True  # strict checkpoint
-
-        args = UniversalCheckpointArgs()
-
-        # get the checkpoint
-        ds_checkpoint = DeepSpeedCheckpoint(input_folder)
-
-        # hack, force this to null if we did not properly save
-        # any universal checkpoint information
-        # - this will not support any pipeline replication and other
-        #   replication such as TP, row parallelism, vocab, sub_params
-        if UNIVERSAL_CHECKPOINT_INFO not in ds_checkpoint.global_state:
-            warnings.warn(
-                "Universal checkpoint information not found, setting it to "
-                "an empty dictionary."
-            )
-            ds_checkpoint.global_state[UNIVERSAL_CHECKPOINT_INFO] = {}
-            assert ds_checkpoint.tp_degree == 1, (
-                "if universal checkpointing info is missing, TP must be absent"
-            )
-            assert ds_checkpoint.pp_degree == 1, (
-                "if universal checkpointing info is missing, PP must be absent"
-            )
-        _check_for_required_state(ds_checkpoint)
-
-        slice_shapes = []
-        for mp_rank_file in ds_checkpoint.mp_rank_files:
-            mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
-            slice_shapes += mp_sd[PARAM_SHAPES]
-
-        # fix back to normal flat dict, merge duplicates for tp>1
-        slice_shapes = dict((k, v) for d in slice_shapes for k, v in d.items())
-        temp_dir = os.path.join(args.output_folder, "tmp")
-
-        log_rank_0(
-            f"\033[93m1. Extracting ZeRO fragments into {temp_dir}\033[0m",
-            to_print=True,
-        )
-        _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
-
-        zero_output_folder = os.path.join(args.output_folder, "zero")
-
-        log_rank_0(
-            f"\033[93m2. Merging slices into {zero_output_folder}\033[0m", to_print=True
-        )
-        _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)
-
-        log_rank_0(
-            f"\033[93m3. Saving common optimizer states into {zero_output_folder}\033[0m",
-            to_print=True,
-        )
-        _save_optimizer_state(args, ds_checkpoint)
-
-        log_rank_0(
-            f"\033[93m4. Removing temp directory {temp_dir}\033[0m", to_print=True
-        )
-        shutil.rmtree(temp_dir, ignore_errors=True)
-
-        latest_file = os.path.join(output_dir, "latest_universal")
-        log_rank_0(f"\033[93m5. Creating {latest_file}\033[0m", to_print=True)
-        with open(latest_file, "w") as f:
-            f.write(step_folder)
-
-    dist.barrier()
-    log_rank_0(f"Preparing universal checkpoint took {time.time() - start} seconds")
-
-
 @contextmanager
 def ensure_loadable_dolomite_checkpoint(
     model_name_or_path: str,
@@ -1050,44 +919,6 @@ def _get_state_dict_patched(model, unwrap=False):
     accelerator.get_state_dict = get_state_dict_unpatched


-# this is native deepspeed saving with optimizer, scheduler
-def save_model_ds_native(
-    args,
-    model,
-    tokenizer,  # pylint: disable=unused-argument
-    samples_seen,
-):
-    # to get a statedict from a zero checkpoint, all you need to do is
-    # - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
-    # - sd = get_fp32_state_dict_from_zero_checkpoint('ckpt')
-    # - sum([math.prod(x.shape) for x in sd.values()])  # check the size (should be correct)
-
-    log_rank_0(
-        f"\033[93mSaving model+optimizer+scheduler in format at samples_seen: {samples_seen}\033[0m",
-        to_print=True,
-    )
-    start = time.time()
-    # used to save huggingface format, so we can use it for hf.from_pretrained
-    output_dir = Path(args.output_dir) / "ds_native"
-    tag = f"samples_{samples_seen}"
-    use_lora = args.lora_r > 0
-
-    # NOTE: this is a distributed save
-    # if its lora, we only save the adapters
-    # - so we exclude frozen if use_lora==True
-    model.save_checkpoint(
-        output_dir,
-        exclude_frozen_parameters=use_lora,
-        tag=tag,  # this will create the subdirectory with the correct name
-    )
-
-    # for now we are not saving tokenizer, config, eg..
-    # so it is not totally "HF compatible"
-
-    log_rank_0(f"\033[93mModel saved in {output_dir}\033[0m", to_print=True)
-    log_rank_0(f"saving took {time.time() - start} seconds")
-
-
 def set_random_seed(seed):
     if seed is not None:
         random.seed(seed)