diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py
index 62ec3b7f7..9d4f2ae89 100644
--- a/src/pl_bolts/callbacks/data_monitor.py
+++ b/src/pl_bolts/callbacks/data_monitor.py
@@ -2,16 +2,15 @@
 import numpy as np
 import torch
+from lightning_utilities import apply_to_collection
+from lightning_utilities.core.rank_zero import rank_zero_warn
 from pytorch_lightning import Callback, LightningModule, Trainer
 from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
-from pytorch_lightning.utilities import rank_zero_warn
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from torch import Tensor, nn
 from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
 
 from pl_bolts.utils import _WANDB_AVAILABLE
-from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg
 
 # Backward compatibility for Lightning Logger
@@ -26,14 +25,13 @@
     warn_missing_pkg("wandb")
 
 
-@under_review()
 class DataMonitorBase(Callback):
     supported_loggers = (
         TensorBoardLogger,
        WandbLogger,
     )
 
-    def __init__(self, log_every_n_steps: int = None) -> None:
+    def __init__(self, log_every_n_steps: Optional[int] = None):
         """Base class for monitoring data histograms in a LightningModule.
 
         This requires a logger configured in the Trainer, otherwise no data is logged. The specific class that inherits
         from this base defines what data gets collected.
@@ -102,7 +100,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None:
             logger.experiment.log(data={name: wandb.Histogram(tensor)}, commit=False)
 
-    def _is_logger_available(self, logger: Logger) -> bool:
+    def _is_logger_available(self, logger: Optional[Logger]) -> bool:
         available = True
         if not logger:
             rank_zero_warn("Cannot log histograms because Trainer has no logger.")
@@ -116,7 +114,6 @@ def _is_logger_available(self, logger: Logger) -> bool:
         return available
 
 
-@under_review()
 class ModuleDataMonitor(DataMonitorBase):
     GROUP_NAME_INPUT = "input"
     GROUP_NAME_OUTPUT = "output"
 
@@ -124,9 +121,10 @@ class ModuleDataMonitor(DataMonitorBase):
     def __init__(
         self,
         submodules: Optional[Union[bool, List[str]]] = None,
-        log_every_n_steps: int = None,
-    ) -> None:
-        """
+        log_every_n_steps: Optional[int] = None,
+    ):
+        """Logs the in- and output histogram of submodules.
+
         Args:
             submodules: If `True`, logs the in- and output histograms of every submodule in the
                 LightningModule, including the root module itself.
@@ -152,7 +150,6 @@ def __init__(
             # specific submodules
             trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=["generator", "generator.conv1"])])
-
         """
         super().__init__(log_every_n_steps=log_every_n_steps)
         self._submodule_names = submodules
@@ -162,7 +159,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         super().on_train_start(trainer, pl_module)
         submodule_dict = dict(pl_module.named_modules())
         self._hook_handles = []
-        for name in self._get_submodule_names(pl_module):
+        for name in self._get_submodule_names(submodule_dict):
             if name not in submodule_dict:
                 rank_zero_warn(
                     f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__},"
@@ -176,7 +173,7 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
         for handle in self._hook_handles:
             handle.remove()
 
-    def _get_submodule_names(self, root_module: nn.Module) -> List[str]:
+    def _get_submodule_names(self, named_modules: Dict[str, nn.Module]) -> List[str]:
         # default is the root module only
         names = [""]
 
@@ -184,7 +181,7 @@ def _get_submodule_names(self, root_module: nn.Module) -> List[str]:
             names = self._submodule_names
 
         if self._submodule_names is True:
-            names = [name for name, _ in root_module.named_modules()]
+            names = list(named_modules)
 
         return names
 
@@ -192,7 +189,7 @@ def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle
         input_group_name = f"{self.GROUP_NAME_INPUT}/{module_name}" if module_name else self.GROUP_NAME_INPUT
         output_group_name = f"{self.GROUP_NAME_OUTPUT}/{module_name}" if module_name else self.GROUP_NAME_OUTPUT
 
-        def hook(_: Module, inp: Sequence, out: Sequence) -> None:
+        def hook(_: Module, inp: Any, out: Any) -> None:
             inp = inp[0] if len(inp) == 1 else inp
             self.log_histograms(inp, group=input_group_name)
             self.log_histograms(out, group=output_group_name)
@@ -201,11 +198,10 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None:
         return module.register_forward_hook(hook)
 
 
-@under_review()
 class TrainingDataMonitor(DataMonitorBase):
     GROUP_NAME = "training_step"
 
-    def __init__(self, log_every_n_steps: int = None) -> None:
+    def __init__(self, log_every_n_steps: Optional[int] = None):
         """Callback that logs the histogram of values in the batched data passed to `training_step`.
 
         Args:
@@ -233,7 +229,11 @@ def on_train_batch_start(
         self.log_histograms(batch, group=self.GROUP_NAME)
 
 
-def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
+def collect_and_name_tensors(
+    data: Union[Tensor, dict, Sequence],
+    output: Dict[str, Tensor],
+    parent_name: str = "input",
+) -> None:
     """Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them.
 
     Data in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing
     integer. The shape of the tensor gets appended to the name as well.
@@ -264,7 +264,6 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
             collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}")
 
 
-@under_review()
 def shape2str(tensor: Tensor) -> str:
     """Returns the shape of a tensor in bracket notation as a string.
@@ -274,4 +273,4 @@ def shape2str(tensor: Tensor) -> str:
         >>> shape2str(torch.rand(4))
         '[4]'
     """
-    return "[" + ", ".join(map(str, tensor.shape)) + "]"
+    return str(list(tensor.shape))
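
Note on the `shape2str` change above: the refactor swaps manual string joining for `str(list(...))`. A quick illustrative check (not part of the patch; the helper names below are made up for the comparison) shows both spellings produce the same bracket notation:

import torch


def shape2str_old(tensor: torch.Tensor) -> str:
    # previous implementation: join the dimensions by hand
    return "[" + ", ".join(map(str, tensor.shape)) + "]"


def shape2str_new(tensor: torch.Tensor) -> str:
    # new implementation: list.__str__ already renders "[a, b, c]"
    return str(list(tensor.shape))


t = torch.rand(2, 3, 4)
assert shape2str_old(t) == shape2str_new(t) == "[2, 3, 4]"
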
diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py
index a7fef8254..466e7b902 100644
--- a/src/pl_bolts/callbacks/verification/batch_gradient.py
+++ b/src/pl_bolts/callbacks/verification/batch_gradient.py
@@ -4,8 +4,8 @@
 import torch
 import torch.nn as nn
+from lightning_utilities import apply_to_collection
 from pytorch_lightning import LightningModule, Trainer
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import Tensor
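
Before the test changes below, a minimal usage sketch of the two callbacks touched by this patch; `model` and `datamodule` are placeholders for a user-defined LightningModule and datamodule, and the step count is an arbitrary example value:

from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
from pytorch_lightning import Trainer

# histogram of every batch entering training_step, logged every 25 global steps (example value)
batch_monitor = TrainingDataMonitor(log_every_n_steps=25)

# input/output histograms of selected submodules, as in the docstring example above
module_monitor = ModuleDataMonitor(submodules=["generator", "generator.conv1"])

trainer = Trainer(callbacks=[batch_monitor, module_monitor])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule supplied by the user
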
diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py
index 3d3ddc0c4..163338926 100644
--- a/tests/callbacks/test_data_monitor.py
+++ b/tests/callbacks/test_data_monitor.py
@@ -1,3 +1,4 @@
+import warnings
 from unittest import mock
 from unittest.mock import call
 
@@ -6,16 +7,24 @@
 from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
 from pl_bolts.datamodules import MNISTDataModule
 from pl_bolts.models import LitMNIST
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch import nn
 
 
 # @pytest.mark.parametrize(("log_every_n_steps", "max_steps", "expected_calls"), [pytest.param(3, 10, 3)])
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
 def test_base_log_interval_override(
-    log_histogram, tmpdir, datadir, log_every_n_steps=3, max_steps=10, expected_calls=3
+    log_histogram, tmpdir, datadir, catch_warnings, log_every_n_steps=3, max_steps=10, expected_calls=3
 ):
     """Test logging interval set by log_every_n_steps argument."""
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
     monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
     model = LitMNIST(num_workers=0)
     datamodule = MNISTDataModule(data_dir=datadir)
@@ -24,6 +33,7 @@ def test_base_log_interval_override(
         log_every_n_steps=1,
         max_steps=max_steps,
         callbacks=[monitor],
+        accelerator="auto",
     )
     trainer.fit(model, datamodule=datamodule)
 
@@ -32,16 +42,24 @@
 @pytest.mark.parametrize(
     ("log_every_n_steps", "max_steps", "expected_calls"),
-    [
-        pytest.param(1, 5, 5),
-        pytest.param(2, 5, 2),
-        pytest.param(5, 5, 1),
-        pytest.param(6, 5, 0),
-    ],
+    [(1, 5, 5), (2, 5, 2), (5, 5, 1), (6, 5, 0)],
 )
 @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram")
-def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
+def test_base_log_interval_fallback(
+    log_histogram,
+    tmpdir,
+    log_every_n_steps,
+    max_steps,
+    expected_calls,
+    datadir,
+    catch_warnings,
+):
     """Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer."""
+    warnings.filterwarnings(
+        "ignore",
+        message=".*does not have many workers which may be a bottleneck.*",
+        category=PossibleUserWarning,
+    )
     monitor = TrainingDataMonitor()
     model = LitMNIST(num_workers=0)
     datamodule = MNISTDataModule(data_dir=datadir)
@@ -50,21 +68,40 @@ def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, ma
         log_every_n_steps=log_every_n_steps,
         max_steps=max_steps,
         callbacks=[monitor],
+        accelerator="auto",
     )
     trainer.fit(model, datamodule=datamodule)
 
     assert log_histogram.call_count == (expected_calls * 2)  # 2 tensors per log call
 
 
-def test_base_no_logger_warning():
+def test_base_no_logger_warning(catch_warnings):
     """Test a warning is displayed when Trainer has no logger."""
     monitor = TrainingDataMonitor()
-    trainer = Trainer(logger=False, callbacks=[monitor])
+    trainer = Trainer(logger=False, callbacks=[monitor], accelerator="auto", max_epochs=-1)
     with pytest.warns(UserWarning, match="Cannot log histograms because Trainer has no logger"):
-        monitor.on_train_start(trainer, pl_module=None)
+        monitor.on_train_start(trainer, pl_module=LightningModule())
+
+
+def test_base_unsupported_logger_warning(tmpdir, catch_warnings):
+    """Test a warning is displayed when an unsupported logger is used."""
+    warnings.filterwarnings(
+        "ignore",
+        message=".*is deprecated in v1.6.*",
+        category=LightningDeprecationWarning,
+    )
+    monitor = TrainingDataMonitor()
+    trainer = Trainer(
+        logger=TensorBoardLogger(tmpdir),
+        callbacks=[monitor],
+        accelerator="auto",
+        max_epochs=1,
+    )
+    with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"):
+        monitor.on_train_start(trainer, pl_module=LightningModule())
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram")
-def test_training_data_monitor(log_histogram, tmpdir, datadir):
+def test_training_data_monitor(log_histogram, tmpdir, datadir, catch_warnings):
     """Test that the TrainingDataMonitor logs histograms of data points going into training_step."""
     monitor = TrainingDataMonitor()
     model = LitMNIST(data_dir=datadir)
@@ -72,6 +109,8 @@ def test_training_data_monitor(log_histogram, tmpdir, datadir):
         default_root_dir=tmpdir,
         log_every_n_steps=1,
         callbacks=[monitor],
+        accelerator="auto",
+        max_epochs=1,
     )
     monitor.on_train_start(trainer, model)
 
@@ -116,8 +155,8 @@ def forward(self, *args, **kwargs):
         return self.sub_layer(*args, **kwargs)
 
 
-class ModuleDataMonitorModel(nn.Module):
-    def __init__(self) -> None:
+class ModuleDataMonitorModel(LightningModule):
+    def __init__(self):
         super().__init__()
         self.layer1 = nn.Linear(12, 5)
         self.layer2 = SubModule(5, 2)
@@ -135,7 +174,7 @@ def forward(self, x):
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
-def test_module_data_monitor_forward(log_histogram, tmpdir):
+def test_module_data_monitor_forward(log_histogram, tmpdir, catch_warnings):
     """Test that the default ModuleDataMonitor logs inputs and outputs of model's forward."""
     monitor = ModuleDataMonitor(submodules=None)
     model = ModuleDataMonitorModel()
@@ -143,6 +182,8 @@ def test_module_data_monitor_forward(log_histogram, tmpdir):
         default_root_dir=tmpdir,
         log_every_n_steps=1,
         callbacks=[monitor],
+        accelerator="auto",
+        max_epochs=1,
     )
     monitor.on_train_start(trainer, model)
     monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
@@ -156,7 +197,7 @@ def test_module_data_monitor_forward(log_histogram, tmpdir):
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
-def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
+def test_module_data_monitor_submodules_all(log_histogram, tmpdir, catch_warnings):
     """Test that the ModuleDataMonitor logs the inputs and outputs of each submodule."""
     monitor = ModuleDataMonitor(submodules=True)
     model = ModuleDataMonitorModel()
@@ -164,6 +205,8 @@ def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
         default_root_dir=tmpdir,
         log_every_n_steps=1,
         callbacks=[monitor],
+        accelerator="auto",
+        max_epochs=1,
     )
     monitor.on_train_start(trainer, model)
     monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
@@ -183,7 +226,7 @@ def test_module_data_monitor_submodules_all(log_histogram, tmpdir):
 
 
 @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram")
-def test_module_data_monitor_submodules_specific(log_histogram, tmpdir):
+def test_module_data_monitor_submodules_specific(log_histogram, tmpdir, catch_warnings):
     """Test that the ModuleDataMonitor logs the inputs and outputs of selected submodules."""
     monitor = ModuleDataMonitor(submodules=["layer1", "layer2.sub_layer"])
     model = ModuleDataMonitorModel()
@@ -191,6 +234,8 @@ def test_module_data_monitor_submodules_specific(log_histogram, tmpdir):
         default_root_dir=tmpdir,
         log_every_n_steps=1,
         callbacks=[monitor],
+        accelerator="auto",
+        max_epochs=1,
     )
     monitor.on_train_start(trainer, model)
     monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0)
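
The updated tests also request a `catch_warnings` fixture that is not defined in this patch (it presumably lives in the test suite's conftest.py). A minimal sketch of such a fixture, assuming it only isolates warning filters per test, might look like:

# conftest.py -- hypothetical sketch; the real fixture in the repository may differ
import warnings

import pytest


@pytest.fixture()
def catch_warnings():
    # give each test its own warning-filter state so filterwarnings() calls in tests do not leak
    with warnings.catch_warnings():
        warnings.simplefilter("always")
        yield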