From 7a773ff4d6a8994a022cef55278fd018c5b199ed Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 13 May 2024 20:14:03 +0200
Subject: [PATCH 1/6] add ema on model parameters

---
 torchmdnet/module.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 108a1915e..ca5f5e8c0 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -14,6 +14,7 @@ from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models.utils import dtype_mapping
 import torch_geometric.transforms as T
+from torch_ema import ExponentialMovingAverage


 class FloatCastDatasetWrapper(T.BaseTransform):
@@ -73,6 +74,8 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)

+        self.ema_weights = ExponentialMovingAverage(self.model.parameters(), decay=0.995)
+
         # initialize exponential smoothing
         self.ema = None
         self._reset_ema_dict()
@@ -251,6 +254,10 @@ def optimizer_step(self, *args, **kwargs):
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
         optimizer.zero_grad()
+
+    def on_before_zero_grad(self, *args, **kwargs):
+        self.ema_weights.to(self.device)
+        self.ema_weights.update(self.model.parameters())

     def _get_mean_loss_dict_for_type(self, type):
         # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)

From 96934fe47f5388eabcbfb91732e2f61c44dc329b Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 13 May 2024 20:19:51 +0200
Subject: [PATCH 2/6] add ema_prmtrs_decay to control ema on model prmts

---
 torchmdnet/module.py        | 10 +++++++---
 torchmdnet/scripts/train.py |  1 +
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index ca5f5e8c0..9a24d01f0 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -74,7 +74,10 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)

-        self.ema_weights = ExponentialMovingAverage(self.model.parameters(), decay=0.995)
+        self.ema_prmtrs = None
+        if self.hparams.ema_prmtrs_decay is not None:
+            # initialize EMA for the model paremeters
+            self.ema_prmtrs = ExponentialMovingAverage(self.model.parameters(), decay=self.hparams.ema_prmtrs_decay)

         # initialize exponential smoothing
         self.ema = None
         self._reset_ema_dict()
@@ -256,8 +259,9 @@ def optimizer_step(self, *args, **kwargs):
         optimizer.zero_grad()

     def on_before_zero_grad(self, *args, **kwargs):
-        self.ema_weights.to(self.device)
-        self.ema_weights.update(self.model.parameters())
+        if self.ema_prmtrs is not None:
+            self.ema_prmtrs.to(self.device)
+            self.ema_prmtrs.update(self.model.parameters())

     def _get_mean_loss_dict_for_type(self, type):
         # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)

diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 2e69212b4..4d43d0b70 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -59,6 +59,7 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
+    parser.add_argument('--ema-prmtrs-decay', type=float, default=None, help='Exponential moving average decay for model parameters (None to disable)')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')

From 7c11034b627b1eec61aade26713ebaeda8c46e40 Mon Sep 17 00:00:00 2001
From: Antonio Mirarchi
Date: Mon, 13 May 2024 20:23:15 +0200
Subject: [PATCH 3/6] add torch_ema to enviroment.yml

---
 environment.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/environment.yml b/environment.yml
index 697fe8084..fc1eb3225 100644
--- a/environment.yml
+++ b/environment.yml
@@ -17,3 +17,5 @@ dependencies:
   - pytest
   - psutil
   - gxx<12
+  - pip:
+    - torch-ema

From dfba2b0662ea2f06d4fcaaa4ecb8afe53a955993 Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 15 Jul 2024 12:08:16 +0200
Subject: [PATCH 4/6] Change name to ema_parameters_decay, fix initialization

---
 torchmdnet/module.py        | 21 +++++++++++++--------
 torchmdnet/scripts/train.py |  2 +-
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 9eae8d348..390e59084 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -74,11 +74,16 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)

-        self.ema_prmtrs = None
-        if self.hparams.ema_prmtrs_decay is not None:
+        self.ema_parameters = None
+        if (
+            "ema_parameters_decay" in self.hparams
+            and self.hparams.ema_parameters_decay is not None
+        ):
             # initialize EMA for the model paremeters
-            self.ema_prmtrs = ExponentialMovingAverage(self.model.parameters(), decay=self.hparams.ema_prmtrs_decay)
-
+            self.ema_parameters = ExponentialMovingAverage(
+                self.model.parameters(), decay=self.hparams.ema_parameters_decay
+            )
+
         # initialize exponential smoothing
         self.ema = None
         self._reset_ema_dict()
@@ -257,11 +262,11 @@ def optimizer_step(self, *args, **kwargs):
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
         optimizer.zero_grad()
-
+
     def on_before_zero_grad(self, *args, **kwargs):
-        if self.ema_prmtrs is not None:
-            self.ema_prmtrs.to(self.device)
-            self.ema_prmtrs.update(self.model.parameters())
+        if self.ema_parameters is not None:
+            self.ema_parameters.to(self.device)
+            self.ema_parameters.update(self.model.parameters())

     def _get_mean_loss_dict_for_type(self, type):
         # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)

diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index acd624e56..44a2e7974 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -59,7 +59,7 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
-    parser.add_argument('--ema-prmtrs-decay', type=float, default=None, help='Exponential moving average decay for model parameters (None to disable)')
+    parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disable). The decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
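
Note: patches 1-4 above keep the averaged weights in a separate torch_ema ExponentialMovingAverage object that shadows the model parameters and is refreshed after every optimizer step. The following is a minimal sketch of how that torch_ema pattern is typically driven outside of Lightning; the toy model, optimizer, data, loop length and loss are illustrative stand-ins, not code from this repository (only the ExponentialMovingAverage calls mirror the diff).

    import torch
    from torch import nn
    from torch_ema import ExponentialMovingAverage

    model = nn.Linear(16, 1)      # stand-in for the model built by create_model in module.py
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    ema = ExponentialMovingAverage(model.parameters(), decay=0.995)

    for step in range(100):
        x, y = torch.randn(8, 16), torch.randn(8, 1)
        loss = (model(x) - y).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # same call the patches issue from on_before_zero_grad after each optimizer step
        ema.update(model.parameters())

    # evaluate with the smoothed weights, then automatically restore the live ones
    with ema.average_parameters():
        val_loss = (model(torch.randn(8, 16)) - torch.randn(8, 1)).pow(2).mean()

Each update moves the shadow copy as avg <- decay * avg + (1 - decay) * param, so with decay=0.995 the average effectively tracks roughly the last 1 / (1 - 0.995) = 200 optimizer steps.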
From de8ffa5f944306d7c4547031debb525f2496ce2c Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 15 Jul 2024 12:55:14 +0200
Subject: [PATCH 5/6] Switch to torch.optim. Add ema_parameters_start to
 train.py. Rename some variables

---
 environment.yml             |  1 -
 torchmdnet/module.py        | 47 +++++++++++++++++++++----------------
 torchmdnet/scripts/train.py |  1 +
 3 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/environment.yml b/environment.yml
index 9392d147d..5116787f8 100644
--- a/environment.yml
+++ b/environment.yml
@@ -12,7 +12,6 @@ dependencies:
   - pydantic
   - torchmetrics
   - tqdm
-  - torch-ema
   # Dev tools
   - flake8
   - pytest

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 390e59084..ea26051d1 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -6,6 +6,7 @@ import torch
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import ReduceLROnPlateau
+import torch.optim.swa_utils
 from torch.nn.functional import local_response_norm, mse_loss, l1_loss
 from torch import Tensor
 from typing import Optional, Dict, Tuple
@@ -14,7 +15,6 @@ from torchmdnet.models.model import create_model, load_model
 from torchmdnet.models.utils import dtype_mapping
 import torch_geometric.transforms as T
-from torch_ema import ExponentialMovingAverage


 class FloatCastDatasetWrapper(T.BaseTransform):
@@ -74,19 +74,26 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)

-        self.ema_parameters = None
+        self.ema_model = None
         if (
             "ema_parameters_decay" in self.hparams
             and self.hparams.ema_parameters_decay is not None
         ):
-            # initialize EMA for the model paremeters
-            self.ema_parameters = ExponentialMovingAverage(
-                self.model.parameters(), decay=self.hparams.ema_parameters_decay
+            self.ema_model = torch.optim.swa_utils.AveragedModel(
+                self.model,
+                multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(
+                    self.hparams.ema_parameters_decay
+                ),
             )
+            self.ema_parameters_start = (
+                self.hparams.ema_parameters_start
+                if "ema_parameters_start" in self.hparams
+                else 0
+            )

-        # initialize exponential smoothing
-        self.ema = None
-        self._reset_ema_dict()
+        # initialize exponential smoothing for the losses
+        self.ema_loss = None
+        self._reset_ema_loss_dict()

         # initialize loss collection
         self.losses = None
@@ -188,12 +195,12 @@ def _update_loss_with_ema(self, stage, type, loss_name, loss):
         alpha = getattr(self.hparams, f"ema_alpha_{type}")
         if stage in ["train", "val"] and alpha < 1 and alpha > 0:
             ema = (
-                self.ema[stage][type][loss_name]
-                if loss_name in self.ema[stage][type]
+                self.ema_loss[stage][type][loss_name]
+                if loss_name in self.ema_loss[stage][type]
                 else loss.detach()
             )
             loss = alpha * loss + (1 - alpha) * ema
-            self.ema[stage][type][loss_name] = loss.detach()
+            self.ema_loss[stage][type][loss_name] = loss.detach()
         return loss

     def step(self, batch, loss_fn_list, stage):
@@ -261,13 +268,13 @@ def optimizer_step(self, *args, **kwargs):
             for pg in optimizer.param_groups:
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
+        if (
+            self.trainer.global_step >= self.ema_parameters_start
+            and self.ema_model is not None
+        ):
+            self.ema_model.update_parameters(self.model)
         optimizer.zero_grad()

-    def on_before_zero_grad(self, *args, **kwargs):
-        if self.ema_parameters is not None:
-            self.ema_parameters.to(self.device)
-            self.ema_parameters.update(self.model.parameters())
-
     def _get_mean_loss_dict_for_type(self, type):
         # Returns a list with the mean loss for each loss_fn for each stage (train, val, test)
@@ -320,9 +327,9 @@ def _reset_losses_dict(self):
             for loss_type in ["total", "y", "neg_dy"]:
                 self.losses[stage][loss_type] = defaultdict(list)

-    def _reset_ema_dict(self):
-        self.ema = {}
+    def _reset_ema_loss_dict(self):
+        self.ema_loss = {}
         for stage in ["train", "val"]:
-            self.ema[stage] = {}
+            self.ema_loss[stage] = {}
             for loss_type in ["y", "neg_dy"]:
-                self.ema[stage][loss_type] = {}
+                self.ema_loss[stage][loss_type] = {}

diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 44a2e7974..aef1b0ece 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -59,6 +59,7 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
     parser.add_argument('--ema-parameters-decay', type=float, default=None, help='Exponential moving average decay for model parameters (defaults to None, meaning disable). The decay is a parameter between 0 and 1 that controls how fast the averaged parameters are decayed.')
+    parser.add_argument('--ema-parameters-start', type=int, default=0, help='Epoch to start averaging the parameters.')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
     parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')

From 55afbbd328b2ccbc8ffded11d80bfcc56cebf47b Mon Sep 17 00:00:00 2001
From: RaulPPealez
Date: Mon, 15 Jul 2024 13:03:13 +0200
Subject: [PATCH 6/6] Use current_epoch to start ema weights

---
 torchmdnet/module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index ea26051d1..51a2ade41 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -269,7 +269,7 @@ def optimizer_step(self, *args, **kwargs):
                 pg["lr"] = lr_scale * self.hparams.lr
         super().optimizer_step(*args, **kwargs)
         if (
-            self.trainer.global_step >= self.ema_parameters_start
+            self.trainer.current_epoch >= self.ema_parameters_start
             and self.ema_model is not None
         ):
             self.ema_model.update_parameters(self.model)
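
Note: after patches 5 and 6 the torch_ema dependency is gone and the averaged copy lives in a torch.optim.swa_utils.AveragedModel configured with get_ema_multi_avg_fn, updated once per optimizer step starting at the epoch given by --ema-parameters-start. Below is a minimal sketch of that built-in PyTorch mechanism; the toy model, optimizer, data, decay value and start epoch are illustrative assumptions, and only the swa_utils calls mirror the diff.

    import torch
    from torch import nn
    from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

    model = nn.Linear(16, 1)        # illustrative stand-in for the real network
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # mirrors AveragedModel(..., multi_avg_fn=get_ema_multi_avg_fn(decay)) from patch 5
    ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

    ema_start_epoch = 2             # plays the role of --ema-parameters-start
    for epoch in range(5):
        for _ in range(50):
            x, y = torch.randn(8, 16), torch.randn(8, 1)
            loss = (model(x) - y).pow(2).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if epoch >= ema_start_epoch:
                # patch 6 gates this update on trainer.current_epoch rather than global_step
                ema_model.update_parameters(model)

    # the smoothed weights are held by the wrapper; use it (or ema_model.module) for inference
    with torch.no_grad():
        y_ema = ema_model(torch.randn(8, 16))

With decay d the averaged weights follow avg <- d * avg + (1 - d) * theta after the first update, i.e. roughly a running mean over the last 1 / (1 - d) optimizer steps (about 1000 for d = 0.999). With the series applied, the behaviour is switched on from the training script via the new flags, e.g. --ema-parameters-decay 0.999 --ema-parameters-start 10 (other training options omitted), and it stays off when --ema-parameters-decay is left at its default of None.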