diff --git a/README.md b/README.md index 784e8d6..550e565 100644 --- a/README.md +++ b/README.md @@ -48,13 +48,16 @@ There are two ways to use the python package of basicsr, which are provided in t - :arrow_right: [installation mode](https://github.com/xinntao/BasicSR-examples/tree/installation): you need to install the project by running `python setup.py develop`. After installation, it is more convenient to import and use. -As a simple introduction and explanation, we use the example of *simple mode*, but we recommend the *installation mode* in practical use. +This `installation` branch uses the *installation mode* for illustration. We recommend using this mode for practical use. ```bash -git clone https://github.com/xinntao/BasicSR-examples.git +git clone -b installation https://github.com/xinntao/BasicSR-examples.git cd BasicSR-examples +python setup.py develop # need to install ``` +**Note that**: the installation model requires a package name for installation. We use `basicsrexamples` as the package name. + ### Preliminary Most deep-learning projects can be divided into the following parts: @@ -106,7 +109,7 @@ Let's explain it separately in the following parts. We need to implement a new dataset to fulfill our purpose. The dataset is used to feed the data into the model. -An example of this dataset is in [data/example_dataset.py](data/example_dataset.py). It has the following steps. +An example of this dataset is in [basicsrexamples/data/example_dataset.py](basicsrexamples/data/example_dataset.py). It has the following steps. 1. Read Ground-Truth (GT) images. BasicSR provides [FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py) for easily reading files in a folder, LMDB file and meta_info txt. In this example, we use the folder mode. For more reading modes, please refer to [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data) 1. Synthesize low resolution images. We can directly implement the data procedures in the `__getitem__(self, index)` function, such as downsampling and adding JPEG compression. Many basic operations can be found in [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/tranforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) ,and [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) @@ -151,7 +154,7 @@ datasets: #### :two: arch -An example of architecture is in [archs/example_arch.py](archs/example_arch.py). It mainly builds the network structure. +An example of architecture is in [basicsrexamples/archs/example_arch.py](basicsrexamples/archs/example_arch.py). It mainly builds the network structure. **Note**: @@ -174,7 +177,7 @@ network_g: #### :three: model -An example of model is in [models/example_model.py](models/example_model.py). It mainly builds the training process of a model. +An example of model is in [basicsrexamples/models/example_model.py](basicsrexamples/models/example_model.py). It mainly builds the training process of a model. In this file: 1. We inherit `SRModel` from basicsr. Many models have similar operations, so you can inherit and modify from [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models). In this way, you can easily implement your ideas, such as GAN model, video model, *etc*. @@ -222,7 +225,7 @@ train: The whole training pipeline can reuse the [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py) in BasicSR. -Based on this, our [train.py](train.py) can be very concise: +Based on this, our [basicsrexamples/train.py](basicsrexamples/train.py) can be very concise: ```python import os.path as osp @@ -243,7 +246,7 @@ if __name__ == '__main__': So far, we have completed the development of our project. We can quickly check whether there is a bug through the `debug` mode: ```bash -python train.py -opt options/example_option.yml --debug +python basicsrexamples/train.py -opt options/example_option.yml --debug ``` With `--debug`, the program will enter the debug mode. In the debug mode, the program will output at each iteration, and perform validation every 8 iterations, so that you can easily know whether the program has a bug~ @@ -253,7 +256,7 @@ With `--debug`, the program will enter the debug mode. In the debug mode, the pr After debugging, we can have the normal training. ```bash -python train.py -opt options/example_option.yml +python basicsrexamples/train.py -opt options/example_option.yml ``` If the training process is interrupted unexpectedly and the resume is required. Please use `--auto_resume` in the command: @@ -268,13 +271,36 @@ So far, you have finished developing your own projects using `BasicSR`. Isn't it You can use BasicSR-Examples as a template for your project. Here are some modifications you may need. +As GitHub does not support a specific branch as a template, we need extra steps to use the `installation` branch as the template. + +1. Click `Use this template` and remember to check the `[ ] Include all branches` checkbox +2. Change the installation branch to the master branch + ```bash + git clone -b installation YOUR_REPO # clone the installation branch + cd REPO_NAME + git branch -m installation master # rename the installation branch to master + git push -f origin master # force push the local master branch to remote + git push origin --delete installation # delete the remote installation branch + ``` + +You may need to modify the following files: + 1. Set up the *pre-commit* hook 1. In the root path, run: > pre-commit install 1. Modify the `LICENSE`
This repository uses the *MIT* license, you may change it to other licenses -The simple mode do not require many modifications. Those using the installation mode may need more modifications. See [here](https://github.com/xinntao/BasicSR-examples/blob/installation/README.md#As-a-Template) +As the installation mode requires the package name, you also need to modify all the `basicsrexamples` names to YOUR_PACKAGE_NAME. +Here are the detailed locations that contain the `basicsrexamples` name: +1. The `basicsrexamples` folder +1. [setup.py](setup.py#L9);   [setup.py](setup.py#L48);  [setup.py](setup.py#L91) +1. [basicsrexamples/train.py](basicsrexamples/train.py#L4-L6) +1. [basicsrexamples/archs/\_\_init\_\_.py](basicsrexamples/archs/__init__.py#L11) +1. [basicsrexamples/data/\_\_init\_\_.py](basicsrexamples/data/__init__.py#L11) +1. [basicsrexamples/models/\_\_init\_\_.py](basicsrexamples/models/__init__.py#L11) + +You also need to modify the corresponding information in the [setup.py](setup.py#L88-L113). ## :e-mail: Contact diff --git a/README_CN.md b/README_CN.md index 8fe94a8..6fbe548 100644 --- a/README_CN.md +++ b/README_CN.md @@ -46,13 +46,16 @@ - :arrow_right: [简单模式](https://github.com/xinntao/BasicSR-examples/tree/master): 项目的仓库不需要安装,就可以运行使用。但它有局限:不方便 import 复杂的层级关系;在其他位置也不容易访问本项目中的函数 - :arrow_right: [安装模式](https://github.com/xinntao/BasicSR-examples/tree/installation): 项目的仓库需要安装 `python setup.py develop`,安装之后 import 和使用都更加方便 -作为简单的入门和讲解, 我们使用*简单模式*的样例,但在实际使用中我们推荐*安装模式*。 +这个`installation`分支使用 *安装模式* 进行说明。在实际使用中我们也推荐 *安装模式*。 ```bash -git clone https://github.com/xinntao/BasicSR-examples.git +git clone -b installation https://github.com/xinntao/BasicSR-examples.git cd BasicSR-examples +python setup.py develop # need to install ``` +**注意**: 安装模式需要一个 包的名字 来安装。在这里,我们使用 `basicsrexamples` 作为包名。 + ### 预备 大部分的深度学习项目,都可以分为以下几个部分: @@ -102,7 +105,7 @@ python scripts/prepare_example_data.py 这个部分是用来确定喂给模型的数据的。 -这个 dataset 的例子在[data/example_dataset.py](data/example_dataset.py) 中,它完成了: +这个 dataset 的例子在[basicsrexamples/data/example_dataset.py](basicsrexamples/data/example_dataset.py) 中,它完成了: 1. 我们读取 Ground-Truth (GT) 的图像。读取的操作,BasicSR 提供了[FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py), 可以方便地读取 folder, lmdb 和 meta_info txt 指定的文件。在这个例子中,我们通过读取 folder 来说明,更多的读取模式可以参考 [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data) 1. 合成低分辨率的图像。我们直接可以在 `__getitem__(self, index)` 的函数中实现我们想要的操作,比如降采样和添加 JPEG 压缩。很多基本操作都可以在 [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/tranforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) 和 [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) 中找到 1. 转换成 Torch Tensor,返回合适的信息 @@ -145,7 +148,7 @@ datasets: #### :two: arch -Architecture 的例子在 [archs/example_arch.py](archs/example_arch.py)中。它主要搭建了网络结构。 +Architecture 的例子在 [basicsrexamples/archs/example_arch.py](basicsrexamples/archs/example_arch.py)中。它主要搭建了网络结构。 **注意**: 1. 需要在 `ExampleArch` 前添加 `@ARCH_REGISTRY.register()`,以便注册好新写的 arch。这个操作主要用来防止出现同名的 arch,从而带来潜在的 bug @@ -167,7 +170,7 @@ network_g: #### :three: model -Model 的例子在 [models/example_model.py](models/example_model.py)中。它主要搭建了模型的训练过程。 +Model 的例子在 [basicsrexamples/models/example_model.py](basicsrexamples/models/example_model.py)中。它主要搭建了模型的训练过程。 在这个文件中: 1. 我们从 basicsr 中继承了 `SRModel`。很多模型都有相似的操作,因此可以通过继承 [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models) 中的模型来更方便地实现自己的想法,比如GAN模型,Video模型等 1. 使用了两个 Loss: L1 和 L2 (MSE) loss @@ -213,7 +216,7 @@ train: 整个 training pipeline 可以复用 basicsr 里面的 [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py)。 -基于此,我们的 [train.py](train.py)可以非常简洁。 +基于此,我们的 [basicsrexamples/train.py](basicsrexamples/train.py)可以非常简洁。 ```python import os.path as osp @@ -234,7 +237,7 @@ if __name__ == '__main__': 至此,我们已经完成了我们这个项目的开发,下面可以通过 `debug` 模式来快捷地看看是否有问题: ```bash -python train.py -opt options/example_option.yml --debug +python basicsrexamples/train.py -opt options/example_option.yml --debug ``` 只要带上 `--debug` 就进入 debug 模式。在 debug 模式中,程序每个iter都会输出,8个iter后就会进行validation,这样可以很方便地知道程序有没有bug啦~ @@ -244,7 +247,7 @@ python train.py -opt options/example_option.yml --debug 经过debug没有问题后,我们就可以正式训练了。 ```bash -python train.py -opt options/example_option.yml +python basicsrexamples/train.py -opt options/example_option.yml ``` 如果训练过程意外中断需要 resume, 则使用 `--auto_resume` 可以方便地自动resume: @@ -258,13 +261,37 @@ python train.py -opt options/example_option.yml --auto_resume 你可以使用 BasicSR-Examples 作为你项目的模板。下面主要展示一下你可能需要的修改。 +由于 GitHub 不支持将特定分支作为模板,因此我们需要额外的步骤来使用`installation`分支作为模板。 + +1. 点击 `Use this template`;记得勾选 `[ ] Include all branches` +2. 将 安装分支 修改为 主分支 + ```bash + git clone -b installation YOUR_REPO # clone the installation branch + cd REPO_NAME + git branch -m installation master # rename the installation branch to master + git push -f origin master # force push the local master branch to remote + git push origin --delete installation # delete the remote installation branch + ``` + +你需要根据需要修改以下文件: + 1. 设置 *pre-commit* hook 1. 在文件夹根目录, 运行 > pre-commit install 1. 修改 `LICENSE` 文件
本仓库使用 *MIT* 许可, 根据需要可以修改成其他许可 -使用 简单模式 的基本不需要修改,使用 安装模式 的可能需要较多修改,参见[这里](https://github.com/xinntao/BasicSR-examples/blob/installation/README_CN.md#As-a-Template) +由于安装模式需要包名,因此还需要将所有`basicsrexamples`名称修改为 YOUR_PACKAGE_NAME。 +以下是包含`basicsrexamples`名称的详细位置: + +1. The `basicsrexamples` folder +1. [setup.py](setup.py#L9);   [setup.py](setup.py#L48);  [setup.py](setup.py#L91) +1. [basicsrexamples/train.py](basicsrexamples/train.py#L4-L6) +1. [basicsrexamples/archs/\_\_init\_\_.py](basicsrexamples/archs/__init__.py#L11) +1. [basicsrexamples/data/\_\_init\_\_.py](basicsrexamples/data/__init__.py#L11) +1. [basicsrexamples/models/\_\_init\_\_.py](basicsrexamples/models/__init__.py#L11) + +你也需要修改文件 [setup.py](setup.py#L88-L113) 中的相应信息。 ## :e-mail: 联系 diff --git a/basicsrexamples/__init__.py b/basicsrexamples/__init__.py new file mode 100644 index 0000000..36731a0 --- /dev/null +++ b/basicsrexamples/__init__.py @@ -0,0 +1,5 @@ +# flake8: noqa +from .archs import * +from .data import * +from .models import * +from .version import __gitsha__, __version__ diff --git a/archs/__init__.py b/basicsrexamples/archs/__init__.py similarity index 77% rename from archs/__init__.py rename to basicsrexamples/archs/__init__.py index 42ec069..372c490 100644 --- a/archs/__init__.py +++ b/basicsrexamples/archs/__init__.py @@ -8,4 +8,4 @@ arch_folder = osp.dirname(osp.abspath(__file__)) arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] # import all the arch modules -_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] +_arch_modules = [importlib.import_module(f'basicsrexamples.archs.{file_name}') for file_name in arch_filenames] diff --git a/archs/example_arch.py b/basicsrexamples/archs/example_arch.py similarity index 100% rename from archs/example_arch.py rename to basicsrexamples/archs/example_arch.py diff --git a/data/__init__.py b/basicsrexamples/data/__init__.py similarity index 77% rename from data/__init__.py rename to basicsrexamples/data/__init__.py index b22fe2b..be92f0e 100644 --- a/data/__init__.py +++ b/basicsrexamples/data/__init__.py @@ -8,4 +8,4 @@ data_folder = osp.dirname(osp.abspath(__file__)) dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] # import all the dataset modules -_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] +_dataset_modules = [importlib.import_module(f'basicsrexamples.data.{file_name}') for file_name in dataset_filenames] diff --git a/data/example_dataset.py b/basicsrexamples/data/example_dataset.py similarity index 100% rename from data/example_dataset.py rename to basicsrexamples/data/example_dataset.py diff --git a/basicsrexamples/losses/__init__.py b/basicsrexamples/losses/__init__.py new file mode 100644 index 0000000..046e7f0 --- /dev/null +++ b/basicsrexamples/losses/__init__.py @@ -0,0 +1,11 @@ +import importlib +from os import path as osp + +from basicsr.utils import scandir + +# automatically scan and import loss modules for registry +# scan all the files that end with '_loss.py' under the loss folder +loss_folder = osp.dirname(osp.abspath(__file__)) +loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] +# import all the loss modules +_model_modules = [importlib.import_module(f'basicsrexamples.losses.{file_name}') for file_name in loss_filenames] diff --git a/basicsrexamples/losses/example_loss.py b/basicsrexamples/losses/example_loss.py new file mode 100644 index 0000000..7939217 --- /dev/null +++ b/basicsrexamples/losses/example_loss.py @@ -0,0 +1,26 @@ +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import LOSS_REGISTRY + + +@LOSS_REGISTRY.register() +class ExampleLoss(nn.Module): + """Example Loss. + + Args: + loss_weight (float): Loss weight for Example loss. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0): + super(ExampleLoss, self).__init__() + self.loss_weight = loss_weight + + def forward(self, pred, target, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. + """ + return self.loss_weight * F.l1_loss(pred, target, reduction='mean') diff --git a/models/__init__.py b/basicsrexamples/models/__init__.py similarity index 77% rename from models/__init__.py rename to basicsrexamples/models/__init__.py index a92d6ac..0183977 100644 --- a/models/__init__.py +++ b/basicsrexamples/models/__init__.py @@ -8,4 +8,4 @@ model_folder = osp.dirname(osp.abspath(__file__)) model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] # import all the model modules -_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] +_model_modules = [importlib.import_module(f'basicsrexamples.models.{file_name}') for file_name in model_filenames] diff --git a/models/example_model.py b/basicsrexamples/models/example_model.py similarity index 100% rename from models/example_model.py rename to basicsrexamples/models/example_model.py diff --git a/basicsrexamples/train.py b/basicsrexamples/train.py new file mode 100644 index 0000000..530726b --- /dev/null +++ b/basicsrexamples/train.py @@ -0,0 +1,12 @@ +# flake8: noqa +import os.path as osp + +import basicsrexamples.archs +import basicsrexamples.data +import basicsrexamples.losses +import basicsrexamples.models +from basicsr.train import train_pipeline + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/options/example_option.yml b/options/example_option.yml index 468af1d..1649ff3 100644 --- a/options/example_option.yml +++ b/options/example_option.yml @@ -65,9 +65,8 @@ train: # losses l1_opt: - type: L1Loss + type: ExampleLoss loss_weight: 1.0 - reduction: mean l2_opt: type: MSELoss diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..1c70c3b --- /dev/null +++ b/setup.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import time + +version_file = 'basicsrexamples/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def get_requirements(filename='requirements.txt'): + here = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(here, filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + write_version_py() + setup( + name='basicsrexamples', + version=get_version(), + description='BasicSR Examples', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, pytorch, basicsr, image restoration, super-resolution', + url='https://github.com/xinntao/BasicSR-examples', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='BSD-3-Clause License', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + zip_safe=False) diff --git a/train.py b/train.py deleted file mode 100644 index b03e0f3..0000000 --- a/train.py +++ /dev/null @@ -1,10 +0,0 @@ -import os.path as osp - -import archs # noqa: F401 -import data # noqa: F401 -import models # noqa: F401 -from basicsr.train import train_pipeline - -if __name__ == '__main__': - root_path = osp.abspath(osp.join(__file__, osp.pardir)) - train_pipeline(root_path)