
Commit 80d8fcb

tjruwase and jeffra authored
Improve overflow handling (#2944)
Co-authored-by: Jeff Rasley <[email protected]>
1 parent 87eaf8f commit 80d8fcb

File tree

3 files changed: +36 −39 lines changed


deepspeed/runtime/fp16/loss_scaler.py

Lines changed: 16 additions & 0 deletions
@@ -16,6 +16,8 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/master/fp16/loss_scaler.py
 #Commit: 93ab4bea59dc5cbf97c079d313741866af4deac9
 
+import torch
+
 INITIAL_LOSS_SCALE = 'init_scale'
 SCALE_WINDOW = 'scale_window'
 DELAYED_SHIFT = 'delayed_shift'
@@ -35,6 +37,7 @@ class LossScalerBase:
     """
     def __init__(self, cur_scale):
         self.cur_scale = cur_scale
+        self.dynamic = False
 
     @property
     def loss_scale(self):
@@ -117,6 +120,7 @@ def __init__(self,
         self.cur_hysteresis = delayed_shift
         self.consecutive_hysteresis = consecutive_hysteresis
         self.raise_error_at_min_scale = raise_error_at_min_scale
+        self.dynamic = True
 
     # `params` is a list / generator of torch.Variable
     def has_overflow_serial(self, params):
@@ -170,6 +174,18 @@ def update_scale(self, overflow):
         self.cur_iter += 1
 
 
+# Although loss scaling is only defined for fp16, yet for backwards compatibility
+# we still create a scaler for other dtypes (fp32, bf16) which does not perform any scaling.
+def CreateLossScaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
+    if dtype == torch.half and dynamic_scaling:
+        if dynamic_loss_args is None:
+            return DynamicLossScaler()
+        return DynamicLossScaler(**dynamic_loss_args)
+
+    loss_scale_value = static_loss_scale if dtype == torch.half else 1.0
+    return LossScaler(scale=loss_scale_value)
+
+
 ##############################################################
 # Example usage below here -- assuming it's in a separate file
 ##############################################################
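The new CreateLossScaler factory centralizes scaler selection: dynamic scaling is engaged only for fp16, while fp32 and bf16 get a static no-op scaler with a scale of 1.0. A minimal usage sketch, assuming a DeepSpeed build that includes this commit; the import path, keyword names, and the dynamic / loss_scale attributes are taken from the diff above:

import torch
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler

# fp16 with dynamic scaling requested -> DynamicLossScaler, dynamic == True
fp16_scaler = CreateLossScaler(dtype=torch.half,
                               static_loss_scale=1024.0,
                               dynamic_scaling=True,
                               dynamic_loss_args=None)
assert fp16_scaler.dynamic

# bf16 never scales, even if dynamic scaling is requested -> static scale of 1.0
bf16_scaler = CreateLossScaler(dtype=torch.bfloat16,
                               static_loss_scale=1024.0,
                               dynamic_scaling=True,
                               dynamic_loss_args=None)
assert not bf16_scaler.dynamic and bf16_scaler.loss_scale == 1.0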

deepspeed/runtime/zero/stage3.py

Lines changed: 10 additions & 18 deletions
@@ -10,7 +10,7 @@
 
 from deepspeed.runtime import ZeROOptimizer
 from deepspeed.utils import logger
-from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
+from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced
 from deepspeed.runtime.utils import inf, get_global_norm, is_model_parallel_parameter
 from deepspeed.runtime.zero.partition_parameters import *
@@ -332,18 +332,11 @@ def __init__(self,
         #exit(0)
 
         # we may have a way of fusing dynamic scale. Do not support for now
-        if self.dtype == torch.float or not dynamic_loss_scale:
-            loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale
-
-            self.dynamic_loss_scale = False
-            self.loss_scaler = LossScaler(scale=loss_scale_value)
-        else:
-            if dynamic_loss_args is None:
-                self.loss_scaler = DynamicLossScaler()
-            else:
-                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
-
-            self.dynamic_loss_scale = True
+        self.loss_scaler = CreateLossScaler(dtype=self.dtype,
+                                            static_loss_scale=static_loss_scale,
+                                            dynamic_scaling=dynamic_loss_scale,
+                                            dynamic_loss_args=dynamic_loss_args)
+        self.dynamic_loss_scale = self.loss_scaler.dynamic
 
         self.debug_fp16_grads = [{} for _ in self.fp16_groups]
 
@@ -1844,11 +1837,10 @@ def _overflow_clean_up(self, prev_scale):
         see_memory_usage('After overflow after clearing gradients', force=False)
 
         if dist.get_rank() == 0:
-            logger.info(
-                "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
-                "reducing to {}".format(dist.get_rank(),
-                                        prev_scale,
-                                        self.loss_scale))
+            overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
+            if self.dtype == torch.half:
+                overflow_msg += f" Attempted loss scale: {prev_scale}, reducing to {self.loss_scale}"
+            logger.info(overflow_msg)
 
     @instrument_w_nvtx
     def _overflow_check_and_loss_scale_update(self):
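The logging change drops the loss-scale details for non-fp16 runs, where the scale is fixed at 1.0 and the old "reducing to" message was misleading. A standalone sketch of the same conditional formatting; build_overflow_message is a hypothetical helper used here for illustration, while the real code builds the string inline as shown in the hunk above:

import torch

def build_overflow_message(rank, dtype, prev_scale, cur_scale):
    # Hypothetical helper mirroring the updated logging: loss-scale details
    # are appended only when the dtype is fp16.
    msg = f"[deepspeed] OVERFLOW! Rank {rank} Skipping step."
    if dtype == torch.half:
        msg += f" Attempted loss scale: {prev_scale}, reducing to {cur_scale}"
    return msg

print(build_overflow_message(0, torch.half, 65536, 32768))    # includes scale info
print(build_overflow_message(0, torch.bfloat16, 1.0, 1.0))    # bare skip message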

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 10 additions & 21 deletions
@@ -9,7 +9,7 @@
 from collections import OrderedDict
 
 from deepspeed.runtime import ZeROOptimizer
-from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
+from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
 from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank,
                                      get_global_norm,
                                      empty_cache,
@@ -506,21 +506,11 @@ def __init__(self,
         self.external_loss_scale = None
 
         # we may have a way of fusing dynamic scale. Do not support for now
-        if self.dtype == torch.float or self.dtype == torch.bfloat16 or not dynamic_loss_scale:
-            loss_scale_value = 1.0 if (
-                (self.dtype == torch.float) or
-                (self.dtype == torch.bfloat16)) else static_loss_scale
-
-            self.dynamic_loss_scale = False
-            self.loss_scaler = LossScaler(scale=loss_scale_value)
-            cur_iter = 0
-        else:
-            if dynamic_loss_args is None:
-                self.loss_scaler = DynamicLossScaler()
-            else:
-                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
-
-            self.dynamic_loss_scale = True
+        self.loss_scaler = CreateLossScaler(dtype=self.dtype,
+                                            static_loss_scale=static_loss_scale,
+                                            dynamic_scaling=dynamic_loss_scale,
+                                            dynamic_loss_args=dynamic_loss_args)
+        self.dynamic_loss_scale = self.loss_scaler.dynamic
 
         see_memory_usage("Before initializing optimizer states", force=True)
         self.initialize_optimizer_states()
@@ -1788,11 +1778,10 @@ def step(self, closure=None):
         self._update_scale(self.overflow)
         if self.overflow:
             if dist.get_rank() == 0:
-                logger.info(
-                    "[deepspeed] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, "
-                    "reducing to {}".format(dist.get_rank(),
-                                            prev_scale,
-                                            self.loss_scale))
+                overflow_msg = f"[deepspeed] OVERFLOW! Rank {dist.get_rank()} Skipping step."
+                if self.dtype == torch.half:
+                    overflow_msg += f" Attempted loss scale: {prev_scale}, reducing to {self.loss_scale}"
+                logger.info(overflow_msg)
 
             see_memory_usage('After overflow before clearing gradients')
             self.zero_grad(set_to_none=True)
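Both ZeRO stages now read dynamic_loss_scale off the scaler itself instead of re-deriving it from dtype and config. A self-contained sketch of that pattern, using simplified stand-ins for the DeepSpeed classes (not the actual implementations; the 2**32 default is only illustrative):

import torch

class LossScalerBase:
    def __init__(self, cur_scale):
        self.cur_scale = cur_scale
        self.dynamic = False    # static by default

class LossScaler(LossScalerBase):
    def __init__(self, scale=1.0):
        super().__init__(scale)

class DynamicLossScaler(LossScalerBase):
    def __init__(self, init_scale=2**32, **kwargs):
        super().__init__(init_scale)
        self.dynamic = True     # the only scaler that adjusts its scale

def create_loss_scaler(dtype, static_loss_scale, dynamic_scaling, dynamic_loss_args):
    # Same selection rule as the factory in the commit: dynamic only for fp16.
    if dtype == torch.half and dynamic_scaling:
        return DynamicLossScaler(**(dynamic_loss_args or {}))
    return LossScaler(scale=static_loss_scale if dtype == torch.half else 1.0)

# Callers can now copy the flag instead of branching on dtype/config:
scaler = create_loss_scaler(torch.bfloat16, 1024.0, True, None)
dynamic_loss_scale = scaler.dynamic
assert dynamic_loss_scale is False and scaler.cur_scale == 1.0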

0 commit comments
