diff --git a/README.md b/README.md
index 784e8d6..550e565 100644
--- a/README.md
+++ b/README.md
@@ -48,13 +48,16 @@ There are two ways to use the python package of basicsr, which are provided in t
- :arrow_right: [installation mode](https://github.com/xinntao/BasicSR-examples/tree/installation): you need to install the project by running `python setup.py develop`. After installation, it is more convenient to import and use.
-As a simple introduction and explanation, we use the example of *simple mode*, but we recommend the *installation mode* in practical use.
+This `installation` branch uses the *installation mode* for illustration. We recommend using this mode for practical use.
```bash
-git clone https://github.com/xinntao/BasicSR-examples.git
+git clone -b installation https://github.com/xinntao/BasicSR-examples.git
cd BasicSR-examples
+python setup.py develop # need to install
```
+**Note that**: the installation model requires a package name for installation. We use `basicsrexamples` as the package name.
+
### Preliminary
Most deep-learning projects can be divided into the following parts:
@@ -106,7 +109,7 @@ Let's explain it separately in the following parts.
We need to implement a new dataset to fulfill our purpose. The dataset is used to feed the data into the model.
-An example of this dataset is in [data/example_dataset.py](data/example_dataset.py). It has the following steps.
+An example of this dataset is in [basicsrexamples/data/example_dataset.py](basicsrexamples/data/example_dataset.py). It has the following steps.
1. Read Ground-Truth (GT) images. BasicSR provides [FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py) for easily reading files in a folder, LMDB file and meta_info txt. In this example, we use the folder mode. For more reading modes, please refer to [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data)
1. Synthesize low resolution images. We can directly implement the data procedures in the `__getitem__(self, index)` function, such as downsampling and adding JPEG compression. Many basic operations can be found in [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/tranforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) ,and [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py)
@@ -151,7 +154,7 @@ datasets:
#### :two: arch
-An example of architecture is in [archs/example_arch.py](archs/example_arch.py). It mainly builds the network structure.
+An example of architecture is in [basicsrexamples/archs/example_arch.py](basicsrexamples/archs/example_arch.py). It mainly builds the network structure.
**Note**:
@@ -174,7 +177,7 @@ network_g:
#### :three: model
-An example of model is in [models/example_model.py](models/example_model.py). It mainly builds the training process of a model.
+An example of model is in [basicsrexamples/models/example_model.py](basicsrexamples/models/example_model.py). It mainly builds the training process of a model.
In this file:
1. We inherit `SRModel` from basicsr. Many models have similar operations, so you can inherit and modify from [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models). In this way, you can easily implement your ideas, such as GAN model, video model, *etc*.
@@ -222,7 +225,7 @@ train:
The whole training pipeline can reuse the [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py) in BasicSR.
-Based on this, our [train.py](train.py) can be very concise:
+Based on this, our [basicsrexamples/train.py](basicsrexamples/train.py) can be very concise:
```python
import os.path as osp
@@ -243,7 +246,7 @@ if __name__ == '__main__':
So far, we have completed the development of our project. We can quickly check whether there is a bug through the `debug` mode:
```bash
-python train.py -opt options/example_option.yml --debug
+python basicsrexamples/train.py -opt options/example_option.yml --debug
```
With `--debug`, the program will enter the debug mode. In the debug mode, the program will output at each iteration, and perform validation every 8 iterations, so that you can easily know whether the program has a bug~
@@ -253,7 +256,7 @@ With `--debug`, the program will enter the debug mode. In the debug mode, the pr
After debugging, we can have the normal training.
```bash
-python train.py -opt options/example_option.yml
+python basicsrexamples/train.py -opt options/example_option.yml
```
If the training process is interrupted unexpectedly and the resume is required. Please use `--auto_resume` in the command:
@@ -268,13 +271,36 @@ So far, you have finished developing your own projects using `BasicSR`. Isn't it
You can use BasicSR-Examples as a template for your project. Here are some modifications you may need.
+As GitHub does not support a specific branch as a template, we need extra steps to use the `installation` branch as the template.
+
+1. Click `Use this template` and remember to check the `[ ] Include all branches` checkbox
+2. Change the installation branch to the master branch
+ ```bash
+ git clone -b installation YOUR_REPO # clone the installation branch
+ cd REPO_NAME
+ git branch -m installation master # rename the installation branch to master
+ git push -f origin master # force push the local master branch to remote
+ git push origin --delete installation # delete the remote installation branch
+ ```
+
+You may need to modify the following files:
+
1. Set up the *pre-commit* hook
1. In the root path, run:
> pre-commit install
1. Modify the `LICENSE`
This repository uses the *MIT* license, you may change it to other licenses
-The simple mode do not require many modifications. Those using the installation mode may need more modifications. See [here](https://github.com/xinntao/BasicSR-examples/blob/installation/README.md#As-a-Template)
+As the installation mode requires the package name, you also need to modify all the `basicsrexamples` names to YOUR_PACKAGE_NAME.
+Here are the detailed locations that contain the `basicsrexamples` name:
+1. The `basicsrexamples` folder
+1. [setup.py](setup.py#L9); [setup.py](setup.py#L48); [setup.py](setup.py#L91)
+1. [basicsrexamples/train.py](basicsrexamples/train.py#L4-L6)
+1. [basicsrexamples/archs/\_\_init\_\_.py](basicsrexamples/archs/__init__.py#L11)
+1. [basicsrexamples/data/\_\_init\_\_.py](basicsrexamples/data/__init__.py#L11)
+1. [basicsrexamples/models/\_\_init\_\_.py](basicsrexamples/models/__init__.py#L11)
+
+You also need to modify the corresponding information in the [setup.py](setup.py#L88-L113).
## :e-mail: Contact
diff --git a/README_CN.md b/README_CN.md
index 8fe94a8..6fbe548 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -46,13 +46,16 @@
- :arrow_right: [简单模式](https://github.com/xinntao/BasicSR-examples/tree/master): 项目的仓库不需要安装,就可以运行使用。但它有局限:不方便 import 复杂的层级关系;在其他位置也不容易访问本项目中的函数
- :arrow_right: [安装模式](https://github.com/xinntao/BasicSR-examples/tree/installation): 项目的仓库需要安装 `python setup.py develop`,安装之后 import 和使用都更加方便
-作为简单的入门和讲解, 我们使用*简单模式*的样例,但在实际使用中我们推荐*安装模式*。
+这个`installation`分支使用 *安装模式* 进行说明。在实际使用中我们也推荐 *安装模式*。
```bash
-git clone https://github.com/xinntao/BasicSR-examples.git
+git clone -b installation https://github.com/xinntao/BasicSR-examples.git
cd BasicSR-examples
+python setup.py develop # need to install
```
+**注意**: 安装模式需要一个 包的名字 来安装。在这里,我们使用 `basicsrexamples` 作为包名。
+
### 预备
大部分的深度学习项目,都可以分为以下几个部分:
@@ -102,7 +105,7 @@ python scripts/prepare_example_data.py
这个部分是用来确定喂给模型的数据的。
-这个 dataset 的例子在[data/example_dataset.py](data/example_dataset.py) 中,它完成了:
+这个 dataset 的例子在[basicsrexamples/data/example_dataset.py](basicsrexamples/data/example_dataset.py) 中,它完成了:
1. 我们读取 Ground-Truth (GT) 的图像。读取的操作,BasicSR 提供了[FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py), 可以方便地读取 folder, lmdb 和 meta_info txt 指定的文件。在这个例子中,我们通过读取 folder 来说明,更多的读取模式可以参考 [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data)
1. 合成低分辨率的图像。我们直接可以在 `__getitem__(self, index)` 的函数中实现我们想要的操作,比如降采样和添加 JPEG 压缩。很多基本操作都可以在 [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/tranforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) 和 [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) 中找到
1. 转换成 Torch Tensor,返回合适的信息
@@ -145,7 +148,7 @@ datasets:
#### :two: arch
-Architecture 的例子在 [archs/example_arch.py](archs/example_arch.py)中。它主要搭建了网络结构。
+Architecture 的例子在 [basicsrexamples/archs/example_arch.py](basicsrexamples/archs/example_arch.py)中。它主要搭建了网络结构。
**注意**:
1. 需要在 `ExampleArch` 前添加 `@ARCH_REGISTRY.register()`,以便注册好新写的 arch。这个操作主要用来防止出现同名的 arch,从而带来潜在的 bug
@@ -167,7 +170,7 @@ network_g:
#### :three: model
-Model 的例子在 [models/example_model.py](models/example_model.py)中。它主要搭建了模型的训练过程。
+Model 的例子在 [basicsrexamples/models/example_model.py](basicsrexamples/models/example_model.py)中。它主要搭建了模型的训练过程。
在这个文件中:
1. 我们从 basicsr 中继承了 `SRModel`。很多模型都有相似的操作,因此可以通过继承 [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models) 中的模型来更方便地实现自己的想法,比如GAN模型,Video模型等
1. 使用了两个 Loss: L1 和 L2 (MSE) loss
@@ -213,7 +216,7 @@ train:
整个 training pipeline 可以复用 basicsr 里面的 [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py)。
-基于此,我们的 [train.py](train.py)可以非常简洁。
+基于此,我们的 [basicsrexamples/train.py](basicsrexamples/train.py)可以非常简洁。
```python
import os.path as osp
@@ -234,7 +237,7 @@ if __name__ == '__main__':
至此,我们已经完成了我们这个项目的开发,下面可以通过 `debug` 模式来快捷地看看是否有问题:
```bash
-python train.py -opt options/example_option.yml --debug
+python basicsrexamples/train.py -opt options/example_option.yml --debug
```
只要带上 `--debug` 就进入 debug 模式。在 debug 模式中,程序每个iter都会输出,8个iter后就会进行validation,这样可以很方便地知道程序有没有bug啦~
@@ -244,7 +247,7 @@ python train.py -opt options/example_option.yml --debug
经过debug没有问题后,我们就可以正式训练了。
```bash
-python train.py -opt options/example_option.yml
+python basicsrexamples/train.py -opt options/example_option.yml
```
如果训练过程意外中断需要 resume, 则使用 `--auto_resume` 可以方便地自动resume:
@@ -258,13 +261,37 @@ python train.py -opt options/example_option.yml --auto_resume
你可以使用 BasicSR-Examples 作为你项目的模板。下面主要展示一下你可能需要的修改。
+由于 GitHub 不支持将特定分支作为模板,因此我们需要额外的步骤来使用`installation`分支作为模板。
+
+1. 点击 `Use this template`;记得勾选 `[ ] Include all branches`
+2. 将 安装分支 修改为 主分支
+ ```bash
+ git clone -b installation YOUR_REPO # clone the installation branch
+ cd REPO_NAME
+ git branch -m installation master # rename the installation branch to master
+ git push -f origin master # force push the local master branch to remote
+ git push origin --delete installation # delete the remote installation branch
+ ```
+
+你需要根据需要修改以下文件:
+
1. 设置 *pre-commit* hook
1. 在文件夹根目录, 运行
> pre-commit install
1. 修改 `LICENSE` 文件
本仓库使用 *MIT* 许可, 根据需要可以修改成其他许可
-使用 简单模式 的基本不需要修改,使用 安装模式 的可能需要较多修改,参见[这里](https://github.com/xinntao/BasicSR-examples/blob/installation/README_CN.md#As-a-Template)
+由于安装模式需要包名,因此还需要将所有`basicsrexamples`名称修改为 YOUR_PACKAGE_NAME。
+以下是包含`basicsrexamples`名称的详细位置:
+
+1. The `basicsrexamples` folder
+1. [setup.py](setup.py#L9); [setup.py](setup.py#L48); [setup.py](setup.py#L91)
+1. [basicsrexamples/train.py](basicsrexamples/train.py#L4-L6)
+1. [basicsrexamples/archs/\_\_init\_\_.py](basicsrexamples/archs/__init__.py#L11)
+1. [basicsrexamples/data/\_\_init\_\_.py](basicsrexamples/data/__init__.py#L11)
+1. [basicsrexamples/models/\_\_init\_\_.py](basicsrexamples/models/__init__.py#L11)
+
+你也需要修改文件 [setup.py](setup.py#L88-L113) 中的相应信息。
## :e-mail: 联系
diff --git a/basicsrexamples/__init__.py b/basicsrexamples/__init__.py
new file mode 100644
index 0000000..36731a0
--- /dev/null
+++ b/basicsrexamples/__init__.py
@@ -0,0 +1,5 @@
+# flake8: noqa
+from .archs import *
+from .data import *
+from .models import *
+from .version import __gitsha__, __version__
diff --git a/archs/__init__.py b/basicsrexamples/archs/__init__.py
similarity index 77%
rename from archs/__init__.py
rename to basicsrexamples/archs/__init__.py
index 42ec069..372c490 100644
--- a/archs/__init__.py
+++ b/basicsrexamples/archs/__init__.py
@@ -8,4 +8,4 @@
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
-_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
+_arch_modules = [importlib.import_module(f'basicsrexamples.archs.{file_name}') for file_name in arch_filenames]
diff --git a/archs/example_arch.py b/basicsrexamples/archs/example_arch.py
similarity index 100%
rename from archs/example_arch.py
rename to basicsrexamples/archs/example_arch.py
diff --git a/data/__init__.py b/basicsrexamples/data/__init__.py
similarity index 77%
rename from data/__init__.py
rename to basicsrexamples/data/__init__.py
index b22fe2b..be92f0e 100644
--- a/data/__init__.py
+++ b/basicsrexamples/data/__init__.py
@@ -8,4 +8,4 @@
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
-_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
+_dataset_modules = [importlib.import_module(f'basicsrexamples.data.{file_name}') for file_name in dataset_filenames]
diff --git a/data/example_dataset.py b/basicsrexamples/data/example_dataset.py
similarity index 100%
rename from data/example_dataset.py
rename to basicsrexamples/data/example_dataset.py
diff --git a/basicsrexamples/losses/__init__.py b/basicsrexamples/losses/__init__.py
new file mode 100644
index 0000000..046e7f0
--- /dev/null
+++ b/basicsrexamples/losses/__init__.py
@@ -0,0 +1,11 @@
+import importlib
+from os import path as osp
+
+from basicsr.utils import scandir
+
+# automatically scan and import loss modules for registry
+# scan all the files that end with '_loss.py' under the loss folder
+loss_folder = osp.dirname(osp.abspath(__file__))
+loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
+# import all the loss modules
+_model_modules = [importlib.import_module(f'basicsrexamples.losses.{file_name}') for file_name in loss_filenames]
diff --git a/basicsrexamples/losses/example_loss.py b/basicsrexamples/losses/example_loss.py
new file mode 100644
index 0000000..7939217
--- /dev/null
+++ b/basicsrexamples/losses/example_loss.py
@@ -0,0 +1,26 @@
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import LOSS_REGISTRY
+
+
+@LOSS_REGISTRY.register()
+class ExampleLoss(nn.Module):
+ """Example Loss.
+
+ Args:
+ loss_weight (float): Loss weight for Example loss. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(ExampleLoss, self).__init__()
+ self.loss_weight = loss_weight
+
+ def forward(self, pred, target, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
+ """
+ return self.loss_weight * F.l1_loss(pred, target, reduction='mean')
diff --git a/models/__init__.py b/basicsrexamples/models/__init__.py
similarity index 77%
rename from models/__init__.py
rename to basicsrexamples/models/__init__.py
index a92d6ac..0183977 100644
--- a/models/__init__.py
+++ b/basicsrexamples/models/__init__.py
@@ -8,4 +8,4 @@
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
-_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
+_model_modules = [importlib.import_module(f'basicsrexamples.models.{file_name}') for file_name in model_filenames]
diff --git a/models/example_model.py b/basicsrexamples/models/example_model.py
similarity index 100%
rename from models/example_model.py
rename to basicsrexamples/models/example_model.py
diff --git a/basicsrexamples/train.py b/basicsrexamples/train.py
new file mode 100644
index 0000000..530726b
--- /dev/null
+++ b/basicsrexamples/train.py
@@ -0,0 +1,12 @@
+# flake8: noqa
+import os.path as osp
+
+import basicsrexamples.archs
+import basicsrexamples.data
+import basicsrexamples.losses
+import basicsrexamples.models
+from basicsr.train import train_pipeline
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/options/example_option.yml b/options/example_option.yml
index 468af1d..1649ff3 100644
--- a/options/example_option.yml
+++ b/options/example_option.yml
@@ -65,9 +65,8 @@ train:
# losses
l1_opt:
- type: L1Loss
+ type: ExampleLoss
loss_weight: 1.0
- reduction: mean
l2_opt:
type: MSELoss
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..1c70c3b
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+import os
+import subprocess
+import time
+
+version_file = 'basicsrexamples/version.py'
+
+
+def readme():
+ with open('README.md', encoding='utf-8') as f:
+ content = f.read()
+ return content
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ except OSError:
+ sha = 'unknown'
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists('.git'):
+ sha = get_git_hash()[:7]
+ else:
+ sha = 'unknown'
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open('VERSION', 'r') as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+ with open(version_file, 'w') as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+def get_requirements(filename='requirements.txt'):
+ here = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(here, filename), 'r') as f:
+ requires = [line.replace('\n', '') for line in f.readlines()]
+ return requires
+
+
+if __name__ == '__main__':
+ write_version_py()
+ setup(
+ name='basicsrexamples',
+ version=get_version(),
+ description='BasicSR Examples',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ author='Xintao Wang',
+ author_email='xintao.wang@outlook.com',
+ keywords='computer vision, pytorch, basicsr, image restoration, super-resolution',
+ url='https://github.com/xinntao/BasicSR-examples',
+ include_package_data=True,
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ ],
+ license='BSD-3-Clause License',
+ setup_requires=['cython', 'numpy'],
+ install_requires=get_requirements(),
+ zip_safe=False)
diff --git a/train.py b/train.py
deleted file mode 100644
index b03e0f3..0000000
--- a/train.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import os.path as osp
-
-import archs # noqa: F401
-import data # noqa: F401
-import models # noqa: F401
-from basicsr.train import train_pipeline
-
-if __name__ == '__main__':
- root_path = osp.abspath(osp.join(__file__, osp.pardir))
- train_pipeline(root_path)