diff --git a/configs/selfsup/_base_/datasets/coco_orl.py b/configs/selfsup/_base_/datasets/coco_orl.py new file mode 100644 index 000000000..6eecee059 --- /dev/null +++ b/configs/selfsup/_base_/datasets/coco_orl.py @@ -0,0 +1,57 @@ +import copy + +# dataset settings +dataset_type = 'mmdet.CocoDataset' +# data_root = 'data/coco/' +data_root = '../data/coco/' +file_client_args = dict(backend='disk') +view_pipeline = [ + dict( + type='RandomResizedCrop', + size=224, + interpolation='bicubic', + backend='pillow'), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomApply', + transforms=[ + dict( + type='ColorJitter', + brightness=0.4, + contrast=0.4, + saturation=0.2, + hue=0.1) + ], + prob=0.8), + dict( + type='RandomGrayscale', + prob=0.2, + keep_channels=True, + channel_weights=(0.114, 0.587, 0.2989)), + dict(type='RandomGaussianBlur', sigma_min=0.1, sigma_max=2.0, prob=1), + dict(type='RandomSolarize', prob=0) +] +view_pipeline1 = copy.deepcopy(view_pipeline) +view_pipeline2 = copy.deepcopy(view_pipeline) +view_pipeline2[4]['prob'] = 0.1 # gaussian blur +view_pipeline2[5]['prob'] = 0.2 # solarization +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict( + type='MultiView', + num_views=[1, 1], + transforms=[view_pipeline1, view_pipeline2]), + dict(type='PackSelfSupInputs', meta_keys=['img_path']) +] +train_dataloader = dict( + batch_size=64, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline=train_pipeline)) diff --git a/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py new file mode 100644 index 000000000..eb7f9e35a --- /dev/null +++ 
b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco.py @@ -0,0 +1,63 @@ +_base_ = [ + '../../_base_/models/byol.py', + '../../_base_/datasets/coco_orl.py', + '../../_base_/schedules/sgd_coslr-200e_in1k.py', + '../../_base_/default_runtime.py', +] + +# model settings +model = dict( + neck=dict( + type='NonLinearNeck', + in_channels=2048, + hid_channels=4096, + out_channels=256, + num_layers=2, + with_bias=False, + with_last_bn=False, + with_avg_pool=True), + head=dict( + type='LatentPredictHead', + predictor=dict( + type='NonLinearNeck', + in_channels=256, + hid_channels=4096, + out_channels=256, + num_layers=2, + with_bias=False, + with_last_bn=False, + with_avg_pool=False))) + +update_interval = 1 # interval for accumulate gradient +# Amp optimizer +optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + accumulative_counts=update_interval, +) +warmup_epochs = 4 +total_epochs = 800 +# learning policy +param_scheduler = [ + # warmup + dict( + type='LinearLR', + start_factor=0.0001, + by_epoch=True, + end=warmup_epochs, + # Update the learning rate after every iters. + convert_to_iter_based=True), + # ConsineAnnealingLR/StepLR/.. 
+ dict( + type='CosineAnnealingLR', + eta_min=0., + T_max=total_epochs, + by_epoch=True, + begin=warmup_epochs, + end=total_epochs) +] + +# runtime settings +default_hooks = dict(checkpoint=dict(interval=100)) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs) diff --git a/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco_extractor.py b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco_extractor.py new file mode 100644 index 000000000..7c4ea9ab6 --- /dev/null +++ b/configs/selfsup/orl/stage1/orl_resnet50_8xb64-coslr-800e_coco_extractor.py @@ -0,0 +1,85 @@ +_base_ = [ + '../../_base_/models/byol.py', + '../../_base_/datasets/coco_orl.py', + '../../_base_/schedules/sgd_coslr-200e_in1k.py', + '../../_base_/default_runtime.py', +] +# model settings +model = dict( + neck=dict( + type='NonLinearNeck', + in_channels=2048, + hid_channels=4096, + out_channels=256, + num_layers=2, + with_bias=False, + with_last_bn=False, + with_avg_pool=True), + head=dict( + type='LatentPredictHead', + predictor=dict( + type='NonLinearNeck', + in_channels=256, + hid_channels=4096, + out_channels=256, + num_layers=2, + with_bias=False, + with_last_bn=False, + with_avg_pool=False))) + +update_interval = 1 # interval for accumulate gradient +# Amp optimizer +optimizer = dict(type='SGD', lr=0.4, weight_decay=0.0001, momentum=0.9) +optim_wrapper = dict( + type='AmpOptimWrapper', + optimizer=optimizer, + accumulative_counts=update_interval, +) +warmup_epochs = 4 +total_epochs = 5 +# learning policy +param_scheduler = [ + # warmup + dict( + type='LinearLR', + start_factor=0.0001, + by_epoch=True, + end=warmup_epochs, + # Update the learning rate after every iters. + convert_to_iter_based=True), + # ConsineAnnealingLR/StepLR/.. 
+ dict( + type='CosineAnnealingLR', + eta_min=0., + T_max=total_epochs, + by_epoch=True, + begin=warmup_epochs, + end=total_epochs) +] + +# "mmselfsup/configs/selfsup/orl/stage1/ +# orl_resnet50_8xb64-coslr-800e_coco_extractor.py" +# runtime settings +default_hooks = dict(checkpoint=dict(interval=100)) +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=total_epochs) +# load_from = './work_dirs/selfsup/orl/stage1/ +# orl_resnet50_8xb64-coslr-800e_coco/epoch_100.pth' +# resume=True +custom_hooks = [ + dict( + type='ExtractorHook', + keys=10, + extract_dataloader=dict( + batch_size=512, + num_workers=6, + persistent_workers=False, + sampler=dict(type='DefaultSampler', shuffle=False, round_up=True), + collate_fn=dict(type='default_collate'), + dataset=dict( + type={{_base_.dataset_type}}, + data_root={{_base_.data_root}}, + ann_file='annotations/instances_train2017.json', + data_prefix=dict(img='train2017/'), + pipeline={{_base_.train_pipeline}})), + normalize=True), +] diff --git a/mmselfsup/engine/hooks/__init__.py b/mmselfsup/engine/hooks/__init__.py index 147d254f5..0f2e42df3 100644 --- a/mmselfsup/engine/hooks/__init__.py +++ b/mmselfsup/engine/hooks/__init__.py @@ -1,10 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .deepcluster_hook import DeepClusterHook from .densecl_hook import DenseCLHook +from .extractor_hook import ExtractorHook from .odc_hook import ODCHook from .simsiam_hook import SimSiamHook from .swav_hook import SwAVHook __all__ = [ - 'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook' + 'DeepClusterHook', 'DenseCLHook', 'ODCHook', 'SimSiamHook', 'SwAVHook', + 'ExtractorHook' ] diff --git a/mmselfsup/engine/hooks/deepcluster_hook.py b/mmselfsup/engine/hooks/deepcluster_hook.py index 902127fde..d6c862c1b 100644 --- a/mmselfsup/engine/hooks/deepcluster_hook.py +++ b/mmselfsup/engine/hooks/deepcluster_hook.py @@ -75,6 +75,7 @@ def deepcluster(self, runner) -> None: # step 1: get features runner.model.eval() features = self.extractor(runner.model.module) + runner.model.train() # step 2: get labels diff --git a/mmselfsup/engine/hooks/extractor_hook.py b/mmselfsup/engine/hooks/extractor_hook.py new file mode 100644 index 000000000..299366c7b --- /dev/null +++ b/mmselfsup/engine/hooks/extractor_hook.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import json +import os +import time +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as T +from mmengine.dist import is_distributed +from mmengine.hooks import Hook +from mmengine.logging import print_log +from mmengine.model import BaseModel +from torch.nn import functional as F + +from mmselfsup.models.utils import Extractor +from mmselfsup.registry import HOOKS + + +# forward global image for knn retrieval +def global_forward(img: list, model: BaseModel): + assert torch.is_floating_point(img[0]), 'image type mismatch' + x = torch.stack(img).cuda() + with torch.no_grad(): + x = model.backbone(x) + feats = model.neck(x)[0] + feats_norm = F.normalize(feats, dim=1) + return feats_norm.detach() + + +class Trans(object): + + def __init__(self): + global_trans_list = [T.Resize(256), T.CenterCrop(224)] + self.global_transform = T.Compose(global_trans_list) + self.img_transform = T.Compose([ + T.ToTensor(), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + +@HOOKS.register_module() +class ExtractorHook(Hook): + """Feature extractor hook for ORL stage 1. + + After training finishes, this hook extracts features of the whole training set, retrieves the top-k nearest-neighbour images of every image in feature space, and saves the resulting image-level knn json file used by the next ORL stage. + + Args: + keys (int): Number of nearest-neighbour images retrieved per image. + extract_dataloader (dict): Config dict of the dataloader used for feature extraction. + normalize (bool): Whether to l2-normalize the extracted features before retrieval. Defaults to True. + seed (int, optional): Random seed. Defaults to None. 
+ """ + + def __init__(self, + keys: int, + extract_dataloader: dict, + normalize=True, + seed: Optional[int] = None) -> None: + + self.dist_mode = is_distributed() + self.keys = keys + self.dataset = extract_dataloader['dataset'] + self.extractor = Extractor( + extract_dataloader=extract_dataloader, + seed=seed, + dist_mode=self.dist_mode, + pool_cfg=None) + self.normalize = normalize + + def retrieve_knn(self, features: torch.Tensor): + # load data + data_root = self.dataset['data_root'] + data_ann = self.dataset['ann_file'] + data_prefix = self.dataset['data_prefix']['img'] + train_json = data_root + data_ann + train_root = data_root + data_prefix + # train_json = '../data/coco/annotations/instances_train2017.json' + # train_root = '../data/coco/train2017/' + with open(train_json, 'r') as json_file: + data = json.load(json_file) + + train_fns = [train_root + item['file_name'] for item in data['images']] + imgids = [item['id'] for item in data['images']] + knn_imgids = [] + # batch processing + # trans = Trans() + batch = 512 + keys = self.keys + # feat_bank = features + + # feats_bank = torch.from_numpy(np.load(feat_bank_npy)).cuda() + feat_bank = features + for i in range(0, len(train_fns), batch): + print('[INFO] processing batch: {}'.format(i + 1)) + start = time.time() + if (i + batch) < len(train_fns): + query_feats = feat_bank[i:i + batch, :] + else: + query_feats = feat_bank[i:len(train_fns), :] + similarity = torch.mm(query_feats, feat_bank.T) + I_knn = torch.topk(similarity, keys + 1, dim=1)[1].cpu() + I_knn = I_knn[:, 1:] # exclude itself (i.e., 1st nn) + knn_list = I_knn.numpy().tolist() + [knn_imgids.append(knn) for knn in knn_list] + end = time.time() + print('[INFO] batch {} took {:.4f} seconds'.format( + i + 1, end - start)) + + # 118287 for coco, 241690 for coco+ + num_image = len(train_fns) + save_dir = data_root + '/meta/' + save_path = save_dir + 'train2017_{}nn_instance.json'.format(keys) + if not os.path.exists(save_dir): + 
os.makedirs(save_dir) + assert len(imgids) == len(knn_imgids) == len(train_fns) == num_image, \ + f'Mismatch number of training images, got: {len(knn_imgids)}' + # dict + data_new = {} + info = {} + image_info = {} + pseudo_anno = {} + info['knn_image_num'] = keys + print(data.keys()) + image_info['file_name'] = [ + item['file_name'] for item in data['images'] + ] + image_info['id'] = [item['id'] for item in data['images']] + pseudo_anno['image_id'] = imgids + pseudo_anno['knn_image_id'] = knn_imgids + data_new['info'] = info + data_new['images'] = image_info + data_new['pseudo_annotations'] = pseudo_anno + with open(save_path, 'w') as f: + json.dump(data_new, f) + print('[INFO] image-level knn json file has been saved to {}'.format( + save_path)) + + def after_run(self, runner): + self._extract_func(runner) + + def _extract_func(self, runner): + # step 1: get features + runner.model.eval() + features = self.extractor(runner.model.module)['feat'] + if self.normalize: + features = nn.functional.normalize( + torch.from_numpy(features), dim=1) + + # step 2: save features + if not self.dist_mode or (self.dist_mode and runner.rank == 0): + np.save( + '{}/feature_epoch_{}.npy'.format(runner.work_dir, + runner.epoch), + features.numpy()) + print_log( + 'Feature extraction done!!! 
total features: {}\t\ + feature dimension: {}'.format( + features.size(0), features.size(1)), + logger='current') + # features = torch.from_numpy(np.load(feat_bank_npy)).cuda() + # step3: retrieval knn + if runner.rank == 0: + self.retrieve_knn(features) + # self.retrieve_knn(features, runner, runner.model.module) diff --git a/mmselfsup/utils/collect.py b/mmselfsup/utils/collect.py index 3022eeb1f..6aa35ad40 100644 --- a/mmselfsup/utils/collect.py +++ b/mmselfsup/utils/collect.py @@ -53,6 +53,7 @@ def dist_forward_collect(func: object, data_loader: DataLoader, rank, world_size = get_dist_info() results = [] if rank == 0: + prog_bar = mmengine.ProgressBar(len(data_loader)) for _, data in enumerate(data_loader): with torch.no_grad(): diff --git a/tools/slurm_train.sh b/tools/slurm_train.sh index ac36d5082..5b56251dc 100644 --- a/tools/slurm_train.sh +++ b/tools/slurm_train.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -x - +# NOTE(review): activate the required environment before submitting the job; do not hardcode a personal conda env ('source activate idea') in a shared launcher script PARTITION=$1 JOB_NAME=$2 CONFIG=$3