diff --git a/diffusion/evaluation/generate_images.py b/diffusion/evaluation/generate_images.py index 3ac40c9c..ebb8af28 100644 --- a/diffusion/evaluation/generate_images.py +++ b/diffusion/evaluation/generate_images.py @@ -104,7 +104,7 @@ def __init__(self, get_file(path=self.load_path, destination=self.local_checkpoint_path, overwrite=True) with dist.local_rank_zero_download_and_wait(self.local_checkpoint_path): # Load the model - state_dict = torch.load(self.local_checkpoint_path) + state_dict = torch.load(self.local_checkpoint_path, map_location='cpu') for key in list(state_dict['state']['model'].keys()): if 'val_metrics.' in key: del state_dict['state']['model'][key] diff --git a/diffusion/generate.py b/diffusion/generate.py index b6b766ad..f03b3190 100644 --- a/diffusion/generate.py +++ b/diffusion/generate.py @@ -26,7 +26,7 @@ def generate(config: DictConfig) -> None: config (DictConfig): Configuration composed by Hydra """ reproducibility.seed_all(config.seed) - device = get_device() # type: ignore + device = get_device(None) # type: ignore dist.initialize_dist(device, config.dist_timeout) # The model to evaluate