
Commit aab4a17

Fix the deprecated torch.cuda.amp module (#21)

* Fix the deprecated torch.cuda.amp module
* Fix test according to the latest GradScaler definition.

1 parent e529128

2 files changed: +4 -2 lines changed

mmengine/optim/optimizer/amp_optimizer_wrapper.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler
+    GradScaler = partial(GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
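
The partial binding keeps existing call sites working: torch.amp.GradScaler takes the target device as a constructor argument, so binding device='cuda' with functools.partial lets the rest of the code keep calling GradScaler() with no arguments, exactly as it did with the deprecated torch.cuda.amp.GradScaler. A minimal sketch of the pattern (the final scaler line is illustrative, not part of the commit):

from functools import partial

from torch.amp import GradScaler

# Rebind the name so legacy call sites that constructed the CUDA scaler
# as `GradScaler()` keep working unchanged on the new torch.amp API.
GradScaler = partial(GradScaler, device='cuda')

scaler = GradScaler()  # illustrative: same as torch.amp.GradScaler(device='cuda')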

tests/test_optim/test_optimizer/test_optimizer_wrapper.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 import torch.distributed as torch_dist
 import torch.nn as nn
 from parameterized import parameterized
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD, Adam, Optimizer
