-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
561114d
commit b6b3187
Showing
5 changed files
with
165 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,177 @@ | ||
# GANTree | ||
# GAN Tree | ||
This repository contains code for the paper `GAN-Tree: An Incrementally Learned Hierarchical Generative Framework for Multi-Modal Data Distributions`, published in ICCV 2019. | ||
|
||
A hierarchical tree based architecture over Generative Adversarial Networks (GANs) for generation of samples from multi-modal distributions through discontinuous embedding manifold, having a flexibility to tweak the degree of interpolatability across the modes in the latent space. | ||
The full paper can be found [here](). If you find our research work helpful, please consider citing: | ||
|
||
```cite | ||
``` | ||
|
||
GAN Tree Algorithm: | ||
## Contents | ||
1. [Overview of the Model](#1-overview-of-the-model) | ||
2. [Setup Instructions and Dependencies](#2-setup-instructions-and-dependencies) | ||
3. [Training GAN Tree from Scratch](#3-training-gan-tree-from-scratch) | ||
3.1. [Parser Arguments](#31-parser-arguments) | ||
4. [Repository Overview](#4-repository-overview) | ||
5. [Experiments](#5-experiments) | ||
5.1. [GAN Tree for Single Channel Dataset](#51-gan-tree-for-single-channel-dataset) | ||
5.2. [GAN Tree for Single Channel Mixed Dataset](#52-gan-tree-for-single-channel-mixed-dataset) | ||
5.3. [GAN Tree for Multiple Channel Mixed Dataset](#53-gan-tree-for-multiple-channel-mixed-dataset) | ||
5.4. [Incremental GAN Tree](#54-incremental-gan-tree) | ||
6. [Results Obtained](#6-results-obtained) | ||
6.1. [Generated GAN Tree for Single Channel Mixed Dataset](#61-generated-gan-tree-for-single-channel-mixed-dataset) | ||
6.2. [Generated GAN Tree for Multiple Channel Mixed Dataset](#62-generated-gan-tree-for-multiple-channel-mixed-dataset) | ||
6.3. [Generated i-GANTree for Adding Digit 5](#63-generated-iGANTree-for-adding-digit-5) | ||
7. [Guidelines for Contributors](#7-guidelines-for-contributors) | ||
7.1. [Reporting Bugs and Opening Issues](#71-reporting-bugs-and-opening-issues) | ||
7.2. [Pull Requests](#72-pull-requests) | ||
8. [License](#7-license) | ||
|
||
X = {0: all_x_data} | ||
|
||
CREATE node0: | ||
node0.E = Encoder() | ||
node0.D = Decoder() | ||
node0.Di = Disc() | ||
node0.mu = [0, 0] | ||
node0.cov = [[1, 0], [0, 1]] | ||
for node0: for each step: | ||
TRAIN node0.E and node0.D over cyclic_loss | ||
TRAIN node0.Di over disc_adv loss | ||
TRAIN node0.D over gen_adv loss | ||
node0.gmm = GaussianMixture(n_components=2) | ||
node0.gmm.fit(node0.encode(X[0])) | ||
split X into 2 labels: 1 and 2 | ||
CREATE node1 and node2 from node0: | ||
E = node0.E.copy() | ||
node1.E = E | ||
node2.E = E | ||
node1.D = node0.D.copy() | ||
node2.D = node0.D.copy() | ||
node1.Di = node0.Di.copy() | ||
node2.Di = node0.Di.copy() | ||
## 1. Overview of the Model | ||
|
||
 | ||
|
||
The overall GAN Tree architecture is given in the above figure. For further details about the architecture and training algorithm, please go through the paper. | ||
|
||
## 2. Setup Instructions and Dependencies | ||
You may setup the repository on your local machine by either downloading it or running the following line on `cmd prompt`: | ||
|
||
``` Batchfile | ||
git clone https://github.com/val-iisc/GANTree.git | ||
``` | ||
|
||
All dependencies required by this repo can be downloaded by creating a virtual or conda environment with Python 2.7 and running | ||
|
||
``` Batchfile | ||
pip install -r requirements.txt | ||
``` | ||
|
||
The `LSUN Bedroom Scene` and `CelebA` required for training can be found in the Google Drive link given inside the `data/datasets.txt` file. | ||
|
||
> 1. Make sure to have the proper CUDA version installed for PyTorch 0.4.1. | ||
> 2. The code will not run on Windows since Pytorch v0.4.1 with Python 2.7 is not supported on it. | ||
## 3. Training GAN Tree from Scratch | ||
To train your own GAN Tree from scratch, run | ||
|
||
```Batchfile | ||
python GANTree.py -hp path/to/hyperparams -en exp_name | ||
``` | ||
|
||
+ The hyper parameters for your experiment should be set in your `hyperparams.py` file (check `src/hyperparams` for examples). | ||
+ The training script will create a folder `experiments/exp_name` as specified in your `hyperparams` file or argument passed in the command line to the `-en` flag. | ||
+ This folder will contain all data related to the experiment such as generated images, logs, plots, and weights. It will also contain a dump of the hyperparameters. | ||
|
||
>1. Training will require a large amount of RAM. | ||
>2. Saving Gnodes requires ample amount of space (~500 MB per node). | ||
### 3.1. Parser Arguments | ||
|
||
The following argument flags are available for training: | ||
|
||
+ `-hp`, `--hyperparams`: path to the `hyperparam.py` file to be used for training. | ||
+ `-en`, `--exp_name`: experiment name. | ||
+ `-g`, `--gpu`: index of the gpu to be used. The default value is `0`. | ||
+ `-t`, `--tensorboard`: if `true`, start Tensorboard with the experiment. The default value is `false`. | ||
+ `-r`, `--resume`: if `true`, the training resumes from the latest step. The default value is `false`. | ||
+ `-d`, `--delete`: delete the entities from the experiment file. The default value is `[]`. The choices are `['logs', 'weights', 'results', 'all']`. | ||
+ `-w`, `--weights`: the weight type to load if resume flag is provided. The default value is `iter`. The choices are `['iter', 'best_gen', 'best_pred']`. | ||
|
||
## 4. Repository Overview | ||
This repository contains the following files and folders: | ||
|
||
1. **data**: This folder contains the various datasets. | ||
|
||
2. **experiments**: This folder contains data for different runs. | ||
|
||
3. **src**: Contains all the source code. | ||
|
||
i. **base**: Contains the code for all base classes. | ||
|
||
node1.mu = node0.gmm.means_[0] | ||
node2.mu = node0.gmm.means_[1] | ||
ii. **dataloaders**: Contains various dataloaders. | ||
|
||
node1.cov = node0.gmm.cov_[0] | ||
node2.cov = node0.gmm.cov_[1] | ||
REPEAT for k iters: | ||
{ | ||
REPEAT for a iters | ||
{ | ||
node = sample(node1, node2) with prior probabilities | ||
z ~ N(node.mu, node.cov) | ||
x ~ X[node.id] | ||
TRAIN node.E and node.D over cyclic_loss | ||
TRAIN node.Di over disc_adv loss | ||
TRAIN node.D over gen_adv loss | ||
TRAIN node.E over x_clf_loss | ||
} | ||
iii. **hyperparams**: Contains different `hyperparam.py` files for running various experiments. | ||
|
||
gmm0 = GaussianMixture(n_components=2) | ||
gmm0.fit(node0.post_gmm_encoder.encode(X[0])) | ||
split X into 2 labels: 1 and 2 | ||
iv. **models**: Contains code for constructing AAE models, GNode and GAN-Tree. | ||
|
||
node1.mu = gmm0.means_[0] | ||
node2.mu = gmm0.means_[1] | ||
v. **modules** and **utils**: Code for various functions used frequently. | ||
|
||
node1.cov = gmm0.cov_[0] | ||
node2.cov = gmm0.cov_[1] | ||
} | ||
vi. **trainers**: Contains code for the trainers of a particular GNode. | ||
|
||
## 5. Experiments | ||
|
||
### 5.1. GAN Tree for Single Channel Dataset | ||
|
||
We train GAN Tree on the MNIST dataset, which is a single channel dataset consisting of handwritten datasets. To run the experiment, the following command can be executed: | ||
|
||
```python | ||
python GANTree_MNIST.py | ||
``` | ||
|
||
### 5.2. GAN Tree for Single Channel Mixed Dataset | ||
|
||
We train GAN Tree on the MNIST and Fashion MNIST dataset mixed together to test its performance on datasets with a clear discontinuous manifold. To run the experiment, the following command can be executed: | ||
|
||
```python | ||
python GANTree_MNIST_Fashion_Mixed.py | ||
``` | ||
|
||
### 5.3. GAN Tree for Multiple Channel Mixed Dataset | ||
|
||
We train GAN Tree on the LSUN Bedroom Scene and CelebA dataset mixed together to test its robustness in the multiple channel scenario. To run the experiment, the following command can be executed: | ||
|
||
```python | ||
python GANTree_FaceBed.py | ||
``` | ||
|
||
### 5.4. Incremental GAN Tree | ||
|
||
The GAN Tree has the unique feature of being able to learn new related data without the need of previous data, i.e. learn incrementally. To run the experiment mentioned in the paper, the following commands can be executed: | ||
|
||
```python | ||
python GANTree_MNIST_0to4.py | ||
``` | ||
|
||
After creating a GAN Tree trained on the digits 0-4 from the MNIST dataset, we would like to add the digit 5. To incrementally learn a GAN Tree for the same, run the following command: | ||
|
||
```python | ||
python iGANTree_add5_dsigma4.py | ||
```` | ||
|
||
or | ||
|
||
```python | ||
python iGANTree_add5_dsigma9.py | ||
``` | ||
|
||
## 6. Results Obtained | ||
### 6.1. Generated GAN Tree for Single Channel Mixed Dataset | ||
|
||
 | ||
|
||
### 6.2. Generated GAN Tree for Multiple Channel Mixed Dataset | ||
 | ||
|
||
### 6.3. Generated i-GANTree for Adding Digit 5 | ||
 | ||
|
||
## 7. Guidelines for Contributors | ||
|
||
### 7.1. Reporting Bugs and Opening Issues | ||
|
||
If you'd like to report a bug or open an issue then please: | ||
|
||
**Check if there is an existing issue.** If there is then please add any more information that you have, or give it a 👍. | ||
|
||
When submitting an issue please describe the issue as clearly as possible, including how to reproduce the bug. If you can include a screenshot of the issues, that would be helpful. | ||
|
||
### 7.2. Pull Requests | ||
|
||
Please first discuss the change you wish to make via an issue. | ||
|
||
We don't have a set format for Pull Requests, but expect you to list changes, bugs generated and other relevant things in the PR message. | ||
|
||
## 8. License | ||
|
||
Generalize above in a recursive way for each split | ||
``` | ||
This repository is licensed under MIT license. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.