diff --git a/passl/core/__init__.py b/passl/core/__init__.py
index d1329172..cae98799 100644
--- a/passl/core/__init__.py
+++ b/passl/core/__init__.py
@@ -16,3 +16,6 @@
 from .grad_scaler import GradScaler
 from .sync_utils import grad_sync, param_sync
 from .param_fuse import get_fused_params
+from .config import Config
+from .builder import Builder, PasslBuilder
+from .trainer import PasslTrainer
diff --git a/passl/core/builder.py b/passl/core/builder.py
new file mode 100644
index 00000000..ab2c3ca5
--- /dev/null
+++ b/passl/core/builder.py
@@ -0,0 +1,245 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Any, Optional
+
+import yaml
+import paddle
+
+from passl.core import manager, Config
+from passl.utils import logger, utils
+from passl.utils.utils import CachedProperty as cached_property
+
+
+class Builder(object):
+    """
+    The base class for building components.
+
+    Args:
+        config (Config): A Config class.
+        comp_list (list, optional): A list of component classes. Default: None
+    """
+
+    def __init__(self, config: Config, comp_list: Optional[list]=None):
+        super().__init__()
+        self.config = config
+        self.comp_list = comp_list
+
+    def build_component(self, cfg):
+        """
+        Create a Python object, such as a model, loss, dataset, etc.
+        """
+        cfg = copy.deepcopy(cfg)
+        if 'type' not in cfg:
+            raise RuntimeError(
+                "It is not possible to create a component object from {}, as 'type' is not specified.".
+                format(cfg))
+
+        class_type = cfg.pop('type')
+        com_class = self.load_component_class(class_type)
+
+        params = {}
+        for key, val in cfg.items():
+            if self.is_meta_type(val):
+                params[key] = self.build_component(val)
+            elif isinstance(val, list):
+                params[key] = [
+                    self.build_component(item)
+                    if self.is_meta_type(item) else item for item in val
+                ]
+            else:
+                params[key] = val
+
+        try:
+            obj = self.build_component_impl(com_class, **params)
+        except Exception as e:
+            if hasattr(com_class, '__name__'):
+                com_name = com_class.__name__
+            else:
+                com_name = ''
+            raise RuntimeError(
+                f"Tried to create a {com_name} object, but the operation has failed. "
+                "Please double check the arguments used to create the object.\n"
+                f"The error message is: \n{str(e)}")
+
+        return obj
+
+    def build_component_impl(self, component_class, *args, **kwargs):
+        return component_class(*args, **kwargs)
+
+    def load_component_class(self, class_type):
+        for com in self.comp_list:
+            if class_type in com.components_dict:
+                return com[class_type]
+        raise RuntimeError("The specified component ({}) was not found.".format(
+            class_type))
+
+    @classmethod
+    def is_meta_type(cls, obj):
+        # TODO: should we define a protocol (see https://peps.python.org/pep-0544/#defining-a-protocol)
+        # to make it more pythonic?
+        return isinstance(obj, dict) and 'type' in obj
+
+    @classmethod
+    def show_msg(cls, name, cfg):
+        msg = 'Use the following config to build {}\n'.format(name)
+        msg += str(yaml.dump({name: cfg}, Dumper=utils.NoAliasDumper))
+        logger.info(msg[0:-1])
+
+
+class PasslBuilder(Builder):
+    """
+    This class is responsible for building the PASSL components
+    (model, optimizer, lr_scheduler, loss and datasets) from a Config.
+    """
+
+    def __init__(self, config, comp_list=None):
+        if comp_list is None:
+            comp_list = [
+                manager.MODELS, manager.BACKBONES, manager.DATASETS,
+                manager.TRANSFORMS, manager.LOSSES, manager.OPTIMIZERS
+            ]
+        super().__init__(config, comp_list)
+
+    @cached_property
+    def model(self) -> paddle.nn.Layer:
+        model_cfg = self.config.model_cfg
+        assert model_cfg != {}, \
+            'No model specified in the configuration file.'
+
+        self.show_msg('model', model_cfg)
+        return self.build_component(model_cfg)
+
+    @cached_property
+    def optimizer(self) -> paddle.optimizer.Optimizer:
+        opt_cfg = self.config.optimizer_cfg
+        assert opt_cfg != {}, \
+            'No optimizer specified in the configuration file.'
+        # For compatibility
+        if opt_cfg['type'] == 'adam':
+            opt_cfg['type'] = 'Adam'
+        if opt_cfg['type'] == 'sgd':
+            opt_cfg['type'] = 'SGD'
+        if opt_cfg['type'] == 'SGD' and 'momentum' in opt_cfg:
+            opt_cfg['type'] = 'Momentum'
+            logger.info('The optimizer type is changed from SGD to Momentum '
+                        'since momentum is set in the optimizer config.')
+        self.show_msg('optimizer', opt_cfg)
+        opt = self.build_component(opt_cfg)
+        opt = opt(self.model, self.lr_scheduler)
+        return opt
+
+    @cached_property
+    def lr_scheduler(self) -> paddle.optimizer.lr.LRScheduler:
+        lr_cfg = self.config.lr_scheduler_cfg
+        assert lr_cfg != {}, \
+            'No lr_scheduler specified in the configuration file.'
+
+        use_warmup = False
+        if 'warmup_iters' in lr_cfg:
+            use_warmup = True
+            warmup_iters = lr_cfg.pop('warmup_iters')
+            assert 'warmup_start_lr' in lr_cfg, \
+                "When using warmup, please set warmup_start_lr and warmup_iters in lr_scheduler"
+            warmup_start_lr = lr_cfg.pop('warmup_start_lr')
+            end_lr = lr_cfg['learning_rate']
+
+        lr_type = lr_cfg.pop('type')
+        if lr_type == 'PolynomialDecay':
+            iters = self.config.iters - warmup_iters if use_warmup else self.config.iters
+            iters = max(iters, 1)
+            lr_cfg.setdefault('decay_steps', iters)
+
+        try:
+            lr_sche = getattr(paddle.optimizer.lr, lr_type)(**lr_cfg)
+        except Exception as e:
+            raise RuntimeError(
+                "Creating {} has failed. Please check the lr_scheduler config. "
+                "The error message: {}".format(lr_type, e))
+
+        if use_warmup:
+            lr_sche = paddle.optimizer.lr.LinearWarmup(
+                learning_rate=lr_sche,
+                warmup_steps=warmup_iters,
+                start_lr=warmup_start_lr,
+                end_lr=end_lr)
+
+        return lr_sche
+
+    @cached_property
+    def loss(self) -> dict:
+        loss_cfg = self.config.loss_cfg
+        assert loss_cfg != {}, \
+            'No loss specified in the configuration file.'
+
+        # check and synchronize the ignore_index in model config and dataset class
+        if self.config.train_dataset_cfg['type'] != 'Dataset':
+            assert hasattr(self.train_dataset_class, 'IGNORE_INDEX'), \
+                'If train_dataset class is not `Dataset`, it must have `IGNORE_INDEX` attr.'
+
+        self.show_msg("loss", loss_cfg)
+        loss_dict = {'coef': loss_cfg['coef'], "types": []}
+        for item in loss_cfg['types']:
+            loss_dict['types'].append(self.build_component(item))
+        return loss_dict
+
+    @cached_property
+    def train_dataset(self) -> paddle.io.Dataset:
+        dataset_cfg = self.config.train_dataset_cfg
+        assert dataset_cfg != {}, \
+            'No train_dataset specified in the configuration file.'
+ self.show_msg('train_dataset', dataset_cfg) + dataset = self.build_component(dataset_cfg) + assert len(dataset) != 0, \ + 'The number of samples in train_dataset is 0. Please check whether the dataset is valid.' + return dataset + + @cached_property + def val_dataset(self) -> paddle.io.Dataset: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + self.show_msg('val_dataset', dataset_cfg) + dataset = self.build_component(dataset_cfg) + if len(dataset) == 0: + logger.warning( + 'The number of samples in val_dataset is 0. Please ensure this is the desired behavior.' + ) + return dataset + + @cached_property + def train_dataset_class(self) -> Any: + dataset_cfg = self.config.train_dataset_cfg + assert dataset_cfg != {}, \ + 'No train_dataset specified in the configuration file.' + dataset_type = dataset_cfg.get('type') + return self.load_component_class(dataset_type) + + @cached_property + def val_dataset_class(self) -> Any: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + dataset_type = dataset_cfg.get('type') + return self.load_component_class(dataset_type) + + @cached_property + def val_transforms(self) -> list: + dataset_cfg = self.config.val_dataset_cfg + assert dataset_cfg != {}, \ + 'No val_dataset specified in the configuration file.' + transforms = [] + for item in dataset_cfg.get('transforms', []): + transforms.append(self.build_component(item)) + return transforms diff --git a/passl/core/config.py b/passl/core/config.py new file mode 100644 index 00000000..521a1579 --- /dev/null +++ b/passl/core/config.py @@ -0,0 +1,209 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import six +import codecs +import os +from ast import literal_eval +from typing import Dict, Optional + +import yaml +import passl.utils as utils + +_INHERIT_KEY = '_inherited_' +_BASE_KEY = '_base_' + + +class Config(object): + """ + Configuration parsing. + + The following hyper-parameters are available in the config file: + batch_size: The number of samples per gpu. + epochs: The total training epochs. + train_dataset: A training data config including type/data_root/transforms/mode. + For data type, please refer to paddleseg.datasets. + For specific transforms, please refer to paddleseg.transforms.transforms. + val_dataset: A validation data config including type/data_root/transforms/mode. + optimizer: A optimizer config. Please refer to paddleseg.optimizers. + loss: A loss config. Multi-loss config is available. The loss type order is + consistent with the seg model outputs, where the coef term indicates the + weight of corresponding loss. Note that the number of coef must be the + same as the number of model outputs, and there could be only one loss type + if using the same loss type among the outputs, otherwise the number of + loss type must be consistent with coef. 
+ model: A model config including type/backbone and model-dependent arguments. + For model type, please refer to paddleseg.models. + For backbone, please refer to paddleseg.models.backbones. + + Args: + path (str) : The path of config file, supports yaml format only. + opts (list, optional): Use opts to update the key-value pairs of all options. + + """ + + def __init__( + self, + path: str, + learning_rate: Optional[float]=None, + batch_size: Optional[int]=None, + epochs: Optional[int]=None, + opts: Optional[list]=None, ): + assert os.path.exists(path), \ + 'Config path ({}) does not exist'.format(path) + assert path.endswith('yml') or path.endswith('yaml'), \ + 'Config file ({}) should be yaml format'.format(path) + + self.dic = self._parse_from_yaml(path) + self.dic = self.update_config_dict( + self.dic, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, + opts=opts) + + @property + def batch_size(self) -> int: + return self.dic.get('batch_size') + + @property + def epochs(self) -> int: + return self.dic.get('epochs') + + @property + def to_static_training(self) -> bool: + return self.dic.get('to_static_training', False) + + @property + def model_cfg(self) -> Dict: + return self.dic.get('model', {}).copy() + + @property + def loss_cfg(self) -> Dict: + return self.dic.get('loss', {}).copy() + + @property + def distill_loss_cfg(self) -> Dict: + return self.dic.get('distill_loss', {}).copy() + + @property + def lr_scheduler_cfg(self) -> Dict: + return self.dic.get('lr_scheduler', {}).copy() + + @property + def optimizer_cfg(self) -> Dict: + return self.dic.get('optimizer', {}).copy() + + @property + def train_dataset_cfg(self) -> Dict: + return self.dic.get('train_dataset', {}).copy() + + @property + def val_dataset_cfg(self) -> Dict: + return self.dic.get('val_dataset', {}).copy() + + # TODO merge test_config into val_dataset + @property + def test_config(self) -> Dict: + return self.dic.get('test_config', {}).copy() + + @classmethod + def update_config_dict(cls, dic: dict, *args, **kwargs) -> dict: + return update_config_dict(dic, *args, **kwargs) + + @classmethod + def _parse_from_yaml(cls, path: str, *args, **kwargs) -> dict: + return parse_from_yaml(path, *args, **kwargs) + + def __str__(self) -> str: + # Use NoAliasDumper to avoid yml anchor + return yaml.dump(self.dic, Dumper=utils.NoAliasDumper) + + +def parse_from_yaml(path: str): + """Parse a yaml file and build config""" + with codecs.open(path, 'r', 'utf-8') as file: + dic = yaml.load(file, Loader=yaml.FullLoader) + + def merge_config_dicts(dic, base_dic): + """Merge dic to base_dic and return base_dic.""" + base_dic = base_dic.copy() + dic = dic.copy() + + if not dic.get(_INHERIT_KEY, True): + dic.pop(_INHERIT_KEY) + return dic + + for key, val in dic.items(): + if isinstance(val, dict) and key in base_dic: + base_dic[key] = merge_config_dicts(val, base_dic[key]) + else: + base_dic[key] = val + + return base_dic + + if _BASE_KEY in dic: + base_files = dic.pop(_BASE_KEY) + if isinstance(base_files, str): + base_files = [base_files] + for bf in base_files: + base_path = os.path.join(os.path.dirname(path), bf) + base_dic = parse_from_yaml(base_path) + dic = merge_config_dicts(dic, base_dic) + + return dic + + + +def update_config_dict(dic: dict, + learning_rate: Optional[float]=None, + batch_size: Optional[int]=None, + epochs: Optional[int]=None, + opts: Optional[list]=None): + """Update config""" + # TODO: If the items to update are marked as anchors in the yaml file, + # we should synchronize the 
references. + dic = dic.copy() + + if learning_rate: + dic['lr_scheduler']['learning_rate'] = learning_rate + if batch_size: + dic['batch_size'] = batch_size + if epochs: + dic['epochs'] = epochs + + if opts is not None: + for item in opts: + assert ('=' in item) and (len(item.split('=')) == 2), "--opts params should be key=value," \ + " such as `--opts batch_size=1 test_config.scales=0.75,1.0,1.25`, " \ + "but got ({})".format(opts) + + key, value = item.split('=') + if isinstance(value, six.string_types): + try: + value = literal_eval(value) + except ValueError: + pass + except SyntaxError: + pass + key_list = key.split('.') + + tmp_dic = dic + for subkey in key_list[:-1]: + assert subkey in tmp_dic, "Can not update {}, because it is not in config.".format( + key) + tmp_dic = tmp_dic[subkey] + tmp_dic[key_list[-1]] = value + + return dic diff --git a/passl/core/manager.py b/passl/core/manager.py new file mode 100644 index 00000000..33478209 --- /dev/null +++ b/passl/core/manager.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from collections.abc import Sequence + +import warnings + + +class ComponentManager: + """ + Implement a manager class to add the new component properly. + The component can be added as either class or function type. + + Args: + name (str): The name of component. + + Returns: + A callable object of ComponentManager. + + Examples 1: + + from paddleseg.cvlibs.manager import ComponentManager + + model_manager = ComponentManager() + + class AlexNet: ... + class ResNet: ... + + model_manager.add_component(AlexNet) + model_manager.add_component(ResNet) + + # Or pass a sequence alliteratively: + model_manager.add_component([AlexNet, ResNet]) + print(model_manager.components_dict) + # {'AlexNet': , 'ResNet': } + + Examples 2: + + # Or an easier way, using it as a Python decorator, while just add it above the class declaration. + from paddleseg.cvlibs.manager import ComponentManager + + model_manager = ComponentManager() + + @model_manager.add_component + class AlexNet: ... + + @model_manager.add_component + class ResNet: ... + + print(model_manager.components_dict) + # {'AlexNet': , 'ResNet': } + """ + + def __init__(self, name=None): + self._components_dict = dict() + self._name = name + + def __len__(self): + return len(self._components_dict) + + def __repr__(self): + name_str = self._name if self._name else self.__class__.__name__ + return "{}:{}".format(name_str, list(self._components_dict.keys())) + + def __getitem__(self, item): + if item not in self._components_dict.keys(): + raise KeyError("{} does not exist in availabel {}".format(item, + self)) + return self._components_dict[item] + + @property + def components_dict(self): + return self._components_dict + + @property + def name(self): + return self._name + + def _add_single_component(self, component): + """ + Add a single component into the corresponding manager. 
+ + Args: + component (function|class): A new component. + + Raises: + TypeError: When `component` is neither class nor function. + KeyError: When `component` was added already. + """ + + # Currently only support class or function type + if not (inspect.isclass(component) or inspect.isfunction(component)): + raise TypeError("Expect class/function type, but received {}". + format(type(component))) + + # Obtain the internal name of the component + component_name = component.__name__ + + # Check whether the component was added already + if component_name in self._components_dict.keys(): + warnings.warn("{} exists already! It is now updated to {} !!!". + format(component_name, component)) + self._components_dict[component_name] = component + + else: + # Take the internal name of the component as its key + self._components_dict[component_name] = component + + def add_component(self, components): + """ + Add component(s) into the corresponding manager. + + Args: + components (function|class|list|tuple): Support four types of components. + + Returns: + components (function|class|list|tuple): Same with input components. + """ + + # Check whether the type is a sequence + if isinstance(components, Sequence): + for component in components: + self._add_single_component(component) + else: + component = components + self._add_single_component(component) + + return components + + +MODELS = ComponentManager("models") +BACKBONES = ComponentManager("backbones") +DATASETS = ComponentManager("datasets") +TRANSFORMS = ComponentManager("transforms") +LOSSES = ComponentManager("losses") +OPTIMIZERS = ComponentManager("optimizers") diff --git a/passl/core/trainer.py b/passl/core/trainer.py new file mode 100644 index 00000000..97670cb4 --- /dev/null +++ b/passl/core/trainer.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
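Taken together, the registry (`ComponentManager`) and `Builder.build_component` above implement a small config-driven factory: a class is registered under its name, and a `type`-keyed dict selects it and supplies its constructor kwargs, recursing into nested `type` dicts. A minimal sketch of that flow, where `ToyHead` and its arguments are made up for illustration, and `config=None` is only acceptable because `build_component` never touches it:

```python
from passl.core import Builder, manager


# Register a toy component under its class name.
@manager.MODELS.add_component
class ToyHead:
    def __init__(self, embed_dim, num_classes=1000):
        self.embed_dim = embed_dim
        self.num_classes = num_classes


# 'type' selects the registered class; the remaining keys become constructor kwargs.
cfg = {'type': 'ToyHead', 'embed_dim': 384, 'num_classes': 10}

builder = Builder(config=None, comp_list=[manager.MODELS])
head = builder.build_component(cfg)
assert isinstance(head, ToyHead) and head.num_classes == 10
```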
+
+import argparse
+from copy import deepcopy
+import os
+import random
+
+import paddle
+import numpy as np
+
+from passl.core import Config, PasslBuilder
+from passl.utils import utils
+
+
+class PasslTrainer():
+    def __init__(self, args) -> None:
+        # load configs
+        self.args = args
+        self.cfg = Config(self.args.config,
+                          learning_rate=self.args.learning_rate,
+                          epochs=self.args.epochs,
+                          batch_size=self.args.batch_size,
+                          opts=self.args.opts)
+
+        # build registered components
+        self.builder = PasslBuilder(self.cfg)
+
+        self.init_env()
+
+        self.model_conversion()
+
+        self.component_init()
+
+    def init_env(self) -> None:
+        utils.show_env_info()
+        utils.show_cfg_info(self.cfg)
+        utils.set_seed(self.args.seed)
+        utils.set_device(self.args.device)
+        utils.set_cv2_num_threads(self.args.num_workers)
+
+        self.init_save_dir()
+        if self.args.use_vdl:
+            self.init_visualdl()
+
+    def init_save_dir(self):
+        if not os.path.isdir(self.args.save_dir):
+            if os.path.exists(self.args.save_dir):
+                os.remove(self.args.save_dir)
+            os.makedirs(self.args.save_dir, exist_ok=True)
+
+    def init_visualdl(self) -> None:
+        from visualdl import LogWriter
+        # keep a handle on the writer so subclasses can log scalars
+        self.log_writer = LogWriter(self.args.save_dir)
+
+    def model_conversion(self):
+        self.model_syncbn_convert()
+        if self.args.ema:
+            self.ema_model = self.build_ema()
+        self.resume_model()
+
+    def model_syncbn_convert(self) -> None:
+        self.model = utils.convert_sync_batchnorm(self.builder.model, self.args.device)
+
+    def build_ema(self):
+        ema_model = deepcopy(self.model)
+        ema_model.eval()
+        for param in ema_model.parameters():
+            param.stop_gradient = True
+        return ema_model
+
+    def resume_model(self):
+        self.start_epoch = 0
+        if self.args.resume_model is not None:
+            self.start_epoch = utils.resume(self.builder.model, self.builder.optimizer, self.args.resume_model)
+
+    def distributed_conversion(self):
+        if paddle.distributed.ParallelEnv().nranks > 1:
+            paddle.distributed.fleet.init(is_collective=True)
+            self.builder.optimizer.optimizer = paddle.distributed.fleet.distributed_optimizer(
+                self.builder.optimizer)  # The return is a Fleet object
+            self.builder.model = paddle.distributed.fleet.distributed_model(self.builder.model)
+
+    def component_init(self):
+        self.batch_sampler = paddle.io.DistributedBatchSampler(
+            self.builder.train_dataset, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
+        self.loader = paddle.io.DataLoader(
+            self.builder.train_dataset,
+            batch_sampler=self.batch_sampler,
+            num_workers=self.args.num_workers,
+            return_list=True,
+            worker_init_fn=utils.worker_init_fn, )
+
+    def train(self) -> None:
+        print('You have built a PasslTrainer; please inherit this class and implement the train method for each model.')
+
+
+
+if __name__ == '__main__':
+    trainer = PasslTrainer()
+    trainer.train()
diff --git a/passl/data/dataset/imagefolder_dataset.py b/passl/data/dataset/imagefolder_dataset.py
index 618f5d77..7e661070 100644
--- a/passl/data/dataset/imagefolder_dataset.py
+++ b/passl/data/dataset/imagefolder_dataset.py
@@ -12,18 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
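`PasslTrainer.train()` above is deliberately a placeholder, so each method is expected to subclass the trainer and supply its own loop. A hedged sketch of what such a subclass might look like: `MyTrainer`, the two-view batch layout, the momentum value and the placeholder config path are all illustrative, not part of this PR.

```python
import argparse

from passl.core import PasslTrainer


class MyTrainer(PasslTrainer):
    def train(self):
        self.distributed_conversion()
        model = self.builder.model
        optimizer = self.builder.optimizer
        model.train()
        for epoch in range(self.start_epoch, self.cfg.epochs):
            for step, (views, _) in enumerate(self.loader):
                # assumes a two-view dataset (e.g. TwoViewsTransform), so views == [x1, x2];
                # the forward signature and momentum value are MoCo-style placeholders
                loss = model(views[0], views[1], m=0.99)
                loss.backward()
                optimizer.step()
                self.builder.lr_scheduler.step()
                optimizer.clear_grad()


if __name__ == '__main__':
    # The Namespace mirrors every attribute the base __init__/init_env read; values are illustrative.
    args = argparse.Namespace(
        config='path/to/config.yml', learning_rate=None, batch_size=64,
        epochs=None, opts=None, seed=42, device='gpu', num_workers=4,
        save_dir='./output', use_vdl=False, ema=False, resume_model=None)
    MyTrainer(args).train()
```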
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union -import numpy as np import os +import numpy as np +from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union import paddle +from passl.core import manager from passl.data.dataset import default_loader IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") - +@manager.DATASETS.add_component class ImageFolder(paddle.io.Dataset): """ Code ref from https://github.com/pytorch/vision/blob/main/torchvision/datasets/folder.py @@ -38,9 +39,9 @@ class ImageFolder(paddle.io.Dataset): the same methods can be overridden to customize the dataset. Args: root (string): Root directory path. - transform (callable, optional): A function/transform that takes in an numpy image + transforms (callable, optional): A function/transforms that takes in an numpy image and returns a transformed version. E.g, ``transforms.RandomCrop`` - target_transform (callable, optional): A function/transform that takes in the + target_transforms (callable, optional): A function/transforms that takes in the target and transforms it. loader (callable, optional): A function to load an image given its path. is_valid_file (callable, optional): A function that takes path of an Image file @@ -53,8 +54,8 @@ class ImageFolder(paddle.io.Dataset): def __init__(self, root, - transform=None, - target_transform=None, + transforms=None, + target_transforms=None, loader=default_loader, extensions=IMG_EXTENSIONS): @@ -70,8 +71,8 @@ def __init__(self, self.imgs = samples self.targets = [s[1] for s in samples] - self.transform = transform - self.target_transform = target_transform + self.transforms = transforms + self.target_transforms = target_transforms self.loader = loader @@ -185,10 +186,10 @@ def find_classes(self, directory): def __getitem__(self, idx): path, target = self.imgs[idx] sample = self.loader(path) - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) + if self.transforms is not None: + sample = self.transforms(sample) + if self.target_transforms is not None: + target = self.target_transforms(target) return (sample, np.int32(target)) def __len__(self) -> int: diff --git a/passl/data/preprocess/basic_transforms.py b/passl/data/preprocess/basic_transforms.py index 475c7985..bed1c015 100644 --- a/passl/data/preprocess/basic_transforms.py +++ b/passl/data/preprocess/basic_transforms.py @@ -31,6 +31,7 @@ from paddle.vision.transforms import ToTensor, Normalize from passl.utils import logger +from passl.core import manager __all__ = [ "Compose", @@ -65,7 +66,7 @@ class OperatorParamError(ValueError): """ pass - +@manager.TRANSFORMS.add_component class Compose(object): def __init__(self, transforms): self.transforms = transforms @@ -84,6 +85,7 @@ def __repr__(self) -> str: return format_string +@manager.TRANSFORMS.add_component class TwoViewsTransform(object): """Take two random crops of one image""" @@ -97,6 +99,7 @@ def __call__(self, x): return [im1, im2] +@manager.TRANSFORMS.add_component class DecodeImage(object): """ decode image """ @@ -182,6 +185,7 @@ def resize(img, size, interpolation=None, backend="cv2"): _RANDOM_INTERPOLATION = ('bilinear', 'bicubic') +@manager.TRANSFORMS.add_component class UnifiedResize(object): def __init__(self, interpolation=None, backend="cv2"): if interpolation == 'random': @@ -196,6 +200,7 @@ def __call__(self, img, size): return resize(img, size, interpolation, self.backend) 
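With `ImageFolder` and the transforms in this file registered, a nested `type`-keyed dataset config can be resolved entirely through the component managers. A sketch of that resolution: the dataset path and the exact transform arguments are illustrative, and the explicit module imports are only there to trigger registration.

```python
from passl.core import Builder, manager
import passl.data.dataset.imagefolder_dataset    # registers ImageFolder
import passl.data.preprocess.basic_transforms    # registers the transforms

dataset_cfg = {
    'type': 'ImageFolder',
    'root': './dataset/ILSVRC2012/train',        # hypothetical path
    'transforms': {
        'type': 'Compose',
        'transforms': [
            {'type': 'RandomResizedCrop', 'size': 224},
            {'type': 'RandomHorizontalFlip'},
            {'type': 'ToCHWImage'},
        ],
    },
}

builder = Builder(config=None,
                  comp_list=[manager.DATASETS, manager.TRANSFORMS])
train_dataset = builder.build_component(dataset_cfg)
# assuming the directory above actually contains images
sample, label = train_dataset[0]
```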
+@manager.TRANSFORMS.add_component class ResizeImage(object): """ resize image """ @@ -231,6 +236,7 @@ def __call__(self, img): return self._resize_func(img, (w, h)) +@manager.TRANSFORMS.add_component class Resize(object): def __init__(self, size, @@ -303,7 +309,7 @@ def __call__(self, img): self.max_size) return self._resize_func(img, (w, h)) - +@manager.TRANSFORMS.add_component class CenterCropImage(object): """ crop image """ @@ -322,6 +328,7 @@ def __call__(self, img): return crop(img, h_start, w_start, h, w) +@manager.TRANSFORMS.add_component class CenterCrop(object): """ center crop image, align torchvision """ @@ -369,6 +376,7 @@ def __call__(self, img): return crop(img, crop_top, crop_left, crop_height, crop_width) +@manager.TRANSFORMS.add_component class RandCropImage(object): """ random crop image """ @@ -469,6 +477,7 @@ def resized_crop( return img +@manager.TRANSFORMS.add_component class RandomResizedCrop(object): def __init__( self, @@ -559,6 +568,7 @@ def __call__(self, img): antialias=self.antialias) +@manager.TRANSFORMS.add_component class RandomResizedCropWithTwoImages(RandomResizedCrop): def __init__( self, @@ -626,11 +636,13 @@ def __call__(self, img): antialias=self.antialias) +@manager.TRANSFORMS.add_component class RandomResizedCropAndInterpolation(RandCropImage): """ only rename """ pass +@manager.TRANSFORMS.add_component class MAERandCropImage(RandCropImage): """ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. @@ -661,6 +673,7 @@ def __call__(self, img): return self._resize_func(img, size) +@manager.TRANSFORMS.add_component class RandFlipImage(object): """ random flip image flip_code: @@ -692,6 +705,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class RandomHorizontalFlip(object): def __init__(self, p=0.5): self.p = p @@ -703,6 +717,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class NormalizeImage(object): """ normalize image such as substract mean, divide std """ @@ -752,6 +767,7 @@ def __call__(self, img): return img.astype(self.output_dtype) +@manager.TRANSFORMS.add_component class ToCHWImage(object): """ convert hwc image to chw image """ @@ -766,6 +782,7 @@ def __call__(self, img): return img.transpose((2, 0, 1)) +@manager.TRANSFORMS.add_component class ColorJitter(PPColorJitter): """ColorJitter. """ @@ -786,6 +803,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class Pixels(object): def __init__(self, mode="const", mean=[0., 0., 0.]): self._mode = mode @@ -804,6 +822,7 @@ def __call__(self, h=224, w=224, c=3): ) +@manager.TRANSFORMS.add_component class RandomErasing(object): """RandomErasing. This code is adapted from https://github.com/zhunzhong07/Random-Erasing, and refer to Timm. 
@@ -855,6 +874,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class RandomApply(object): def __init__(self, transforms, p=0.5): self.transforms = transforms @@ -868,6 +888,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class RandomGrayscale(object): def __init__(self, p=0.1): self.p = p @@ -905,6 +926,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class SimCLRGaussianBlur(object): """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" @@ -925,6 +947,7 @@ def __call__(self, img): return img +@manager.TRANSFORMS.add_component class BYOLSolarize(object): """Solarize augmentation from BYOL: https://arxiv.org/abs/2006.07733""" diff --git a/passl/models/backbones/vit_moco.py b/passl/models/backbones/vit_moco.py new file mode 100644 index 00000000..5f3030e0 --- /dev/null +++ b/passl/models/backbones/vit_moco.py @@ -0,0 +1,200 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import paddle +import paddle.nn as nn +from functools import partial, reduce +from operator import mul + +from passl.core import manager +from passl.models.vision_transformer import VisionTransformer, PatchEmbed, to_2tuple + +from passl.nn import init + + +class VisionTransformerMoCo(VisionTransformer): + def __init__(self, stop_grad_conv1=False, **kwargs): + super().__init__(**kwargs) + # Use fixed 2D sin-cos position embedding + self.build_2d_sincos_position_embedding() + + # weight initialization + for name, m in self.named_sublayers(): + if isinstance(m, nn.Linear): + if 'qkv' in name: + # treat the weights of Q, K, V separately + val = math.sqrt( + 6. / float(m.weight.shape[1] // 3 + m.weight.shape[0])) + init.uniform_(m.weight, -val, val) + else: + init.xavier_uniform_(m.weight) + init.zeros_(m.bias) + init.normal_(self.cls_token, std=1e-6) + + if isinstance(self.patch_embed, PatchEmbed): + # xavier_uniform initialization + val = math.sqrt(6. / float(3 * reduce( + mul, self.patch_embed.patch_size, 1) + self.embed_dim)) + init.uniform_(self.patch_embed.proj.weight, -val, val) + init.zeros_(self.patch_embed.proj.bias) + + if stop_grad_conv1: + self.patch_embed.proj.weight.stop_gradient = True + self.patch_embed.proj.bias.stop_gradient = True + + def build_2d_sincos_position_embedding(self, temperature=10000.): + h = self.patch_embed.img_size[0] // self.patch_embed.patch_size[0] + w = self.patch_embed.img_size[1] // self.patch_embed.patch_size[1] + grid_w = paddle.arange(w, dtype=paddle.float32) + grid_h = paddle.arange(h, dtype=paddle.float32) + grid_w, grid_h = paddle.meshgrid(grid_w, grid_h) + assert self.embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' + pos_dim = self.embed_dim // 4 + omega = paddle.arange(pos_dim, dtype=paddle.float32) / pos_dim + omega = 1. 
/ (temperature**omega) + + out_w = grid_w.flatten()[..., None] @omega[None] + out_h = grid_h.flatten()[..., None] @omega[None] + pos_emb = paddle.concat( + [ + paddle.sin(out_w), paddle.cos(out_w), paddle.sin(out_h), + paddle.cos(out_h) + ], + axis=1)[None, :, :] + pe_token = paddle.zeros([1, 1, self.embed_dim], dtype=paddle.float32) + + pos_embed = paddle.concat([pe_token, pos_emb], axis=1) + self.pos_embed = self.create_parameter(shape=pos_embed.shape) + self.pos_embed.set_value(pos_embed) + self.pos_embed.stop_gradient = True + + +class ConvStem(nn.Layer): + """ + ConvStem, from Early Convolutions Help Transformers See Better, Tete et al. https://arxiv.org/abs/2106.14881 + """ + + def __init__(self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True): + super().__init__() + + assert patch_size == 16, 'ConvStem only supports patch size of 16' + assert embed_dim % 8 == 0, 'Embed dimension must be divisible by 8 for ConvStem' + + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], + img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + # build stem, similar to the design in https://arxiv.org/abs/2106.14881 + stem = [] + input_dim, output_dim = 3, embed_dim // 8 + for l in range(4): + stem.append( + nn.Conv2D( + input_dim, + output_dim, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False)) + stem.append(nn.BatchNorm2D(output_dim)) + stem.append(nn.ReLU()) + input_dim = output_dim + output_dim *= 2 + stem.append(nn.Conv2D(input_dim, embed_dim, kernel_size=1)) + self.proj = nn.Sequential(*stem) + + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose((0, 2, 1)) # BCHW -> BNC + x = self.norm(x) + return x + +@manager.BACKBONES.add_component +def moco_vit_small(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=384, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + **kwargs) + return model + +@manager.BACKBONES.add_component +def moco_vit_base(**kwargs): + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + **kwargs) + return model + + +@manager.BACKBONES.add_component +def moco_vit_conv_small(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=384, + depth=11, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + embed_layer=ConvStem, + **kwargs) + return model + +@manager.BACKBONES.add_component +def moco_vit_conv_base(**kwargs): + # minus one ViT block + model = VisionTransformerMoCo( + patch_size=16, + embed_dim=768, + depth=11, + num_heads=12, + mlp_ratio=4, + qkv_bias=True, + norm_layer=partial( + nn.LayerNorm, epsilon=1e-6), + embed_layer=ConvStem, + **kwargs) + return model diff --git a/passl/models/mocov3.py b/passl/models/mocov3.py new file mode 100644 index 00000000..5f29c2e2 --- /dev/null +++ b/passl/models/mocov3.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
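The registered ViT backbones above are meant to be wrapped by the MoCo v3 model defined in this file; roughly, wiring them together looks like the sketch below (batch size, momentum coefficient and temperature are illustrative values, not the PR's recommended settings).

```python
import paddle

from passl.models.backbones.vit_moco import moco_vit_base
from passl.models.mocov3 import MoCo_ViT

# The backbone factory is passed as a callable; MoCo instantiates it twice
# (base encoder and momentum encoder) with num_classes=mlp_dim.
model = MoCo_ViT(base_encoder=moco_vit_base, dim=256, mlp_dim=4096, T=0.2)

x1 = paddle.randn([4, 3, 224, 224])   # first augmented view of the batch
x2 = paddle.randn([4, 3, 224, 224])   # second augmented view
loss = model(x1, x2, m=0.99)          # m is the momentum-encoder coefficient
loss.backward()
```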
+ +import paddle +import paddle.nn as nn + +from passl.core import manager + + +class MoCo(nn.Layer): + """ + Build a MoCo model with a base encoder, a momentum encoder, and two MLPs + https://arxiv.org/abs/1911.05722 + """ + + def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0): + """ + dim: feature dimension (default: 256) + mlp_dim: hidden dimension in MLPs (default: 4096) + T: softmax temperature (default: 1.0) + """ + super(MoCo, self).__init__() + + self.T = T + + # build encoders + self.base_encoder = base_encoder(num_classes=mlp_dim) + self.momentum_encoder = base_encoder(num_classes=mlp_dim) + + self._build_projector_and_predictor_mlps(dim, mlp_dim) + + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + param_m.copy_(param_b, False) # initialize + param_m.stop_gradient = True # not update by gradient + + def _build_mlp(self, + num_layers, + input_dim, + mlp_dim, + output_dim, + last_bn=True): + mlp = [] + for l in range(num_layers): + dim1 = input_dim if l == 0 else mlp_dim + dim2 = output_dim if l == num_layers - 1 else mlp_dim + + mlp.append(nn.Linear(dim1, dim2, bias_attr=False)) + + if l < num_layers - 1: + mlp.append(nn.BatchNorm1D(dim2)) + mlp.append(nn.ReLU()) + elif last_bn: + # follow SimCLR's design: https://github.com/google-research/simclr/blob/master/model_util.py#L157 + # for simplicity, we further removed gamma in BN + mlp.append( + nn.BatchNorm1D( + dim2, weight_attr=False, bias_attr=False)) + + return nn.Sequential(*mlp) + + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + pass + + @paddle.no_grad() + def _update_momentum_encoder(self, m): + """Momentum update of the momentum encoder""" + with paddle.amp.auto_cast(False): + for param_b, param_m in zip(self.base_encoder.parameters(), + self.momentum_encoder.parameters()): + paddle.assign((param_m * m + param_b * (1. 
- m)), param_m) + + def contrastive_loss(self, q, k): + # normalize + q = nn.functional.normalize(q, axis=1) + k = nn.functional.normalize(k, axis=1) + # gather all targets + k = concat_all_gather(k) + # Einstein sum is more intuitive + logits = paddle.einsum('nc,mc->nm', q, k) / self.T + N = logits.shape[0] # batch size per GPU + labels = (paddle.arange( + N, dtype=paddle.int64) + N * paddle.distributed.get_rank()) + return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T) + + def forward(self, x1, x2, m): + """ + Input: + x1: first views of images + x2: second views of images + m: moco momentum + Output: + loss + """ + + # compute features + q1 = self.predictor(self.base_encoder(x1)) + q2 = self.predictor(self.base_encoder(x2)) + + with paddle.no_grad(): # no gradient + self._update_momentum_encoder(m) # update the momentum encoder + + # compute momentum features as targets + k1 = self.momentum_encoder(x1) + k2 = self.momentum_encoder(x2) + + return self.contrastive_loss(q1, k2) + self.contrastive_loss(q2, k1) + +@manager.MODELS.add_component +class MoCo_ResNet(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.fc.weight.shape[0] + del self.base_encoder.fc, self.momentum_encoder.fc # remove original fc layer + + # projectors + self.base_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + self.momentum_encoder.fc = self._build_mlp(2, hidden_dim, mlp_dim, dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim, False) + +@manager.MODELS.add_component +class MoCo_ViT(MoCo): + def _build_projector_and_predictor_mlps(self, dim, mlp_dim): + hidden_dim = self.base_encoder.head.weight.shape[0] + del self.base_encoder.head, self.momentum_encoder.head # remove original fc layer + + # projectors + self.base_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, dim) + self.momentum_encoder.head = self._build_mlp(3, hidden_dim, mlp_dim, + dim) + + # predictor + self.predictor = self._build_mlp(2, dim, mlp_dim, dim) + + +# utils +@paddle.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + """ + if paddle.distributed.get_world_size() < 2: + return tensor + + tensors_gather = [] + paddle.distributed.all_gather(tensors_gather, tensor) + + output = paddle.concat(tensors_gather, axis=0) + return output diff --git a/passl/utils/__init__.py b/passl/utils/__init__.py index 97043fd7..f00126f3 100644 --- a/passl/utils/__init__.py +++ b/passl/utils/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from . import logger +from .utils import * \ No newline at end of file diff --git a/passl/utils/download.py b/passl/utils/download.py new file mode 100644 index 00000000..5e6cd878 --- /dev/null +++ b/passl/utils/download.py @@ -0,0 +1,180 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import os +import shutil +import sys +import tarfile +import time +import zipfile + +import requests + +lasttime = time.time() +FLUSH_INTERVAL = 0.1 + + +def progress(str, end=False): + global lasttime + if end: + str += "\n" + lasttime = 0 + if time.time() - lasttime >= FLUSH_INTERVAL: + sys.stdout.write("\r%s" % str) + lasttime = time.time() + sys.stdout.flush() + + +def _download_file(url, savepath, print_progress): + if print_progress: + print("Connecting to {}".format(url)) + r = requests.get(url, stream=True, timeout=15) + total_length = r.headers.get('content-length') + + if total_length is None: + with open(savepath, 'wb') as f: + shutil.copyfileobj(r.raw, f) + else: + with open(savepath, 'wb') as f: + dl = 0 + total_length = int(total_length) + starttime = time.time() + if print_progress: + print("Downloading %s" % os.path.basename(savepath)) + for data in r.iter_content(chunk_size=4096): + dl += len(data) + f.write(data) + if print_progress: + done = int(50 * dl / total_length) + progress("[%-50s] %.2f%%" % + ('=' * done, float(100 * dl) / total_length)) + if print_progress: + progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) + + +def _uncompress_file_zip(filepath, extrapath): + files = zipfile.ZipFile(filepath, 'r') + filelist = files.namelist() + rootpath = filelist[0] + total_num = len(filelist) + for index, file in enumerate(filelist): + files.extract(file, extrapath) + yield total_num, index, rootpath + files.close() + yield total_num, index, rootpath + + +def _uncompress_file_tar(filepath, extrapath, mode="r:gz"): + files = tarfile.open(filepath, mode) + filelist = files.getnames() + total_num = len(filelist) + rootpath = filelist[0] + for index, file in enumerate(filelist): + files.extract(file, extrapath) + yield total_num, index, rootpath + files.close() + yield total_num, index, rootpath + + +def _uncompress_file(filepath, extrapath, delete_file, print_progress): + if print_progress: + print("Uncompress %s" % os.path.basename(filepath)) + + if filepath.endswith("zip"): + handler = _uncompress_file_zip + elif filepath.endswith("tgz"): + handler = functools.partial(_uncompress_file_tar, mode="r:*") + else: + handler = functools.partial(_uncompress_file_tar, mode="r") + + for total_num, index, rootpath in handler(filepath, extrapath): + if print_progress: + done = int(50 * float(index) / total_num) + progress("[%-50s] %.2f%%" % + ('=' * done, float(100 * index) / total_num)) + if print_progress: + progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True) + + if delete_file: + os.remove(filepath) + + return rootpath + + +def download_file_and_uncompress(url, + savepath=None, + extrapath=None, + extraname=None, + print_progress=True, + cover=False, + delete_file=True, + filename=None): + if savepath is None: + savepath = "." + + if extrapath is None: + extrapath = "." 
+ + savename = url.split("/")[-1] + if not os.path.exists(savepath): + os.makedirs(savepath) + + savepath = os.path.join(savepath, savename) + savename = ".".join(savename.split(".")[:-1]) + savename = os.path.join(extrapath, savename) + extraname = savename if extraname is None else os.path.join(extrapath, + extraname) + + if cover: + if os.path.exists(savepath): + shutil.rmtree(savepath) + if os.path.exists(savename): + shutil.rmtree(savename) + if os.path.exists(extraname): + shutil.rmtree(extraname) + full_path = os.path.join(extraname, + filename) if filename is not None else extraname + + rank_id_curr_node = int(os.environ.get("PADDLE_RANK_IN_NODE", 0)) + + if not os.path.exists( + full_path): # If pretrained model exists, skip download process. + lock_path = extraname + '.download.lock' + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if rank_id_curr_node == 0: + if not os.path.exists(savename): + if not os.path.exists(savepath): + _download_file(url, savepath, print_progress) + + if (not tarfile.is_tarfile(savepath)) and ( + not zipfile.is_zipfile(savepath)): + if not os.path.exists(extraname): + os.makedirs(extraname) + shutil.move(savepath, extraname) + + else: + savename = _uncompress_file(savepath, extrapath, + delete_file, print_progress) + savename = os.path.join(extrapath, savename) + shutil.move(savename, extraname) + + os.remove(lock_path) + + else: + while os.path.exists(lock_path): + time.sleep(0.5) + + return extraname diff --git a/passl/utils/logger.py b/passl/utils/logger.py index 8d44a7ac..7c75b940 100644 --- a/passl/utils/logger.py +++ b/passl/utils/logger.py @@ -1,4 +1,4 @@ -# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,111 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import sys +import time -import logging -import datetime -import paddle.distributed as dist -from passl.utils.misc import AverageMeter +import paddle -_logger = None +levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'} +log_level = 2 -def init_logger(name='passl', log_file=None, log_level=logging.INFO): - """Initialize and get a logger by name. - If the logger has not been initialized, this method will initialize the - logger by adding one or two handlers, otherwise the initialized logger will - be directly returned. During initialization, a StreamHandler will always be - added. If `log_file` is specified a FileHandler will also be added. - Args: - name (str): Logger name. - log_file (str | None): The log filename. If specified, a FileHandler - will be added to the logger. - log_level (int): The logger level. Note that only the process of - rank 0 is affected, and other processes will set the level to - "Error" thus be silent most of the time. - Returns: - logging.Logger: The expected logger. - """ - global _logger - assert _logger is None, "logger should not be initialized twice or more." 
- _logger = logging.getLogger(name) +def log(level=2, message=""): + if paddle.distributed.ParallelEnv().local_rank == 0: + current_time = time.time() + time_array = time.localtime(current_time) + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time_array) + if log_level >= level: + print("{} [{}]\t{}".format(current_time, levels[level], message) + .encode("utf-8").decode("latin1")) + sys.stdout.flush() - formatter = logging.Formatter( - '[%(asctime)s] %(name)s %(levelname)s: %(message)s', - datefmt="%Y/%m/%d %H:%M:%S") - stream_handler = logging.StreamHandler(stream=sys.stdout) - stream_handler.setFormatter(formatter) - _logger.addHandler(stream_handler) - if log_file is not None and dist.get_rank() == 0: - log_file_folder = os.path.split(log_file)[0] - os.makedirs(log_file_folder, exist_ok=True) - file_handler = logging.FileHandler(log_file, 'a') - file_handler.setFormatter(formatter) - _logger.addHandler(file_handler) - if dist.get_rank() == 0: - _logger.setLevel(log_level) - else: - _logger.setLevel(logging.ERROR) +def debug(message=""): + log(level=3, message=message) -def log_at_trainer0(log): - """ - logs will print multi-times when calling Fleet API. - Only display single log and ignore the others. - """ +def info(message=""): + log(level=2, message=message) - def wrapper(fmt, *args): - if dist.get_rank() == 0: - log(fmt, *args) - return wrapper +def warning(message=""): + log(level=1, message=message) -@log_at_trainer0 -def info(fmt, *args): - _logger.info(fmt, *args) - - -@log_at_trainer0 -def debug(fmt, *args): - _logger.debug(fmt, *args) - - -@log_at_trainer0 -def warning(fmt, *args): - _logger.warning(fmt, *args) - - -@log_at_trainer0 -def error(fmt, *args): - _logger.error(fmt, *args) - - -def scaler(name, value, step, writer): - """ - This function will draw a scalar curve generated by the visualdl. - Usage: Install visualdl: pip3 install visualdl==2.0.0b4 - and then: - visualdl --logdir ./scalar --host 0.0.0.0 --port 8830 - to preview loss corve in real time. - """ - if writer is None: - return - writer.add_scalar(tag=name, step=step, value=value) - - -def dict_format(d, float_placeholders="{:.5f}"): - str_list = [] - for key in d: - if isinstance(d[key], float): - value = ("{}: " + float_placeholders).format(key, d[key]) - elif isinstance(d[key], AverageMeter): - value = ("{}: " + float_placeholders).format(key, d[key].avg) - else: - value = "{}: {}".format(key, d[key]) - - str_list.append(value) - return ", ".join(str_list) +def error(message=""): + log(level=0, message=message) diff --git a/passl/utils/utils.py b/passl/utils/utils.py new file mode 100644 index 00000000..4230538f --- /dev/null +++ b/passl/utils/utils.py @@ -0,0 +1,491 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
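Note that the logger rewrite above changes the calling convention from printf-style `info(fmt, *args)` to a single pre-formatted message, gated by the module-level `log_level` and printed only on rank 0. A small usage sketch (the messages are illustrative):

```python
from passl.utils import logger

# Callers format the message themselves before logging.
logger.info("Loading pretrained model from {}".format("weights/model.pdparams"))
logger.warning("The number of samples in val_dataset is 0.")

logger.log_level = 3   # raise the module-level threshold to also emit DEBUG
logger.debug("lr={:.6f} at step {}".format(0.001, 100))
```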
+ +import os +import sys +import glob +import yaml +import random +import tempfile +import platform +import subprocess +import contextlib +from urllib.parse import urlparse, unquote + +import numpy as np + +import cv2 +import paddle +import passl +from passl.utils import logger +from passl.utils.download import download_file_and_uncompress + + + +class NoAliasDumper(yaml.SafeDumper): + """ + Avoid yaml anchor + """ + def ignore_aliases(self): + return True + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + +def _get_user_home(): + return os.path.expanduser('~') + + +def _get_passl_home(): + if 'PASSL_HOME' in os.environ: + home_path = os.environ['PASSL_HOME'] + if os.path.exists(home_path): + if os.path.isdir(home_path): + return home_path + else: + logger.warning('PASSL_HOME {} is a file!'.format(home_path)) + else: + return home_path + return os.path.join(_get_user_home(), '.paddleseg') + + +def _get_sub_home(directory): + home = os.path.join(_get_passl_home(), directory) + if not os.path.exists(home): + os.makedirs(home, exist_ok=True) + return home + + +USER_HOME = _get_user_home() +PASSL_HOME = _get_passl_home() +DATA_HOME = _get_sub_home('dataset') +TMP_HOME = _get_sub_home('tmp') +PRETRAINED_MODEL_HOME = _get_sub_home('pretrained_model') + + + +def set_seed(seed=None): + if seed is not None: + paddle.seed(seed) + np.random.seed(seed) + random.seed(seed) + + +def show_env_info(): + env_info = get_sys_env() + info = ['{}: {}'.format(k, v) for k, v in env_info.items()] + info = '\n'.join(['', format('Environment Information', '-^48s')] + info + + ['-' * 48]) + logger.info(info) + + +def show_cfg_info(config): + msg = '\n---------------Config Information---------------\n' + ordered_module = ('batch_size', 'epochs', 'train_dataset', 'val_dataset', + 'optimizer', 'lr_scheduler', 'loss', 'model') + all_module = set(config.dic.keys()) + for module in ordered_module: + if module in config.dic: + module_dic = {module: config.dic[module]} + msg += str(yaml.dump(module_dic, Dumper=NoAliasDumper)) + all_module.remove(module) + for module in all_module: + module_dic = {module: config.dic[module]} + msg += str(yaml.dump(module_dic, Dumper=NoAliasDumper)) + msg += '------------------------------------------------\n' + logger.info(msg) + + +def set_device(device): + env_info = get_sys_env() + if device == 'gpu' and env_info['Paddle compiled with cuda'] \ + and env_info['GPUs used']: + place = 'gpu' + elif device == 'xpu' and paddle.is_compiled_with_xpu(): + place = 'xpu' + elif device == 'npu' and "npu" in 
paddle.device.get_all_custom_device_type( + ): + place = 'npu' + elif device == 'mlu' and paddle.is_compiled_with_mlu(): + place = 'mlu' + else: + place = 'cpu' + paddle.set_device(place) + logger.info("Set device: {}".format(place)) + + +def convert_sync_batchnorm(model, device): + # Convert bn to sync_bn when use multi GPUs + env_info = get_sys_env() + if device == 'gpu' and env_info['Paddle compiled with cuda'] \ + and env_info['GPUs used'] and paddle.distributed.ParallelEnv().nranks > 1: + model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) + logger.info("Convert bn to sync_bn") + return model + + +def set_cv2_num_threads(num_workers): + # Limit cv2 threads if too many subprocesses are spawned. + # This should reduce resource allocation and thus boost performance. + nranks = paddle.distributed.ParallelEnv().nranks + if nranks >= 8 and num_workers >= 8: + logger.warning("The number of threads used by OpenCV is " \ + "set to 1 to improve performance.") + cv2.setNumThreads(1) + + +def worker_init_fn(worker_id): + np.random.seed(random.randint(0, 100000)) + +@contextlib.contextmanager +def generate_tempdir(directory: str=None, **kwargs): + '''Generate a temporary directory''' + directory = TMP_HOME if not directory else directory + with tempfile.TemporaryDirectory(dir=directory, **kwargs) as _dir: + yield _dir + + +def load_entire_model(model, pretrained): + if pretrained is not None: + load_pretrained_model(model, pretrained) + else: + logger.warning('Weights are not loaded for {} model since the ' + 'path of weights is None'.format( + model.__class__.__name__)) + + +def download_pretrained_model(pretrained_model): + """ + Download pretrained model from url. + Args: + pretrained_model (str): the url of pretrained weight + Returns: + str: the path of pretrained weight + """ + assert urlparse(pretrained_model).netloc, "The url is not valid." 
+ + pretrained_model = unquote(pretrained_model) + savename = pretrained_model.split('/')[-1] + if not savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')): + savename = pretrained_model.split('/')[-2] + filename = pretrained_model.split('/')[-1] + else: + savename = savename.split('.')[0] + filename = 'model.pdparams' + + with generate_tempdir() as _dir: + pretrained_model = download_file_and_uncompress( + pretrained_model, + savepath=_dir, + cover=False, + extrapath=PRETRAINED_MODEL_HOME, + extraname=savename, + filename=filename) + pretrained_model = os.path.join(pretrained_model, filename) + return pretrained_model + + +def load_pretrained_model(model, pretrained_model): + if pretrained_model is not None: + logger.info('Loading pretrained model from {}'.format(pretrained_model)) + + if urlparse(pretrained_model).netloc: + pretrained_model = download_pretrained_model(pretrained_model) + + if os.path.exists(pretrained_model): + para_state_dict = paddle.load(pretrained_model) + + model_state_dict = model.state_dict() + keys = model_state_dict.keys() + num_params_loaded = 0 + for k in keys: + if k not in para_state_dict: + logger.warning("{} is not in pretrained model".format(k)) + elif list(para_state_dict[k].shape) != list(model_state_dict[k] + .shape): + logger.warning( + "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})" + .format(k, para_state_dict[k].shape, model_state_dict[k] + .shape)) + else: + model_state_dict[k] = para_state_dict[k] + num_params_loaded += 1 + model.set_dict(model_state_dict) + logger.info("There are {}/{} variables loaded into {}.".format( + num_params_loaded, + len(model_state_dict), model.__class__.__name__)) + + else: + raise ValueError('The pretrained model directory is not Found: {}'. + format(pretrained_model)) + else: + logger.info( + 'No pretrained model to load, {} will be trained from scratch.'. + format(model.__class__.__name__)) + + +def resume(model, optimizer, resume_model): + if resume_model is not None: + logger.info('Resume model from {}'.format(resume_model)) + if os.path.exists(resume_model): + resume_model = os.path.normpath(resume_model) + ckpt_path = os.path.join(resume_model, 'model.pdparams') + para_state_dict = paddle.load(ckpt_path) + ckpt_path = os.path.join(resume_model, 'model.pdopt') + opti_state_dict = paddle.load(ckpt_path) + model.set_state_dict(para_state_dict) + optimizer.set_state_dict(opti_state_dict) + + epoch = resume_model.split('_')[-1] + epoch = int(epoch) + return epoch + else: + raise ValueError( + 'Directory of the model needed to resume is not Found: {}'. 
+                format(resume_model))
+    else:
+        logger.info('No model needed to resume.')
+
+
+def get_image_list(image_path):
+    """Get the list of image paths under `image_path`."""
+    valid_suffix = [
+        '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
+    ]
+    image_list = []
+    image_dir = None
+    if os.path.isfile(image_path):
+        if os.path.splitext(image_path)[-1] in valid_suffix:
+            image_list.append(image_path)
+        else:
+            # Treat the file as a list of image paths, one per line.
+            image_dir = os.path.dirname(image_path)
+            with open(image_path, 'r') as f:
+                for line in f:
+                    line = line.strip()
+                    if len(line.split()) > 1:
+                        line = line.split()[0]
+                    image_list.append(os.path.join(image_dir, line))
+    elif os.path.isdir(image_path):
+        image_dir = image_path
+        for root, dirs, files in os.walk(image_path):
+            for f in files:
+                if '.ipynb_checkpoints' in root:
+                    continue
+                if f.startswith('.'):
+                    continue
+                if os.path.splitext(f)[-1] in valid_suffix:
+                    image_list.append(os.path.join(root, f))
+    else:
+        raise FileNotFoundError(
+            '`--image_path` is not found. It should be the path of an image, a file list containing image paths, or a directory of images.'
+        )
+
+    if len(image_list) == 0:
+        raise RuntimeError(
+            'No image files were found in `--image_path`={}'.format(image_path))
+
+    return image_list, image_dir
+
+
+class NoAliasDumper(yaml.SafeDumper):
+    def ignore_aliases(self, data):
+        return True
+
+
+class CachedProperty(object):
+    """
+    A property that is only computed once per instance and then replaces itself with an ordinary attribute.
+
+    The implementation refers to https://github.com/pydanny/cached-property/blob/master/cached_property.py .
+    Note that this implementation does NOT work in multi-thread or coroutine scenarios.
+    """
+
+    def __init__(self, func):
+        super().__init__()
+        self.func = func
+        self.__doc__ = getattr(func, '__doc__', '')
+
+    def __get__(self, obj, cls):
+        if obj is None:
+            return self
+        val = self.func(obj)
+        # Hack the __dict__ of obj to inject the value.
+        # Note that this is only executed once.
+        obj.__dict__[self.func.__name__] = val
+        return val
+
+
+def get_in_channels(model_cfg):
+    if 'backbone' in model_cfg:
+        return model_cfg['backbone'].get('in_channels', None)
+    else:
+        return model_cfg.get('in_channels', None)
+
+
+def set_in_channels(model_cfg, in_channels):
+    model_cfg = model_cfg.copy()
+    if 'backbone' in model_cfg:
+        model_cfg['backbone']['in_channels'] = in_channels
+    else:
+        model_cfg['in_channels'] = in_channels
+    return model_cfg
+
+
+def _find_cuda_home():
+    '''Finds the CUDA install path. It refers to the implementation of
+    PyTorch.
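+
+    The search order is: the CUDA_HOME / CUDA_PATH environment variables, the
+    location of the nvcc binary on PATH, and finally the platform's default
+    install location.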
+ ''' + # Guess #1 + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # Guess #2 + try: + which = 'where' if sys.platform == 'win32' else 'which' + nvcc = subprocess.check_output([which, + 'nvcc']).decode().rstrip('\r\n') + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + # Guess #3 + if sys.platform == 'win32': + cuda_homes = glob.glob( + 'C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*') + if len(cuda_homes) == 0: + cuda_home = '' + else: + cuda_home = cuda_homes[0] + else: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + return cuda_home + + +def _get_nvcc_info(cuda_home): + if cuda_home is not None and os.path.isdir(cuda_home): + try: + nvcc = os.path.join(cuda_home, 'bin/nvcc') + if not sys.platform == 'win32': + nvcc = subprocess.check_output( + "{} -V".format(nvcc), shell=True).decode() + else: + nvcc = subprocess.check_output( + "\"{}\" -V".format(nvcc), shell=True).decode() + nvcc = nvcc.strip().split('\n')[-1] + except subprocess.SubprocessError: + nvcc = "Not Available" + else: + nvcc = "Not Available" + return nvcc + + +def _get_gpu_info(): + try: + gpu_info = subprocess.check_output(['nvidia-smi', + '-L']).decode().strip() + gpu_info = gpu_info.split('\n') + for i in range(len(gpu_info)): + gpu_info[i] = ' '.join(gpu_info[i].split(' ')[:4]) + except: + gpu_info = ' Can not get GPU information. Please make sure CUDA have been installed successfully.' + return gpu_info + + +def get_sys_env(): + """collect environment information""" + env_info = {} + env_info['platform'] = platform.platform() + + env_info['Python'] = sys.version.replace('\n', '') + + # TODO is_compiled_with_cuda() has not been moved + compiled_with_cuda = paddle.is_compiled_with_cuda() + env_info['Paddle compiled with cuda'] = compiled_with_cuda + + if compiled_with_cuda: + cuda_home = _find_cuda_home() + env_info['NVCC'] = _get_nvcc_info(cuda_home) + # refer to https://github.com/PaddlePaddle/Paddle/blob/release/2.0-rc/paddle/fluid/platform/device_context.cc#L327 + v = paddle.get_cudnn_version() + v = str(v // 1000) + '.' + str(v % 1000 // 100) + env_info['cudnn'] = v + if 'gpu' in paddle.get_device(): + gpu_nums = paddle.distributed.ParallelEnv().nranks + else: + gpu_nums = 0 + env_info['GPUs used'] = gpu_nums + + env_info['CUDA_VISIBLE_DEVICES'] = os.environ.get( + 'CUDA_VISIBLE_DEVICES') + if gpu_nums == 0: + os.environ['CUDA_VISIBLE_DEVICES'] = '' + env_info['GPU'] = _get_gpu_info() + + try: + gcc = subprocess.check_output(['gcc', '--version']).decode() + gcc = gcc.strip().split('\n')[0] + env_info['GCC'] = gcc + except: + pass + + env_info['Passl'] = passl.__version__ + env_info['PaddlePaddle'] = paddle.__version__ + env_info['OpenCV'] = cv2.__version__ + + return env_info + + diff --git a/tasks/ssl/mocov3_new/configs/pretrain.yml b/tasks/ssl/mocov3_new/configs/pretrain.yml new file mode 100644 index 00000000..1be71b10 --- /dev/null +++ b/tasks/ssl/mocov3_new/configs/pretrain.yml @@ -0,0 +1,82 @@ +batch_size: 4096 # this is the total batch size +epochs: 300 + +train_dataset: + type: ImageFolder + root: ./dataset/ILSVRC2012/train + transforms: + - type: TwoViewsTransform + type: Compose + - type: RandomResizedCrop + size: 224 + scale: [0.08, 1.0] + - type: RandomApply + - type: ColorJitter + p: 0.4 + brightness: 0.4 + contrast: 0.2 + saturation: 0.1 + p: 0.8 + - type: RandomGrayscale + p: 0.2 + - type: RandomApply + - type: SimCLRGaussianBlur + sigma: [.1, 2.] 
+ p: 1.0 + - type: RandomHorizontalFlip + - type: ToTensor + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + type: Compose + - type: RandomResizedCrop + size: 224 + scale: [0.08, 1.0] + - type: RandomApply + - type: ColorJitter + p: 0.4 + brightness: 0.4 + contrast: 0.2 + saturation: 0.1 + p: 0.8 + - type: RandomGrayscale + p: 0.2 + - type: RandomApply + - type: SimCLRGaussianBlur + sigma: [.1, 2.] + p: 1.0 + - type: RandomApply + - type: BYOLSolarize + p: 0.2 + - type: RandomHorizontalFlip + - type: ToTensor + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + + +val_dataset: + type: ImageFolder + root: ./dataset/ILSVRC2012/train + transforms: + + target_transforms: + + +model: + type: MoCo_ViT + backbone: + type: moco_vit_base + stop_grad_conv1: True + dim: 256 + mlp_dim: 4096 + T: 0.2 + +optimizer: + type: AdamW + weight_decay: 0.1 + learning_rate: 1.5e-4 + +special_config: + moco_m_cos: True + moco_m: 0.99 \ No newline at end of file diff --git a/tasks/ssl/mocov3_new/main_moco_new.py b/tasks/ssl/mocov3_new/main_moco_new.py new file mode 100644 index 00000000..d03f6abe --- /dev/null +++ b/tasks/ssl/mocov3_new/main_moco_new.py @@ -0,0 +1,349 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import builtins +import math +import os +import random +import shutil +import time +import warnings + +import paddle +import paddle.nn as nn +import paddle.distributed as dist + +import passl +from passl.utils import utils + +from visualdl import LogWriter as SummaryWriter + + + +def parse_args(): + parser = argparse.ArgumentParser(description='MoCo ImageNet Pre-Training') + + # Common params + parser.add_argument( + '--train_mode', + help='The three training model choose from ["pretrain", "linear_probe", "finetune"].', + type=str, + default='pretrain') + parser.add_argument("--config", help="The path of config file.", type=str) + parser.add_argument( + '--device', + help='Set the device place for training model.', + default='gpu', + choices=['cpu', 'gpu', 'xpu', 'npu', 'mlu'], + type=str) + parser.add_argument( + '--save_dir', + help='The directory for saving the model snapshot.', + type=str, + default='./output') + parser.add_argument( + '-j', + '--num_workers', + help='Number of workers for data loader. Bigger num_workers can speed up data processing.', + type=int, + default=8) + parser.add_argument( + '--use_vdl', + help='Whether to record the data to VisualDL in training.', + action='store_true') + parser.add_argument( + '--use_ema', + help='Whether to ema the model in training.', + action='store_true') + + # Runntime params + parser.add_argument( + '--resume_model', + help='The path of the model to resume training.', + type=str) + parser.add_argument( + '--epochs', + default=100, + help='Epochs in training.', + type=int) + parser.add_argument( + '-b', + '--batch_size', + default=4096, + help='Mini batch size of one gpu or cpu. 
', + type=int) + parser.add_argument( + '--lr', + '--learning_rate', + default=0.6, + help='Learning rate.', + type=float) + parser.add_argument( + '--warmup_epochs', + default=10, + help="num of epochs for linear warmup.", + type=int) + + parser.add_argument( + '--save_interval', + help='How many epochs to save a model snapshot once during training.', + type=int, + default=1000) + parser.add_argument( + '--log_iters', + help='Display logging information at every `log_iters`.', + default=10, + type=int) + + # Other params + parser.add_argument( + '--seed', + help='Set the random seed in training.', + default=None, + type=int) + parser.add_argument( + "--precision", + default="fp32", + type=str, + choices=["fp32", "fp16"], + help="Use AMP (Auto mixed precision) if precision='fp16'. If precision='fp32', the training is normal." + ) + parser.add_argument( + "--amp_level", + default="O1", + type=str, + choices=["O1", "O2"], + help="Auto mixed precision level. Accepted values are “O1” and “O2”: O1 represent mixed precision, the input \ + data type of each operator will be casted by white_list and black_list; O2 represent Pure fp16, all operators \ + parameters and input data will be casted to fp16, except operators in black_list, don’t support fp16 kernel \ + and batchnorm. Default is O1(amp).") + + parser.add_argument( + '--opts', help='Update the key-value pairs of all options.', nargs='+') + + return parser.parse_args() + +################################################ +## Code need to be modified to compact Trainer # +## >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>># +################################################ +def main(): + if args.seed is not None: + random.seed(args.seed) + paddle.seed(args.seed) + np.random.seed(args.seed) + RELATED_FLAGS_SETTING = {} + RELATED_FLAGS_SETTING['FLAGS_cudnn_deterministic'] = 1 + paddle.fluid.set_flags(RELATED_FLAGS_SETTING) + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! 
' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + device = paddle.set_device("gpu") + dist.init_parallel_env() + args.world_size = dist.get_world_size() + args.rank = dist.get_rank() + args.distributed = args.world_size > 1 + + # suppress printing if not first GPU on each node + if args.rank != 0: + + def print_pass(*args): + pass + + builtins.print = print_pass + + + # infer learning rate before changing batch size + args.lr = args.lr * args.batch_size / 256 + + if args.distributed: + # apply SyncBN + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + + args.batch_size = int(args.batch_size / args.world_size) + model = paddle.DataParallel(model) + + print(model) # print model after SyncBatchNorm + + scaler = paddle.amp.GradScaler( + init_loss_scaling=2.**16, incr_every_n_steps=2000) + + summary_writer = SummaryWriter() if args.rank == 0 else None + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + checkpoint = paddle.load(args.resume) + args.start_epoch = checkpoint['epoch'] + model.set_state_dict(checkpoint['state_dict']) + optimizer.set_state_dict(checkpoint['optimizer']) + scaler.load_state_dict(checkpoint['scaler']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + + train_sampler = paddle.io.DistributedBatchSampler( + train_dataset, + shuffle=True, + batch_size=args.batch_size, + drop_last=True) + + train_loader = paddle.io.DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=args.workers, + use_shared_memory=True, ) + + for epoch in range(args.start_epoch, args.epochs): + if args.distributed: + train_loader.batch_sampler.set_epoch(epoch) + + # train for one epoch + train(train_loader, model, optimizer, scaler, summary_writer, epoch, + args) + + if args.rank == 0 and epoch % 10 == 0 or epoch == args.epochs - 1: # only the first GPU saves checkpoint + save_checkpoint( + { + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'scaler': scaler.state_dict(), + }, + is_best=False, + filename='checkpoint_%04d.pd' % epoch) + + if args.rank == 0: + summary_writer.close() + + +def train(train_loader, model, optimizer, scaler, summary_writer, epoch, args): + batch_time = utils.AverageMeter('Time', ':6.3f') + data_time = utils.AverageMeter('Data', ':6.3f') + learning_rates = utils.AverageMeter('LR', ':.4e') + losses = utils.AverageMeter('Loss', ':.4e') + progress = utils.ProgressMeter( + len(train_loader), [batch_time, data_time, learning_rates, losses], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + iters_per_epoch = len(train_loader) + moco_m = args.moco_m + for i, (images, _) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + # adjust learning rate and momentum coefficient per iteration + lr = adjust_learning_rate(optimizer, epoch + i / iters_per_epoch, args) + learning_rates.update(lr) + if args.moco_m_cos: + moco_m = adjust_moco_momentum(epoch + i / iters_per_epoch, args) + + images[0] = images[0].cuda() + images[1] = images[1].cuda() + + # compute output + with paddle.amp.auto_cast(): + loss = model(images[0], images[1], moco_m) + + losses.update(loss.item(), images[0].shape[0]) + if args.rank == 0: + summary_writer.add_scalar("loss", + 
loss.item(), epoch * iters_per_epoch + i) + + # compute gradient and do SGD step + optimizer.clear_grad() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def save_checkpoint(state, is_best, filename='checkpoint.pd'): + paddle.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pd') + + + +def adjust_learning_rate(optimizer, epoch, args): + """Decays the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.lr * 0.5 * ( + 1. + math.cos(math.pi * (epoch - args.warmup_epochs) / + (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr + + +def adjust_moco_momentum(epoch, args): + """Adjust moco momentum based on current epoch""" + m = 1. - 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) * ( + 1. - args.moco_m) + return m +################################################ +## >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>># +################################################ + + + +# code starts here +class MOCOV3PretrainTrainer([passl.core.PasslTrainer]): + def __init__(self) -> None: + """ + The init part has three important components: + self.builder: contains all the component build from config: loss, models,backbones, dataset, optimier, transforms. + self.args: The cli args, mainly includes the training configs. + self.cfg: The config get from the yaml, update serveral configs from args as well, for print to the log. + """ + super().__init__(args) + + + def train(self): + """ + Train the model + """ + + +if __name__ == '__main__': + args = parse_args() + if args.train_mode == "pretrain": + trainer = MOCOV3PretrainTrainer(args) + elif args.train_mode == "finetune": + pass + # trainer = MOCOV3FinetuneTrainer(args) + elif args.train_mode == "linear_probe": + pass + # trainer = MOCOV3LinearTrainer(args) + trainer.train() diff --git a/tasks/ssl/mocov3_new/pretrain.sh b/tasks/ssl/mocov3_new/pretrain.sh new file mode 100644 index 00000000..bbfe1f3c --- /dev/null +++ b/tasks/ssl/mocov3_new/pretrain.sh @@ -0,0 +1,26 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#unset PADDLE_TRAINER_ENDPOINTS +#export PADDLE_NNODES=4 +#export PADDLE_MASTER="10.67.228.16:12538" +#export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export FLAGS_stop_check_timeout=3600 + +python -m paddle.distributed.launch \ + --nnodes=$PADDLE_NNODES \ + --master=$PADDLE_MASTER \ + --devices=$CUDA_VISIBLE_DEVICES \ + main_moco.py --warmup_epochs 40 \ + --train_mode pretrain