
Commit 8015d62

[Feature] Support using gradient checkpointing in FSDP (#1382)
1 parent bf30c44 commit 8015d62

2 files changed: +90 -45 lines

mmengine/_strategy/fsdp.py (44 additions, 1 deletion)
@@ -5,6 +5,7 @@
 import os.path as osp
 import time
 from collections import OrderedDict
+from functools import partial
 from typing import Callable, Dict, List, Optional, Sequence, Union
 
 import torch.nn as nn
@@ -25,7 +26,7 @@
 from mmengine.optim import (AmpOptimWrapper, BaseOptimWrapper, OptimWrapper,
                             OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
-from mmengine.registry import (MODEL_WRAPPERS, OPTIM_WRAPPERS,
+from mmengine.registry import (FUNCTIONS, MODEL_WRAPPERS, OPTIM_WRAPPERS,
                                PARAM_SCHEDULERS, STRATEGIES, Registry)
 from mmengine.utils import get_git_hash, mkdir_or_exist
 from .distributed import DDPStrategy
@@ -91,6 +92,19 @@ class FSDPStrategy(DDPStrategy):
             :meth:`setup_env`. Defaults to None.
         - log_kwargs (dict, optional): Logger config passed in
             :meth:`build_logger`. Defaults to None.
+        activation_checkpointing (dict, optional): Config dict for gradient
+            checkpointing.
+
+            Examples:
+                >>> activation_checkpointing = dict(check_fn='CustomCheckFn')
+                >>> activation_checkpointing = dict(check_fn=dict(type='CustomCheckFn', arg1=arg1))
+
+            The ``check_fn`` field should behave consistently with the
+            ``auto_wrap_policy`` defined in ``model_wrapper``; the other
+            fields will be passed to ``apply_activation_checkpointing``.
+
+            `New in version 0.9.0.`
 
     .. _FSDP official api documents: https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type
     """  # noqa: E501
@@ -100,13 +114,15 @@ def __init__(self,
                  model_wrapper: Optional[dict] = None,
                  skip_init_weights=False,
                  state_dict_cfg: Union[str, dict] = 'local',
+                 activation_checkpointing: Optional[dict] = None,
                  **kwargs):
         super().__init__(model_wrapper=model_wrapper, **kwargs)
         self._init_state_dict_cfg(state_dict_cfg)
         if not isinstance(skip_init_weights, bool):
             raise TypeError('skip_init_weights must be a boolean, but got '
                             f'{type(skip_init_weights)}')
         self.skip_init_weights = skip_init_weights
+        self.activation_checkpointing = activation_checkpointing
 
     def _wrap_model(self, model: nn.Module) -> None:
         """Wrap the model to :obj:``MMFullyShardedDataParallel`` or other
@@ -119,6 +135,12 @@ def _wrap_model(self, model: nn.Module) -> None:
             FullyShardedDataParallel: ``MMFullyShardedDataParallel``
             or subclass of ``FullyShardedDataParallel``.
         """
+        try:
+            from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
+                apply_activation_checkpointing  # noqa: E501
+        except ImportError:
+            apply_activation_checkpointing = None
+
         for module in model.modules():
             if isinstance(module, BaseDataPreprocessor):
                 module.to(get_device())
@@ -138,6 +160,27 @@ def _wrap_model(self, model: nn.Module) -> None:
         model.set_state_dict_type(model, self.state_dict_type,
                                   self.state_dict_config,
                                   self.optim_state_dict_config)
+
+        if self.activation_checkpointing is not None:
+            if apply_activation_checkpointing is None:
+                raise RuntimeError(
+                    '`activation_checkpointing` is not supported by the '
+                    'current PyTorch version; please switch to PyTorch 2.0 '
+                    'or 2.1 to use `activation_checkpointing`.')
+            cfg = copy.deepcopy(self.activation_checkpointing)
+            with FUNCTIONS.switch_scope_and_registry(None):
+                check_fn = cfg.pop('check_fn')
+                if isinstance(check_fn, str):
+                    check_fn = FUNCTIONS.get(check_fn)
+                elif isinstance(check_fn, dict):
+                    fn_type = check_fn.pop('type')
+                    if isinstance(fn_type, str):
+                        fn_type = FUNCTIONS.get(fn_type)
+                    check_fn = partial(fn_type, **check_fn)
+
+                if not callable(check_fn):
+                    raise TypeError('`check_fn` must be a callable function')
+                apply_activation_checkpointing(model, check_fn=check_fn, **cfg)
         return model
 
     def _is_full_state_dict(self):
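For reference, the PyTorch utility this new branch delegates to can be exercised on its own; a minimal sketch assuming PyTorch >= 2.0 (the toy model is illustrative):

    import torch
    import torch.nn as nn
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
        apply_activation_checkpointing

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

    # Wrap every Linear in a CheckpointWrapper: its activations are freed
    # after forward and recomputed during backward, trading compute for memory.
    apply_activation_checkpointing(
        model, check_fn=lambda m: isinstance(m, nn.Linear))

    out = model(torch.randn(4, 8, requires_grad=True))
    out.sum().backward()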

mmengine/model/wrappers/fully_sharded_distributed.py (46 additions, 44 deletions)
@@ -146,51 +146,53 @@ def __init__(
                 '`cpu_offload` should be `None`, `bool`'
                 f'or `CPUOffload`, but has type {type(cpu_offload)}')
 
-        if isinstance(auto_wrap_policy, str):
-            auto_wrap_policy = FUNCTIONS.get(  # type: ignore
-                auto_wrap_policy)
-            if auto_wrap_policy is None:
-                raise ValueError('`auto_wrap_policy` is not registered!')
-        elif isinstance(auto_wrap_policy, dict):
-            policy = auto_wrap_policy.pop('type')
-            if isinstance(policy, str):
-                policy = FUNCTIONS.get(policy)  # type: ignore
-            if policy is None:
-                raise ValueError('`auto_wrap_policy` is not registered!')
-            auto_wrap_policy = partial(policy, **auto_wrap_policy)
-
-        if not (auto_wrap_policy is None
-                or callable(auto_wrap_policy)):  # type: ignore
-            raise TypeError('`auto_wrap_policy` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(auto_wrap_policy)}')
-
-        if isinstance(backward_prefetch, str):
-            backward_prefetch = BackwardPrefetch[backward_prefetch]
-        if not (isinstance(backward_prefetch, BackwardPrefetch)
-                or backward_prefetch is None):
-            raise TypeError(
-                '`backward_prefetch` should be `None`, string of '
-                '"BACKWARD_PRE" and "BACKWARD_POST", or '
-                f'`BackwardPrefetch`, but has type {type(backward_prefetch)}')
-
-        if isinstance(param_init_fn, str):
-            param_init_fn = FUNCTIONS.get(  # type: ignore
-                param_init_fn)
-            if param_init_fn is None:
-                raise ValueError('`param_init_fn` is not registered!')
-        elif isinstance(param_init_fn, dict):
-            init_fn = param_init_fn.pop('type')
+        with FUNCTIONS.switch_scope_and_registry(None):
+            if isinstance(auto_wrap_policy, str):
+                auto_wrap_policy = FUNCTIONS.get(  # type: ignore
+                    auto_wrap_policy)
+                if auto_wrap_policy is None:
+                    raise ValueError('`auto_wrap_policy` is not registered!')
+            elif isinstance(auto_wrap_policy, dict):
+                policy = auto_wrap_policy.pop('type')
+                if isinstance(policy, str):
+                    policy = FUNCTIONS.get(policy)  # type: ignore
+                if policy is None:
+                    raise ValueError('`auto_wrap_policy` is not registered!')
+                auto_wrap_policy = partial(policy, **auto_wrap_policy)
+
+            if not (auto_wrap_policy is None
+                    or callable(auto_wrap_policy)):  # type: ignore
+                raise TypeError('`auto_wrap_policy` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(auto_wrap_policy)}')
+
+            if isinstance(backward_prefetch, str):
+                backward_prefetch = BackwardPrefetch[backward_prefetch]
+            if not (isinstance(backward_prefetch, BackwardPrefetch)
+                    or backward_prefetch is None):
+                raise TypeError(
+                    '`backward_prefetch` should be `None`, string of '
+                    '"BACKWARD_PRE" and "BACKWARD_POST", or '
+                    f'`BackwardPrefetch`, but has type {type(backward_prefetch)}'  # noqa: E501
+                )
+
             if isinstance(param_init_fn, str):
-                init_fn = FUNCTIONS.get(init_fn)  # type: ignore
-                if init_fn is None:
-                    raise ValueError('`param_init_fn` is not registered!')
-                param_init_fn = partial(init_fn, **param_init_fn)
-
-        if not (callable(param_init_fn) or param_init_fn is None):
-            raise TypeError('`param_init_fn` should be a str, a '
-                            'callable, a dict or None, but has type '
-                            f'{type(param_init_fn)}')
+                param_init_fn = FUNCTIONS.get(  # type: ignore
+                    param_init_fn)
+                if param_init_fn is None:
+                    raise ValueError('`param_init_fn` is not registered!')
+            elif isinstance(param_init_fn, dict):
+                init_fn = param_init_fn.pop('type')
+                if isinstance(init_fn, str):
+                    init_fn = FUNCTIONS.get(init_fn)  # type: ignore
+                if init_fn is None:
+                    raise ValueError('`param_init_fn` is not registered!')
+                param_init_fn = partial(init_fn, **param_init_fn)
+
+            if not (callable(param_init_fn) or param_init_fn is None):
+                raise TypeError('`param_init_fn` should be a str, a '
+                                'callable, a dict or None, but has type '
+                                f'{type(param_init_fn)}')
 
         def parse_dtype(dtype):
             if dtype is None:
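For context, a sketch of the dict-to-callable normalization the wrapper now performs inside `FUNCTIONS.switch_scope_and_registry(None)`; the policy name `linear_wrap_policy` and its `threshold` argument are hypothetical, and the `(module, recurse, nonwrapped_numel)` signature follows recent PyTorch FSDP releases:

    from functools import partial

    import torch.nn as nn
    from mmengine.registry import FUNCTIONS

    @FUNCTIONS.register_module()
    def linear_wrap_policy(module, recurse, nonwrapped_numel, threshold=0):
        # While ``recurse`` is True, FSDP is still descending into children;
        # returning True keeps the traversal going. Once ``recurse`` is
        # False, the return value decides whether this submodule is wrapped.
        if recurse:
            return True
        return isinstance(module, nn.Linear) and nonwrapped_numel >= threshold

    # Dict form as it may appear in a config file.
    cfg = dict(type='linear_wrap_policy', threshold=100)

    # The same normalization performed in ``__init__`` above: resolve the
    # registered name, then bind the remaining keys as keyword arguments.
    with FUNCTIONS.switch_scope_and_registry(None):
        policy = FUNCTIONS.get(cfg.pop('type'))
        auto_wrap_policy = partial(policy, **cfg)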
