Skip to content

Commit

Permalink
swav-improvements (Lightning-Universe#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
Atharva-Phatak authored and matsumotosan committed Oct 27, 2022
1 parent 486d858 commit 2f5538f
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/models/self_supervised/unit/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import torch
from PIL import Image

from pl_bolts.models.self_supervised.swav.transforms import (
SwAVEvalDataTransform,
SwAVFinetuneTransform,
SwAVTrainDataTransform,
)
from pl_bolts.transforms.self_supervised.simclr_transforms import (
SimCLREvalDataTransform,
SimCLRFinetuneTransform,
Expand Down Expand Up @@ -62,6 +67,53 @@ def test_swav_finetune_transform(catch_warnings):
assert view.size(1) == view.size(2) == input_height


@pytest.mark.parametrize(
    "transform_cls",
    [pytest.param(SwAVTrainDataTransform, id="train-data"), pytest.param(SwAVEvalDataTransform, id="eval-data")],
)
def test_swav_train_data_transform(catch_warnings, transform_cls):
    """SwAV train/eval transforms must emit the full multi-crop stack of views.

    Both transform classes are expected to produce, for one input image:
    2 global crops, 4 local crops, and 1 online-evaluation view (7 total).
    """
    # dummy image (random RGB noise is enough — only shapes/types are checked)
    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
    img = Image.fromarray(img)
    crop_sizes = (96, 36)  # (global crop size, local crop size)

    # generate the views
    transform = transform_cls(size_crops=crop_sizes)
    views = transform(img)

    # the transform must output a list or a tuple of images
    assert isinstance(views, (list, tuple))

    # the transform must output seven images:
    # 2 global crops + 4 local crops + 1 online-evaluation view
    assert len(views) == 7

    # all views are tensors
    assert all(torch.is_tensor(v) for v in views)

    # the two global views share the global crop size
    assert all(v.size(1) == v.size(2) == crop_sizes[0] for v in views[:2])
    # the local views share the local crop size (last element is the online-eval view)
    assert all(v.size(1) == v.size(2) == crop_sizes[1] for v in views[2:-1])


def test_swav_finetune_transform_view_shape(catch_warnings):
    """SwAVFinetuneTransform must return a single tensor view of the requested size.

    NOTE(review): renamed from ``test_swav_finetune_transform`` — this diff adds
    lines only (0 deletions) and its hunk context shows a function with that exact
    name already defined earlier in the file, so the duplicate definition silently
    shadowed one of the two and pytest collected only one. The rename ensures both
    tests are collected; if they are truly identical, delete one outright.
    """
    # dummy image (random RGB noise is enough — only shape/type are checked)
    img = np.random.randint(low=0, high=255, size=(32, 32, 3), dtype=np.uint8)
    img = Image.fromarray(img)

    # size of the generated view
    input_height = 96
    transform = SwAVFinetuneTransform(input_height=input_height)
    view = transform(img)

    # the finetune transform produces one tensor, not a list of views
    assert torch.is_tensor(view)

    # spatial dimensions match the requested input height
    assert view.size(1) == view.size(2) == input_height


@pytest.mark.parametrize(
"transform_cls",
[pytest.param(SimCLRTrainDataTransform, id="train-data"), pytest.param(SimCLREvalDataTransform, id="eval-data")],
Expand Down

0 comments on commit 2f5538f

Please sign in to comment.