diff --git a/interface/utils.py b/interface/utils.py index a27ea4f..0850969 100644 --- a/interface/utils.py +++ b/interface/utils.py @@ -62,7 +62,7 @@ def face_interface(tensor_source_img, tensor_source_prompt, tensor_drving_img, FaceModel = FreeFaceModel(8, 8, model_size, is_train=False).to(device) # current_dir = os.path.dirname(os.path.abspath(__file__)) inference_model_path = os.path.join(current_dir, "../checkpoint/FreeFace/epoch_40.pth") - state_dict = torch.load(inference_model_path)['state_dict']['net_g'] + state_dict = torch.load(inference_model_path, map_location=device)['state_dict']['net_g'] FaceModel.load_state_dict(state_dict) FaceModel.eval() in0 = F.interpolate(tensor_source_img, size=(model_size, model_size), mode='nearest') @@ -238,4 +238,4 @@ def rotation_matrix_to_euler_angles(R): z = math.degrees(z) # print([x, y, z]) # exit() - return [x, y, z] \ No newline at end of file + return [x, y, z]