Revision of CPC module and tests #902

Draft: wants to merge 27 commits into base: master.

Commits (27)
73ea0f2  Add typing. Add docs. (matsumotosan, Oct 10, 2022)
e5fb432  Add logging to shared step (matsumotosan, Oct 11, 2022)
0fd056d  Merge branch 'master' into cpc_module (matsumotosan, Oct 19, 2022)
a194e2b  Merge branch 'master' into cpc_module (Borda, Oct 27, 2022)
9294eef  Merge branch 'master' into cpc_module (matsumotosan, Oct 28, 2022)
0166425  Merge branch 'master' into cpc_module (matsumotosan, Oct 31, 2022)
1402654  Merge branch 'master' into cpc_module (matsumotosan, Nov 1, 2022)
8ec0edb  Merge branch 'master' into cpc_module (Nov 2, 2022)
06ccdc5  Merge branch 'master' into cpc_module (matsumotosan, Dec 31, 2022)
32b01e5  Merge branch 'master' into cpc_module (matsumotosan, Jan 9, 2023)
2f1e1d6  Merge branch 'master' into cpc_module (matsumotosan, Feb 25, 2023)
3fddca0  Merge branch 'master' into cpc_module (Borda, May 18, 2023)
4d7e946  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 18, 2023)
bd23c27  update mergify team (Borda, May 19, 2023)
1e4d3f6  Merge branch 'master' into cpc_module (Borda, May 19, 2023)
cafc9ac  Merge branch 'master' into cpc_module (Borda, May 19, 2023)
0df73ec  Merge branch 'cpc_module' of https://github.com/matsumotosan/lightnin… (matsumotosan, Jul 3, 2023)
c17b402  build(deps): update jsonargparse[signatures] requirement from <=4.22.… (dependabot[bot], Jul 10, 2023)
011f209  cleaning broken strings (Borda, Jul 12, 2023)
c6f6d3b  format docs with 120 (#1057) (Borda, Jul 12, 2023)
b8966d1  adding required wrapper (#1056) (Borda, Jul 12, 2023)
4f910f6  build(deps): bump pytest-rerunfailures from 11.1.2 to 12.0 in /requir… (dependabot[bot], Jul 12, 2023)
a912d83  Add typing. Add docs. (matsumotosan, Oct 10, 2022)
0c8ae31  Add logging to shared step (matsumotosan, Oct 11, 2022)
ab52e99  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 18, 2023)
0184091  Merge branch 'cpc_module' of https://github.com/matsumotosan/lightnin… (matsumotosan, Jul 13, 2023)
25c70f7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 13, 2023)
Files changed
8 changes: 2 additions & 6 deletions .pre-commit-config.yaml
@@ -32,7 +32,8 @@ repos:
rev: v1.7.3
hooks:
- id: docformatter
args: [--in-place, --wrap-summaries=115, --wrap-descriptions=120]
additional_dependencies: [tomli]
args: ["--in-place"]

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
@@ -44,11 +45,6 @@ repos:
- mdformat_frontmatter
exclude: CHANGELOG.md

#- repo: https://github.com/PyCQA/isort
# rev: 5.12.0
# hooks:
# - id: isort

- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `requires` wrapper ([#1056](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/1056))


### Changed
10 changes: 5 additions & 5 deletions pyproject.toml
@@ -50,11 +50,11 @@ relative_files = true
line-length = 120
exclude = "(.eggs|.git|.hg|.mypy_cache|.venv|_build|buck-out|build|dist)"

[tool.isort]
known_first_party = ["pl_bolts", "tests", "notebooks"]
skip_glob = []
profile = "black"
line_length = 120
[tool.docformatter]
recursive = true
wrap-summaries = 120
wrap-descriptions = 120
blank = true


[tool.ruff]
4 changes: 2 additions & 2 deletions requirements/test.txt
@@ -2,9 +2,9 @@ coverage[toml] >7.0.0, <8.0.0
pytest ==7.4.0
pytest-cov ==4.1.0
pytest-timeout ==2.1.0
pytest-rerunfailures ==11.1.2
pytest-rerunfailures ==12.0

scikit-learn >=1.0.2, <=1.3.0
sparseml >1.0.0, <1.6.0
ale-py >=0.7, <=0.8.1
jsonargparse[signatures] >4.0.0, <=4.22.0 # for LightningCLI
jsonargparse[signatures] >4.0.0, <=4.22.1 # for LightningCLI
3 changes: 3 additions & 0 deletions setup.py
@@ -31,6 +31,7 @@ def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: bool = True
'arrow>=1.2.0, <=1.2.2 # strict'
>>> _augment_requirement("arrow", unfreeze=True)
'arrow'

"""
# filter all comments
if comment_char in ln:
@@ -61,6 +62,7 @@ def _load_requirements(path_dir: str, file_name: str, unfreeze: bool = not _FREE
>>> path_req = os.path.join(_PATH_ROOT, "requirements")
>>> _load_requirements(path_req, "docs.txt") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
['sphinx>=4.0', ...]

"""
with open(os.path.join(path_dir, file_name)) as file:
lines = [ln.strip() for ln in file.readlines()]
@@ -77,6 +79,7 @@ def _load_readme_description(path_dir: str, homepage: str, ver: str) -> str:

>>> _load_readme_description(_PATH_ROOT, "", "") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
'<div align="center">...'

"""
path_readme = os.path.join(path_dir, "README.md")
with open(path_readme, encoding="utf-8") as fo:
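Note on the setup.py hunks above: the doctests show _augment_requirement keeping a pin that carries a "# strict" comment and returning the bare requirement otherwise, but the function body is collapsed in this view. A minimal sketch of that behavior, reconstructed from the doctests alone (names and details here are illustrative, not the actual setup.py code):

def _augment_requirement_sketch(ln: str, comment_char: str = "#", unfreeze: bool = True) -> str:
    # Split off a trailing comment; a "strict" marker means the pin must be kept.
    is_strict = False
    if comment_char in ln:
        comment = ln[ln.index(comment_char):]
        ln = ln[: ln.index(comment_char)]
        is_strict = "strict" in comment
    req = ln.strip()
    if is_strict:
        return f"{req}  # strict"
    if unfreeze:
        # Drop upper version bounds so CI can exercise newer releases.
        req = ", ".join(part.strip() for part in req.split(",") if "<" not in part)
    return req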
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/byol_updates.py
@@ -30,6 +30,7 @@ class BYOLMAWeightUpdate(Callback):
model.target_network = ...

trainer = Trainer(callbacks=[BYOLMAWeightUpdate()])

"""

def __init__(self, initial_tau: float = 0.996) -> None:
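Only the docstring of BYOLMAWeightUpdate is touched above. For context, the callback applies an exponential-moving-average update from the online network to the target network; a minimal sketch of that rule, assuming two modules with matching parameter order (illustrative only, not the callback's exact code):

import torch
from torch import nn

@torch.no_grad()
def ema_update(online: nn.Module, target: nn.Module, tau: float = 0.996) -> None:
    # target <- tau * target + (1 - tau) * online, parameter by parameter
    for online_p, target_p in zip(online.parameters(), target.parameters()):
        target_p.mul_(tau).add_(online_p, alpha=1.0 - tau)

In the callback itself, initial_tau sets the starting value and tau is gradually increased over training rather than held fixed.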
18 changes: 11 additions & 7 deletions src/pl_bolts/callbacks/data_monitor.py
@@ -35,12 +35,13 @@ class DataMonitorBase(Callback):

def __init__(self, log_every_n_steps: int = None) -> None:
"""Base class for monitoring data histograms in a LightningModule. This requires a logger configured in the
Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data
gets collected.
Trainer, otherwise no data is logged. The specific class that inherits from this base defines what data gets
collected.

Args:
log_every_n_steps: The interval at which histograms should be logged. This defaults to the
interval defined in the Trainer. Use this to override the Trainer default.

"""
super().__init__()
self._log_every_n_steps: Optional[int] = log_every_n_steps
@@ -84,12 +85,13 @@ def log_histograms(self, batch: Any, group: str = "") -> None:
self.log_histogram(tensor, name)

def log_histogram(self, tensor: Tensor, name: str) -> None:
"""Override this method to customize the logging of histograms. Detaches the tensor from the graph and
moves it to the CPU for logging.
"""Override this method to customize the logging of histograms. Detaches the tensor from the graph and moves it
to the CPU for logging.

Args:
tensor: The tensor for which to log a histogram
name: The name of the tensor as determined by the callback. Example: ``input/0/[64, 1, 28, 28]``

"""
logger = self._trainer.logger
tensor = tensor.detach().cpu()
@@ -234,9 +236,9 @@ def on_train_batch_start(


def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name: str = "input") -> None:
"""Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data
in dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer.
The shape of the tensor gets appended to the name as well.
"""Recursively fetches all tensors in a (nested) collection of data (depth-first search) and names them. Data in
dictionaries get named by their corresponding keys and otherwise they get indexed by an increasing integer. The
shape of the tensor gets appended to the name as well.

Args:
data: A collection of data (potentially nested).
@@ -249,6 +251,7 @@ def collect_and_name_tensors(data: Any, output: Dict[str, Tensor], parent_name:
>>> collect_and_name_tensors(data, output)
>>> output # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
{'input/x/[2, 3]': ..., 'input/y/z/[5]': ...}

"""
assert isinstance(output, dict)
if isinstance(data, Tensor):
@@ -273,5 +276,6 @@ def shape2str(tensor: Tensor) -> str:
'[1, 2, 3]'
>>> shape2str(torch.rand(4))
'[4]'

"""
return "[" + ", ".join(map(str, tensor.shape)) + "]"
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/printing.py
@@ -33,6 +33,7 @@ class PrintTableMetricsCallback(Callback):
# loss│train_loss│val_loss│epoch
# ──────────────────────────────
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0

"""

def __init__(self) -> None:
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/sparseml.py
@@ -33,6 +33,7 @@ class SparseMLCallback(Callback):
Args:
recipe_path: Path to a SparseML compatible yaml recipe.
More information at https://docs.neuralmagic.com/sparseml/source/recipes.html

"""

def __init__(self, recipe_path: str) -> None:
2 changes: 2 additions & 0 deletions src/pl_bolts/callbacks/ssl_online.py
@@ -31,6 +31,7 @@ class SSLOnlineEvaluator(Callback): # pragma: no cover
online_eval = SSLOnlineEvaluator(
z_dim=model.z_dim
)

"""

def __init__(
@@ -182,6 +183,7 @@ def set_training(module: nn.Module, mode: bool):
Args:
module: module to set training mode
mode: whether to set training mode (True) or evaluation mode (False).

"""
original_mode = module.training

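set_training above saves module.training before switching modes; the restore half of that pattern falls outside the visible hunk. A hypothetical context-manager equivalent of the same save/switch/restore idea (not part of the module, just a sketch):

from contextlib import contextmanager
from torch import nn

@contextmanager
def training_mode(module: nn.Module, mode: bool):
    # Switch to the requested mode, then restore whatever mode the module was in.
    original_mode = module.training
    module.train(mode)
    try:
        yield module
    finally:
        module.train(original_mode)

Usage would look like: with training_mode(encoder, False): features = encoder(x)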
5 changes: 3 additions & 2 deletions src/pl_bolts/callbacks/variational.py
@@ -18,8 +18,8 @@

@under_review()
class LatentDimInterpolator(Callback):
"""Interpolates the latent space for a model by setting all dims to zero and stepping through the first two
dims increasing one unit at a time.
"""Interpolates the latent space for a model by setting all dims to zero and stepping through the first two dims
increasing one unit at a time.

Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5)

@@ -28,6 +28,7 @@ class LatentDimInterpolator(Callback):
from pl_bolts.callbacks import LatentDimInterpolator

Trainer(callbacks=[LatentDimInterpolator()])

"""

def __init__(
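The LatentDimInterpolator docstring above describes sweeping the first two latent dimensions over [-5, 5] while holding the rest at zero. A minimal sketch of that sweep, assuming a decode callable that maps latent vectors to images (illustrative; the callback itself drives the attached LightningModule):

import torch

@torch.no_grad()
def interpolate_first_two_dims(decode, latent_dim: int, low: int = -5, high: int = 5):
    images = []
    for z1 in range(low, high + 1):
        for z2 in range(low, high + 1):
            z = torch.zeros(1, latent_dim)  # all dims held at zero ...
            z[0, 0] = float(z1)             # ... except the first two, stepped one unit at a time
            z[0, 1] = float(z2)
            images.append(decode(z))
    return images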
8 changes: 6 additions & 2 deletions src/pl_bolts/callbacks/verification/base.py
@@ -17,6 +17,7 @@ class VerificationBase:

All verifications should run with any
:class: `torch.nn.Module` unless otherwise stated.

"""

def __init__(self, model: nn.Module) -> None:
@@ -39,14 +40,16 @@ def check(self, *args: Any, **kwargs: Any) -> bool:
`True` if the test passes, and `False` otherwise. Some verifications can only be performed
with a heuristic accuracy, thus the return value may not always reflect the true state of
the system in these cases.

"""

def _get_input_array_copy(self, input_array: Optional[Any] = None) -> Any:
"""Returns a deep copy of the example input array in cases where it is expected that the input changes
during the verification process.
"""Returns a deep copy of the example input array in cases where it is expected that the input changes during
the verification process.

Arguments:
input_array: The input to clone.

"""
if input_array is None and isinstance(self.model, LightningModule):
input_array = self.model.example_input_array
@@ -89,6 +92,7 @@ class VerificationCallbackBase(Callback):
This type of verification is expected to only work with
:class:`~pytorch_lightning.core.lightning.LightningModule` and will take the input array
from :attr:`~pytorch_lightning.core.lightning.LightningModule.example_input_array` if needed.

"""

def __init__(self, warn: bool = True, error: bool = False) -> None:
4 changes: 4 additions & 0 deletions src/pl_bolts/callbacks/verification/batch_gradient.py
@@ -19,6 +19,7 @@ class BatchGradientVerification(VerificationBase):

This can happen if reshape- and/or permutation operations are carried out in the wrong order or on the wrong tensor
dimensions.

"""

NORM_LAYER_CLASSES = (
@@ -57,6 +58,7 @@ def check(

Returns:
``True`` if the data in the batch does not mix during the forward pass, and ``False`` otherwise.

"""
input_mapping = input_mapping or default_input_mapping
output_mapping = output_mapping or default_output_mapping
@@ -151,6 +153,7 @@ def default_input_mapping(data: Any) -> List[Tensor]:
torch.Size([3, 1])
>>> result[1].shape
torch.Size([3, 2])

"""
tensors = collect_tensors(data)
batches: List[Tensor] = []
@@ -181,6 +184,7 @@ def default_output_mapping(data: Any) -> Tensor:
>>> result = default_output_mapping(data)
>>> result.shape
torch.Size([3, 7])

"""
if isinstance(data, Tensor):
return data
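For context on what BatchGradientVerification.check tests: if a single sample's output is backpropagated, the gradient should be non-zero only for that sample's input. A minimal standalone sketch of that idea, assuming a model whose first input dimension is the batch dimension (illustrative; the real check also maps nested inputs/outputs and handles the normalization layers in NORM_LAYER_CLASSES, which mix the batch by design):

import torch

def batch_mixing_check(model: torch.nn.Module, batch: torch.Tensor, sample_idx: int = 0) -> bool:
    # Returns True if sample `sample_idx` only receives gradient from its own output.
    batch = batch.clone().requires_grad_(True)
    output = model(batch)
    output[sample_idx].sum().backward()
    grad = batch.grad
    other_rows = torch.cat([grad[:sample_idx], grad[sample_idx + 1:]])
    return bool(torch.all(other_rows == 0))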
1 change: 1 addition & 0 deletions src/pl_bolts/callbacks/vision/confused_logit.py
@@ -46,6 +46,7 @@ class ConfusedLogitCallback(Callback): # pragma: no cover
Authored by:

- Alfredo Canziani

"""

def __init__(
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/async_dataloader.py
@@ -28,6 +28,7 @@ class AsynchronousLoader:
if set and DataLoader has a __len__. Otherwise it can be left as None
**kwargs: Any additional arguments to pass to the dataloader if we're
constructing one here

"""

def __init__(
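A hedged usage sketch for AsynchronousLoader, based on the docstring above: wrap an existing DataLoader and iterate over batches that a background queue has already moved to a CUDA device (the exact keyword names, e.g. device or q_size, should be checked against the class signature):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader

dataset = TensorDataset(torch.randn(256, 32))
loader = AsynchronousLoader(DataLoader(dataset, batch_size=64), device=torch.device("cuda", 0))
for (batch,) in loader:
    ...  # each batch arrives already transferred to the device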
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/cifar10_datamodule.py
@@ -152,6 +152,7 @@ class TinyCIFAR10DataModule(CIFAR10DataModule):

dm = CIFAR10DataModule(PATH)
model = LitModel(datamodule=dm)

"""

dataset_cls = TrialCIFAR10
1 change: 1 addition & 0 deletions src/pl_bolts/datamodules/emnist_datamodule.py
@@ -176,6 +176,7 @@ def num_classes(self) -> int:
"""Returns the number of classes.

See the table above.

"""
return len(self.dataset_cls.classes_split_dict[self.split])

18 changes: 13 additions & 5 deletions src/pl_bolts/datamodules/experience_source.py
@@ -30,6 +30,7 @@ class ExperienceSourceDataset(IterableDataset):

Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is
generated is defined in the Lightning model itself

"""

def __init__(self, generate_batch: Callable) -> None:
@@ -95,6 +96,7 @@ def runner(self, device: torch.device) -> Tuple[Experience]:

Returns:
Tuple of Experiences

"""
while True:
# get actions for all envs
@@ -116,14 +118,15 @@ def update_history_queue(self, env_idx, exp, history) -> None:
self.iter_idx += 1

def update_history_queue(self, env_idx, exp, history) -> None:
"""Updates the experience history queue with the lastest experiences. In the event of an experience step is
in the done state, the history will be incrementally appended to the queue, removing the tail of the
history each time.
"""Updates the experience history queue with the lastest experiences. In the event of an experience step is in
the done state, the history will be incrementally appended to the queue, removing the tail of the history each
time.

Args:
env_idx: index of the environment
exp: the current experience
history: history of experience steps for this environment

"""
# If there is a full history of step, append history to queue
if len(history) == self.n_steps:
@@ -184,6 +187,7 @@ def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience:

Returns:
Experience tuple

"""
next_state, r, is_done, _ = env.step(action[0])

Expand All @@ -198,6 +202,7 @@ def update_env_stats(self, env_idx: int) -> None:

Args:
env_idx: index of the environment used to update stats

"""
self._total_rewards.append(self.cur_rewards[env_idx])
self.total_steps.append(self.cur_steps[env_idx])
@@ -248,6 +253,7 @@ def runner(self, device: torch.device) -> Experience:

Yields:
Discounted Experience

"""
for experiences in super().runner(device):
last_exp_state, tail_experiences = self.split_head_tail_exp(experiences)
@@ -263,14 +269,15 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
)

def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
"""Takes in a tuple of experiences and returns the last state and tail experiences based on if the last
state is the end of an episode.
"""Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is
the end of an episode.

Args:
experiences: Tuple of N Experience

Returns:
last state (Array or None) and remaining Experience

"""
if experiences[-1].done and len(experiences) <= self.steps:
last_exp_state = experiences[-1].new_state
Expand All @@ -288,6 +295,7 @@ def discount_rewards(self, experiences: Tuple[Experience]) -> float:

Returns:
total discounted reward

"""
total_reward = 0.0
for exp in reversed(experiences):
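The discount_rewards hunk above is cut off after the start of its loop. The standard backward recurrence it implements (with the class's gamma as the discount factor) looks roughly like this sketch (illustrative, not the class's exact code):

from typing import Sequence

def discount_rewards_sketch(rewards: Sequence[float], gamma: float = 0.99) -> float:
    # Walk the episode backwards: R <- r_t + gamma * R
    total_reward = 0.0
    for r in reversed(rewards):
        total_reward = r + gamma * total_reward
    return total_reward

# e.g. discount_rewards_sketch([1.0, 1.0, 1.0], gamma=0.9) -> 2.71, i.e. 1.0 + 0.9 * (1.0 + 0.9 * 1.0)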