From fcae35dd18fba8e90b647a1b3fc72ca2cdc99a0e Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Wed, 6 Mar 2024 04:49:38 +0000
Subject: [PATCH 1/2] add context manager to fake ScalingTensor -> torch.Tensor

---
 msamp/common/tensor/tensor.py | 10 ++++++++++
 msamp/nn/parameter.py         |  4 ++++
 2 files changed, 14 insertions(+)

diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index a173f025..e3228dfb 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -3,6 +3,7 @@
 
 """MS-AMP tensor module."""
 
+from contextlib import contextmanager
 import torch
 import torch.nn.functional as F
 from msamp.common.tensor import ScalingMeta
@@ -12,7 +13,16 @@
 from msamp.common.utils import TransformerEngineWrapper
 
 
+should_pretend_to_be_tt = False
+@contextmanager
+def pretend_scaling_is_torch():
+    global lol
+    should_pretend_to_be_tt = True
+    yield
+    should_pretend_to_be_tt = False
 class ScalingTensor:
+    @property
+    def __class__(self): return torch.Tensor if should_pretend_to_be_tt else ScalingTensor
     """Customized tensor with scaling."""
     class UniqueDtypeDecorator:
         """A decorator class to check whether dtype is supported and parameters are uniqie."""
diff --git a/msamp/nn/parameter.py b/msamp/nn/parameter.py
index db9f2601..c41d270b 100644
--- a/msamp/nn/parameter.py
+++ b/msamp/nn/parameter.py
@@ -3,11 +3,15 @@
 
 """MS-AMP parameter module."""
 
+import torch
 from msamp.common.tensor import ScalingTensor
+import msamp.common.tensor.tensor as tensor_py
 
 
 class ScalingParameter(ScalingTensor):
     """Parameter class for ScalingTensor."""
+    @property
+    def __class__(self): return torch.Tensor if tensor_py.should_pretend_to_be_tt else ScalingParameter
     def __init__(self, tensor, requires_grad=True):
         """Constructor.
 

From e1a1a06929178848abdc4745638919e8b897650b Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Wed, 6 Mar 2024 04:58:20 +0000
Subject: [PATCH 2/2] mnist ddp/single gpu examples fixed

---
 examples/mnist.py             | 2 +-
 examples/mnist_ddp.py         | 2 +-
 msamp/common/tensor/tensor.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/mnist.py b/examples/mnist.py
index 5a3e5f2a..a56e8f17 100644
--- a/examples/mnist.py
+++ b/examples/mnist.py
@@ -70,7 +70,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
         output = model(data)
         loss = F.nll_loss(output, target)
         scaler.scale(loss).backward()
-        scaler.step(optimizer)
+        with msamp.common.tensor.tensor.pretend_scaling_is_torch(): scaler.step(optimizer)
         scaler.update()
         if batch_idx % args.log_interval == 0:
             print(
diff --git a/examples/mnist_ddp.py b/examples/mnist_ddp.py
index 85a56fac..812558f3 100644
--- a/examples/mnist_ddp.py
+++ b/examples/mnist_ddp.py
@@ -74,7 +74,7 @@ def train(args, model, device, train_loader, optimizer, epoch):
         output = model(data)
         loss = F.nll_loss(output, target)
         scaler.scale(loss).backward()
-        scaler.step(optimizer)
+        with msamp.common.tensor.tensor.pretend_scaling_is_torch(): scaler.step(optimizer)
         scaler.update()
         if dist.get_rank() == 0:
             if batch_idx % args.log_interval == 0:
diff --git a/msamp/common/tensor/tensor.py b/msamp/common/tensor/tensor.py
index e3228dfb..9db20481 100644
--- a/msamp/common/tensor/tensor.py
+++ b/msamp/common/tensor/tensor.py
@@ -16,7 +16,7 @@
 should_pretend_to_be_tt = False
 @contextmanager
 def pretend_scaling_is_torch():
-    global lol
+    global should_pretend_to_be_tt
     should_pretend_to_be_tt = True
     yield
     should_pretend_to_be_tt = False
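
Note (not part of either commit): the patches rely on the fact that isinstance() falls back to an object's __class__ attribute when the real type check fails, so exposing __class__ as a property lets a ScalingTensor report itself as torch.Tensor only while the module-level flag set by pretend_scaling_is_torch() is active. A minimal standalone sketch of that trick follows; FakeTorchTensor and the simplified ScalingTensor here are illustrative stand-ins, not the real torch or msamp types.

    from contextlib import contextmanager

    # Module-level flag, mirroring should_pretend_to_be_tt in the patch.
    should_pretend_to_be_tt = False

    @contextmanager
    def pretend_scaling_is_torch():
        """Flip the flag for the duration of the with-block, then restore it."""
        global should_pretend_to_be_tt
        should_pretend_to_be_tt = True
        try:
            yield
        finally:
            should_pretend_to_be_tt = False

    class FakeTorchTensor:
        """Stand-in for torch.Tensor in this sketch."""

    class ScalingTensor:
        """Stand-in for msamp's ScalingTensor."""
        @property
        def __class__(self):
            # isinstance() consults __class__ when type(obj) does not match,
            # so overriding it is enough to make the object masquerade as the
            # "torch.Tensor" stand-in while the flag is set.
            return FakeTorchTensor if should_pretend_to_be_tt else ScalingTensor

    t = ScalingTensor()
    print(isinstance(t, FakeTorchTensor))      # False
    with pretend_scaling_is_torch():
        print(isinstance(t, FakeTorchTensor))  # True
    print(isinstance(t, FakeTorchTensor))      # False

Unlike the patch, the sketch wraps the yield in try/finally so the flag is restored even if the code inside the with-block (e.g. scaler.step()) raises; the commits above omit that guard.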