Skip to content

Fix torch.load() for PyTorch 2.6+ compatibility#414

Open
SeanL009 wants to merge 1 commit into
TMElyralab:mainfrom
SeanL009:fix/pytorch-26-weights-only
Open

Fix torch.load() for PyTorch 2.6+ compatibility#414
SeanL009 wants to merge 1 commit into
TMElyralab:mainfrom
SeanL009:fix/pytorch-26-weights-only

Conversation

@SeanL009
Copy link
Copy Markdown

Summary

PyTorch 2.6+ changed the default of torch.load() to weights_only=True, and also introduced stricter zip archive handling. The previous code omitted map_location when CUDA was available, which can cause deserialization errors with zip-serialized checkpoints.

This PR always passes map_location=self.device to torch.load() for consistent behavior across PyTorch versions.

Changes

  • musetalk/models/unet.py: Always use map_location=self.device in torch.load() instead of conditionally omitting it when CUDA is available

Always use map_location=self.device in torch.load() to ensure
compatibility with PyTorch 2.6+. The previous code omitted map_location
when CUDA was available, which can cause deserialization errors with
zip-serialized checkpoints in newer PyTorch versions.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant