Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reviewing GAN basics, VisionDataModule, MNISTDataModule, CIFAR10DataModule #843

Merged
merged 16 commits into from
Jul 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
CIFAR10 = None


@under_review()
class CIFAR10DataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
Expand Down
2 changes: 0 additions & 2 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,7 +11,6 @@
warn_missing_pkg("torchvision")


@under_review()
class MNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
Expand Down
36 changes: 33 additions & 3 deletions pl_bolts/datamodules/vision_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split

from pl_bolts.utils.stability import under_review


@under_review()
class VisionDataModule(LightningDataModule):

EXTRA_ARGS: dict = {}
Expand All @@ -30,6 +27,9 @@ def __init__(
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
train_transforms: Optional[Callable] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we include these transforms into the class docstring?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree we should add it. What will be an ideal way of doing this? Do I need to create another pull request for this?

Copy link
Contributor

@otaj otaj Aug 2, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this as well, good catch @luca-medeiros!

@shivammehta007, yes, opening a new PR is the only option (not even I have push rights to master 😂). Since this PR is already merged, there really is no other option. Just write there it's a followup of #843

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@otaj

Seems you are the only member doing reviews for #839.
What about assembling a small (3~5 people) team to help you out with it? I would be willing to help review as well!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@luca-medeiros Oh, absolutely! That's the whole point of our Slack channel. I think you're not in there yet, can you ping me on PL Slack (@Ota) and I will add you there?

val_transforms: Optional[Callable] = None,
test_transforms: Optional[Callable] = None,
*args: Any,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -58,6 +58,36 @@ def __init__(
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms

@property
def train_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to train dataset."""
return self._train_transforms

@train_transforms.setter
def train_transforms(self, t: Callable) -> None:
self._train_transforms = t

@property
def val_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to validation dataset."""
return self._val_transforms

@val_transforms.setter
def val_transforms(self, t: Callable) -> None:
self._val_transforms = t

@property
def test_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to test dataset."""
return self._test_transforms

@test_transforms.setter
def test_transforms(self, t: Callable) -> None:
self._test_transforms = t

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
"""Saves files to data_dir."""
Expand Down
14 changes: 8 additions & 6 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F

from pl_bolts.models.gans.basic.components import Discriminator, Generator
from pl_bolts.utils.stability import under_review


@under_review()
class GAN(LightningModule):
"""Vanilla GAN implementation.

Expand All @@ -22,7 +21,7 @@ class GAN(LightningModule):
Example CLI::

# mnist
python basic_gan_module.py --gpus 1
python basic_gan_module.py --gpus 1

# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012'
Expand Down Expand Up @@ -166,7 +165,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main(args=None):
from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule
Expand All @@ -193,8 +191,12 @@ def cli_main(args=None):

dm = dm_cls.from_argparse_args(args)
model = GAN(*dm.size(), **vars(args))
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
callbacks = [
TensorboardGenerativeModelImageSampler(),
LatentDimInterpolator(interpolate_epoch_interval=5),
TQDMProgressBar(refresh_rate=20),
]
trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
trainer.fit(model, datamodule=dm)
return dm, model, trainer

Expand Down
4 changes: 0 additions & 4 deletions pl_bolts/models/gans/basic/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from torch import nn
from torch.nn import functional as F

from pl_bolts.utils.stability import under_review


@under_review()
class Generator(nn.Module):
def __init__(self, latent_dim, img_shape, hidden_dim=256):
super().__init__()
Expand All @@ -27,7 +24,6 @@ def forward(self, z):
return img


@under_review()
class Discriminator(nn.Module):
def __init__(self, img_shape, hidden_dim=1024):
super().__init__()
Expand Down
6 changes: 3 additions & 3 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ def test_cityscapes_datamodule(datadir):


@pytest.mark.parametrize("val_split, train_len", [(0.2, 48_000), (5_000, 55_000)])
def test_vision_data_module(datadir, val_split, train_len):
def test_vision_data_module(datadir, val_split, catch_warnings, train_len):
dm = _create_dm(MNISTDataModule, datadir, val_split=val_split)
assert len(dm.dataset_train) == train_len


@pytest.mark.parametrize("dm_cls", [BinaryMNISTDataModule, CIFAR10DataModule, FashionMNISTDataModule, MNISTDataModule])
def test_data_modules(datadir, dm_cls):
def test_data_modules(datadir, catch_warnings, dm_cls):
dm = _create_dm(dm_cls, datadir)
loader = dm.train_dataloader()
img, _ = next(iter(loader))
assert img.size() == torch.Size([2, *dm.size()])
assert img.size() == torch.Size([2, *dm.dims])


def _create_dm(dm_cls, datadir, **kwargs):
Expand Down
Empty file added tests/models/gans/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import warnings

import pytest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as transform_lib

Expand All @@ -10,17 +13,26 @@

@pytest.mark.parametrize(
"dm_cls",
[
pytest.param(MNISTDataModule, id="mnist"),
pytest.param(CIFAR10DataModule, id="cifar10"),
],
[pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")],
)
def test_gan(tmpdir, datadir, dm_cls):
seed_everything()

def test_gan(tmpdir, datadir, catch_warnings, dm_cls):
# Validation loop for GANs is not well defined!
warnings.filterwarnings(
"ignore",
message="You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.",
category=UserWarning,
)
warnings.filterwarnings(
"ignore",
message="The dataloader, train_dataloader, does not have many workers which may be a bottleneck",
category=PossibleUserWarning,
)
seed_everything(1234)
dm = dm_cls(data_dir=datadir, num_workers=0)
model = GAN(*dm.size())
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
model = GAN(*dm.dims)
trainer = Trainer(
fast_dev_run=True, default_root_dir=tmpdir, max_epochs=-1, accelerator="auto", log_every_n_steps=1
)
trainer.fit(model, datamodule=dm)


Expand Down
Empty file.
38 changes: 38 additions & 0 deletions tests/models/gans/unit/test_basic_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import torch
from pytorch_lightning import seed_everything

from pl_bolts.models.gans.basic.components import Discriminator, Generator


@pytest.mark.parametrize(
"latent_dim, img_shape",
[
pytest.param(100, (3, 28, 28), id="100-multichannel"),
pytest.param(100, (1, 28, 28), id="100-singlechannel"),
],
)
def test_generator(catch_warnings, latent_dim, img_shape):
batch_dim = 10
seed_everything()
generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
noise = torch.randn(batch_dim, latent_dim)
samples = generator(noise)
assert samples.shape == (batch_dim, *img_shape)


@pytest.mark.parametrize(
"img_shape",
[
pytest.param((3, 28, 28), id="discriminator-multichannel"),
pytest.param((1, 28, 28), id="discriminator-singlechannel"),
],
)
def test_discriminator(catch_warnings, img_shape):
batch_dim = 10
seed_everything()
discriminator = Discriminator(img_shape=img_shape)
samples = torch.randn(batch_dim, *img_shape)
real_or_fake = discriminator(samples)
assert real_or_fake.shape == (batch_dim, 1)
assert (torch.clamp(real_or_fake.clone(), 0, 1) == real_or_fake).all()