diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 443ce61dcb..526ae8d58d 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -719,10 +719,7 @@ def init_distributed_data_parallel_model(self): broadcast_buffers=broadcast_buffers, find_unused_parameters=self.find_unused_parameters, ) - if ( - isinstance(self.base_loss, ClassyLoss) - and self.base_loss.has_learned_parameters() - ): + if self._loss_has_learnable_params(): logging.info("Initializing distributed loss") self.distributed_loss = init_distributed_data_parallel_model( self.base_loss, @@ -1014,6 +1011,13 @@ def _broadcast_buffers(self): for buffer in buffers: broadcast(buffer, 0, group=self.distributed_model.process_group) + def _loss_has_learnable_params(self): + """Returns True if the loss has any learnable parameters""" + return ( + isinstance(self.base_loss, ClassyLoss) + and self.base_loss.has_learned_parameters() + ) + # TODO: Functions below should be better abstracted into the dataloader # abstraction def get_batchsize_per_replica(self):