Description
Bug description
With the release of torch 2.6.0, the default behavior of torch.load has changed: the weights_only argument now defaults to True. This change affects the convert_zero_checkpoint_to_fp32_state_dict function for converting a DeepSpeed checkpoint, leading to the following error when loading checkpoints that include non-weight data:
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint...
The error occurs because the function relies on loading optimizer and model state files that contain more than just weights. In previous torch versions, weights_only defaulted to False, allowing the entire state dict to be loaded without issue.
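To illustrate why non-weight data trips the new default, here is a rough analogy using only the standard library's pickle module: an unpickler that resolves only allowlisted globals. This is not how torch implements weights_only, just a sketch of the same idea; AllowlistUnpickler and Stage are hypothetical names introduced for the example.

```python
import enum
import io
import pickle


class Stage(enum.Enum):
    """Hypothetical stand-in for a non-weight object stored in a checkpoint."""
    weights = 3


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals named in an explicit allowlist."""

    def __init__(self, file, allowed=()):
        super().__init__(file)
        self._allowed = {(c.__module__, c.__qualname__): c for c in allowed}

    def find_class(self, module, name):
        # Called for every global (class or function) referenced by the pickle.
        try:
            return self._allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


payload = pickle.dumps({"step": 10, "zero_stage": Stage.weights})

# Without an allowlist entry, the enum class cannot be resolved.
try:
    AllowlistUnpickler(io.BytesIO(payload)).load()
    rejected = False
except pickle.UnpicklingError:
    rejected = True

# Once the class is allowlisted, the same payload loads.
restored = AllowlistUnpickler(io.BytesIO(payload), allowed=[Stage]).load()
```

Plain integers and dicts load in both cases because they need no global lookup; only the enum class requires one, just as ZeroStageEnum does in the DeepSpeed checkpoint.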
A simple fix would be to explicitly set weights_only=False when calling torch.load within the function:
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)
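I have not assessed the security implications of this fix. Note that the error message itself warns that loading with weights_only=False can result in arbitrary code execution and should only be done for checkpoints from a trusted source.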
What version are you seeing the problem on?
v2.5
How to reproduce the bug
1. Use a checkpoint generated with DeepSpeed and Lightning (e.g., one that includes both model weights and additional state such as optimizer state, learning rate scheduler, etc.).
2. Run the convert_zero_checkpoint_to_fp32_state_dict function.
3. Observe that the call to torch.load fails due to the default weights_only=True setting in torch 2.6.0.
Error messages and logs
Traceback (most recent call last):
...
File "~/venv/lib/python3.11/site-packages/lightning/pytorch/utilities/deepspeed.py", line 96, in convert_zero_checkpoint_to_fp32_state_dict
optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/venv/lib/python3.11/site-packages/torch/serialization.py", line 1470, in load
raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
WeightsUnpickler error: Unsupported global: GLOBAL deepspeed.runtime.zero.config.ZeroStageEnum was not an allowed global by default. Please use `torch.serialization.add_safe_globals([ZeroStageEnum])` or the `torch.serialization.safe_globals([ZeroStageEnum])` context manager to allowlist this global if you trust this class/function.
Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.
Environment
Current environment
#- PyTorch Lightning Version: 2.5.0
#- PyTorch Version: 2.6.0
#- Python version (e.g., 3.11):
#- OS (e.g., Linux):
#- How you installed Lightning(`pip`):
More info
No response