Improve overflow handling in ZeRO #6976

Open
wants to merge 66 commits into master from olruwase/ds_5241
Changes from all commits
Commits (66)
a3a18f7
Improve overflow handling in ZeRO
tjruwase Jan 28, 2025
19431f8
Unit test and pydantic configuration
tjruwase Jan 28, 2025
406cf26
Formatting fixes
tjruwase Jan 28, 2025
35570f5
Merge branch 'master' into olruwase/ds_5241
tjruwase Jan 29, 2025
cb78444
Remove unused symbol
tjruwase Jan 29, 2025
ee1c1fd
Fix typo
tjruwase Jan 29, 2025
0b2cf73
Pydantic fp16 config
tjruwase Jan 29, 2025
c7a90f9
Fix more typos
tjruwase Jan 29, 2025
3694e07
Address #4986
tjruwase Jan 29, 2025
2bbcf00
Merge branch 'master' into olruwase/ds_5241
tjruwase Jan 29, 2025
c1b87ea
Merge branch 'master' into olruwase/ds_5241
tjruwase Jan 30, 2025
5da6cd0
Merge branch 'master' into olruwase/ds_5241
tjruwase Jan 30, 2025
a65d20c
Merge branch 'master' into olruwase/ds_5241
loadams Jan 30, 2025
ae039b2
Fix typo
tjruwase Jan 30, 2025
0446192
Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int…
tjruwase Jan 30, 2025
5d48745
Merge branch 'master' into olruwase/ds_5241
tjruwase Jan 31, 2025
05c362d
Merge branch 'master' into olruwase/ds_5241
loadams Jan 31, 2025
5e17ed6
Merge branch 'master' into olruwase/ds_5241
loadams Feb 1, 2025
06bb3a6
Merge branch 'master' into olruwase/ds_5241
tjruwase Feb 5, 2025
0d0ab3d
Fix min loss scale
tjruwase Feb 5, 2025
cccd5b1
Merge branch 'master' into olruwase/ds_5241
tjruwase Feb 5, 2025
2c6f630
Fix UTs
tjruwase Feb 6, 2025
21bfca0
Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int…
tjruwase Feb 6, 2025
5fe5810
Merge branch 'master' into olruwase/ds_5241
tjruwase Feb 6, 2025
732ceb7
Using explicit GPU upcast for ZeRO-Offload (#6962)
xylian86 Jan 21, 2025
db9aff9
Update version.txt after 0.16.3 release (#6965)
loadams Jan 21, 2025
4edeb03
Precisely track nvme optimizer offload (#6963)
tjruwase Jan 23, 2025
f00f4ea
Update build_win.bat script to exclue GDS op as it lacks Windows supp…
loadams Jan 24, 2025
c3846fa
Improve overflow handling in ZeRO
tjruwase Jan 28, 2025
7d56ffa
Unit test and pydantic configuration
tjruwase Jan 28, 2025
6ca11ef
Formatting fixes
tjruwase Jan 28, 2025
49f3df8
Add CUDA 12.8 support and comment on CUDA 12.7 (#6975)
loadams Jan 28, 2025
8364b12
Update torch versions to support 2.6 (#6977)
loadams Jan 29, 2025
ea9b473
Remove unused symbol
tjruwase Jan 29, 2025
d2425a2
Fix typo
tjruwase Jan 29, 2025
7d5be07
Pydantic fp16 config
tjruwase Jan 29, 2025
e8fc098
Fix more typos
tjruwase Jan 29, 2025
2bbb7b4
Address #4986
tjruwase Jan 29, 2025
3ab5e88
generalize deepspeed linear and implement it for non cuda systems (#6…
oelayan7 Jan 29, 2025
271db94
Fix typo
tjruwase Jan 30, 2025
b1900af
Update recommended Windows whl building versions (#6983)
loadams Jan 30, 2025
e3d10e5
Title: Fix setup_env_ranks to Properly Set Environment Variables Inst…
fabiosanger Jan 30, 2025
b8d8e39
Specify torchvision in nv-ds-chat workflow (prevents errors with torc…
loadams Jan 30, 2025
fde7df1
Remove assumption that padding only occurs on last rank (#6974)
xylian86 Jan 31, 2025
b0b0132
Use ds-specific module id to avoid conflicts (#6847)
tjruwase Jan 31, 2025
353ab08
Update A6000 workflows to use newer docker container - 24.09 vs 24.03…
loadams Jan 31, 2025
14189a7
Allow NVIDIA Blackwell (#6991)
fabiendupont Feb 4, 2025
75996f8
Update GH org references (#6998)
tjruwase Feb 5, 2025
b23c545
Fix min loss scale
tjruwase Feb 5, 2025
7cd3a9f
Fix UTs
tjruwase Feb 6, 2025
2c5629e
Update CNAME
loadams Feb 5, 2025
6b15688
Update CNAME
loadams Feb 5, 2025
3773d83
[XPU] max1100 workflow update for docker and softwares (#7003)
Liangliang-Ma Feb 5, 2025
64c4b04
autotp training(fix dco) (#7004)
inkcherry Feb 5, 2025
5fa2910
Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int…
tjruwase Feb 6, 2025
1f5a672
Merge branch 'master' into olruwase/ds_5241
tjruwase Feb 7, 2025
9882116
Fix ds-chat CI regression
tjruwase Feb 7, 2025
97d7915
Merge branch 'olruwase/ds_7014' of github.com:microsoft/DeepSpeed int…
tjruwase Feb 7, 2025
4a1dd0f
Fix bug
tjruwase Feb 7, 2025
0ac4457
Avoid naming collision on partition()
tjruwase Feb 7, 2025
1597d48
Merge branch 'master' into olruwase/ds_5241
tjruwase Feb 8, 2025
2ae2062
Use new API
tjruwase Feb 8, 2025
9fb73a4
Merge branch 'master' into olruwase/ds_7014
tjruwase Feb 8, 2025
26fa8af
Merge branch 'olruwase/ds_7014' of github.com:microsoft/DeepSpeed int…
tjruwase Feb 8, 2025
b565d77
Merge branch 'olruwase/ds_5241' of github.com:microsoft/DeepSpeed int…
tjruwase Feb 8, 2025
d098c32
Merge branch 'master' into olruwase/ds_5241
loadams Feb 10, 2025
6 changes: 4 additions & 2 deletions deepspeed/runtime/bf16_optimizer.py
@@ -36,6 +36,7 @@ class BF16_Optimizer(ZeROOptimizer):
def __init__(self,
init_optimizer,
param_names,
bfloat16_config,
mpu=None,
clip_grad=0.0,
norm_type=2,
@@ -44,7 +45,6 @@ def __init__(self,
timers=None,
grad_acc_dtype=None,
graph_harvesting=False,
immediate_grad_update=False,
has_moe_layers=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
@@ -53,10 +53,12 @@ def __init__(self,
self.param_names = param_names
self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

assert bfloat16_config.enabled, f"BF16Optimizer: requires bfloat16 to be enabled"
assert grad_acc_dtype in [torch.float32, torch.bfloat16
], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
self.grad_acc_dtype = grad_acc_dtype
self.immediate_grad_update = immediate_grad_update

self.immediate_grad_update = bfloat16_config.immediate_grad_update

self.clip_grad = clip_grad
self.norm_type = norm_type
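The hunk above replaces the scalar `immediate_grad_update` keyword with a single `bfloat16_config` object, from which the optimizer reads `enabled` and `immediate_grad_update`. Below is a minimal sketch of the new calling convention, assuming a stub config object; the real object is the pydantic model produced by `get_bfloat16_config` in `deepspeed/runtime/precision_config.py`, which is not part of the hunks shown here.

```python
# Minimal sketch of the new calling convention (not the full argument list).
# Bf16ConfigStub stands in for the pydantic bf16 config model; only the two
# fields this hunk actually reads are modeled.
from dataclasses import dataclass

from deepspeed.runtime.bf16_optimizer import BF16_Optimizer


@dataclass
class Bf16ConfigStub:
    enabled: bool = True
    immediate_grad_update: bool = False


def build_bf16_optimizer(base_optimizer, param_names, **engine_kwargs):
    # Before this PR: BF16_Optimizer(..., immediate_grad_update=False, ...)
    # After this PR: the whole bf16 section travels as one config object.
    return BF16_Optimizer(base_optimizer,
                          param_names,
                          bfloat16_config=Bf16ConfigStub(),
                          **engine_kwargs)
```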
180 changes: 79 additions & 101 deletions deepspeed/runtime/config.py
@@ -14,13 +14,13 @@
import base64

from .constants import *
from .fp16.loss_scaler import (
INITIAL_LOSS_SCALE,
SCALE_WINDOW,
DELAYED_SHIFT,
CONSECUTIVE_HYSTERESIS,
MIN_LOSS_SCALE,
)
# from .fp16.loss_scaler import (
# INITIAL_LOSS_SCALE,
# SCALE_WINDOW,
# DELAYED_SHIFT,
# CONSECUTIVE_HYSTERESIS,
# MIN_LOSS_SCALE,
# )
from .config_utils import (
get_scalar_param,
dict_raise_error_on_duplicate_keys,
@@ -31,6 +31,7 @@
from ..comm.config import DeepSpeedCommsConfig
from ..monitor.config import get_monitor_config
from ..inference.config import WeightQuantConfig
from .precision_config import get_bfloat16_config, get_float16_config

from deepspeed import comm as dist
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
@@ -157,88 +158,64 @@ def get_amp_params(param_dict):
return False


def get_fp16_enabled(param_dict):
if FP16 in param_dict.keys():
return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
else:
return False


def get_bfloat16_enabled(param_dict):
for key in [BFLOAT16, BFLOAT16_OLD]:
if key in param_dict.keys():
return get_scalar_param(param_dict[key], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT)
return False


def get_bfloat16_immediate_grad_update(param_dict):
for key in [BFLOAT16, BFLOAT16_OLD]:
if key in param_dict.keys():
return get_scalar_param(param_dict[key], BFLOAT16_IMMEDIATE_GRAD_UPDATE,
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT)
return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
else:
return False


def get_fp16_auto_cast(param_dict):
if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)


def get_loss_scale(param_dict):
if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT)
elif get_bfloat16_enabled(param_dict):
return 1.0
else:
return FP16_LOSS_SCALE_DEFAULT


def get_initial_dynamic_scale(param_dict):
if get_fp16_enabled(param_dict):
initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER,
FP16_INITIAL_SCALE_POWER_DEFAULT)
elif get_bfloat16_enabled(param_dict):
initial_scale_power = 0
else:
initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

return 2**initial_scale_power


def get_dynamic_loss_scale_args(param_dict):
loss_scale_args = None
if get_fp16_enabled(param_dict):
fp16_dict = param_dict[FP16]
dynamic_loss_args = [
FP16_INITIAL_SCALE_POWER,
FP16_LOSS_SCALE_WINDOW,
FP16_MIN_LOSS_SCALE,
FP16_HYSTERESIS,
FP16_CONSECUTIVE_HYSTERESIS,
]
if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args):
init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT)
scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT)
delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT)
consecutive_hysteresis = get_scalar_param(fp16_dict, FP16_CONSECUTIVE_HYSTERESIS,
FP16_CONSECUTIVE_HYSTERESIS_DEFAULT)
min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT)
loss_scale_args = {
INITIAL_LOSS_SCALE: 2**init_scale,
SCALE_WINDOW: scale_window,
DELAYED_SHIFT: delayed_shift,
CONSECUTIVE_HYSTERESIS: consecutive_hysteresis,
MIN_LOSS_SCALE: min_loss_scale,
}

return loss_scale_args
# def get_fp16_enabled(param_dict):
# if FP16 in param_dict.keys():
# return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT)
# else:
# return False

# def get_fp16_master_weights_and_grads_enabled(param_dict):
# if get_fp16_enabled(param_dict):
# return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
# else:
# return False

# def get_fp16_auto_cast(param_dict):
# if get_fp16_enabled(param_dict):
# return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)

# def get_loss_scale(param_dict):
# if get_fp16_enabled(param_dict):
# return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT)
# else:
# return FP16_LOSS_SCALE_DEFAULT

# def get_initial_dynamic_scale(param_dict):
# if get_fp16_enabled(param_dict):
# initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER,
# FP16_INITIAL_SCALE_POWER_DEFAULT)
# else:
# initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

# return 2**initial_scale_power

# def get_dynamic_loss_scale_args(param_dict):
# loss_scale_args = None
# if get_fp16_enabled(param_dict):
# fp16_dict = param_dict[FP16]
# dynamic_loss_args = [
# FP16_INITIAL_SCALE_POWER,
# FP16_LOSS_SCALE_WINDOW,
# FP16_MIN_LOSS_SCALE,
# FP16_HYSTERESIS,
# FP16_CONSECUTIVE_HYSTERESIS,
# ]
# if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args):
# init_scale = get_scalar_param(fp16_dict, FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT)
# scale_window = get_scalar_param(fp16_dict, FP16_LOSS_SCALE_WINDOW, FP16_LOSS_SCALE_WINDOW_DEFAULT)
# delayed_shift = get_scalar_param(fp16_dict, FP16_HYSTERESIS, FP16_HYSTERESIS_DEFAULT)
# consecutive_hysteresis = get_scalar_param(fp16_dict, FP16_CONSECUTIVE_HYSTERESIS,
# FP16_CONSECUTIVE_HYSTERESIS_DEFAULT)
# min_loss_scale = get_scalar_param(fp16_dict, FP16_MIN_LOSS_SCALE, FP16_MIN_LOSS_SCALE_DEFAULT)
# loss_scale_args = {
# INITIAL_LOSS_SCALE: 2**init_scale,
# SCALE_WINDOW: scale_window,
# DELAYED_SHIFT: delayed_shift,
# CONSECUTIVE_HYSTERESIS: consecutive_hysteresis,
# MIN_LOSS_SCALE: min_loss_scale,
# }

# return loss_scale_args


def get_gradient_accumulation_steps(param_dict):
@@ -827,18 +804,19 @@ def _initialize_params(self, param_dict):
self.monitor_config = get_monitor_config(param_dict)

self.gradient_clipping = get_gradient_clipping(param_dict)
self.fp16_enabled = get_fp16_enabled(param_dict)
self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
self.bfloat16_immediate_grad_update = get_bfloat16_immediate_grad_update(param_dict)
assert not (self.fp16_enabled
and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
# self.fp16_enabled = get_fp16_enabled(param_dict)
# self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
self.float16_config = get_float16_config(param_dict)
self.bfloat16_config = get_bfloat16_config(param_dict)
assert not (self.float16_config.enabled
and self.bfloat16_config.enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
# self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
# self.loss_scale = get_loss_scale(param_dict)
# self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
# self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)

self.amp_enabled = get_amp_enabled(param_dict)
self.amp_params = get_amp_params(param_dict)
self.loss_scale = get_loss_scale(param_dict)
self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict)
self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict)

self.compression_config = get_compression_config(param_dict)
self.graph_harvesting = get_graph_harvesting(param_dict)
@@ -1018,11 +996,11 @@ def _do_error_check(self):
<= ZeroStageEnum.max_stage), "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
ZeroStageEnum.max_stage)

if self.fp16_master_weights_and_gradients:
if self.float16_config.fp16_master_weights_and_grads:
assert self.zero_enabled and self.zero_optimization_stage == ZeroStageEnum.gradients, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."

def _do_warning_check(self):
fp16_enabled = self.fp16_enabled
fp16_enabled = self.float16_config.enabled

vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
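The config.py changes retire the ad-hoc `get_fp16_*` / `get_bfloat16_*` getters in favor of `get_float16_config` and `get_bfloat16_config` from the new `precision_config` module. That module is not included in the hunks shown, so the following is a hypothetical sketch of what the models might look like, inferred only from the attributes the rest of this PR reads (`enabled`, `auto_cast`, `loss_scale`, `initial_dynamic_scale()`, `dynamic_loss_scale_args()`, `fp16_master_weights_and_grads`, `immediate_grad_update`, `check_grad_overflow`); field defaults here are illustrative, not authoritative.

```python
# Hypothetical stand-ins for the precision_config models; not the actual
# DeepSpeed classes. Field names mirror what engine.py and config.py read.
from pydantic import BaseModel


class Float16ConfigStub(BaseModel):
    enabled: bool = False
    auto_cast: bool = False
    loss_scale: float = 0.0            # 0 selects dynamic loss scaling
    initial_scale_power: int = 16      # illustrative default
    loss_scale_window: int = 1000
    hysteresis: int = 2
    consecutive_hysteresis: bool = False
    min_loss_scale: float = 1.0
    fp16_master_weights_and_grads: bool = False

    def initial_dynamic_scale(self) -> float:
        # Mirrors the retired get_initial_dynamic_scale(): 2 ** initial_scale_power.
        return 2.0 ** self.initial_scale_power


class Bfloat16ConfigStub(BaseModel):
    enabled: bool = False
    immediate_grad_update: bool = False
    check_grad_overflow: bool = False  # surfaced in JSON as "check_overflow"
```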
7 changes: 6 additions & 1 deletion deepspeed/runtime/constants.py
@@ -117,7 +117,9 @@
BFLOAT16_FORMAT = '''
BFLOAT16 parameters should be of the format:
"bf16": {
"enabled": true
"enabled": true,
"immediate_grad_update": false,
"check_overflow": false
}
'''
BFLOAT16 = "bf16"
@@ -126,6 +128,9 @@
BFLOAT16_ENABLED = "enabled"
BFLOAT16_ENABLED_DEFAULT = False

CHECK_OVERFLOW = "check_overflow"
BFLOAT16_CHECK_OVERFLOW_DEFAULT = False

# BFLOAT16 optimizer immediate gradient update
BFLOAT16_IMMEDIATE_GRAD_UPDATE = "immediate_grad_update"
BFLOAT16_IMMEDIATE_GRAD_UPDATE_DEFAULT = False
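With the new constants, a bf16 run can opt into ZeRO gradient-overflow tracking from the user config. A minimal example using the keys documented in `BFLOAT16_FORMAT` above (the surrounding keys such as batch size and ZeRO stage are illustrative):

```python
# Example DeepSpeed config enabling the new bf16 overflow check.
# "check_overflow" defaults to False; setting it True opts in.
ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "zero_optimization": {"stage": 2},
    "bf16": {
        "enabled": True,
        "immediate_grad_update": False,
        "check_overflow": True,
    },
}
```

A dict like this would typically be passed to `deepspeed.initialize(..., config=ds_config)` in the usual way.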
28 changes: 18 additions & 10 deletions deepspeed/runtime/engine.py
@@ -906,13 +906,13 @@ def graph_harvesting(self):
return self._config.graph_harvesting

def fp16_enabled(self):
return self._config.fp16_enabled
return self._config.float16_config.enabled

def bfloat16_enabled(self):
return self._config.bfloat16_enabled
return self._config.bfloat16_config.enabled

def fp16_master_weights_and_gradients(self):
return self._config.fp16_master_weights_and_gradients
return self._config.float16_config.fp16_master_weights_and_grads

def amp_enabled(self):
return self._config.amp_enabled
@@ -921,10 +921,10 @@ def amp_params(self):
return self._config.amp_params

def fp16_auto_cast(self):
return self._config.fp16_auto_cast
return self._config.float16_config.auto_cast

def loss_scale(self):
return self._config.loss_scale
return self._config.float16_config.loss_scale

def gradient_accumulation_steps(self):
return self._config.gradient_accumulation_steps
@@ -990,13 +990,13 @@ def gradient_clipping(self):
return self._config.gradient_clipping

def dynamic_loss_scale(self):
return self._config.loss_scale == 0
return self._config.float16_config.loss_scale == 0

def initial_dynamic_scale(self):
return self._config.initial_dynamic_scale
return self._config.float16_config.initial_dynamic_scale()

def dynamic_loss_scale_args(self):
return self._config.dynamic_loss_scale_args
return self._config.float16_config.dynamic_loss_scale_args()

def swap_tensor_config(self):
return self._config.swap_tensor_config
@@ -1597,14 +1597,14 @@ def _configure_bf16_optimizer(self, optimizer):
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = BF16_Optimizer(optimizer,
self.param_names,
bfloat16_config=self._config.bfloat16_config,
mpu=self.mpu,
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.seq_data_parallel_group,
timers=timers,
grad_acc_dtype=self.get_data_types()[1],
graph_harvesting=self.graph_harvesting(),
immediate_grad_update=self._config.bfloat16_immediate_grad_update,
has_moe_layers=self.has_moe_layers)

return optimizer
@@ -1615,6 +1615,13 @@ def _configure_zero_optimizer(self, optimizer):
mics_shard_size = self.mics_shard_size()
model_dtype, gradient_accumulation_dtype = self.get_data_types()

if self.bfloat16_enabled():
check_grad_overflow = self._config.bfloat16_config.check_grad_overflow
elif self.fp16_enabled():
check_grad_overflow = True
else:
check_grad_overflow = False

timers = self.timers if self.wall_clock_breakdown() else NoopTimer()

if optimizer is None:
@@ -1666,7 +1673,8 @@
fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(),
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type,
elastic_checkpoint=self.zero_elastic_checkpoint())
elastic_checkpoint=self.zero_elastic_checkpoint(),
check_grad_overflow=check_grad_overflow)

elif zero_stage == ZeroStageEnum.weights:
assert not self.has_moe_layers, "MoE not supported with Stage 3"
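For reference, the `check_grad_overflow` resolution added to `_configure_zero_optimizer` can be restated as a small pure function. This is only a sketch of the selection logic shown in the hunk above; how the flag is consumed inside the stage 1/2 ZeRO optimizer is not part of the hunks shown here.

```python
# Restatement of the engine logic: overflow checking is unconditional under
# fp16 (dynamic loss scaling depends on it), opt-in under bf16 via
# bf16.check_overflow, and off for pure fp32 training.
def resolve_check_grad_overflow(fp16_enabled: bool,
                                bf16_enabled: bool,
                                bf16_check_overflow: bool) -> bool:
    if bf16_enabled:
        return bf16_check_overflow
    if fp16_enabled:
        return True
    return False


# A default bf16 run skips the gradient overflow check unless the user sets
# "check_overflow": true in the bf16 section of the config.
assert resolve_check_grad_overflow(False, True, False) is False
assert resolve_check_grad_overflow(True, False, False) is True
```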