diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py
index 443ce61dcb..a99fffcf38 100644
--- a/classy_vision/tasks/classification_task.py
+++ b/classy_vision/tasks/classification_task.py
@@ -6,6 +6,7 @@
 
 import copy
 import enum
+import itertools
 import json
 import logging
 import math
@@ -157,6 +158,7 @@ def __init__(self):
         )
         self.amp_args = None
         self.mixup_transform = None
+        self.grad_norm_clip = None
         self.perf_log = []
         self.last_batch = None
         self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED
@@ -412,6 +414,24 @@ def set_optimizer_schedulers(self, schedulers):
         self.optimizer_schedulers = schedulers
         return self
 
+    def set_grad_norm_clip(
+        self,
+        grad_norm_clip: Optional[float],
+    ) -> "ClassificationTask":
+        """Enable / disable clipping the gradient norm
+
+        Args:
+            grad_norm_clip: The value to clip the gradient by, set to None to disable
+        """
+        if grad_norm_clip is None:
+            logging.info(f"Disabled gradient norm clipping: {grad_norm_clip}")
+        else:
+            logging.info(
+                f"Enabled gradient norm clipping with threshold: {grad_norm_clip}"
+            )
+        self.grad_norm_clip = grad_norm_clip
+        return self
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         """Instantiates a ClassificationTask from a configuration.
@@ -489,6 +509,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_distributed_options(**distributed_options)
             .set_hooks(hooks)
             .set_bn_weight_decay(config.get("bn_weight_decay", False))
+            .set_grad_norm_clip(config.get("grad_norm_clip"))
         )
 
         if not test_only:
@@ -719,10 +740,7 @@ def init_distributed_data_parallel_model(self):
             broadcast_buffers=broadcast_buffers,
             find_unused_parameters=self.find_unused_parameters,
         )
-        if (
-            isinstance(self.base_loss, ClassyLoss)
-            and self.base_loss.has_learned_parameters()
-        ):
+        if self._loss_has_learnable_params():
             logging.info("Initializing distributed loss")
             self.distributed_loss = init_distributed_data_parallel_model(
                 self.base_loss,
@@ -919,6 +937,9 @@ def train_step(self):
         else:
             self.optimizer.backward(local_loss)
 
+        if self.grad_norm_clip is not None:
+            self._clip_grad_norm()
+
         self.check_inf_nan(loss)
 
         self.optimizer.step(where=self.where)
@@ -992,6 +1013,22 @@ def create_data_iterators(self):
         del self.data_iterator
         self.data_iterator = iter(self.dataloader)
 
+    def _clip_grad_norm(self):
+        """Clip the gradient norms based on self.grad_norm_clip"""
+        model_params = (
+            self.base_model.parameters()
+            if self.amp_args is None
+            else apex.amp.master_params(self.optimizer.optimizer)
+        )
+        loss_params = (
+            self.base_loss.parameters()
+            if self._loss_has_learnable_params()
+            else iter(())
+        )
+        nn.utils.clip_grad_norm_(
+            itertools.chain(model_params, loss_params), self.grad_norm_clip
+        )
+
     def _set_model_train_mode(self):
         """Set train mode for model"""
         phase = self.phases[self.phase_idx]
@@ -1014,6 +1051,13 @@ def _broadcast_buffers(self):
         for buffer in buffers:
             broadcast(buffer, 0, group=self.distributed_model.process_group)
 
+    def _loss_has_learnable_params(self):
+        """Returns True if the loss has any learnable parameters"""
+        return (
+            isinstance(self.base_loss, ClassyLoss)
+            and self.base_loss.has_learned_parameters()
+        )
+
     # TODO: Functions below should be better abstracted into the dataloader
     # abstraction
     def get_batchsize_per_replica(self):
diff --git a/test/tasks_classification_task_test.py b/test/tasks_classification_task_test.py
index b1fbbbdf92..b5547c539c 100644
--- a/test/tasks_classification_task_test.py
+++ b/test/tasks_classification_task_test.py
@@ -7,6 +7,7 @@
 import copy
 import shutil
 import tempfile
+import itertools
 import unittest
 from test.generic.config_utils import get_fast_test_task_config, get_test_task_config
 from test.generic.utils import (
@@ -284,3 +285,24 @@ def test_get_classy_state_on_loss(self):
         task = build_task(config)
         task.prepare()
         self.assertIn("alpha", task.get_classy_state()["loss"])
+
+    def test_grad_norm_clip(self):
+        config = get_fast_test_task_config()
+        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
+        config["grad_norm_clip"] = grad_norm_clip = 1
+        task = build_task(config)
+        task.prepare()
+
+        # set fake gradients with norm > grad_norm_clip
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            param.grad = 1.1 + torch.rand(param.shape)
+            self.assertGreater(param.grad.norm(), grad_norm_clip)
+
+        task._clip_grad_norm()
+
+        for param in itertools.chain(
+            task.base_model.parameters(), task.base_loss.parameters()
+        ):
+            self.assertLessEqual(param.grad.norm(), grad_norm_clip)
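
For reference, a minimal usage sketch of the option added by this patch, assuming Classy Vision is installed; the clip thresholds shown are illustrative values, not defaults introduced here:

    from classy_vision.tasks import ClassificationTask

    # Builder-style: the new setter returns the task, so it can be chained.
    task = ClassificationTask().set_grad_norm_clip(1.0)  # clip total grad norm to 1.0
    task = task.set_grad_norm_clip(None)                 # disable clipping again

    # Config-driven: from_config reads the same optional key, e.g.
    #     config["grad_norm_clip"] = 1.0
    # and leaves clipping disabled when the key is absent or None.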