Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision of the MoCo SSL model #928

Merged
merged 25 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
443514e
Moco_v2 renamed to MoCo
senarvi Oct 28, 2022
ddef423
Refactored
senarvi Oct 28, 2022
e5011bb
Merge branch 'master' into moco-revision
senarvi Oct 31, 2022
7ebf094
Merge branch 'master' into moco-revision
senarvi Nov 3, 2022
74de0aa
Terminology changed: projector renamed to head
senarvi Nov 24, 2022
e5e1266
Merge branch 'master' into moco-revision
senarvi Jun 29, 2023
221c36d
Fixed transform names
senarvi Jun 29, 2023
9e33775
Fixed code formatting
senarvi Jun 29, 2023
7fef0bf
Fixed MoCo unit test
senarvi Jun 29, 2023
b6c3c6d
Fixed lr_scheduler type annotation
senarvi Jun 29, 2023
ef5ea1d
Merge branch 'master' into moco-revision
senarvi Jun 29, 2023
cc4c077
Supports a 4-dimensional tensor of transformed images, in addition to…
senarvi Jul 2, 2023
b4c238d
Merge branch 'moco-revision' of github.com:groke-technologies/pytorch…
senarvi Jul 2, 2023
00d86bb
Merge branch 'master' into moco-revision
senarvi Jul 2, 2023
0bf096c
Fixed error messages and code formatting
senarvi Jul 2, 2023
337c9f9
moco2_module is now named moco_module
senarvi Jul 2, 2023
f1c7be3
moco_modules uses LightningCLI
senarvi Jul 2, 2023
7c0b0b6
MoCo CLI uses CIFAR10 so that the tests will pass (ImageNet cannot be…
senarvi Jul 2, 2023
b17cf8f
Fixed mypy errors in the backward compatibility code of LightningCLI …
senarvi Jul 2, 2023
2a6c359
Merge branch 'master' into moco-revision
senarvi Jul 4, 2023
ed39b61
Merge branch 'master' into moco-revision
senarvi Jul 4, 2023
558fe0f
Import LightningCLI from pytorch_lightning.utilities.cli
senarvi Jul 8, 2023
adc59bb
Merge branch 'moco-revision' of github.com:groke-technologies/pytorch…
senarvi Jul 8, 2023
ea8c6be
Fixed code formatting
senarvi Jul 8, 2023
b057957
Import LightningCLI from pytorch_lightning.cli
senarvi Jul 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, G

from pl_bolts.models import VAE
from pl_bolts.models.vision import GPT2, ImageGPT, PixelCNN
from pl_bolts.models.self_supervised import AMDIM, CPC_v2, SimCLR, Moco_v2
from pl_bolts.models.self_supervised import AMDIM, CPC_v2, SimCLR, MoCo
from pl_bolts.models import LinearRegression, LogisticRegression
from pl_bolts.models.gans import GAN
from pl_bolts.callbacks import PrintTableMetricsCallback
Expand Down
2 changes: 1 addition & 1 deletion docs/source/models/self_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ CPC (v2) API
Moco (v2) API
^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.Moco_v2
.. autoclass:: pl_bolts.models.self_supervised.MoCo
:noindex:

SimCLR
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ module = [
"pl_bolts.models.self_supervised.cpc.transforms",
"pl_bolts.models.self_supervised.evaluator",
"pl_bolts.models.self_supervised.moco.callbacks",
"pl_bolts.models.self_supervised.moco.moco2_module",
"pl_bolts.models.self_supervised.moco.transforms",
"pl_bolts.models.self_supervised.resnets",
"pl_bolts.models.self_supervised.simclr.simclr_finetuner",
Expand Down
12 changes: 8 additions & 4 deletions src/pl_bolts/models/detection/retinanet/retinanet_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,17 @@ def configure_optimizers(self):
@under_review()
def cli_main():
try: # Backward compatibility for Lightning CLI
from pytorch_lightning.cli import LightningCLI # PL v1.9+
except ImportError:
from pytorch_lightning.utilities.cli import LightningCLI # PL v1.8
import pytorch_lightning.cli

cli_class: Any = getattr(pytorch_lightning.cli, "LightningCLI") # PL v1.9+
except Exception:
Borda marked this conversation as resolved.
Show resolved Hide resolved
import pytorch_lightning.utilities.cli

cli_class = getattr(pytorch_lightning.utilities.cli, "LightningCLI") # PL v1.8
Borda marked this conversation as resolved.
Show resolved Hide resolved

from pl_bolts.datamodules import VOCDetectionDataModule

LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
cli_class(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cli_class(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)



if __name__ == "__main__":
Expand Down
12 changes: 10 additions & 2 deletions src/pl_bolts/models/detection/yolo/yolo_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torch
import torch.nn as nn
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor, optim

Expand Down Expand Up @@ -614,4 +613,13 @@ def _resize(self, image: Tensor, target: TARGET) -> Tuple[Tensor, TARGET]:


if __name__ == "__main__":
LightningCLI(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42)
try: # Backward compatibility for Lightning CLI
Borda marked this conversation as resolved.
Show resolved Hide resolved
import pytorch_lightning.cli

cli_class: Any = getattr(pytorch_lightning.cli, "LightningCLI") # PL v1.9+
except Exception:
import pytorch_lightning.utilities.cli

cli_class = getattr(pytorch_lightning.utilities.cli, "LightningCLI") # PL v1.8

cli_class(CLIYOLO, ResizedVOCDetectionDataModule, seed_everything_default=42)
4 changes: 2 additions & 2 deletions src/pl_bolts/models/self_supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pl_bolts.models.self_supervised.byol.byol_module import BYOL
from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pl_bolts.models.self_supervised.moco.moco2_module import Moco_v2
from pl_bolts.models.self_supervised.moco.moco_module import MoCo
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from pl_bolts.models.self_supervised.simsiam.simsiam_module import SimSiam
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
Expand All @@ -31,7 +31,7 @@
"BYOL",
"CPC_v2",
"SSLEvaluator",
"Moco_v2",
"MoCo",
"SimCLR",
"SimSiam",
"SSLFineTuner",
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/moco/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@under_review()
class MocoLRScheduler(Callback):
class MoCoLRScheduler(Callback):
def __init__(self, initial_lr=0.03, use_cosine_scheduler=False, schedule=(120, 160), max_epochs=200) -> None:
super().__init__()
self.lr = initial_lr
Expand Down
Loading