From 14cf441c4a7aef97e1929990d5c260877c9c0281 Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Sun, 31 Jul 2022 16:04:06 +0900 Subject: [PATCH 01/18] add type hint --- pl_bolts/callbacks/data_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 96537962cd..e082f0e6b0 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -28,7 +28,7 @@ class DataMonitorBase(Callback): WandbLogger, ) - def __init__(self, log_every_n_steps: int = None): + def __init__(self, log_every_n_steps: Optional[int] = None): """Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets collected. @@ -97,7 +97,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None: logger.experiment.log(data={name: wandb.Histogram(tensor)}, commit=False) - def _is_logger_available(self, logger: LightningLoggerBase) -> bool: + def _is_logger_available(self, logger: Optional[LightningLoggerBase]) -> bool: available = True if not logger: rank_zero_warn("Cannot log histograms because Trainer has no logger.") From 7488302691b9988266dcf44ebf5f15bd9d559b6c Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Mon, 1 Aug 2022 00:01:33 +0900 Subject: [PATCH 02/18] minor changes data_monitor --- pl_bolts/callbacks/data_monitor.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index e082f0e6b0..73c32a215a 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -20,7 +20,6 @@ warn_missing_pkg("wandb") -@under_review() class DataMonitorBase(Callback): supported_loggers = ( @@ -105,13 +104,11 @@ def _is_logger_available(self, logger: Optional[LightningLoggerBase]) -> bool: if not isinstance(logger, self.supported_loggers): rank_zero_warn( f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}." - f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}" - ) + f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}") available = False return available -@under_review() class ModuleDataMonitor(DataMonitorBase): GROUP_NAME_INPUT = "input" @@ -120,9 +117,9 @@ class ModuleDataMonitor(DataMonitorBase): def __init__( self, submodules: Optional[Union[bool, List[str]]] = None, - log_every_n_steps: int = None, + log_every_n_steps: Optional[int] = None, ): - """ + """Logs the in- and output histogram of submodules. Args: submodules: If `True`, logs the in- and output histograms of every submodule in the LightningModule, including the root module itself. @@ -157,13 +154,10 @@ def __init__( def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_start(trainer, pl_module) submodule_dict = dict(pl_module.named_modules()) - self._hook_handles = [] - for name in self._get_submodule_names(pl_module): + for name in self._get_submodule_names(submodule_dict): if name not in submodule_dict: - rank_zero_warn( - f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__}," - " skipping this key." 
- ) + rank_zero_warn(f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__}," + " skipping this key.") continue handle = self._register_hook(name, submodule_dict[name]) self._hook_handles.append(handle) @@ -172,7 +166,7 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: for handle in self._hook_handles: handle.remove() - def _get_submodule_names(self, root_module: nn.Module) -> List[str]: + def _get_submodule_names(self, named_modules: dict) -> List[str]: # default is the root module only names = [""] @@ -180,7 +174,7 @@ def _get_submodule_names(self, root_module: nn.Module) -> List[str]: names = self._submodule_names if self._submodule_names is True: - names = [name for name, _ in root_module.named_modules()] + names = list(named_modules.keys()) return names @@ -197,12 +191,11 @@ def hook(_: Module, inp: Sequence, out: Sequence) -> None: return handle -@under_review() class TrainingDataMonitor(DataMonitorBase): GROUP_NAME = "training_step" - def __init__(self, log_every_n_steps: int = None): + def __init__(self, log_every_n_steps: Optional[int] = None): """Callback that logs the histogram of values in the batched data passed to `training_step`. Args: @@ -271,4 +264,4 @@ def shape2str(tensor: Tensor) -> str: >>> shape2str(torch.rand(4)) '[4]' """ - return "[" + ", ".join(map(str, tensor.shape)) + "]" + return str(list(tensor.shape)) From c00b13140197bc422689b5e736a2ddf1112f7c90 Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Mon, 1 Aug 2022 00:01:55 +0900 Subject: [PATCH 03/18] review tests data_monitor consistency --- tests/callbacks/test_data_monitor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index e4abe1f914..811b352440 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -3,7 +3,7 @@ import pytest import torch -from pytorch_lightning import Trainer +from pytorch_lightning import Trainer, LightningModule from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from torch import nn @@ -57,7 +57,7 @@ def test_base_no_logger_warning(): monitor = TrainingDataMonitor() trainer = Trainer(logger=False, callbacks=[monitor]) with pytest.warns(UserWarning, match="Cannot log histograms because Trainer has no logger"): - monitor.on_train_start(trainer, pl_module=None) + monitor.on_train_start(trainer, pl_module=LightningModule()) def test_base_unsupported_logger_warning(tmpdir): @@ -65,7 +65,7 @@ def test_base_unsupported_logger_warning(tmpdir): monitor = TrainingDataMonitor() trainer = Trainer(logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor]) with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"): - monitor.on_train_start(trainer, pl_module=None) + monitor.on_train_start(trainer, pl_module=LightningModule()) @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") @@ -121,7 +121,7 @@ def forward(self, *args, **kwargs): return self.sub_layer(*args, **kwargs) -class ModuleDataMonitorModel(nn.Module): +class ModuleDataMonitorModel(LightningModule): def __init__(self): super().__init__() self.layer1 = nn.Linear(12, 5) From 5c967fcaaba634e144ded469c597dc886dcd6c9f Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Wed, 3 Aug 2022 01:59:23 +0900 Subject: [PATCH 04/18] pre-commit --- pl_bolts/callbacks/data_monitor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git 
a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 73c32a215a..e7d5113d00 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -104,7 +104,8 @@ def _is_logger_available(self, logger: Optional[LightningLoggerBase]) -> bool: if not isinstance(logger, self.supported_loggers): rank_zero_warn( f"{self.__class__.__name__} does not support logging with {logger.__class__.__name__}." - f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}") + f" Supported loggers are: {', '.join(map(lambda x: str(x.__name__), self.supported_loggers))}" + ) available = False return available @@ -156,8 +157,10 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: submodule_dict = dict(pl_module.named_modules()) for name in self._get_submodule_names(submodule_dict): if name not in submodule_dict: - rank_zero_warn(f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__}," - " skipping this key.") + rank_zero_warn( + f"{name} is not a valid identifier for a submodule in {pl_module.__class__.__name__}," + " skipping this key." + ) continue handle = self._register_hook(name, submodule_dict[name]) self._hook_handles.append(handle) From d3cf97884e287d8d58c1fc87b3942f7b834fe794 Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Wed, 3 Aug 2022 02:28:21 +0900 Subject: [PATCH 05/18] type hint, under review --- pl_bolts/callbacks/data_monitor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index e7d5113d00..515507fe2b 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -11,7 +11,6 @@ from torch.utils.hooks import RemovableHandle from pl_bolts.utils import _WANDB_AVAILABLE -from pl_bolts.utils.stability import under_review from pl_bolts.utils.warnings import warn_missing_pkg if _WANDB_AVAILABLE: @@ -185,7 +184,7 @@ def _register_hook(self, module_name: str, module: nn.Module) -> RemovableHandle input_group_name = f"{self.GROUP_NAME_INPUT}/{module_name}" if module_name else self.GROUP_NAME_INPUT output_group_name = f"{self.GROUP_NAME_OUTPUT}/{module_name}" if module_name else self.GROUP_NAME_OUTPUT - def hook(_: Module, inp: Sequence, out: Sequence) -> None: + def hook(_: Module, inp: Any, out: Any) -> None: inp = inp[0] if len(inp) == 1 else inp self.log_histograms(inp, group=input_group_name) self.log_histograms(out, group=output_group_name) @@ -226,7 +225,11 @@ def on_train_batch_start( self.log_histograms(batch, group=self.GROUP_NAME) -def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None: +def collect_and_name_tensors( + data: Union[Tensor, dict, Sequence], + output: Dict[str, Tensor], + parent_name: str = "input", +) -> None: """Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer. The shape of the tensor gets appended to the name as well. @@ -257,7 +260,6 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}") -@under_review() def shape2str(tensor: Tensor) -> str: """Returns the shape of a tensor in bracket notation as a string. 
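For context on the `collect_and_name_tensors` signature narrowed in the patch above: the function does a depth-first walk over a (possibly nested) batch, naming dict entries by their keys, sequence entries by an increasing index, and appending each tensor's shape to its name. Below is a minimal, self-contained sketch of that traversal, assuming only the behaviour described in the docstring and the recursion visible in the diffs above — the exact pl_bolts implementation may differ in details.

import torch
from torch import Tensor


def shape2str(tensor: Tensor) -> str:
    # Bracket notation for the shape, e.g. torch.rand(1, 2, 3) -> "[1, 2, 3]".
    return str(list(tensor.shape))


def collect_and_name_tensors(data, output, parent_name="input"):
    # Depth-first search: dict entries are named by key, sequence entries by
    # index, and every tensor gets its shape appended to the accumulated name.
    if isinstance(data, Tensor):
        output[f"{parent_name}/{shape2str(data)}"] = data
    elif isinstance(data, dict):
        for key, value in data.items():
            collect_and_name_tensors(value, output, parent_name=f"{parent_name}/{key}")
    elif isinstance(data, (list, tuple)):
        for i, item in enumerate(data):
            collect_and_name_tensors(item, output, parent_name=f"{parent_name}/{i:d}")


named = {}
collect_and_name_tensors({"x": torch.rand(4, 5), "y": [torch.rand(2), torch.rand(3, 1)]}, named)
print(sorted(named))  # ['input/x/[4, 5]', 'input/y/0/[2]', 'input/y/1/[3, 1]']

These names are what the callbacks use as histogram keys, prefixed by a group such as the `training_step` GROUP_NAME seen in `TrainingDataMonitor` above.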
From 05c93435b0b1044f5e4c9db6e0f70ec53256e23f Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Wed, 3 Aug 2022 03:04:19 +0900 Subject: [PATCH 06/18] add catch_warnings --- tests/callbacks/test_data_monitor.py | 40 ++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 811b352440..bb18c4530e 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -1,10 +1,12 @@ +import warnings from unittest import mock from unittest.mock import call import pytest import torch -from pytorch_lightning import Trainer, LightningModule +from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch import nn from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor @@ -13,8 +15,15 @@ @pytest.mark.parametrize(["log_every_n_steps", "max_steps", "expected_calls"], [pytest.param(3, 10, 3)]) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") -def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir): +def test_base_log_interval_override( + log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir, catch_warnings +): """Test logging interval set by log_every_n_steps argument.""" + warnings.filterwarnings( + "ignore", + message=".*does not have many workers which may be a bottleneck.*", + category=PossibleUserWarning, + ) monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps) model = LitMNIST(data_dir=datadir, num_workers=0) trainer = Trainer( @@ -22,6 +31,7 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma log_every_n_steps=1, max_steps=max_steps, callbacks=[monitor], + accelerator="auto", ) trainer.fit(model) @@ -38,8 +48,15 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma ], ) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") -def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir): +def test_base_log_interval_fallback( + log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir, catch_warnings +): """Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer.""" + warnings.filterwarnings( + "ignore", + message=".*does not have many workers which may be a bottleneck.*", + category=PossibleUserWarning, + ) monitor = TrainingDataMonitor() model = LitMNIST(data_dir=datadir, num_workers=0) trainer = Trainer( @@ -47,6 +64,7 @@ def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, ma log_every_n_steps=log_every_n_steps, max_steps=max_steps, callbacks=[monitor], + accelerator="auto", ) trainer.fit(model) assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call @@ -55,7 +73,11 @@ def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, ma def test_base_no_logger_warning(): """Test a warning is displayed when Trainer has no logger.""" monitor = TrainingDataMonitor() - trainer = Trainer(logger=False, callbacks=[monitor]) + trainer = Trainer( + logger=False, + callbacks=[monitor], + accelerator="auto", + ) with pytest.warns(UserWarning, match="Cannot log histograms because Trainer has no logger"): 
monitor.on_train_start(trainer, pl_module=LightningModule()) @@ -63,7 +85,11 @@ def test_base_no_logger_warning(): def test_base_unsupported_logger_warning(tmpdir): """Test a warning is displayed when an unsupported logger is used.""" monitor = TrainingDataMonitor() - trainer = Trainer(logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor]) + trainer = Trainer( + logger=LoggerCollection([TensorBoardLogger(tmpdir)]), + callbacks=[monitor], + accelerator="auto", + ) with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"): monitor.on_train_start(trainer, pl_module=LightningModule()) @@ -77,6 +103,7 @@ def test_training_data_monitor(log_histogram, tmpdir, datadir): default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + accelerator="auto", ) monitor.on_train_start(trainer, model) @@ -149,6 +176,7 @@ def test_module_data_monitor_forward(log_histogram, tmpdir): default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + accelerator="auto", ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) @@ -170,6 +198,7 @@ def test_module_data_monitor_submodules_all(log_histogram, tmpdir): default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + accelerator="auto", ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) @@ -197,6 +226,7 @@ def test_module_data_monitor_submodules_specific(log_histogram, tmpdir): default_root_dir=tmpdir, log_every_n_steps=1, callbacks=[monitor], + accelerator="auto", ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) From 72edfa34928646aa632afd5332fd80588679ea07 Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Wed, 10 Aug 2022 08:13:14 +0900 Subject: [PATCH 07/18] catch_warnings to all tests --- tests/callbacks/test_data_monitor.py | 45 ++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index bb18c4530e..fb3b68c818 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -6,6 +6,7 @@ import torch from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch import nn @@ -16,7 +17,13 @@ @pytest.mark.parametrize(["log_every_n_steps", "max_steps", "expected_calls"], [pytest.param(3, 10, 3)]) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") def test_base_log_interval_override( - log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir, catch_warnings + log_histogram, + tmpdir, + log_every_n_steps, + max_steps, + expected_calls, + datadir, + catch_warnings, ): """Test logging interval set by log_every_n_steps argument.""" warnings.filterwarnings( @@ -49,7 +56,13 @@ def test_base_log_interval_override( ) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") def test_base_log_interval_fallback( - log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir, catch_warnings + log_histogram, + tmpdir, + log_every_n_steps, + max_steps, + expected_calls, + datadir, + catch_warnings, ): """Test that if log_every_n_steps not set in the callback, fallback to what is defined in 
the Trainer.""" warnings.filterwarnings( @@ -70,32 +83,34 @@ def test_base_log_interval_fallback( assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call -def test_base_no_logger_warning(): +def test_base_no_logger_warning(catch_warnings): """Test a warning is displayed when Trainer has no logger.""" monitor = TrainingDataMonitor() - trainer = Trainer( - logger=False, - callbacks=[monitor], - accelerator="auto", - ) + trainer = Trainer(logger=False, callbacks=[monitor], accelerator="auto", max_epochs=-1) with pytest.warns(UserWarning, match="Cannot log histograms because Trainer has no logger"): monitor.on_train_start(trainer, pl_module=LightningModule()) -def test_base_unsupported_logger_warning(tmpdir): +def test_base_unsupported_logger_warning(tmpdir, catch_warnings): """Test a warning is displayed when an unsupported logger is used.""" + warnings.filterwarnings( + "ignore", + message=".*is deprecated in v1.6.*", + category=LightningDeprecationWarning, + ) monitor = TrainingDataMonitor() trainer = Trainer( logger=LoggerCollection([TensorBoardLogger(tmpdir)]), callbacks=[monitor], accelerator="auto", + max_epochs=1, ) with pytest.warns(UserWarning, match="does not support logging with LoggerCollection"): monitor.on_train_start(trainer, pl_module=LightningModule()) @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") -def test_training_data_monitor(log_histogram, tmpdir, datadir): +def test_training_data_monitor(log_histogram, tmpdir, datadir, catch_warnings): """Test that the TrainingDataMonitor logs histograms of data points going into training_step.""" monitor = TrainingDataMonitor() model = LitMNIST(data_dir=datadir) @@ -104,6 +119,7 @@ def test_training_data_monitor(log_histogram, tmpdir, datadir): log_every_n_steps=1, callbacks=[monitor], accelerator="auto", + max_epochs=1, ) monitor.on_train_start(trainer, model) @@ -168,7 +184,7 @@ def forward(self, x): @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") -def test_module_data_monitor_forward(log_histogram, tmpdir): +def test_module_data_monitor_forward(log_histogram, tmpdir, catch_warnings): """Test that the default ModuleDataMonitor logs inputs and outputs of model's forward.""" monitor = ModuleDataMonitor(submodules=None) model = ModuleDataMonitorModel() @@ -177,6 +193,7 @@ def test_module_data_monitor_forward(log_histogram, tmpdir): log_every_n_steps=1, callbacks=[monitor], accelerator="auto", + max_epochs=1, ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) @@ -190,7 +207,7 @@ def test_module_data_monitor_forward(log_histogram, tmpdir): @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") -def test_module_data_monitor_submodules_all(log_histogram, tmpdir): +def test_module_data_monitor_submodules_all(log_histogram, tmpdir, catch_warnings): """Test that the ModuleDataMonitor logs the inputs and outputs of each submodule.""" monitor = ModuleDataMonitor(submodules=True) model = ModuleDataMonitorModel() @@ -199,6 +216,7 @@ def test_module_data_monitor_submodules_all(log_histogram, tmpdir): log_every_n_steps=1, callbacks=[monitor], accelerator="auto", + max_epochs=1, ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) @@ -218,7 +236,7 @@ def test_module_data_monitor_submodules_all(log_histogram, tmpdir): @mock.patch("pl_bolts.callbacks.data_monitor.ModuleDataMonitor.log_histogram") -def 
test_module_data_monitor_submodules_specific(log_histogram, tmpdir): +def test_module_data_monitor_submodules_specific(log_histogram, tmpdir, catch_warnings): """Test that the ModuleDataMonitor logs the inputs and outputs of selected submodules.""" monitor = ModuleDataMonitor(submodules=["layer1", "layer2.sub_layer"]) model = ModuleDataMonitorModel() @@ -227,6 +245,7 @@ def test_module_data_monitor_submodules_specific(log_histogram, tmpdir): log_every_n_steps=1, callbacks=[monitor], accelerator="auto", + max_epochs=1, ) monitor.on_train_start(trainer, model) monitor.on_train_batch_start(trainer, model, batch=None, batch_idx=0) From a62e904201600bef98b8939e4e8dbee94acf1cb5 Mon Sep 17 00:00:00 2001 From: Luca Medeiros Date: Mon, 19 Sep 2022 14:07:12 -0300 Subject: [PATCH 08/18] fix minor changes --- pl_bolts/callbacks/data_monitor.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py index 515507fe2b..219b78db89 100644 --- a/pl_bolts/callbacks/data_monitor.py +++ b/pl_bolts/callbacks/data_monitor.py @@ -3,7 +3,7 @@ import numpy as np import torch from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger +from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from torch import Tensor, nn @@ -95,7 +95,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None: logger.experiment.log(data={name: wandb.Histogram(tensor)}, commit=False) - def _is_logger_available(self, logger: Optional[LightningLoggerBase]) -> bool: + def _is_logger_available(self, logger: Optional[Logger]) -> bool: available = True if not logger: rank_zero_warn("Cannot log histograms because Trainer has no logger.") @@ -154,6 +154,7 @@ def __init__( def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_start(trainer, pl_module) submodule_dict = dict(pl_module.named_modules()) + self._hook_handles = [] for name in self._get_submodule_names(submodule_dict): if name not in submodule_dict: rank_zero_warn( @@ -168,7 +169,7 @@ def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: for handle in self._hook_handles: handle.remove() - def _get_submodule_names(self, named_modules: dict) -> List[str]: + def _get_submodule_names(self, named_modules: Dict[str, nn.Module]) -> List[str]: # default is the root module only names = [""] @@ -176,7 +177,7 @@ def _get_submodule_names(self, named_modules: dict) -> List[str]: names = self._submodule_names if self._submodule_names is True: - names = list(named_modules.keys()) + names = list(named_modules) return names From c36a496f0170058f1f4505b366a969d94c7a0604 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 2 Nov 2022 14:27:14 +0100 Subject: [PATCH 09/18] precommit --- tests/callbacks/test_data_monitor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 61116296d6..9f36a7779c 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -4,12 +4,10 @@ import pytest import torch - from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.utilities.rank_zero import 
LightningDeprecationWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning - from torch import nn from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor @@ -96,7 +94,6 @@ def test_base_no_logger_warning(catch_warnings): monitor.on_train_start(trainer, pl_module=LightningModule()) - def test_base_unsupported_logger_warning(tmpdir, catch_warnings): """Test a warning is displayed when an unsupported logger is used.""" warnings.filterwarnings( @@ -115,7 +112,6 @@ def test_base_unsupported_logger_warning(tmpdir, catch_warnings): monitor.on_train_start(trainer, pl_module=LightningModule()) - @mock.patch("pl_bolts.callbacks.data_monitor.TrainingDataMonitor.log_histogram") def test_training_data_monitor(log_histogram, tmpdir, datadir, catch_warnings): """Test that the TrainingDataMonitor logs histograms of data points going into training_step.""" From bd23c27f3adc7712a268c22e19edf05b78f7b117 Mon Sep 17 00:00:00 2001 From: Jirka B Date: Fri, 19 May 2023 10:17:31 -0400 Subject: [PATCH 10/18] update mergify team --- .github/mergify.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/mergify.yml b/.github/mergify.yml index dc431ff4c5..314ae28dca 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -79,4 +79,4 @@ pull_request_rules: actions: request_reviews: teams: - - "@PyTorchLightning/core-bolts" + - "@Lightning-Universe/core-Bolts" From 90338ae9b5ef4f3130c63787516d1070837dba84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 20 May 2023 20:49:14 +0000 Subject: [PATCH 11/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pl_bolts/callbacks/data_monitor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py index 1f3d88732b..3d37c1728a 100644 --- a/src/pl_bolts/callbacks/data_monitor.py +++ b/src/pl_bolts/callbacks/data_monitor.py @@ -150,7 +150,6 @@ def __init__( # specific submodules trainer = Trainer(callbacks=[ModuleDataMonitor(submodules=["generator", "generator.conv1"])]) - """ super().__init__(log_every_n_steps=log_every_n_steps) self._submodule_names = submodules From 33824d419ed5c6775f188a9cec064e7e83da161c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 May 2023 18:23:42 +0000 Subject: [PATCH 12/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_data_monitor.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 51d903af4a..1ef6a9e02c 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -18,11 +18,7 @@ # @pytest.mark.parametrize(("log_every_n_steps", "max_steps", "expected_calls"), [pytest.param(3, 10, 3)]) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") def test_base_log_interval_override( - log_histogram, - tmpdir, - datadir, - catch_warnings, - log_every_n_steps=3, max_steps=10, expected_calls=3 + log_histogram, tmpdir, datadir, catch_warnings, log_every_n_steps=3, max_steps=10, expected_calls=3 ): """Test logging interval set by log_every_n_steps argument.""" warnings.filterwarnings( From 7b8097a7ddcbcbdb50e9c8d0b43a31be90b599e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 31 May 2023 22:43:29 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_data_monitor.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 35b8f44c6b..a78165bdc7 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -4,18 +4,15 @@ import pytest import torch +from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor +from pl_bolts.datamodules import MNISTDataModule +from pl_bolts.models import LitMNIST from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch import nn -from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor -from pl_bolts.datamodules import MNISTDataModule -from pl_bolts.models import LitMNIST -from pytorch_lightning import Trainer -from torch import nn - # @pytest.mark.parametrize(("log_every_n_steps", "max_steps", "expected_calls"), [pytest.param(3, 10, 3)]) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") From b4e44f6b1dea37b5eb21d25436e41ff09bba66e5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 30 Jun 2023 10:16:40 +0200 Subject: [PATCH 14/18] drop LoggerCollection --- tests/callbacks/test_data_monitor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index a78165bdc7..eefbb8474c 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -8,7 +8,7 @@ from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.rank_zero import LightningDeprecationWarning from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch import nn @@ -96,7 +96,7 @@ def test_base_unsupported_logger_warning(tmpdir, catch_warnings): ) monitor = TrainingDataMonitor() trainer = Trainer( - logger=LoggerCollection([TensorBoardLogger(tmpdir)]), + logger=[TensorBoardLogger(tmpdir)], callbacks=[monitor], accelerator="auto", max_epochs=1, From b052d9f0c8bbab1d8ba6e1e0f747dd8b02a57c13 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Fri, 30 Jun 2023 11:07:32 +0200 Subject: [PATCH 15/18] logger --- tests/callbacks/test_data_monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index eefbb8474c..097d22cd74 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -96,7 +96,7 @@ def test_base_unsupported_logger_warning(tmpdir, catch_warnings): ) monitor = TrainingDataMonitor() trainer = Trainer( - logger=[TensorBoardLogger(tmpdir)], + logger=TensorBoardLogger(tmpdir), callbacks=[monitor], accelerator="auto", max_epochs=1, From f3c97229adfdd58e16ecfe4e2e40ed8b07708b26 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 30 Jun 2023 11:25:15 +0200 
Subject: [PATCH 16/18] use lightning_utilities --- src/pl_bolts/callbacks/data_monitor.py | 4 ++-- src/pl_bolts/callbacks/verification/batch_gradient.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pl_bolts/callbacks/data_monitor.py b/src/pl_bolts/callbacks/data_monitor.py index 7b30c9c1d3..9d4f2ae892 100644 --- a/src/pl_bolts/callbacks/data_monitor.py +++ b/src/pl_bolts/callbacks/data_monitor.py @@ -2,10 +2,10 @@ import numpy as np import torch +from lightning_utilities import apply_to_collection +from lightning_utilities.core.rank_zero import rank_zero_warn from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.apply_func import apply_to_collection from torch import Tensor, nn from torch.nn import Module from torch.utils.hooks import RemovableHandle diff --git a/src/pl_bolts/callbacks/verification/batch_gradient.py b/src/pl_bolts/callbacks/verification/batch_gradient.py index a7fef82548..466e7b902c 100644 --- a/src/pl_bolts/callbacks/verification/batch_gradient.py +++ b/src/pl_bolts/callbacks/verification/batch_gradient.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn +from lightning_utilities import apply_to_collection from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import Tensor From 85c839792eb5fb0ccb9631ed706e8b1ff8f13d94 Mon Sep 17 00:00:00 2001 From: Jirka Date: Fri, 30 Jun 2023 11:46:03 +0200 Subject: [PATCH 17/18] params --- tests/callbacks/test_data_monitor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index 097d22cd74..f6525b77e4 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -43,10 +43,10 @@ def test_base_log_interval_override( @pytest.mark.parametrize( ("log_every_n_steps", "max_steps", "expected_calls"), [ - pytest.param(1, 5, 5), - pytest.param(2, 5, 2), - pytest.param(5, 5, 1), - pytest.param(6, 5, 0), + (1, 5, 5), + (2, 5, 2), + (5, 5, 1), + (6, 5, 0) ], ) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") From c33dabc6b587aeb461e237426995fb7e6d0c3f66 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 30 Jun 2023 09:46:35 +0000 Subject: [PATCH 18/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/callbacks/test_data_monitor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index f6525b77e4..3de4529bfd 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -42,12 +42,7 @@ def test_base_log_interval_override( @pytest.mark.parametrize( ("log_every_n_steps", "max_steps", "expected_calls"), - [ - (1, 5, 5), - (2, 5, 2), - (5, 5, 1), - (6, 5, 0) - ], + [(1, 5, 5), (2, 5, 2), (5, 5, 1), (6, 5, 0)], ) @mock.patch("pl_bolts.callbacks.data_monitor.DataMonitorBase.log_histogram") def test_base_log_interval_fallback(