diff --git a/musetalk/models/unet.py b/musetalk/models/unet.py index 575e79af..c81473c7 100755 --- a/musetalk/models/unet.py +++ b/musetalk/models/unet.py @@ -41,7 +41,7 @@ def __init__(self, self.device = device else: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) + weights = torch.load(model_path, map_location=self.device, weights_only=False) self.model.load_state_dict(weights) if use_float16: self.model = self.model.half()