[Feature] Support using gradient checkpointing in FSDP (#1382)
HAOCHENYE authored Oct 9, 2023
1 parent bf30c44 commit 8015d62
Showing 2 changed files with 90 additions and 45 deletions.
45 changes: 44 additions & 1 deletion mmengine/_strategy/fsdp.py
@@ -5,6 +5,7 @@
import os.path as osp
import time
from collections import OrderedDict
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Union

import torch.nn as nn
@@ -25,7 +26,7 @@
from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
OptimWrapperDict, _ParamScheduler,
build_optim_wrapper)
-from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
+from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
PARAM_SCHEDULERS, STRATEGIES, Registry)
from mmengine.utils import get_git_hash, mkdir_or_exist
from .distributed import DDPStrategy
@@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
:meth:`setup_env`. Defaults to None.
- log_kwargs (dict, optional): Logger config passed in
:meth:`build_logger`. Defaults to None.
        activation_checkpointing (dict, optional): Config dict for gradient
            checkpointing.

            Examples:
                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
                >>> activation_checkpointing = dict(
                ...     check_fn=dict(type='CustomCheckFn', arg1=arg1))

            The ``check_fn`` field should behave consistently with the
            ``auto_wrap_policy`` defined in ``model_wrapper``, and the other
            fields are passed to ``apply_activation_checkpointing``.

            `New in version 0.9.0.`

    .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
    """  # noqa: E501
@@ -100,13 +114,15 @@ def __init__(self,
model_wrapper: Optional[dict] = None,
skip_init_weights=False,
state_dict_cfg: Union[str, dict] = 'local',
activation_checkpointing: Optional[dict] = None,
**kwargs):
super().__init__(model_wrapper=model_wrapper, **kwargs)
self._init_state_dict_cfg(state_dict_cfg)
if not isinstance(skip_init_weights, bool):
raise TypeError('skip_init_weights must be a boolean, but got '
f'{type(skip_init_weights)}')
self.skip_init_weights = skip_init_weights
self.activation_checkpointing = activation_checkpointing

def _wrap_model(self, model: nn.Module) -> None:
"""Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@@ -119,6 +135,12 @@ def _wrap_model(self, model: nn.Module) -> None:
FullyShardedDataParallel: ``MMFullyShardedDataParallel``
or subclass of ``FullyShardedDataParallel``.
"""
try:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
apply_activation_checkpointing # noqa: E501
except ImportError:
apply_activation_checkpointing = None

for module in model.modules():
if isinstance(module, BaseDataPreprocessor):
module.to(get_device())
@@ -138,6 +160,27 @@ def _wrap_model(self, model: nn.Module) -> None:
model.set_state_dict_type(model, self.state_dict_type,
self.state_dict_config,
self.optim_state_dict_config)

        if self.activation_checkpointing is not None:
            if apply_activation_checkpointing is None:
                raise RuntimeError(
                    '`activation_checkpointing` is not supported by the '
                    'current PyTorch version. Please upgrade to PyTorch 2.0 '
                    'or 2.1 to use `activation_checkpointing`.')
            cfg = copy.deepcopy(self.activation_checkpointing)
            with FUNCTIONS.switch_scope_and_registry(None):
                check_fn = cfg.pop('check_fn')
                if isinstance(check_fn, str):
                    check_fn = FUNCTIONS.get(check_fn)
                elif isinstance(check_fn, dict):
                    fn_type = check_fn.pop('type')
                    if isinstance(fn_type, str):
                        fn_type = FUNCTIONS.get(fn_type)
                    # Bind the remaining keys of the ``check_fn`` dict to the
                    # check function; the rest of ``cfg`` is reserved for
                    # ``apply_activation_checkpointing``.
                    check_fn = partial(fn_type, **check_fn)

            if not callable(check_fn):
                raise TypeError('`check_fn` must be a callable function')
            apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
        return model

def _is_full_state_dict(self):
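Editorial sketch (not part of the commit) to make the ``check_fn`` contract concrete. ``checkpoint_large_linear`` is a made-up name; the comments mirror the ``_wrap_model`` resolution logic added above, and PyTorch 2.0+ is assumed for ``apply_activation_checkpointing``.

from functools import partial

import torch.nn as nn

from mmengine.registry import FUNCTIONS


@FUNCTIONS.register_module()
def checkpoint_large_linear(submodule, min_params=1_000_000):
    """Return True for submodules whose activations should be recomputed."""
    return (isinstance(submodule, nn.Linear)
            and sum(p.numel() for p in submodule.parameters()) >= min_params)


# With activation_checkpointing=dict(
#     check_fn=dict(type='checkpoint_large_linear', min_params=10_000_000)),
# _wrap_model() resolves the dict to a callable roughly like this:
check_fn = partial(
    FUNCTIONS.get('checkpoint_large_linear'), min_params=10_000_000)
# ...and then hands it to PyTorch:
# apply_activation_checkpointing(model, check_fn=check_fn)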
90 changes: 46 additions & 44 deletions mmengine/model/wrappers/fully_sharded_distributed.py
@@ -146,51 +146,53 @@ def __init__(
'`cpu_offload` should be `None`, `bool`'
f'or `CPUOffload`, but has type {type(cpu_offload)}')

-        if isinstance(auto_wrap_policy, str):
-            auto_wrap_policy = FUNCTIONS.get(  # type: ignore
-                auto_wrap_policy)
-            if auto_wrap_policy is None:
-                raise ValueError('`auto_wrap_policy` is not registered!')
-        elif isinstance(auto_wrap_policy, dict):
-            policy = auto_wrap_policy.pop('type')
-            if isinstance(policy, str):
-                policy = FUNCTIONS.get(policy)  # type: ignore
-                if policy is None:
-                    raise ValueError('`auto_wrap_policy` is not registered!')
-                auto_wrap_policy = partial(policy, **auto_wrap_policy)
-
-        if not (auto_wrap_policy is None
-                or callable(auto_wrap_policy)):  # type: ignore
-            raise TypeError('`auto_wrap_policy` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(auto_wrap_policy)}')
-
-        if isinstance(backward_prefetch, str):
-            backward_prefetch = BackwardPrefetch[backward_prefetch]
-        if not (isinstance(backward_prefetch, BackwardPrefetch)
-                or backward_prefetch is None):
-            raise TypeError(
-                '`backward_prefetch` should be `None`, string of '
-                '"BACKWARD_PRE" and "BACKWARD_POST", or '
-                f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
-
-        if isinstance(param_init_fn, str):
-            param_init_fn = FUNCTIONS.get(  # type: ignore
-                param_init_fn)
-            if param_init_fn is None:
-                raise ValueError('`param_init_fn` is not registered!')
-        elif isinstance(param_init_fn, dict):
-            init_fn = param_init_fn.pop('type')
-            init_fn = FUNCTIONS.get(init_fn)  # type: ignore
-            if init_fn is None:
-                raise ValueError('`param_init_fn` is not registered!')
-            param_init_fn = partial(init_fn, **param_init_fn)
-
-        if not (callable(param_init_fn) or param_init_fn is None):
-            raise TypeError('`param_init_fn` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(param_init_fn)}')
+        with FUNCTIONS.switch_scope_and_registry(None):
+            if isinstance(auto_wrap_policy, str):
+                auto_wrap_policy = FUNCTIONS.get(  # type: ignore
+                    auto_wrap_policy)
+                if auto_wrap_policy is None:
+                    raise ValueError('`auto_wrap_policy` is not registered!')
+            elif isinstance(auto_wrap_policy, dict):
+                policy = auto_wrap_policy.pop('type')
+                if isinstance(policy, str):
+                    policy = FUNCTIONS.get(policy)  # type: ignore
+                    if policy is None:
+                        raise ValueError('`auto_wrap_policy` is not registered!')
+                    auto_wrap_policy = partial(policy, **auto_wrap_policy)
+
+            if not (auto_wrap_policy is None
+                    or callable(auto_wrap_policy)):  # type: ignore
+                raise TypeError('`auto_wrap_policy` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(auto_wrap_policy)}')
+
+            if isinstance(backward_prefetch, str):
+                backward_prefetch = BackwardPrefetch[backward_prefetch]
+            if not (isinstance(backward_prefetch, BackwardPrefetch)
+                    or backward_prefetch is None):
+                raise TypeError(
+                    '`backward_prefetch` should be `None`, string of '
+                    '"BACKWARD_PRE" and "BACKWARD_POST", or '
+                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}'  # noqa: E501
+                )
+
+            if isinstance(param_init_fn, str):
+                param_init_fn = FUNCTIONS.get(  # type: ignore
+                    param_init_fn)
+                if param_init_fn is None:
+                    raise ValueError('`param_init_fn` is not registered!')
+            elif isinstance(param_init_fn, dict):
+                init_fn = param_init_fn.pop('type')
+                if isinstance(param_init_fn, str):
+                    init_fn = FUNCTIONS.get(init_fn)  # type: ignore
+                    if init_fn is None:
+                        raise ValueError('`param_init_fn` is not registered!')
+                    param_init_fn = partial(init_fn, **param_init_fn)
+
+            if not (callable(param_init_fn) or param_init_fn is None):
+                raise TypeError('`param_init_fn` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(param_init_fn)}')

def parse_dtype(dtype):
if dtype is None:
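Editorial note (not part of the commit): the hunk above mainly moves the existing str/dict-to-callable conversion for ``auto_wrap_policy``, ``backward_prefetch``, and ``param_init_fn`` inside a ``FUNCTIONS.switch_scope_and_registry(None)`` block; most of the remaining lines only change indentation. A rough sketch of that conversion with a made-up policy name; the callback signature is a toy and must match whatever the installed PyTorch FSDP expects.

from functools import partial

from mmengine.registry import FUNCTIONS


@FUNCTIONS.register_module()
def my_wrap_policy(module, recurse, nonwrapped_numel, min_num_params=1e8):
    """Toy size-based policy: wrap a module once it holds enough params."""
    return recurse or nonwrapped_numel >= min_num_params


# e.g. model_wrapper=dict(auto_wrap_policy=dict(type='my_wrap_policy',
#                                               min_num_params=1e7))
auto_wrap_policy = dict(type='my_wrap_policy', min_num_params=1e7)
with FUNCTIONS.switch_scope_and_registry(None):
    policy = auto_wrap_policy.pop('type')
    if isinstance(policy, str):
        policy = FUNCTIONS.get(policy)  # look up the registered function
    # The remaining keys become keyword arguments of the resolved policy.
    auto_wrap_policy = partial(policy, **auto_wrap_policy)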
