Commit cd8acc2
(3/n) Support 2D Parallelism - Efficient loading of full-state checkpoints (#19870)
* memory-optimized loading of full checkpoints into dist model
* simplify
* handle buffers
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* handle strict loading, buffers, and add test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* chlog

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9455871 commit cd8acc2
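In essence, the change loads the full (non-sharded) checkpoint memory-mapped on CPU and writes it into the sharded model by broadcasting values from rank 0. Below is a minimal sketch of that idea outside of Lightning, assuming PyTorch >= 2.4, an already-initialized process group, and a DTensor/FSDP-sharded `model`; the function name and the `"model"` checkpoint key are hypothetical, and this is an illustration rather than the strategy's actual code path.

```python
# Illustrative only: mirrors the technique in this commit, not Lightning's API.
# Assumes torch >= 2.4, torch.distributed already initialized, and `model`
# already sharded (FSDP2 / DTensor). The "model" checkpoint key is hypothetical.
import torch
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def load_full_checkpoint_into_sharded_model(model: torch.nn.Module, path: str) -> None:
    # mmap=True keeps tensor data on disk and pages it in lazily, so the whole
    # checkpoint is never materialized in CPU memory at once
    checkpoint = torch.load(path, mmap=True, map_location="cpu")
    options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    # rank 0 supplies the full values; every rank writes only its local shard
    set_model_state_dict(model, checkpoint["model"], options=options)
```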

File tree: 3 files changed (+82, -13 lines)

src/lightning/fabric/CHANGELOG.md (+1, -1)

@@ -15,7 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))
 
-- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))
+- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852), [#19870](https://github.com/Lightning-AI/pytorch-lightning/pull/19870))
 
 
 ### Changed

src/lightning/fabric/strategies/model_parallel.py (+36, -11)

@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import itertools
 import shutil
 from contextlib import ExitStack
 from datetime import timedelta
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Literal, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generator, Literal, Optional, TypeVar, Union
 
 import torch
 from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
@@ -429,7 +430,6 @@ def _load_checkpoint(
         StateDictOptions,
         get_model_state_dict,
         get_optimizer_state_dict,
-        set_model_state_dict,
         set_optimizer_state_dict,
     )
 
@@ -484,13 +484,8 @@ def _load_checkpoint(
         if not _TORCH_GREATER_EQUAL_2_4:
            raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
 
-        state_dict_options = StateDictOptions(
-            broadcast_from_rank0=True,  # type: ignore[call-arg]
-            full_state_dict=True,
-            strict=strict,
-        )
         checkpoint = torch.load(path, mmap=True, map_location="cpu")
-        set_model_state_dict(module, checkpoint.pop(module_key), options=state_dict_options)
+        _load_raw_module_state(checkpoint.pop(module_key), module, strict=strict)
 
         requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
         _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)
@@ -525,7 +520,9 @@ def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int
     _load_raw_module_state(state_dict=state_dict, module=module, world_size=world_size, strict=strict)
 
 
-def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
+def _load_raw_module_state(
+    state_dict: Dict[str, Any], module: Module, world_size: int = 1, strict: bool = True
+) -> None:
     """Loads the state dict into the module by gathering all weights first and then writing back to each shard."""
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
@@ -535,11 +532,39 @@ def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_siz
 
         from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
 
-        state_dict_options = StateDictOptions(broadcast_from_rank0=True, full_state_dict=True)  # type: ignore[call-arg]
-        set_model_state_dict(module, state_dict, options=state_dict_options)
+        state_dict_options = StateDictOptions(
+            broadcast_from_rank0=True,  # type: ignore[call-arg]
+            full_state_dict=True,
+            strict=strict,  # gets ignored at the moment
+        )
+
+        for submodule_name, submodule in module.named_modules():
+            for param_name, _ in _named_parameters_and_buffers_to_load(submodule):
+                full_param_name = f"{submodule_name}{'.' if submodule_name else ''}{param_name}"
+                if full_param_name not in state_dict:
+                    # Note: PyTorch does not currently respect the `strict` setting in state_dict_options!
+                    if not strict:
+                        continue
+                    raise KeyError(
+                        f"The model contains a key '{full_param_name}' that does not exist in the loaded checkpoint."
+                        " To disable strict loading, set `strict=False`."
+                    )
+                local_state_dict = {param_name: state_dict[full_param_name]}
+                set_model_state_dict(submodule, local_state_dict, options=state_dict_options)
 
     elif isinstance(module, FSDP):
         with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
             module.load_state_dict(state_dict, strict=strict)
     else:
         module.load_state_dict(state_dict, strict=strict)
+
+
+def _named_parameters_and_buffers_to_load(module: Module) -> Generator:
+    """Returns parameters and buffers, with non-persistent buffers excluded."""
+    for param_name, param in itertools.chain(
+        module.named_buffers(recurse=False),
+        module.named_parameters(recurse=False),
+    ):
+        if param_name in module._non_persistent_buffers_set:
+            continue
+        yield param_name, param
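The per-submodule loop above is what bounds peak memory: instead of broadcasting the entire full state dict in one call, only one parameter or buffer is materialized and broadcast at a time. The helper skips non-persistent buffers because PyTorch never writes them to a checkpoint, so a strict key check would otherwise fail on them. A small standalone illustration (module, buffer, and class names here are hypothetical, not from the commit):

```python
# Hypothetical example: shows why non-persistent buffers must be excluded
# when matching model keys against a checkpoint.
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("running_mean", torch.zeros(4), persistent=True)
        self.register_buffer("scratch", torch.zeros(4), persistent=False)


block = Block()
print(sorted(block.state_dict().keys()))
# ['linear.bias', 'linear.weight', 'running_mean'] -- 'scratch' never appears
print(block._non_persistent_buffers_set)
# {'scratch'} -- the set that _named_parameters_and_buffers_to_load() checks against
```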

tests/tests_fabric/strategies/test_model_parallel_integration.py (+45, -1)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from copy import deepcopy
 from pathlib import Path
 from unittest import mock
 
@@ -20,7 +21,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from lightning.fabric import Fabric
-from lightning.fabric.strategies import ModelParallelStrategy
+from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state
 from lightning.fabric.utilities.load import _load_distributed_checkpoint
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -675,3 +676,46 @@ def test_save_sharded_and_consolidate_and_load(tmp_path):
 
     state = {"model": model, "steps": 1}
     fabric.load(checkpoint_path_full, state)
+
+
+@RunIf(min_torch="2.4", min_cuda_gpus=2, standalone=True)
+def test_load_raw_module_state():
+    from torch.distributed.device_mesh import init_device_mesh
+    from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.parameter = nn.Parameter(torch.rand(2, 2))
+            self.layer1 = nn.Linear(4, 4)
+            self.layer2 = nn.Linear(4, 4)
+            self.register_buffer("persistent_buffer", torch.rand(2), persistent=True)
+            self.register_buffer("non_persistent_buffer", torch.rand(2), persistent=False)
+
+    fabric = Fabric(accelerator="cuda", devices=2)
+    fabric.launch()
+    fabric.seed_everything(0)
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    state_dict = deepcopy(model.state_dict())
+
+    with fabric.init_module():
+        model = CustomModel()
+
+    device_mesh = init_device_mesh("cuda", mesh_shape=(2,), mesh_dim_names=("tp",))
+    plan = {"layer1": ColwiseParallel()}
+    parallelize_module(model, device_mesh, plan)
+    _load_raw_module_state(state_dict, model, strict=True)
+
+    assert torch.equal(model.parameter, state_dict["parameter"])
+    assert torch.equal(model.layer1.weight.full_tensor(), state_dict["layer1.weight"])
+    assert torch.equal(model.layer2.weight, state_dict["layer2.weight"])
+    assert torch.equal(model.persistent_buffer, state_dict["persistent_buffer"])
+
+    state_dict.pop("parameter")
+    with pytest.raises(KeyError, match="The model contains a key 'parameter' that does not exist"):
+        _load_raw_module_state(state_dict, model, strict=True)
+
+    _load_raw_module_state(state_dict, model, strict=False)
