Swav checks #900

Closed
Changes from all commits
37 commits
258241f
default use torchvision transforms for swav
Atharva-Phatak Oct 3, 2022
c1b8023
added refactor loss into a new function in SWAVModule
Atharva-Phatak Oct 3, 2022
5224dcd
import fixes
Atharva-Phatak Oct 3, 2022
d77bf10
remove under review tag from model implementation
Atharva-Phatak Oct 3, 2022
34fe46d
add reviewed swav_finetuner.py
Atharva-Phatak Oct 3, 2022
73fd788
review swav cli-main
Atharva-Phatak Oct 3, 2022
f210d60
remove under_review blocks
Atharva-Phatak Oct 3, 2022
5065b40
Added tests and import fixes
Atharva-Phatak Oct 6, 2022
65b44d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2022
f5871a3
remove unwanted imports and antipatterns
Atharva-Phatak Oct 6, 2022
9c8b9d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2022
63fd56d
pre-commit-fixes
Atharva-Phatak Oct 6, 2022
26f655f
Merge branch 'master' into swav-improvements
Atharva-Phatak Oct 6, 2022
3f17aac
deepsource fix
Atharva-Phatak Oct 6, 2022
7f7c4bd
fix trainer issues as pointed by azure
Atharva-Phatak Oct 7, 2022
eec7dae
maybe azure fix
Atharva-Phatak Oct 7, 2022
289df9e
enable gpus
Atharva-Phatak Oct 7, 2022
c3609df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2022
867586b
maybe trainer fix
Atharva-Phatak Oct 7, 2022
0edd78c
retry azure fix :(
Atharva-Phatak Oct 7, 2022
1c86ced
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2022
4181f8d
retry azure fix :( :(
Atharva-Phatak Oct 7, 2022
7dcab90
Hopeful azure fix
Atharva-Phatak Oct 7, 2022
561cffd
run on multiple gpus
Atharva-Phatak Oct 8, 2022
faae10d
fix for cpu tests
Atharva-Phatak Oct 8, 2022
d59cb29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2022
9808144
fix for cpu tests
Atharva-Phatak Oct 8, 2022
90fa0a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2022
525bade
mypy fixes
Atharva-Phatak Oct 8, 2022
722b6de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2022
38ed5b0
mypy-fix
Atharva-Phatak Oct 8, 2022
4cb17fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2022
e31d354
type-fix
Atharva-Phatak Oct 8, 2022
f636df6
typo-fix
Atharva-Phatak Oct 9, 2022
2ae69d7
mypy checks
Atharva-Phatak Oct 9, 2022
cb47037
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2022
a42b87d
mypy-fix
Atharva-Phatak Oct 9, 2022
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/ssl_finetuner.py
@@ -6,10 +6,8 @@
from torchmetrics import Accuracy

from pl_bolts.models.self_supervised import SSLEvaluator
from pl_bolts.utils.stability import under_review


@under_review()
class SSLFineTuner(LightningModule):
"""Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP
with 1024 units.
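
For context, a minimal sketch of how SSLFineTuner is typically wired up now that it is no longer gated behind @under_review(). This snippet is not part of the diff, and the constructor arguments (backbone, in_features, num_classes) are assumptions based on the usual finetuning setup:

import torch
from torchvision.models import resnet50

from pl_bolts.models.self_supervised import SSLFineTuner

backbone = resnet50()
backbone.fc = torch.nn.Identity()  # expose the pooled 2048-d features

finetuner = SSLFineTuner(backbone=backbone, in_features=2048, num_classes=10)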
2 changes: 2 additions & 0 deletions pl_bolts/models/self_supervised/swav/__init__.py
@@ -1,3 +1,4 @@
from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
@@ -13,4 +14,5 @@
"SwAVEvalDataTransform",
"SwAVFinetuneTransform",
"SwAVTrainDataTransform",
"SWAVLoss",
]
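
With the added import and __all__ entry, the loss is exposed at the subpackage level, so downstream code can import it alongside the model and transforms:

from pl_bolts.models.self_supervised.swav import SwAV, SWAVLoss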
133 changes: 133 additions & 0 deletions pl_bolts/models/self_supervised/swav/loss.py
@@ -0,0 +1,133 @@
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import distributed as dist


class SWAVLoss(nn.Module):
    def __init__(
        self,
        temperature: float,
        crops_for_assign: tuple,
        nmb_crops: tuple,
        sinkhorn_iterations: int,
        epsilon: float,
        gpus: int,
        num_nodes: int,
    ):
        """Implementation of the SwAV loss function.

        Args:
            temperature: loss temperature
            crops_for_assign: list of crop ids for computing assignments
            nmb_crops: number of global and local crops, e.g. [2, 6]
            sinkhorn_iterations: number of iterations for Sinkhorn normalization
            epsilon: epsilon value for SwAV assignments
            gpus: number of gpus per node used in training, passed to the SwAV module
                to manage the queue and select the distributed Sinkhorn
            num_nodes: number of nodes to train on
        """
        super().__init__()
        self.temperature = temperature
        self.crops_for_assign = crops_for_assign
        self.softmax = nn.Softmax(dim=1)
        self.sinkhorn_iterations = sinkhorn_iterations
        self.epsilon = epsilon
        self.nmb_crops = nmb_crops
        self.gpus = gpus
        self.num_nodes = num_nodes
        if self.gpus * self.num_nodes > 1:
            self.assignment_fn = self.distributed_sinkhorn
        else:
            self.assignment_fn = self.sinkhorn

    def forward(
        self,
        output: torch.Tensor,
        embedding: torch.Tensor,
        prototype_weights: torch.Tensor,
        batch_size: int,
        queue: Optional[torch.Tensor] = None,
        use_queue: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], bool]:
        loss = 0
        for i, crop_id in enumerate(self.crops_for_assign):
            with torch.no_grad():
                out = output[batch_size * crop_id : batch_size * (crop_id + 1)]

                # Time to use the queue
                if queue is not None:
                    if use_queue or not torch.all(queue[i, -1, :] == 0):
                        use_queue = True
                        out = torch.cat((torch.mm(queue[i], prototype_weights.t()), out))
                    # fill the queue
                    queue[i, batch_size:] = queue[i, :-batch_size].clone()  # type: ignore
                    queue[i, :batch_size] = embedding[crop_id * batch_size : (crop_id + 1) * batch_size]  # type: ignore

                # get assignments
                q = torch.exp(out / self.epsilon).t()
                q = self.assignment_fn(q, self.sinkhorn_iterations)[-batch_size:]

            # cluster assignment prediction
            subloss = 0
            for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
                p = self.softmax(output[batch_size * v : batch_size * (v + 1)] / self.temperature)
                subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
            loss += subloss / (np.sum(self.nmb_crops) - 1)  # type: ignore
        loss /= len(self.crops_for_assign)  # type: ignore

        return loss, queue, use_queue

    def sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
        """Implementation of Sinkhorn clustering."""
        with torch.no_grad():
            sum_Q = torch.sum(Q)
            Q /= sum_Q

            K, B = Q.shape

            if self.gpus > 0:
                u = torch.zeros(K).cuda()
                r = torch.ones(K).cuda() / K
                c = torch.ones(B).cuda() / B
            else:
                u = torch.zeros(K)
                r = torch.ones(K) / K
                c = torch.ones(B) / B

            for _ in range(nmb_iters):
                u = torch.sum(Q, dim=1)

                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

    def distributed_sinkhorn(self, Q: torch.Tensor, nmb_iters: int) -> torch.Tensor:
        """Implementation of Distributed Sinkhorn."""
        with torch.no_grad():
            sum_Q = torch.sum(Q)
            dist.all_reduce(sum_Q)
            Q /= sum_Q

            if self.gpus > 0:
                u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
                r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
                c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
            else:
                u = torch.zeros(Q.shape[0])
                r = torch.ones(Q.shape[0]) / Q.shape[0]
                c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

            curr_sum = torch.sum(Q, dim=1)
            dist.all_reduce(curr_sum)

            for _ in range(nmb_iters):
                u = curr_sum
                Q *= (r / u).unsqueeze(1)
                Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
                curr_sum = torch.sum(Q, dim=1)
                dist.all_reduce(curr_sum)
            return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()
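
For reference, a small self-contained sketch of exercising the new SWAVLoss on its own (not part of the PR). The forward signature comes from the diff above; the tensor shapes assume the usual SwAV convention of concatenating all crops along the batch dimension, and the values are placeholders:

import torch

from pl_bolts.models.self_supervised.swav import SWAVLoss

criterion = SWAVLoss(
    temperature=0.1,
    crops_for_assign=(0, 1),
    nmb_crops=(2, 6),
    sinkhorn_iterations=3,
    epsilon=0.05,
    gpus=0,  # gpus * num_nodes <= 1 selects the single-process Sinkhorn
    num_nodes=1,
)

batch_size, num_prototypes, feat_dim = 8, 512, 128
total_crops = 2 + 6  # sum(nmb_crops); all crops are stacked along the batch dim

output = torch.randn(batch_size * total_crops, num_prototypes)  # prototype scores
embedding = torch.randn(batch_size * total_crops, feat_dim)  # projected features
prototype_weights = torch.randn(num_prototypes, feat_dim)

loss, queue, use_queue = criterion(
    output=output,
    embedding=embedding,
    prototype_weights=prototype_weights,
    batch_size=batch_size,
    queue=None,  # no queue yet, so the queue branch is skipped
    use_queue=False,
)
print(loss)  # scalar tensor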
2 changes: 0 additions & 2 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
@@ -7,10 +7,8 @@
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
from pl_bolts.utils.stability import under_review


@under_review()
def cli_main(): # pragma: no cover
from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

112 changes: 22 additions & 90 deletions pl_bolts/models/self_supervised/swav/swav_module.py
@@ -2,13 +2,12 @@
import os
from argparse import ArgumentParser

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torch import distributed as dist
from torch import nn

from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.optimizers.lars import LARS
from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay
@@ -17,10 +16,8 @@
imagenet_normalization,
stl10_normalization,
)
from pl_bolts.utils.stability import under_review


@under_review()
class SwAV(LightningModule):
def __init__(
self,
@@ -129,19 +126,21 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs

if self.gpus * self.num_nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn

self.model = self.init_model()

self.criterion = SWAVLoss(
gpus=self.gpus,
num_nodes=self.num_nodes,
temperature=self.temperature,
crops_for_assign=self.crops_for_assign,
nmb_crops=self.nmb_crops,
sinkhorn_iterations=self.sinkhorn_iterations,
epsilon=self.epsilon,
)
self.use_the_queue = None
# compute iters per epoch
global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

self.queue = None
self.softmax = nn.Softmax(dim=1)

def setup(self, stage):
if self.queue_length > 0:
@@ -216,33 +215,17 @@ def shared_step(self, batch):
embedding = embedding.detach()
bs = inputs[0].size(0)

# 3. swav loss computation
loss = 0
for i, crop_id in enumerate(self.crops_for_assign):
with torch.no_grad():
out = output[bs * crop_id : bs * (crop_id + 1)]

# 4. time to use the queue
if self.queue is not None:
if self.use_the_queue or not torch.all(self.queue[i, -1, :] == 0):
self.use_the_queue = True
out = torch.cat((torch.mm(self.queue[i], self.model.prototypes.weight.t()), out))
# fill the queue
self.queue[i, bs:] = self.queue[i, :-bs].clone()
self.queue[i, :bs] = embedding[crop_id * bs : (crop_id + 1) * bs]

# 5. get assignments
q = torch.exp(out / self.epsilon).t()
q = self.get_assignments(q, self.sinkhorn_iterations)[-bs:]

# cluster assignment prediction
subloss = 0
for v in np.delete(np.arange(np.sum(self.nmb_crops)), crop_id):
p = self.softmax(output[bs * v : bs * (v + 1)] / self.temperature)
subloss -= torch.mean(torch.sum(q * torch.log(p), dim=1))
loss += subloss / (np.sum(self.nmb_crops) - 1)
loss /= len(self.crops_for_assign)

# SWAV loss computation
loss, queue, use_queue = self.criterion(
output=output,
embedding=embedding,
prototype_weights=self.model.prototypes.weight,
batch_size=bs,
queue=self.queue,
use_queue=self.use_the_queue,
)
self.queue = queue
self.use_the_queue = use_queue
return loss

def training_step(self, batch, batch_idx):
@@ -302,56 +285,6 @@ def configure_optimizers(self):

return [optimizer], [scheduler]

def sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
Q /= sum_Q

K, B = Q.shape

if self.gpus > 0:
u = torch.zeros(K).cuda()
r = torch.ones(K).cuda() / K
c = torch.ones(B).cuda() / B
else:
u = torch.zeros(K)
r = torch.ones(K) / K
c = torch.ones(B) / B

for _ in range(nmb_iters):
u = torch.sum(Q, dim=1)

Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)

return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

def distributed_sinkhorn(self, Q, nmb_iters):
with torch.no_grad():
sum_Q = torch.sum(Q)
dist.all_reduce(sum_Q)
Q /= sum_Q

if self.gpus > 0:
u = torch.zeros(Q.shape[0]).cuda(non_blocking=True)
r = torch.ones(Q.shape[0]).cuda(non_blocking=True) / Q.shape[0]
c = torch.ones(Q.shape[1]).cuda(non_blocking=True) / (self.gpus * Q.shape[1])
else:
u = torch.zeros(Q.shape[0])
r = torch.ones(Q.shape[0]) / Q.shape[0]
c = torch.ones(Q.shape[1]) / (self.gpus * Q.shape[1])

curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)

for it in range(nmb_iters):
u = curr_sum
Q *= (r / u).unsqueeze(1)
Q *= (c / torch.sum(Q, dim=0)).unsqueeze(0)
curr_sum = torch.sum(Q, dim=1)
dist.all_reduce(curr_sum)
return (Q / torch.sum(Q, dim=0, keepdim=True)).t().float()

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
@@ -446,7 +379,6 @@ def add_model_specific_args(parent_parser):
return parser


@under_review()
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
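
For context, a minimal sketch of constructing the refactored module (not part of this diff). Only the loss-related hyperparameters, gpus, num_nodes, num_samples, and batch_size appear in the changes above; the dataset argument and the chosen values are assumptions based on the existing SwAV signature:

from pl_bolts.models.self_supervised import SwAV

model = SwAV(
    gpus=0,
    num_nodes=1,
    num_samples=45_000,
    batch_size=128,
    dataset="cifar10",  # assumed argument
    nmb_crops=(2, 6),
    crops_for_assign=(0, 1),
    temperature=0.1,
    sinkhorn_iterations=3,
    epsilon=0.05,
)

# After this refactor the loss lives in `model.criterion` (a SWAVLoss instance)
# instead of being computed inline in `shared_step`.
print(type(model.criterion).__name__)  # SWAVLoss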