@@ -21,46 +21,96 @@ install the "gpu" version of PyTorch.<br>
21
21
import pytorch_pro_gan.PRO_GAN as pg
22
22
23
23
Use the modules ` pg.Generator ` , ` pg.Discriminator ` and
24
- ` pg.ProGAN ` .
25
-
26
- Help on class ProGAN in module pro_gan_pytorch.PRO_GAN:
27
-
28
- class ProGAN(builtins.object)
29
- | Wrapper around the Generator and the Discriminator
30
- |
31
- | Methods defined here:
32
- |
33
- | __init__(self, depth=7, latent_size=64, learning_rate=0.001, beta_1=0, beta_2=0.99, eps=1e-08, drift=0.001, n_critic=1, device=device(type='cpu'))
34
- | constructor for the class
35
- | :param depth: depth of the GAN (will be used for each generator and discriminator)
36
- | :param latent_size: latent size of the manifold used by the GAN
37
- | :param learning_rate: learning rate for Adam
38
- | :param beta_1: beta_1 for Adam
39
- | :param beta_2: beta_2 for Adam
40
- | :param eps: epsilon for Adam
41
- | :param n_critic: number of times to update discriminator
42
- | :param device: device to run the GAN on (GPU / CPU)
43
- |
44
- | optimize_discriminator(self, noise, real_batch, depth, alpha)
45
- | performs one step of weight update on discriminator using the batch of data
46
- | :param noise: input noise of sample generation
47
- | :param real_batch: real samples batch
48
- | :param depth: current depth of optimization
49
- | :param alpha: current alpha for fade-in
50
- | :return: current loss (Wasserstein loss)
51
- |
52
- | optimize_generator(self, noise, depth, alpha)
53
- | performs one step of weight update on generator for the given batch_size
54
- | :param noise: input random noise required for generating samples
55
- | :param depth: depth of the network at which optimization is done
56
- | :param alpha: value of alpha for fade-in effect
57
- | :return: current loss (Wasserstein estimate)
58
- |
59
- | ----------------------------------------------------------------------
60
- | Data descriptors defined here:
61
- |
62
- | __dict__
63
- | dictionary for instance variables (if defined)
64
- |
65
- | __weakref__
66
- | list of weak references to the object (if defined)
24
+ ` pg.ProGAN ` . Mostly, you'll only need the ProGAN module.
25
+
26
+ 4.) Example Code for CIFAR-10 dataset:
27
+
28
+ import torch as th
29
+ import torchvision as tv
30
+ import pro_gan_pytorch.PRO_GAN as pg
31
+
32
+ # select the device to be used for training
33
+ device = th.device("cuda" if th.cuda.is_available() else "cpu")
34
+ data_path = "cifar-10/"
35
+
36
+ def setup_data(batch_size, num_workers, download=False):
37
+ """
38
+ setup the CIFAR-10 dataset for training the CNN
39
+ :param batch_size: batch_size for sgd
40
+ :param num_workers: num_readers for data reading
41
+ :param download: Boolean for whether to download the data
42
+ :return: classes, trainloader, testloader => training and testing data loaders
43
+ """
44
+ # data setup:
45
+ classes = ('plane', 'car', 'bird', 'cat', 'deer',
46
+ 'dog', 'frog', 'horse', 'ship', 'truck')
47
+
48
+ transforms = tv.transforms.ToTensor()
49
+
50
+ trainset = tv.datasets.CIFAR10(root=data_path,
51
+ transform=transforms,
52
+ download=download)
53
+ trainloader = th.utils.data.DataLoader(trainset, batch_size=batch_size,
54
+ shuffle=True,
55
+ num_workers=num_workers)
56
+
57
+ testset = tv.datasets.CIFAR10(root=data_path,
58
+ transform=transforms, train=False,
59
+ download=False)
60
+ testloader = th.utils.data.DataLoader(testset, batch_size=batch_size,
61
+ shuffle=True,
62
+ num_workers=num_workers)
63
+
64
+ return classes, trainloader, testloader
65
+
66
+
67
+ if __name__ == '__main__':
68
+
69
+ # some parameters:
70
+ depth = 4
71
+ num_epochs = 100 # number of epochs per depth (resolution)
72
+ latent_size = 128
73
+
74
+ # get the data. Ignore the test data and their classes
75
+ _, train_data_loader, _ = setup_data(batch_size=32, num_workers=3, download=True)
76
+
77
+ # ======================================================================
78
+ # This line creates the PRO-GAN
79
+ # ======================================================================
80
+ pro_gan = pg.ProGAN(depth=depth, latent_size=latent_size, device=device)
81
+ # ======================================================================
82
+
83
+ # train the pro_gan using the cifar-10 data
84
+ for current_depth in range(depth):
85
+ print("working on depth:", current_depth)
86
+
87
+ # note that the rest of the api indexes depth from 0
88
+ for epoch in range(1, num_epochs + 1):
89
+ print("\ncurrent_epoch: ", epoch)
90
+
91
+ # calculate the value of aplha for fade-in effect
92
+ alpha = int(epoch / num_epochs)
93
+
94
+ # iterate over the dataset in batches:
95
+ for i, batch in enumerate(train_data_loader, 1):
96
+ images, _ = batch
97
+ # generate some random noise:
98
+ noise = th.randn(images.shape[0], latent_size)
99
+
100
+ # optimize discriminator:
101
+ dis_loss = pro_gan.optimize_discriminator(noise, images, current_depth, alpha)
102
+
103
+ # optimize generator:
104
+ gen_loss = pro_gan.optimize_generator(noise, current_depth, alpha)
105
+
106
+ print("Batch: %d dis_loss: %.3f gen_loss: %.3f"
107
+ % (i, dis_loss, gen_loss))
108
+
109
+ print("epoch finished ...")
110
+
111
+ print("training complete ...")
112
+
113
+ # #TODO
114
+ 1.) Add the conditional PRO_GAN module <br >
115
+ 2.) Setup the travis - checker. (I have to figure out some good unit tests too : D lulz!) <br >
116
+ 3.) Write an informative README.rst (although it is rarely read) <br >
0 commit comments