Skip to content

Commit

Permalink
Reviewing GAN basics, VisionDataModule, MNISTDataModule, CIFAR10DataM…
Browse files Browse the repository at this point in the history
…odule (#843)
  • Loading branch information
shivammehta25 authored Jul 29, 2022
1 parent 675b176 commit acc0c98
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 28 deletions.
1 change: 0 additions & 1 deletion pl_bolts/datamodules/cifar10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
CIFAR10 = None


@under_review()
class CIFAR10DataModule(VisionDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
Expand Down
2 changes: 0 additions & 2 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
Expand All @@ -12,7 +11,6 @@
warn_missing_pkg("torchvision")


@under_review()
class MNISTDataModule(VisionDataModule):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
Expand Down
36 changes: 33 additions & 3 deletions pl_bolts/datamodules/vision_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split

from pl_bolts.utils.stability import under_review


@under_review()
class VisionDataModule(LightningDataModule):

EXTRA_ARGS: dict = {}
Expand All @@ -30,6 +27,9 @@ def __init__(
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
train_transforms: Optional[Callable] = None,
val_transforms: Optional[Callable] = None,
test_transforms: Optional[Callable] = None,
*args: Any,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -58,6 +58,36 @@ def __init__(
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
self._train_transforms = train_transforms
self._val_transforms = val_transforms
self._test_transforms = test_transforms

@property
def train_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to train dataset."""
return self._train_transforms

@train_transforms.setter
def train_transforms(self, t: Callable) -> None:
self._train_transforms = t

@property
def val_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to validation dataset."""
return self._val_transforms

@val_transforms.setter
def val_transforms(self, t: Callable) -> None:
self._val_transforms = t

@property
def test_transforms(self) -> Optional[Callable[..., Any]]:
"""Optional transforms (or collection of transforms) you can apply to test dataset."""
return self._test_transforms

@test_transforms.setter
def test_transforms(self, t: Callable) -> None:
self._test_transforms = t

def prepare_data(self, *args: Any, **kwargs: Any) -> None:
"""Saves files to data_dir."""
Expand Down
14 changes: 8 additions & 6 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from torch.nn import functional as F

from pl_bolts.models.gans.basic.components import Discriminator, Generator
from pl_bolts.utils.stability import under_review


@under_review()
class GAN(LightningModule):
"""Vanilla GAN implementation.
Expand All @@ -22,7 +21,7 @@ class GAN(LightningModule):
Example CLI::
# mnist
python basic_gan_module.py --gpus 1
python basic_gan_module.py --gpus 1
# imagenet
python basic_gan_module.py --gpus 1 --dataset 'imagenet2012'
Expand Down Expand Up @@ -166,7 +165,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main(args=None):
from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, MNISTDataModule, STL10DataModule
Expand All @@ -193,8 +191,12 @@ def cli_main(args=None):

dm = dm_cls.from_argparse_args(args)
model = GAN(*dm.size(), **vars(args))
callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]
trainer = Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
callbacks = [
TensorboardGenerativeModelImageSampler(),
LatentDimInterpolator(interpolate_epoch_interval=5),
TQDMProgressBar(refresh_rate=20),
]
trainer = Trainer.from_argparse_args(args, callbacks=callbacks)
trainer.fit(model, datamodule=dm)
return dm, model, trainer

Expand Down
4 changes: 0 additions & 4 deletions pl_bolts/models/gans/basic/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from torch import nn
from torch.nn import functional as F

from pl_bolts.utils.stability import under_review


@under_review()
class Generator(nn.Module):
def __init__(self, latent_dim, img_shape, hidden_dim=256):
super().__init__()
Expand All @@ -27,7 +24,6 @@ def forward(self, z):
return img


@under_review()
class Discriminator(nn.Module):
def __init__(self, img_shape, hidden_dim=1024):
super().__init__()
Expand Down
6 changes: 3 additions & 3 deletions tests/datamodules/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,17 @@ def test_cityscapes_datamodule(datadir):


@pytest.mark.parametrize("val_split, train_len", [(0.2, 48_000), (5_000, 55_000)])
def test_vision_data_module(datadir, val_split, train_len):
def test_vision_data_module(datadir, val_split, catch_warnings, train_len):
dm = _create_dm(MNISTDataModule, datadir, val_split=val_split)
assert len(dm.dataset_train) == train_len


@pytest.mark.parametrize("dm_cls", [BinaryMNISTDataModule, CIFAR10DataModule, FashionMNISTDataModule, MNISTDataModule])
def test_data_modules(datadir, dm_cls):
def test_data_modules(datadir, catch_warnings, dm_cls):
dm = _create_dm(dm_cls, datadir)
loader = dm.train_dataloader()
img, _ = next(iter(loader))
assert img.size() == torch.Size([2, *dm.size()])
assert img.size() == torch.Size([2, *dm.dims])


def _create_dm(dm_cls, datadir, **kwargs):
Expand Down
Empty file added tests/models/gans/__init__.py
Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import warnings

import pytest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms as transform_lib

Expand All @@ -10,17 +13,26 @@

@pytest.mark.parametrize(
"dm_cls",
[
pytest.param(MNISTDataModule, id="mnist"),
pytest.param(CIFAR10DataModule, id="cifar10"),
],
[pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")],
)
def test_gan(tmpdir, datadir, dm_cls):
seed_everything()

def test_gan(tmpdir, datadir, catch_warnings, dm_cls):
# Validation loop for GANs is not well defined!
warnings.filterwarnings(
"ignore",
message="You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.",
category=UserWarning,
)
warnings.filterwarnings(
"ignore",
message="The dataloader, train_dataloader, does not have many workers which may be a bottleneck",
category=PossibleUserWarning,
)
seed_everything(1234)
dm = dm_cls(data_dir=datadir, num_workers=0)
model = GAN(*dm.size())
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
model = GAN(*dm.dims)
trainer = Trainer(
fast_dev_run=True, default_root_dir=tmpdir, max_epochs=-1, accelerator="auto", log_every_n_steps=1
)
trainer.fit(model, datamodule=dm)


Expand Down
Empty file.
38 changes: 38 additions & 0 deletions tests/models/gans/unit/test_basic_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pytest
import torch
from pytorch_lightning import seed_everything

from pl_bolts.models.gans.basic.components import Discriminator, Generator


@pytest.mark.parametrize(
"latent_dim, img_shape",
[
pytest.param(100, (3, 28, 28), id="100-multichannel"),
pytest.param(100, (1, 28, 28), id="100-singlechannel"),
],
)
def test_generator(catch_warnings, latent_dim, img_shape):
batch_dim = 10
seed_everything()
generator = Generator(latent_dim=latent_dim, img_shape=img_shape)
noise = torch.randn(batch_dim, latent_dim)
samples = generator(noise)
assert samples.shape == (batch_dim, *img_shape)


@pytest.mark.parametrize(
"img_shape",
[
pytest.param((3, 28, 28), id="discriminator-multichannel"),
pytest.param((1, 28, 28), id="discriminator-singlechannel"),
],
)
def test_discriminator(catch_warnings, img_shape):
batch_dim = 10
seed_everything()
discriminator = Discriminator(img_shape=img_shape)
samples = torch.randn(batch_dim, *img_shape)
real_or_fake = discriminator(samples)
assert real_or_fake.shape == (batch_dim, 1)
assert (torch.clamp(real_or_fake.clone(), 0, 1) == real_or_fake).all()

0 comments on commit acc0c98

Please sign in to comment.