
convert_zero_checkpoint_to_fp32_state_dict fails with torch 2.6.0 due to weights_only default change #20643

Open
championsnet opened this issue Mar 14, 2025 · 0 comments
Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x

championsnet commented Mar 14, 2025

Bug description

With the release of torch 2.6.0, the default behavior of torch.load changed: the weights_only argument now defaults to True. This breaks convert_zero_checkpoint_to_fp32_state_dict (in lightning.pytorch.utilities.deepspeed), which converts a DeepSpeed checkpoint to a single fp32 state dict, and leads to the following error when the checkpoint contains non-weight data:

_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint...

The error occurs because the function loads optimizer and model state files that contain more than plain tensors (for example, DeepSpeed's ZeroStageEnum), which the weights-only unpickler rejects. In previous torch versions, weights_only defaulted to False, so the entire state dict loaded without issue.

A simple fix would be to explicitly set weights_only=False in the two torch.load calls inside lightning/pytorch/utilities/deepspeed.py:

optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE, weights_only=False)
client_state = torch.load(model_file, map_location=CPU_DEVICE, weights_only=False)

I am not aware of any security concerns this fix would introduce, beyond the usual caveat that weights_only=False should only be used on checkpoints from a trusted source.
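
Alternatively, the error message points at allowlisting the blocked DeepSpeed class instead of disabling the weights-only check. A minimal sketch of that approach, assuming ZeroStageEnum is the only blocked global (real checkpoints may need additional DeepSpeed classes allowlisted) and using placeholder paths:

import torch
from deepspeed.runtime.zero.config import ZeroStageEnum
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

# Temporarily allowlist the DeepSpeed enum that the weights-only unpickler rejects,
# then run the conversion as usual ("my_checkpoint" and "fp32_model.pt" are placeholders).
with torch.serialization.safe_globals([ZeroStageEnum]):
    convert_zero_checkpoint_to_fp32_state_dict("my_checkpoint", "fp32_model.pt")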

What version are you seeing the problem on?

v2.5

How to reproduce the bug

1. Use a checkpoint generated with DeepSpeed and Lightning (e.g., one that includes both model weights and additional state such as optimizer state, learning rate scheduler, etc.).
2. Run the convert_zero_checkpoint_to_fp32_state_dict function.
3. Observe that the call to torch.load fails due to the default weights_only=True setting in torch 2.6.0 (a minimal sketch follows this list).
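
A minimal repro sketch, assuming torch 2.6.0, lightning 2.5.0, and deepspeed are installed on a machine with at least one GPU; the model, data, and paths below are placeholders, not taken from the original report:

import torch
import lightning as L
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

class TinyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

train_data = torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

trainer = L.Trainer(strategy="deepspeed_stage_2", accelerator="gpu", devices=1, max_steps=5)
trainer.fit(TinyModel(), train_data)
trainer.save_checkpoint("ckpt_dir")  # with the DeepSpeed strategy this writes a checkpoint directory

# Under torch 2.6.0 this raises the UnpicklingError shown below
convert_zero_checkpoint_to_fp32_state_dict("ckpt_dir", "fp32_model.pt")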

Error messages and logs

Traceback (most recent call last):
  ...
  File "~/venv/lib/python3.11/site-packages/lightning/pytorch/utilities/deepspeed.py", line 96, in convert_zero_checkpoint_to_fp32_state_dict
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/venv/lib/python3.11/site-packages/torch/serialization.py", line 1470, in load
    raise pickle.UnpicklingError(_get_wo_message(str(e))) from None
_pickle.UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, do those steps only if you trust the source of the checkpoint.
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL deepspeed.runtime.zero.config.ZeroStageEnum was not an allowed global by default. Please use `torch.serialization.add_safe_globals([ZeroStageEnum])` or the `torch.serialization.safe_globals([ZeroStageEnum])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

Environment

- PyTorch Lightning Version: 2.5.0
- PyTorch Version: 2.6.0
- Python version: 3.11
- OS: Linux
- How you installed Lightning: pip

More info

No response
