From d1f1aabf81749e28d5eba6dbf264c8b8ea2c2de4 Mon Sep 17 00:00:00 2001 From: fanqiNO1 <75657629+fanqiNO1@users.noreply.github.com> Date: Fri, 17 May 2024 15:27:53 +0800 Subject: [PATCH] [Feature] Support calculating loss during validation (#1503) --- mmengine/runner/loops.py | 82 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 329fd48914..5a678db7b9 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -8,8 +8,10 @@ from torch.utils.data import DataLoader from mmengine.evaluator import Evaluator -from mmengine.logging import print_log +from mmengine.logging import HistoryBuffer, print_log from mmengine.registry import LOOPS +from mmengine.structures import BaseDataElement +from mmengine.utils import is_list_of from .amp import autocast from .base_loop import BaseLoop from .utils import calc_dynamic_intervals @@ -363,17 +365,26 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 + self.val_loss: Dict[str, HistoryBuffer] = dict() def run(self) -> dict: """Launch validation.""" self.runner.call_hook('before_val') self.runner.call_hook('before_val_epoch') self.runner.model.eval() + + # clear val loss + self.val_loss.clear() for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + if self.val_loss: + loss_dict = _parse_losses(self.val_loss, 'val') + metrics.update(loss_dict) + self.runner.call_hook('after_val_epoch', metrics=metrics) self.runner.call_hook('after_val') return metrics @@ -391,6 +402,9 @@ def run_iter(self, idx, data_batch: Sequence[dict]): # outputs should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.val_step(data_batch) + + outputs, self.val_loss = _update_losses(outputs, self.val_loss) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_val_iter', @@ -435,17 +449,26 @@ def __init__(self, logger='current', level=logging.WARNING) self.fp16 = fp16 + self.test_loss: Dict[str, HistoryBuffer] = dict() def run(self) -> dict: """Launch test.""" self.runner.call_hook('before_test') self.runner.call_hook('before_test_epoch') self.runner.model.eval() + + # clear test loss + self.test_loss.clear() for idx, data_batch in enumerate(self.dataloader): self.run_iter(idx, data_batch) # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) + + if self.test_loss: + loss_dict = _parse_losses(self.test_loss, 'test') + metrics.update(loss_dict) + self.runner.call_hook('after_test_epoch', metrics=metrics) self.runner.call_hook('after_test') return metrics @@ -462,9 +485,66 @@ def run_iter(self, idx, data_batch: Sequence[dict]) -> None: # predictions should be sequence of BaseDataElement with autocast(enabled=self.fp16): outputs = self.runner.model.test_step(data_batch) + + outputs, self.test_loss = _update_losses(outputs, self.test_loss) + self.evaluator.process(data_samples=outputs, data_batch=data_batch) self.runner.call_hook( 'after_test_iter', batch_idx=idx, data_batch=data_batch, outputs=outputs) + + +def _parse_losses(losses: Dict[str, HistoryBuffer], + stage: str) -> Dict[str, float]: + """Parses the raw losses of the network. + + Args: + losses (dict): raw losses of the network. + stage (str): The stage of loss, e.g., 'val' or 'test'. + + Returns: + dict[str, float]: The key is the loss name, and the value is the + average loss. + """ + all_loss = 0 + loss_dict: Dict[str, float] = dict() + + for loss_name, loss_value in losses.items(): + avg_loss = loss_value.mean() + loss_dict[loss_name] = avg_loss + if 'loss' in loss_name: + all_loss += avg_loss + + loss_dict[f'{stage}_loss'] = all_loss + return loss_dict + + +def _update_losses(outputs: list, losses: dict) -> Tuple[list, dict]: + """Update and record the losses of the network. + + Args: + outputs (list): The outputs of the network. + losses (dict): The losses of the network. + + Returns: + list: The updated outputs of the network. + dict: The updated losses of the network. + """ + if isinstance(outputs[-1], + BaseDataElement) and outputs[-1].keys() == ['loss']: + loss = outputs[-1].loss # type: ignore + outputs = outputs[:-1] + else: + loss = dict() + + for loss_name, loss_value in loss.items(): + if loss_name not in losses: + losses[loss_name] = HistoryBuffer() + if isinstance(loss_value, torch.Tensor): + losses[loss_name].update(loss_value.item()) + elif is_list_of(loss_value, torch.Tensor): + for loss_value_i in loss_value: + losses[loss_name].update(loss_value_i.item()) + return outputs, losses