diff --git a/examples/mnist.py b/examples/mnist.py
index 5a3e5f2a..a56e8f17 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -70,7 +70,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
         output = model(data)
         loss = F.nll_loss(output, target)
         scaler.scale(loss).backward()
-        scaler.step(optimizer)
+        with msamp.common.tensor.tensor.pretend_scaling_is_torch():
+            scaler.step(optimizer)
         scaler.update()
         if batch_idx % args.log_interval == 0:
             print(
diff --git a/examples/mnist_ddp.py b/examples/mnist_ddp.py
index 85a56fac..812558f3 100644
--- a/examples/mnist_ddp.py
+++ b/examples/mnist_ddp.py
@@ -74,7 +74,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
         output = model(data)
         loss = F.nll_loss(output, target)
         scaler.scale(loss).backward()
-        scaler.step(optimizer)
+        with msamp.common.tensor.tensor.pretend_scaling_is_torch():
+            scaler.step(optimizer)
         scaler.update()
         if dist.get_rank() == 0:
             if batch_idx % args.log_interval == 0:
diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index a173f025..9db20481 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -3,6 +3,7 @@
 """MS-AMP tensor module."""
 
+from contextlib import contextmanager
 import torch
 import torch.nn.functional as F
 
 from msamp.common.tensor import ScalingMeta
@@ -12,7 +13,25 @@
 from msamp.common.utils import TransformerEngineWrapper
 
+should_pretend_to_be_tt = False
+
+
+@contextmanager
+def pretend_scaling_is_torch():
+    """Report ScalingTensor's __class__ as torch.Tensor within this context."""
+    global should_pretend_to_be_tt
+    prev = should_pretend_to_be_tt
+    should_pretend_to_be_tt = True
+    try:
+        yield
+    finally:
+        should_pretend_to_be_tt = prev
 
 class ScalingTensor:
     """Customized tensor with scaling."""
+    @property
+    def __class__(self):
+        # isinstance() falls back to __class__, so torch.Tensor checks pass.
+        return torch.Tensor if should_pretend_to_be_tt else ScalingTensor
+
     class UniqueDtypeDecorator:
         """A decorator class to check whether dtype is supported and parameters are unique."""
diff --git a/msamp/nn/parameter.py b/msamp/nn/parameter.py
index db9f2601..c41d270b 100644
--- a/msamp/nn/parameter.py
+++ b/msamp/nn/parameter.py
@@ -3,11 +3,17 @@
 """MS-AMP parameter module."""
 
+import torch
 from msamp.common.tensor import ScalingTensor
+import msamp.common.tensor.tensor as tensor_py
 
 
 class ScalingParameter(ScalingTensor):
     """Parameter class for ScalingTensor."""
+    @property
+    def __class__(self):
+        # Mirror ScalingTensor so parameters also pass torch.Tensor checks.
+        return torch.Tensor if tensor_py.should_pretend_to_be_tt else ScalingParameter
 
     def __init__(self, tensor, requires_grad=True):
         """Constructor.
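
Why the __class__ override works: CPython's isinstance() falls back to an object's __class__ attribute when type(obj) does not match the target class, so a property that conditionally reports torch.Tensor lets a non-tensor pass isinstance(obj, torch.Tensor) checks, which is presumably what torch.cuda.amp.GradScaler relies on internally during step(). Below is a minimal, self-contained sketch of the pattern; the names Fake, pretend_is_torch, and _pretend are illustration-only stand-ins, not part of MS-AMP.

from contextlib import contextmanager

import torch

_pretend = False


@contextmanager
def pretend_is_torch():
    """Temporarily make Fake instances report torch.Tensor as their class."""
    global _pretend
    prev = _pretend
    _pretend = True
    try:
        yield
    finally:
        # Restore the previous value even if the body raises.
        _pretend = prev


class Fake:
    @property
    def __class__(self):
        # isinstance() consults __class__ when type(obj) is not a match.
        return torch.Tensor if _pretend else Fake


obj = Fake()
print(isinstance(obj, torch.Tensor))  # False
with pretend_is_torch():
    print(isinstance(obj, torch.Tensor))  # True
print(isinstance(obj, torch.Tensor))  # False again

The try/finally in the patched pretend_scaling_is_torch() matters for the same reason it does here: without it, an exception inside scaler.step() would leave the flag set, and every later isinstance check on a ScalingTensor would silently see torch.Tensor.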