Skip to content

Commit 1170618

Browse files
committed
fix
1 parent 82cbe27 commit 1170618

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/lightning/fabric/strategies/model_parallel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _load_checkpoint(
484484
raise ImportError("Loading a non-distributed checkpoint into a distributed model requires PyTorch >= 2.4.")
485485

486486
checkpoint = torch.load(path, mmap=True, map_location="cpu")
487-
_load_raw_module_state(checkpoint.pop(module_key), module, world_size=1, strict=strict)
487+
_load_raw_module_state(checkpoint.pop(module_key), module, world_size=self.world_size, strict=strict)
488488

489489
requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys()
490490
_validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict)

0 commit comments

Comments
 (0)