diff --git a/pytorch3dunet/predict.py b/pytorch3dunet/predict.py index 6fc77567..b0fd88d5 100755 --- a/pytorch3dunet/predict.py +++ b/pytorch3dunet/predict.py @@ -38,7 +38,6 @@ def main(): if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': model = nn.DataParallel(model) logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') - model = model.cuda() if torch.cuda.is_available() and not config['device'] == 'cpu': model = model.cuda() diff --git a/pytorch3dunet/unet3d/trainer.py b/pytorch3dunet/unet3d/trainer.py index f86d2529..407fb6c3 100644 --- a/pytorch3dunet/unet3d/trainer.py +++ b/pytorch3dunet/unet3d/trainer.py @@ -23,7 +23,6 @@ def create_trainer(config): if torch.cuda.device_count() > 1 and not config['device'] == 'cpu': model = nn.DataParallel(model) logger.info(f'Using {torch.cuda.device_count()} GPUs for prediction') - model = model.cuda() if torch.cuda.is_available() and not config['device'] == 'cpu': model = model.cuda()