Commit b8fb9c3

parallelize writing of layer checkpoint files across data parallel instances (#1419)

* parallelize layer checkpoints across data parallel groups

* use partition_uniform to determine start/end index values

* formatting fix

* config: add option for parallel write of layer checkpoints in pipeline stage

* yapf fixes

* enable parallel layer write according to config param

* avoid extraneous makedir when rank 0 writes all layers

Co-authored-by: Olatunji Ruwase <[email protected]>
adammoody and tjruwase authored Oct 21, 2022
1 parent 99fde3b commit b8fb9c3
Showing 5 changed files with 61 additions and 12 deletions.
deepspeed/runtime/config.py: 15 additions & 0 deletions
@@ -669,6 +669,19 @@ def get_checkpoint_tag_validation_mode(checkpoint_params):
     )


+def get_checkpoint_parallel_write_pipeline(checkpoint_params):
+    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
+    par_write_pipeline = par_write_params.get(
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
+        CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
+    if par_write_pipeline in [True, False]:
+        return par_write_pipeline
+    else:
+        raise DeepSpeedConfigError(
+            "checkpoint::parallel_write::pipeline_stage "
+            f"value of '{par_write_pipeline}' is invalid, expecting: true or false")
+
+
 def get_dataloader_drop_last(param_dict):
     return get_scalar_param(param_dict,
                             DATALOADER_DROP_LAST,
@@ -887,6 +900,8 @@ def _initialize_params(self, param_dict):
         self.load_universal_checkpoint = checkpoint_params.get(
            LOAD_UNIVERSAL_CHECKPOINT,
            LOAD_UNIVERSAL_CHECKPOINT_DEFAULT)
+        par_write_pipe = get_checkpoint_parallel_write_pipeline(checkpoint_params)
+        self.checkpoint_parallel_write_pipeline = par_write_pipe

         self.aio_config = get_aio_config(param_dict)

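For reference, here is a minimal standalone sketch of the parsing behavior added above: a missing section falls back to the default, and any non-boolean value is rejected. The constants mirror the constants.py hunk below, and ConfigError is a local stand-in for DeepSpeedConfigError so the sketch runs without importing DeepSpeed.

# Minimal sketch of the lookup-with-default plus validation pattern used above.
# Constants mirror deepspeed/runtime/constants.py; the error class is a local stand-in.
CHECKPOINT_PARALLEL_WRITE = "parallel_write"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False


class ConfigError(Exception):
    """Stand-in for DeepSpeedConfigError in this self-contained sketch."""


def parse_parallel_write_pipeline(checkpoint_params):
    # Missing "parallel_write" or "pipeline_stage" keys fall back to the default (False).
    par_write_params = checkpoint_params.get(CHECKPOINT_PARALLEL_WRITE, {})
    value = par_write_params.get(CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE,
                                 CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT)
    if value in [True, False]:
        return value
    raise ConfigError(f"pipeline_stage value of '{value}' is invalid, expecting: true or false")


print(parse_parallel_write_pipeline({}))                                             # False (default)
print(parse_parallel_write_pipeline({"parallel_write": {"pipeline_stage": True}}))   # True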
deepspeed/runtime/constants.py: 7 additions & 0 deletions
@@ -367,6 +367,9 @@ class ValidationMode:
 # "checkpoint": {
 #   tag_validation=["Ignore"|"Warn"|"Fail"]
 #   load_universal=false
+#   parallel_write: {
+#     pipeline_stage: [True|False]
+#   }
 # }
 CHECKPOINT = "checkpoint"
 CHECKPOINT_TAG_VALIDATION = "tag_validation"
@@ -380,6 +383,10 @@ class ValidationMode:
 LOAD_UNIVERSAL_CHECKPOINT = "load_universal"
 LOAD_UNIVERSAL_CHECKPOINT_DEFAULT = False

+CHECKPOINT_PARALLEL_WRITE = "parallel_write"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE = "pipeline_stage"
+CHECKPOINT_PARALLEL_WRITE_PIPELINE_STAGE_DEFAULT = False
+
 #########################################
 # Drop the last incomplete Batch
 # #########################################
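Putting the schema together, a user opts in through the "checkpoint" section of the DeepSpeed config. The sketch below shows such a config as a Python dict; the non-checkpoint keys are illustrative placeholders, and the dict would typically be passed as the config argument to deepspeed.initialize.

# Sketch of a DeepSpeed config dict enabling parallel layer checkpoint writes.
# The "checkpoint" section follows the schema documented in constants.py above;
# the other keys are illustrative placeholders, not part of this commit.
ds_config = {
    "train_batch_size": 16,
    "checkpoint": {
        "tag_validation": "Warn",
        "load_universal": False,
        "parallel_write": {
            "pipeline_stage": True  # spread layer files across data parallel ranks
        },
    },
}

# engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)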
deepspeed/runtime/engine.py: 17 additions & 7 deletions
@@ -2924,7 +2924,11 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True)
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_moe_checkpoint(save_dir, tag, client_state=client_state)

-        if self.save_non_zero_checkpoint:
+        # We distribute the task of saving layer checkpoint files among
+        # data parallel instances, so all procs should call _save_checkpoint.
+        # All procs then call module_state_dict(), but only procs of data
+        # parallel rank 0 save the general model params.
+        if not self.has_moe_layers:
             self._create_checkpoint_file(save_dir, tag, False)
             self._save_checkpoint(save_dir, tag, client_state=client_state)

@@ -3091,12 +3095,18 @@ def _create_zero_checkpoint_files(self, save_dir, tag):
     def _save_checkpoint(self, save_dir, tag, client_state={}):

         save_path = self._get_ckpt_name(save_dir, tag)
+
+        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
+
         # A hack to save the checkpointing directory. Pipeline parallelism overrides
         # module_state_dict() and uses this path to save the model. module_state_dict()
-        # then instead just returns None.
+        # then instead just returns None. The module_state_dict() implementation in
+        # PipelineEngine expects the save path to be set in self._curr_ckpt_path.
         self._curr_ckpt_path = os.path.join(save_dir, tag)
-        zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
-        state = dict(module=self.module_state_dict(),
+        module = self.module_state_dict()
+        self._curr_ckpt_path = None
+
+        state = dict(module=module,
                     buffer_names=self._get_buffer_names(),
                     optimizer=self.optimizer.state_dict()
                     if self.optimizer and not zero_optimizer_state else None,
@@ -3114,9 +3124,9 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):
                     ds_version=version)
         state.update(client_state)

-        log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
-        self.checkpoint_engine.save(state, save_path)
-        self._curr_save_path = None
+        if self.save_non_zero_checkpoint:
+            log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0, 1])
+            self.checkpoint_engine.save(state, save_path)

     def _get_buffer_names(self):
         buffer_names = []
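The reworked flow can be summarized with a small self-contained toy (this is not DeepSpeed code; the class, file names, and pickle-based saving are stand-ins): every rank sets _curr_ckpt_path, calls module_state_dict() so the pipeline module can write its share of layer files under that path and return None, and only ranks flagged with save_non_zero_checkpoint write the consolidated state file.

# Toy sketch (no DeepSpeed imports) of the handshake that the expanded comment in
# _save_checkpoint describes. Names and file formats are illustrative only.
import os
import pickle
import tempfile


class ToyPipelineEngine:
    def __init__(self, save_non_zero_checkpoint):
        self.save_non_zero_checkpoint = save_non_zero_checkpoint
        self._curr_ckpt_path = None

    def module_state_dict(self):
        # Stands in for the pipeline module_state_dict(): each rank saves its share
        # of layer files under self._curr_ckpt_path and returns None.
        os.makedirs(self._curr_ckpt_path, exist_ok=True)
        with open(os.path.join(self._curr_ckpt_path, "layer_00-model_states.pt"), "wb") as f:
            pickle.dump({"weight": [0.0]}, f)
        return None

    def save_checkpoint(self, save_dir, tag):
        self._curr_ckpt_path = os.path.join(save_dir, tag)
        module = self.module_state_dict()   # all ranks participate in layer writes
        self._curr_ckpt_path = None
        state = dict(module=module)
        if self.save_non_zero_checkpoint:   # typically only data parallel rank 0
            with open(os.path.join(save_dir, tag, "mp_rank_00_model_states.pt"), "wb") as f:
                pickle.dump(state, f)


ToyPipelineEngine(save_non_zero_checkpoint=True).save_checkpoint(tempfile.mkdtemp(), "global_step0")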
deepspeed/runtime/pipe/engine.py: 2 additions & 0 deletions
@@ -182,6 +182,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         self.module.activation_checkpoint_interval = self._config.pipeline[
             'activation_checkpoint_interval']

+        self.module.checkpoint_parallel_write_pipeline = self._config.checkpoint_parallel_write_pipeline
+
         if self.is_last_stage():
             self.loss_model = self.module.loss_fn

deepspeed/runtime/pipe/module.py: 20 additions & 5 deletions
@@ -562,13 +562,28 @@ def ckpt_layer_path_list(self, ckpt_dir, local_layer_idx):
         return ckpt_files

     def save_state_dict(self, save_dir, checkpoint_engine):
-        if self._grid.data_parallel_id != 0:
-            return
+        # Processes having the same model parallel rank on different data parallel instances
+        # have identical layer weights. We can distribute the task of saving the layer weights
+        # among the data parallel ranks. For example, if a pipeline stage has 9 layers and
+        # if there are 2 data parallel instances, rank 0 will save the first 5 layers and
+        # rank 1 will save the last 4.
+        dp_rank = self._grid.data_parallel_id
+        dp_size = self._grid.data_parallel_size
+        num_layers = len(self.forward_funcs)
+        if self.checkpoint_parallel_write_pipeline:
+            # spread layers evenly across data parallel ranks
+            offsets = ds_utils.partition_uniform(num_layers, dp_size)
+            start, end = offsets[dp_rank], offsets[dp_rank + 1]
+        else:
+            # data parallel rank 0 writes all layers
+            if dp_rank != 0:
+                return
+            start, end = 0, num_layers
+        layer_list = self.forward_funcs[start:end]

         os.makedirs(save_dir, exist_ok=True)
         layer_offset = self._local_start
-        for idx, layer in enumerate(self.forward_funcs):
-            model_ckpt_path = self.ckpt_layer_path(save_dir, idx)
+        for idx, layer in enumerate(layer_list):
+            model_ckpt_path = self.ckpt_layer_path(save_dir, start + idx)
             if not hasattr(layer, 'state_dict'):
                 continue
            # We pass cloned tensors to torch.save() to avoid checkpoint bloat which occurs because torch.save()
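The layer split itself comes from ds_utils.partition_uniform(num_layers, dp_size), which returns dp_size + 1 boundary offsets. Below is a self-contained sketch of that partitioning, reproducing the 9-layers, 2-ranks example from the comment above; the helper here is a stand-in, and DeepSpeed's own partition_uniform may break ties differently.

# Self-contained sketch of the layer partitioning used in save_state_dict above.
# partition_ranges mimics the role of deepspeed.runtime.utils.partition_uniform:
# it returns num_parts + 1 boundary offsets, handing any remainder to the earliest ranks.
def partition_ranges(num_items, num_parts):
    base, remainder = divmod(num_items, num_parts)
    offsets = [0]
    for part in range(num_parts):
        offsets.append(offsets[-1] + base + (1 if part < remainder else 0))
    return offsets


num_layers, dp_size = 9, 2            # the example from the commit's comment
offsets = partition_ranges(num_layers, dp_size)
for dp_rank in range(dp_size):
    start, end = offsets[dp_rank], offsets[dp_rank + 1]
    print(f"data parallel rank {dp_rank} saves layers [{start}, {end})")
# rank 0 saves layers [0, 5), rank 1 saves layers [5, 9)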
