Skip to content

Commit 0f3e62b

Browse files
committed
2 parents aa357f6 + 9c052a6 commit 0f3e62b

File tree

1 file changed

+93
-43
lines changed

1 file changed

+93
-43
lines changed

README.md

Lines changed: 93 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -21,46 +21,96 @@ install the "gpu" version of PyTorch.<br>
2121
import pytorch_pro_gan.PRO_GAN as pg
2222

2323
Use the modules `pg.Generator`, `pg.Discriminator` and
24-
`pg.ProGAN`.
25-
26-
Help on class ProGAN in module pro_gan_pytorch.PRO_GAN:
27-
28-
class ProGAN(builtins.object)
29-
| Wrapper around the Generator and the Discriminator
30-
|
31-
| Methods defined here:
32-
|
33-
| __init__(self, depth=7, latent_size=64, learning_rate=0.001, beta_1=0, beta_2=0.99, eps=1e-08, drift=0.001, n_critic=1, device=device(type='cpu'))
34-
| constructor for the class
35-
| :param depth: depth of the GAN (will be used for each generator and discriminator)
36-
| :param latent_size: latent size of the manifold used by the GAN
37-
| :param learning_rate: learning rate for Adam
38-
| :param beta_1: beta_1 for Adam
39-
| :param beta_2: beta_2 for Adam
40-
| :param eps: epsilon for Adam
41-
| :param n_critic: number of times to update discriminator
42-
| :param device: device to run the GAN on (GPU / CPU)
43-
|
44-
| optimize_discriminator(self, noise, real_batch, depth, alpha)
45-
| performs one step of weight update on discriminator using the batch of data
46-
| :param noise: input noise of sample generation
47-
| :param real_batch: real samples batch
48-
| :param depth: current depth of optimization
49-
| :param alpha: current alpha for fade-in
50-
| :return: current loss (Wasserstein loss)
51-
|
52-
| optimize_generator(self, noise, depth, alpha)
53-
| performs one step of weight update on generator for the given batch_size
54-
| :param noise: input random noise required for generating samples
55-
| :param depth: depth of the network at which optimization is done
56-
| :param alpha: value of alpha for fade-in effect
57-
| :return: current loss (Wasserstein estimate)
58-
|
59-
| ----------------------------------------------------------------------
60-
| Data descriptors defined here:
61-
|
62-
| __dict__
63-
| dictionary for instance variables (if defined)
64-
|
65-
| __weakref__
66-
| list of weak references to the object (if defined)
24+
`pg.ProGAN`. Mostly, you'll only need the ProGAN module.
25+
26+
4.) Example Code for CIFAR-10 dataset:
27+
28+
import torch as th
29+
import torchvision as tv
30+
import pro_gan_pytorch.PRO_GAN as pg
31+
32+
# select the device to be used for training
33+
device = th.device("cuda" if th.cuda.is_available() else "cpu")
34+
data_path = "cifar-10/"
35+
36+
def setup_data(batch_size, num_workers, download=False):
37+
"""
38+
setup the CIFAR-10 dataset for training the CNN
39+
:param batch_size: batch_size for sgd
40+
:param num_workers: num_readers for data reading
41+
:param download: Boolean for whether to download the data
42+
:return: classes, trainloader, testloader => training and testing data loaders
43+
"""
44+
# data setup:
45+
classes = ('plane', 'car', 'bird', 'cat', 'deer',
46+
'dog', 'frog', 'horse', 'ship', 'truck')
47+
48+
transforms = tv.transforms.ToTensor()
49+
50+
trainset = tv.datasets.CIFAR10(root=data_path,
51+
transform=transforms,
52+
download=download)
53+
trainloader = th.utils.data.DataLoader(trainset, batch_size=batch_size,
54+
shuffle=True,
55+
num_workers=num_workers)
56+
57+
testset = tv.datasets.CIFAR10(root=data_path,
58+
transform=transforms, train=False,
59+
download=False)
60+
testloader = th.utils.data.DataLoader(testset, batch_size=batch_size,
61+
shuffle=True,
62+
num_workers=num_workers)
63+
64+
return classes, trainloader, testloader
65+
66+
67+
if __name__ == '__main__':
68+
69+
# some parameters:
70+
depth = 4
71+
num_epochs = 100 # number of epochs per depth (resolution)
72+
latent_size = 128
73+
74+
# get the data. Ignore the test data and their classes
75+
_, train_data_loader, _ = setup_data(batch_size=32, num_workers=3, download=True)
76+
77+
# ======================================================================
78+
# This line creates the PRO-GAN
79+
# ======================================================================
80+
pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, device=device)
81+
# ======================================================================
82+
83+
# train the pro_gan using the cifar-10 data
84+
for current_depth in range(depth):
85+
print("working on depth:", current_depth)
86+
87+
# note that the rest of the api indexes depth from 0
88+
for epoch in range(1, num_epochs + 1):
89+
print("\ncurrent_epoch: ", epoch)
90+
91+
# calculate the value of aplha for fade-in effect
92+
alpha = int(epoch / num_epochs)
93+
94+
# iterate over the dataset in batches:
95+
for i, batch in enumerate(train_data_loader, 1):
96+
images, _ = batch
97+
# generate some random noise:
98+
noise = th.randn(images.shape[0], latent_size)
99+
100+
# optimize discriminator:
101+
dis_loss = pro_gan.optimize_discriminator(noise, images, current_depth, alpha)
102+
103+
# optimize generator:
104+
gen_loss = pro_gan.optimize_generator(noise, current_depth, alpha)
105+
106+
print("Batch: %d dis_loss: %.3f gen_loss: %.3f"
107+
% (i, dis_loss, gen_loss))
108+
109+
print("epoch finished ...")
110+
111+
print("training complete ...")
112+
113+
# #TODO
114+
1.) Add the conditional PRO_GAN module <br>
115+
2.) Setup the travis - checker. (I have to figure out some good unit tests too :D lulz!) <br>
116+
3.) Write an informative README.rst (although it is rarely read) <br>

0 commit comments

Comments
 (0)