【Hackathon 8th No.23】Improved Training of Wasserstein GANs paper reproduction -part #1146
Merged
0b02a82
Create wgangp.yaml
XvLingWYY 83ecd01
Add files via upload
XvLingWYY 0f56f37
Add files via upload
XvLingWYY 51befb3
Add files via upload
XvLingWYY 532f852
Delete examples/wgangp/conf/wgangp.yaml
XvLingWYY f0da2e9
Update wgangp_cifar10.yaml
XvLingWYY 5f18b0b
Update functions.py
XvLingWYY c3d753d
Update wgangp_cifar10.py
XvLingWYY 84230d1
Update model.py
XvLingWYY 8331599
Update wgangp_cifar10.py
XvLingWYY 326da7a
Update wgangp_mnist.py
XvLingWYY b06cee6
Update wgangp_toy.py
XvLingWYY 93118bf
Update functions.py
XvLingWYY ba7481b
Update model.py
XvLingWYY afbd7d8
Update wgangp_cifar10.py
XvLingWYY e3fddbd
Update wgangp_mnist.py
XvLingWYY 3068ff6
Update wgangp_toy.py
XvLingWYY 0c2b21c
Update wgangp_cifar10.yaml
XvLingWYY bdf54d2
Update model.py
XvLingWYY c0e34b3
Update wgan_gp.md
XvLingWYY 127650b
Update wgangp_cifar10.yaml
XvLingWYY 617c9f3
Update model.py
XvLingWYY acf302f
Update wgan_gp.md
XvLingWYY 9549f1a
Update wgan_gp.md
XvLingWYY 6871216
Update wgan_gp.md
XvLingWYY 2c38a16
Add files via upload
XvLingWYY 314321e
Update wgan_gp.md
XvLingWYY 26adbd1
Add files via upload
XvLingWYY 0914ba6
Add files via upload
XvLingWYY aef0367
Add files via upload
XvLingWYY 1e14469
Add files via upload
XvLingWYY d0627c1
Add files via upload
XvLingWYY 70d7852
Add files via upload
XvLingWYY 4194dcc
Add files via upload
XvLingWYY 5933fbf
Update wgangp_cifar10.yaml
XvLingWYY 6cdbc38
Update wgangp_cifar10.yaml
XvLingWYY 0442eee
Add files via upload
XvLingWYY 15f855c
Add files via upload
XvLingWYY 865b731
Add files via upload
XvLingWYY cbe55e3
Add files via upload
XvLingWYY d9dd40b
Merge branch 'PaddlePaddle:develop' into develop
XvLingWYY 6a4dd20
Delete examples/wgangp directory
XvLingWYY 59181b3
Create wgangp_cifar10.yaml
XvLingWYY ea46640
Add files via upload
XvLingWYY e5ce98f
Add files via upload
XvLingWYY 30bd672
Add files via upload
XvLingWYY 16eb6df
Add files via upload
XvLingWYY 8bae84a
Add files via upload
XvLingWYY 5b7b437
Add files via upload
XvLingWYY af72624
Add files via upload
XvLingWYY 3cbafe5
Add files via upload
XvLingWYY d0d8249
Add files via upload
XvLingWYY 5b43407
Add files via upload
XvLingWYY cf80107
Add files via upload
XvLingWYY 56e5d68
Delete examples/wgangp_toy_model.py
XvLingWYY 6c872c8
Add files via upload
XvLingWYY 0722c3c
Add files via upload
XvLingWYY 71a103f
Add files via upload
XvLingWYY 071fbd7
Add files via upload
XvLingWYY a3ec077
Merge branch 'PaddlePaddle:develop' into develop
XvLingWYY 0cc7ef7
Delete docs/zh/examples/wgan_gp.md
XvLingWYY 5f6489c
Delete docs/index.md
XvLingWYY 7811c40
Delete mkdocs.yml
XvLingWYY 353d6e7
Add files via upload
XvLingWYY 3885f41
Add files via upload
XvLingWYY d8d9ca1
Add files via upload
XvLingWYY
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,349 @@ | ||
# WGANGP | ||
|
||
!!! note | ||
|
||
1. Before running, download [Cifar10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz) and update data_path in wgangp_cifar10.yaml | ||
2. Before running, download [MNIST](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz) and update data_path in wgangp_mnist.yaml | ||
|
||
=== "Model training commands" | ||
```sh | ||
python wgangp_cifar10.py | ||
``` | ||
```sh | ||
python wgangp_mnist.py | ||
``` | ||
```sh | ||
python wgangp_toy.py | ||
``` | ||
|
||
=== "Model evaluation commands" | ||
```sh | ||
python wgangp_cifar10.py mode=eval | ||
``` | ||
```sh | ||
python wgangp_mnist.py mode=eval | ||
``` | ||
```sh | ||
python wgangp_toy.py mode=eval | ||
``` | ||
|
||
|
||
| Pretrained model | | ||
|:-----------------------------------| | ||
| wgangp_cifar10_pretrained.pdparams | | ||
| wgangp_mnist_pretrained.pdparams | | ||
| wgangp_toy_pretrained.pdparams | | ||
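## 1. Background | ||
In digital image processing and machine learning, generative adversarial networks (GANs) have attracted wide attention for their image generation ability. However, training a conventional GAN can be unstable, especially when generating high-resolution images or complex scenes. To address this, researchers proposed the Wasserstein GAN with gradient penalty (WGAN-GP), which makes training more stable and improves the quality of the generated images. | ||
|
||
WGAN-GP changes the loss function so that training minimizes the Wasserstein distance between the real data distribution and the generated data distribution, and adds a gradient penalty that keeps the critic smooth and the training stable. This mitigates the mode collapse that is common in conventional GANs and supports more efficient training and more realistic generated images. | ||
|
||
## 2. Model principle | ||
WGAN-GP proposes an alternative to weight clipping: penalizing the norm of the critic's gradient with respect to its input. This stabilizes the training of a variety of GAN architectures with almost no hyperparameter tuning. | ||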
|
||
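### 2.1 Model structure | ||
|
||
WGAN-GP is an adversarial network consisting of a noise-to-image generator and a CNN discriminator (the critic). The overall structure of the model is shown below. | ||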
|
||
``` | ||
noise===>generator===>fake_image== | ||
==>discriminator===>Wasserstein Loss+Gradient Penalty | ||
image== | ||
``` | ||
|
||
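- `Generator` is a convolutional neural network. | ||
|
||
- `Discriminator` is a model built from convolutional blocks. It takes an image as input and outputs a realness score for that image. | ||
|
||
### 2.2 Loss functions | ||
|
||
The discriminator loss combines the Wasserstein loss with a gradient penalty: | ||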
|
||
$$ | ||
L_d = \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}} D(\tilde{x}) - \underset{x \sim \mathbb{P}_r}{\mathbb{E}}D(x) + \lambda \underset{\hat{x} \sim \mathbb{P}_{\hat{x}}}{\mathbb{E}} \left[ \left( \| \nabla_{\hat{x}} D(\hat{x}) \|_2 - 1 \right)^2 \right] | ||
$$ | ||
|
||
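where $\mathbb{P}_g$ is the generator distribution, $\mathbb{P}_r$ is the real data distribution, $\mathbb{P}_{\hat{x}}$ is the distribution of samples $\hat{x}$ interpolated between samples from $\mathbb{P}_g$ and $\mathbb{P}_r$, and $\lambda$ is the gradient penalty coefficient. | ||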
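
The discriminator loss above can be sketched numerically. The example below is only an illustration under stated assumptions, not the code from this PR: it uses NumPy instead of Paddle, a hypothetical linear critic $D(x) = w^\top x + b$ (whose input gradient is $w$ everywhere, so no automatic differentiation is needed), one uniform interpolation coefficient per sample, and a default of `lam=10.0`, the value suggested in the referenced paper.

```python
import numpy as np


def critic_loss_linear(w, b, real, fake, lam=10.0, rng=None):
    """WGAN-GP critic loss for a hypothetical linear critic D(x) = x @ w + b."""
    rng = np.random.default_rng(0) if rng is None else rng
    d_real = real @ w + b
    d_fake = fake @ w + b
    # x_hat = eps * x + (1 - eps) * x_tilde, with one eps per sample
    eps = rng.uniform(size=(real.shape[0], 1))
    x_hat = eps * real + (1.0 - eps) * fake
    # For a linear critic the gradient w.r.t. x_hat is w at every point
    grads = np.broadcast_to(w, x_hat.shape)
    grad_norm = np.linalg.norm(grads, axis=1)
    penalty = np.mean((grad_norm - 1.0) ** 2)
    # E[D(fake)] - E[D(real)] + lambda * gradient penalty
    return d_fake.mean() - d_real.mean() + lam * penalty
```

With a real network the gradient at `x_hat` would come from the framework's autograd; the structure of the three terms stays the same.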
|
||
The generator loss is the adversarial loss $- \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x})$ (the Cifar10 case adds a classification loss, see 3.4.1): | ||
|
||
$$ | ||
L_g = - \underset{\tilde{x} \sim \mathbb{P}_g}{\mathbb{E}}D(\tilde{x}) | ||
$$ | ||
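where $\mathbb{P}_g$ is the generator distribution. | ||
|
||
## 3. Model construction | ||
|
||
Next, we explain how to implement WGAN-GP with the PaddleScience framework. Only the key steps are described below; for other details, see the [API documentation](https://paddlescience-docs.readthedocs.io/zh-cn/latest/zh/api/arch/). | ||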
|
||
### 3.1 Datasets | ||
|
||
The datasets used are [Cifar10](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz), [MNIST](http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz), and toy datasets (swissroll/8gaussians/25gaussians). | ||
|
||
The Cifar10 dataset contains 60000 32x32 color images in 10 classes, with 6000 images per class. | ||
|
||
The Cifar10 dataset is distributed in 3 versions: | ||
|
||
| Version | Size | md5sum | | ||
|:-----------------|:------------|:-------------------------------------| | ||
| CIFAR-10 python | 163 MB | c58f30108f718f92721af3b95e74349a | | ||
| CIFAR-10 Matlab | 175 MB | 70270af85842c9e89bb428ec9976c926 | | ||
| CIFAR-10 binary | 162 MB | c32a1d4ab5d03f1284b67883e8d87530 | | ||
|
||
This implementation uses the CIFAR-10 python version. | ||
|
||
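The MNIST dataset contains 70000 28x28 grayscale images of handwritten digits in 10 classes (60000 training images and 10000 test images); the classes are roughly, but not exactly, balanced. | ||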
|
||
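Toy datasets: | ||
|
||
Swissroll: a three-dimensional nonlinear manifold dataset with a continuous rolled spiral structure. | ||
|
||
8gaussians: a two-dimensional synthetic dataset of eight symmetrically placed Gaussian clusters whose centers are evenly spaced on a circle. | ||
|
||
25gaussians: a dense Gaussian mixture dataset made of 25 two-dimensional Gaussians arranged on a regular grid, with tight spacing between clusters. | ||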
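
As a rough illustration of the two Gaussian toy datasets described above, the following NumPy sketch draws samples from eight clusters on a circle and from a 5x5 grid of clusters. The function names, the radius, the grid spacing, and the standard deviations are assumptions chosen for this example; they are not taken from `examples/wgangp/functions.py`.

```python
import numpy as np


def sample_8gaussians(n, radius=2.0, std=0.02, seed=0):
    """Draw n 2-D points from eight Gaussians centered evenly on a circle."""
    rng = np.random.default_rng(seed)
    angles = 2.0 * np.pi * np.arange(8) / 8
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    idx = rng.integers(0, 8, size=n)  # pick a cluster for every sample
    return (centers[idx] + std * rng.normal(size=(n, 2))).astype(np.float32)


def sample_25gaussians(n, spacing=2.0, std=0.05, seed=0):
    """Draw n 2-D points from 25 Gaussians centered on a regular 5x5 grid."""
    rng = np.random.default_rng(seed)
    grid = spacing * np.arange(-2, 3)
    centers = np.array([(x, y) for x in grid for y in grid], dtype=np.float64)
    idx = rng.integers(0, 25, size=n)
    return (centers[idx] + std * rng.normal(size=(n, 2))).astype(np.float32)
```

Arrays produced this way have shape `(n, 2)` and can be handed to an array-based dataset wrapper in the same way as the image data.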
|
||
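### 3.2 Building the dataset API | ||
|
||
Because the Cifar10 training data is split across 5 batch files, PaddleScience's built-in dataset API cannot read it directly. We therefore read all the data first and then wrap it in ```ppsci.data.dataset.array_dataset.NamedArrayDataset```. | ||
|
||
The code that reads the Cifar10 dataset is shown below: | ||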
``` py linenums="152" | ||
--8<-- | ||
examples/wgangp/functions.py:152:160 | ||
--8<-- | ||
``` | ||
Here `data_path` is the path to the CIFAR-10 data. | ||
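
For readers without the repository at hand, reading the five training batches can be sketched as follows. This is a minimal sketch, not the PR's `functions.py`: the function name is hypothetical, and scaling pixels to [-1, 1] is an assumption. It relies only on the documented layout of the CIFAR-10 python version, where each `data_batch_i` file is a pickled dictionary with a `data` array of shape (10000, 3072) and a `labels` list.

```python
import os
import pickle

import numpy as np


def load_cifar10_train(data_path):
    """Read data_batch_1..5 and return images (N, 3, 32, 32) and labels (N,)."""
    images, labels = [], []
    for i in range(1, 6):
        with open(os.path.join(data_path, f"data_batch_{i}"), "rb") as f:
            batch = pickle.load(f, encoding="bytes")  # keys are bytes objects
        images.append(batch[b"data"])
        labels.extend(batch[b"labels"])
    x = np.concatenate(images).reshape(-1, 3, 32, 32).astype(np.float32)
    x = x / 127.5 - 1.0  # assumed normalization to [-1, 1]
    y = np.asarray(labels, dtype=np.int64)
    return x, y
```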
|
||
The dataloader configuration code is shown below: | ||
``` py linenums="108" | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:108:123 | ||
--8<-- | ||
``` | ||
|
||
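Because the MNIST dataset cannot be read directly with PaddleScience's built-in dataset API, we read all the data first and then wrap it in ```ppsci.data.dataset.array_dataset.NamedArrayDataset```. | ||
|
||
The code that reads the MNIST dataset is shown below: | ||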
``` py linenums="352" | ||
--8<-- | ||
examples/wgangp/functions.py:352:361 | ||
--8<-- | ||
``` | ||
|
||
The dataloader configuration code is shown below: | ||
``` py linenums="97" | ||
--8<-- | ||
examples/wgangp/wgangp_mnist.py:97:111 | ||
--8<-- | ||
``` | ||
|
||
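Because the toy datasets cannot be read directly with PaddleScience's built-in dataset API, we generate all the data first and then wrap it in ```ppsci.data.dataset.array_dataset.NamedArrayDataset```. | ||
|
||
The code that generates the toy datasets is shown below: | ||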
``` py linenums="177" | ||
--8<-- | ||
examples/wgangp/functions.py:177:219 | ||
--8<-- | ||
``` | ||
|
||
The dataloader configuration code is shown below: | ||
``` py linenums="89" | ||
--8<-- | ||
examples/wgangp/wgangp_toy.py:89:103 | ||
--8<-- | ||
``` | ||
|
||
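### 3.3 Building the models | ||
|
||
The WGAN-GP models in this example are not built into PaddleScience and must be implemented separately, so we define `WganGpCifar10Generator` and `WganGpCifar10Discriminator`, `WganGpMnistGenerator` and `WganGpMnistDiscriminator`, and `WganGpToyGenerator` and `WganGpToyDiscriminator`. | ||
|
||
The models are constructed as follows: | ||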
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:96:98 | ||
examples/wgangp/wgangp_mnist.py:87:89 | ||
examples/wgangp/wgangp_toy.py:79:81 | ||
--8<-- | ||
``` | ||
|
||
The parameter configuration is as follows: | ||
|
||
``` yaml | ||
--8<-- | ||
examples/wgangp/conf/wgangp_cifar10.yaml:29:43 | ||
examples/wgangp/conf/wgangp_mnist.yaml:29:38 | ||
examples/wgangp/conf/wgangp_toy.yaml:29:37 | ||
--8<-- | ||
``` | ||
|
||
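### 3.4 Custom loss | ||
|
||
The WGAN-GP loss functions are relatively complex and need custom implementations. PaddleScience provides an API for custom loss functions, `ppsci.loss.FunctionalLoss`: define the loss function first, then pass the function as an argument to `FunctionalLoss`. Note that the inputs and outputs of a custom loss function must be dictionaries. | ||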
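
The dictionary-in, dictionary-out convention can be illustrated with a small stand-alone function. This is a sketch under assumptions: NumPy arrays stand in for Paddle tensors, the key names `critic_fake` and `loss_g` are hypothetical, and the three-argument signature `(output_dict, label_dict, weight_dict)` is assumed to be what `FunctionalLoss` expects, which should be checked against the PaddleScience API documentation for the installed version.

```python
import numpy as np


def generator_loss(output_dict, label_dict=None, weight_dict=None):
    """Adversarial generator loss L_g = -E[D(G(z))], dictionary in and out."""
    # "critic_fake" is a hypothetical key holding the critic scores D(G(z))
    critic_scores = output_dict["critic_fake"]
    return {"loss_g": -np.mean(critic_scores)}
```

A function of this shape is what would be passed, by name, to the custom-loss wrapper; the real implementations referenced below operate on Paddle tensors.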
|
||
#### 3.4.1 Generator loss | ||
|
||
The Cifar10_Generator loss consists of an adversarial loss and a classification loss. Each term has its own weight; if the weight of a term is 0, that term is not added during training. | ||
|
||
``` py linenums="16" | ||
--8<-- | ||
examples/wgangp/functions.py:16:44 | ||
--8<-- | ||
``` | ||
|
||
The MNIST_Generator loss contains only the adversarial loss. | ||
``` py linenums="297" | ||
--8<-- | ||
examples/wgangp/functions.py:297:312 | ||
--8<-- | ||
``` | ||
The Toy_Generator loss contains only the adversarial loss. | ||
``` py linenums="221" | ||
--8<-- | ||
examples/wgangp/functions.py:221:237 | ||
--8<-- | ||
``` | ||
|
||
#### 3.4.2 Discriminator loss | ||
|
||
The Cifar10_Discriminator loss consists of the Wasserstein loss, the gradient penalty, and a classification loss. Only the classification loss has a weight parameter. | ||
``` py linenums="46" | ||
--8<-- | ||
examples/wgangp/functions.py:46:95 | ||
--8<-- | ||
``` | ||
|
||
The MNIST_Discriminator loss consists of the Wasserstein loss and the gradient penalty. | ||
``` py linenums="314" | ||
--8<-- | ||
examples/wgangp/functions.py:314:350 | ||
--8<-- | ||
``` | ||
|
||
The Toy_Discriminator loss consists of the Wasserstein loss and the gradient penalty. | ||
``` py linenums="239" | ||
--8<-- | ||
examples/wgangp/functions.py:239:275 | ||
--8<-- | ||
``` | ||
|
||
### 3.5 Building the constraints | ||
|
||
All examples build their constraints with `ppsci.constraint.SupervisedConstraint`. | ||
|
||
The construction code is as follows: | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:125:142 | ||
examples/wgangp/wgangp_mnist.py:113:129 | ||
examples/wgangp/wgangp_toy.py:105:121 | ||
--8<-- | ||
``` | ||
|
||
### 3.6 Building the optimizers | ||
|
||
WGAN-GP uses the Adam optimizer, which can be built directly with `ppsci.optimizer.Adam`: | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:114:159 | ||
examples/wgangp/wgangp_mnist.py:131:134 | ||
examples/wgangp/wgangp_toy.py:123:127 | ||
--8<-- | ||
``` | ||
|
||
### 3.7 Building the Solver | ||
|
||
Pass the constructed models, constraints, optimizers, and other parameters to `ppsci.solver.Solver`. | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:161:179 | ||
examples/wgangp/wgangp_mnist.py:136:152 | ||
examples/wgangp/wgangp_toy.py:129:145 | ||
--8<-- | ||
``` | ||
|
||
### 3.8 Model training | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:182:187 | ||
examples/wgangp/wgangp_mnist.py:154:160 | ||
examples/wgangp/wgangp_toy.py:147:153 | ||
--8<-- | ||
``` | ||
|
||
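### 3.9 Custom metric | ||
|
||
Only the Cifar10 example has an evaluation metric, the Inception Score; the MNIST and Toy examples have none. Because an empty metric raises an error, we also define a no-op metric for those two examples, | ||
so two metrics are implemented in total. | ||
|
||
PaddleScience provides an API for custom metric functions, `ppsci.metric.FunctionalMetric`: define the metric function first, then pass the function as an argument to `FunctionalMetric`. Note that the inputs and outputs of a custom metric function must be dictionaries. | ||
|
||
The Inception Score is implemented as follows: | ||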
``` py linenums="97" | ||
--8<-- | ||
examples/wgangp/functions.py:97:139 | ||
--8<-- | ||
``` | ||
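
Independently of the implementation referenced above, the core Inception Score formula, $\exp\left(\mathbb{E}_x\, \mathrm{KL}(p(y|x) \,\|\, p(y))\right)$, can be sketched in a few lines of NumPy. This is a simplified sketch: it assumes the class probabilities `probs` have already been produced by a classifier (normally an Inception network, which is not included here) and it does not split the samples into groups as many evaluations do.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Inception Score from an (N, num_classes) array of class probabilities."""
    probs = np.asarray(probs, dtype=np.float64)
    marginal = probs.mean(axis=0, keepdims=True)  # p(y), averaged over samples
    # KL(p(y|x) || p(y)) for every sample; eps guards against log(0)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(marginal + eps)), axis=1)
    return float(np.exp(kl.mean()))
```

The score is 1 when every sample gets the same prediction distribution and approaches the number of classes when predictions are confident and spread evenly over all classes.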
|
||
The code of invalid_metric is as follows: | ||
``` py linenums="373" | ||
--8<-- | ||
examples/wgangp/functions.py:373:374 | ||
--8<-- | ||
``` | ||
|
||
### 3.10 Building the validators | ||
|
||
This example builds the validators with `ppsci.validate.SupervisedValidator`. | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:53:63 | ||
examples/wgangp/wgangp_mnist.py:45:54 | ||
examples/wgangp/wgangp_toy.py:45:52 | ||
--8<-- | ||
``` | ||
|
||
### 3.11 Model evaluation | ||
|
||
After passing the model, validator, and weights path to `ppsci.solver.Solver`, start the evaluation with `solver.eval()`. | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:65:74 | ||
examples/wgangp/wgangp_mnist.py:56:65 | ||
examples/wgangp/wgangp_toy.py:54:63 | ||
--8<-- | ||
``` | ||
|
||
### 3.12 Visualization | ||
|
||
After the evaluation finishes, the results are visualized as images: | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py:76:92 | ||
examples/wgangp/wgangp_mnist.py:67:83 | ||
examples/wgangp/wgangp_toy.py:65:75 | ||
--8<-- | ||
``` | ||
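
One common way to turn a batch of generated samples into a single picture is to tile them on a grid. The helper below is a hypothetical sketch, not the visualization code of this PR; it assumes generator outputs of shape (N, C, H, W) with values in [-1, 1] and returns a `uint8` array that any image library can save.

```python
import numpy as np


def make_grid(images, ncols):
    """Tile images (N, C, H, W) in [-1, 1] into one (rows*H, ncols*W, C) uint8 array."""
    n, c, h, w = images.shape
    nrows = -(-n // ncols)  # ceiling division
    canvas = np.zeros((nrows * h, ncols * w, c), dtype=np.uint8)
    # Map [-1, 1] to [0, 255]; the input range is an assumption of this sketch
    pixels = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)
    for i in range(n):
        r, col = divmod(i, ncols)
        # Convert each image from channel-first to channel-last before pasting
        canvas[r * h:(r + 1) * h, col * w:(col + 1) * w] = pixels[i].transpose(1, 2, 0)
    return canvas
```

Unused cells in the last row stay black when N is not a multiple of `ncols`.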
|
||
## 4. Complete code | ||
|
||
``` py | ||
--8<-- | ||
examples/wgangp/wgangp_cifar10.py | ||
examples/wgangp/wgangp_mnist.py | ||
examples/wgangp/wgangp_toy.py | ||
--8<-- | ||
``` | ||
|
||
## 5. References | ||
|
||
- [Improved Training of Wasserstein GANs paper](https://arxiv.org/abs/1704.00028) | ||
|
||
- [Reference code](https://github.com/igul222/improved_wgan_training) |
Please also update mkdocs.yml and docs/index.md.
OK, thank you. This change has been made.