convert_zero_checkpoint_to_fp32_state_dict fails with torch 2.6.0 due to weights_only default change #20643

Open
@championsnet

Description

Bug description

With the release of torch 2.6.0, the default value of the weights_only argument of torch.load changed from False to True. This breaks the convert_zero_checkpoint_to_fp32_state_dict function, which converts a DeepSpeed checkpoint: loading checkpoint files that include non-weight data now fails with the following error:

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint...

The error occurs because the function relies on loading optimizer and model state files that contain more than just weights. In previous torch versions, weights_only defaulted to False, allowing the entire state dict to be loaded without issue.
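To illustrate why a weights-only load rejects such files, here is a minimal sketch using only the standard library. It is not torch's actual implementation: `AllowlistUnpickler` and `restricted_loads` are hypothetical names, and `http.HTTPStatus` is just a stand-in for an enum like the one stored in the DeepSpeed checkpoint. The idea is that the unpickler refuses any global that has not been explicitly allowlisted.

```python
import io
import pickle
from http import HTTPStatus  # stand-in for an enum stored in a checkpoint


class AllowlistUnpickler(pickle.Unpickler):
    """Hypothetical unpickler that only resolves explicitly allowed classes."""

    def __init__(self, file, allowed=()):
        super().__init__(file)
        # Map (module, qualified name) -> class for every allowed global.
        self._allowed = {(c.__module__, c.__qualname__): c for c in allowed}

    def find_class(self, module, name):
        # Called whenever the pickle stream references a global (class/function).
        try:
            return self._allowed[(module, name)]
        except KeyError:
            raise pickle.UnpicklingError(
                f"Unsupported global: {module}.{name}"
            ) from None


def restricted_loads(data, allowed=()):
    """Unpickle `data`, rejecting every global not listed in `allowed`."""
    return AllowlistUnpickler(io.BytesIO(data), allowed=allowed).load()


# Plain containers and numbers need no globals; the enum member does.
payload = pickle.dumps({"weights": [0.1, 0.2], "stage": HTTPStatus.OK})

try:
    restricted_loads(payload)
except pickle.UnpicklingError as err:
    print("rejected:", err)

# Allowlisting the enum class lets the same payload load.
state = restricted_loads(payload, allowed=[HTTPStatus])
print(state["stage"])
```

The first call fails in the same spirit as the "Unsupported global" message in the traceback; the second succeeds because the class was allowlisted.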

A simple fix would be to explicitly set weights_only=False when calling torch.load within the function:

optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)

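I do not know what security concerns this fix might raise. The error message itself warns that loading with weights_only=False can result in arbitrary code execution, so it should only be done for checkpoints from a trusted source.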
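As a possible alternative to disabling weights_only entirely, the error message suggests allowlisting the offending class. The following is a rough sketch under assumptions, not a tested patch for Lightning: `load_checkpoint_file` is a hypothetical helper, `torch.serialization.safe_globals` is only available in recent torch releases, and a real DeepSpeed checkpoint may well reference more classes than ZeroStageEnum.

```python
def load_checkpoint_file(path, extra_safe_globals=(), trust_source=False):
    """Hypothetical helper: load a checkpoint file onto the CPU.

    trust_source=True mirrors the fix proposed in this issue
    (weights_only=False); only use it for files you trust, since it
    can execute arbitrary code during unpickling.
    Otherwise the file is loaded with weights_only=True while the
    classes in `extra_safe_globals` are temporarily allowlisted.
    """
    import torch  # imported lazily so the helper can be defined without torch

    if trust_source:
        return torch.load(path, map_location="cpu", weights_only=False)

    # Assumption: the installed torch provides the safe_globals context manager.
    with torch.serialization.safe_globals(list(extra_safe_globals)):
        return torch.load(path, map_location="cpu", weights_only=True)


# Usage sketch (not executed here; requires deepspeed and a real checkpoint):
#   from deepspeed.runtime.zero.config import ZeroStageEnum
#   optim_state = load_checkpoint_file(optim_files[0], [ZeroStageEnum])
```

If further "Unsupported global" errors appear, each reported class would need to be reviewed and added to the list, which is the trade-off compared with simply passing weights_only=False.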

What version are you seeing the problem on?

v2.5

How to reproduce the bug

1. Use a checkpoint generated with DeepSpeed and Lightning (e.g., one that includes both model weights and additional state such as optimizer state, learning rate scheduler, etc.).
2. Run the convert_zero_checkpoint_to_fp32_state_dict function.
3. Observe that the call to torch.load fails due to the default weights_only=True setting in torch 2.6.0.

Error messages and logs

Traceback (most recent call last):
  ...
  File "~/venv/lib/python3.11/site-packages/lightning/pytorch/utilities/deepspeed.py", line 96, in convert_zero_checkpoint_to_fp32_state_dict
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/venv/lib/python3.11/site-packages/torch/serialization.py", line 1470, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL deepspeed.runtime.zero.config.ZeroStageEnum was not an allowed global by default. Please use `torch.serialization.add_safe_globals([ZeroStageEnum])` or the `torch.serialization.safe_globals([ZeroStageEnum])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

Environment

Current environment
#- PyTorch Lightning Version (2.5.0):
#- PyTorch Version (2.6.0):
#- Python version (e.g., 3.11):
#- OS (e.g., Linux):
#- How you installed Lightning(`pip`):

More info

No response

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
