Skip to content

Commit bb1c0c4

Browse files
committed
Add pytorch version
1 parent 2f046e2 commit bb1c0c4

12 files changed

+221
-0
lines changed

pytorch/vae/model.py

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
6+
class VaeNet(nn.Module):
    """Convolutional variational autoencoder (encoder plus decoder).

    The encoder maps an image to the mean and log-variance of a diagonal
    Gaussian posterior over the latent code; the decoder maps a latent code
    back to an image. The training objective is the ELBO,

        ELBO(x) = marginal_likelihood(x) - KL(posterior || prior)

    which is computed by the caller, not by this class.

    The layer sizes are fixed for 3-channel 32x32 inputs (e.g. CIFAR10):
    the encoder flattens to 128 * 3 * 3 features and the decoder produces
    3x32x32 images squashed to [-1, 1] by tanh. Inputs of any other spatial
    size fail with a shape error in the first linear layer.

    Args:
        latent_dim: Size of the latent code.
        batch_size: Stored on the instance. When ``base_channels`` is not
            given it is also used as the decoder's base channel width,
            which is the historical behaviour of this model.
        base_channels: Optional decoder base channel width. Pass it
            explicitly to decouple the architecture (and checkpoint
            shapes) from the training batch size.
    """

    def __init__(self, latent_dim, batch_size, base_channels=None):
        super(VaeNet, self).__init__()

        self.latent_dim = latent_dim
        self.batch_size = batch_size
        # Default keeps the original "width == batch size" architecture so
        # existing callers and checkpoints are unaffected.
        self.base_channels = batch_size if base_channels is None else base_channels
        self.encoder_network()
        self.decoder_network()

    def encoder_network(self):
        """Create the encoder layers (shapes below assume 3x32x32 input)."""
        # nn.Conv2d(in_channels, out_channels, kernel_size, ...)
        self.en_conv1 = nn.Conv2d(3, 8, 3, padding=1)       # [N x 8 x 32 x 32]
        self.en_conv2 = nn.Conv2d(8, 16, 3)                 # [N x 16 x 30 x 30]
        self.en_bn2 = nn.BatchNorm2d(16)
        self.en_conv3 = nn.Conv2d(16, 32, 2, stride=2)      # [N x 32 x 15 x 15]
        self.en_bn3 = nn.BatchNorm2d(32)
        self.en_conv4 = nn.Conv2d(32, 64, 3, stride=2)      # [N x 64 x 7 x 7]
        self.en_bn4 = nn.BatchNorm2d(64)
        self.en_conv5 = nn.Conv2d(64, 128, 3, stride=2)     # [N x 128 x 3 x 3]
        self.en_bn5 = nn.BatchNorm2d(128)
        self.en_fc1 = nn.Linear(128 * 3 * 3, 256)
        self.en_mu = nn.Linear(256, self.latent_dim)
        # Despite the name, this head outputs log(sigma^2).
        self.en_sigma = nn.Linear(256, self.latent_dim)

    def decoder_network(self):
        """Create the decoder layers (1x1 latent map up to 3x32x32)."""
        width = self.base_channels
        # nn.ConvTranspose2d(in, out, kernel_size, stride, padding)
        self.de_deconv1 = nn.ConvTranspose2d(self.latent_dim, width * 4, 4, 1, 0)  # 4x4
        self.de_bn1 = nn.BatchNorm2d(width * 4)
        self.de_deconv2 = nn.ConvTranspose2d(width * 4, width * 2, 4, 2, 1)        # 8x8
        self.de_bn2 = nn.BatchNorm2d(width * 2)
        self.de_deconv3 = nn.ConvTranspose2d(width * 2, width, 4, 2, 1)            # 16x16
        self.de_bn3 = nn.BatchNorm2d(width)
        self.de_deconv4 = nn.ConvTranspose2d(width, 3, 4, 2, 1)                    # 32x32

    def encoder(self, x):
        """Return ``(mu, log_sigma_sq)`` of the posterior for a batch ``x``."""
        x = F.elu(self.en_conv1(x))
        x = F.elu(self.en_bn2(self.en_conv2(x)))
        x = F.elu(self.en_bn3(self.en_conv3(x)))
        x = F.elu(self.en_bn4(self.en_conv4(x)))
        x = F.elu(self.en_bn5(self.en_conv5(x)))
        # Flatten per sample. Keeping the batch dimension explicit makes a
        # wrongly sized input fail in en_fc1 instead of being silently
        # reshaped into a different number of rows.
        x = x.view(x.size(0), -1)
        x = F.elu(self.en_fc1(x))
        x_mu = self.en_mu(x)
        x_log_sigma_sq = self.en_sigma(x)
        return x_mu, x_log_sigma_sq

    def reparameterize(self, mu, log_sigma_sq):
        """Sample ``z = mu + sigma * eps`` in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_sigma_sq)
        # randn_like matches the dtype and device of std.
        eps = torch.randn_like(std)
        return mu + eps * std

    def decoder(self, x):
        """Decode latent codes into images with values in [-1, 1]."""
        # Treat each latent vector as a latent_dim-channel 1x1 feature map.
        x = x.view(-1, self.latent_dim, 1, 1)
        x = F.elu(self.de_bn1(self.de_deconv1(x)))
        x = F.elu(self.de_bn2(self.de_deconv2(x)))
        x = F.elu(self.de_bn3(self.de_deconv3(x)))
        x = torch.tanh(self.de_deconv4(x))
        return x

    def forward(self, x):
        """Encode, sample and decode; return ``(reconstruction, mu, log_sigma_sq)``."""
        mu, log_sigma_sq = self.encoder(x)
        z = self.reparameterize(mu, log_sigma_sq)
        reconstructed_img = self.decoder(z)
        return reconstructed_img, mu, log_sigma_sq
95+
96+

pytorch/vae/trainer.py

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from __future__ import print_function
2+
import torch
3+
import torch.utils.data
4+
from torch import nn, optim
5+
from torch.autograd import Variable
6+
import argparse
7+
import numpy as np
8+
from torch.nn import functional as F
9+
from torchvision import datasets, transforms
10+
from torchvision.utils import save_image
11+
from model import VaeNet
12+
13+
# Define arguments required for training using parser.
parser = argparse.ArgumentParser(description='VAE CIFAR example')
parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
# The flag turns CUDA off, so the help text must say "disables".
parser.add_argument('--no_cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--latent_dim', type=int, default=100, metavar='L',
                    help='size of the latent dimension (default: 100)')
parser.add_argument('--log_interval', type=int, default=100, metavar='N',
                    help='how many batches to wait for before logging training status')

# Parse the arguments and see if cuda is available
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Use the defined seed to initialize state
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# Define the transformation process of the data.
# Normalize with mean 0.5 / std 0.5 maps ToTensor's [0, 1] range to [-1, 1],
# which matches the tanh output range of the decoder.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = datasets.CIFAR10(root='./data/', train=True, download=True,
                            transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True,
                                          num_workers=3)
testset = datasets.CIFAR10(root='./data/', train=False, download=True,
                           transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False,
                                         num_workers=3)
# Define the model and port it to the gpu
model = VaeNet(batch_size=args.batch_size, latent_dim=args.latent_dim)
if args.cuda:
    model = model.cuda()

optimizer = optim.Adam(model.parameters(), lr=1e-4)
54+
55+
# Define the loss function.
56+
def vae_loss(x_recons, x_original, mu, log_sigma_sq):
    """Return the two terms of the negative ELBO as ``(kl_div, reconstruct_loss)``.

    Args:
        x_recons: Reconstructed batch produced by the decoder.
        x_original: Input batch, same shape as ``x_recons``.
        mu: Posterior means.
        log_sigma_sq: Posterior log-variances, same shape as ``mu``.

    Returns:
        A ``(kl_div, reconstruct_loss)`` tuple of scalar tensors. Both are
        averaged over the number of elements in ``x_original`` so they are
        on the same scale and can simply be added.
    """
    reconstruct_loss = F.mse_loss(x_recons, x_original)
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_div = -0.5 * torch.sum(1 + log_sigma_sq - mu.pow(2) - log_sigma_sq.exp())
    # Normalise by the real element count of this batch, not a nominal
    # batch size and image shape. This stays correct for the last (smaller)
    # batch and for any input resolution.
    kl_div = kl_div / x_original.numel()
    return kl_div, reconstruct_loss
63+
64+
# Define the train step
65+
def train(epoch):
    """Run one training epoch and print running and per-epoch average losses.

    Args:
        epoch: 1-based epoch number, used only for logging.
    """
    model.train()
    train_loss = 0.0
    likelihood = 0.0
    divergence = 0.0
    for batch_idx, (images, _) in enumerate(trainloader):
        if args.cuda:
            images = images.cuda()
        optimizer.zero_grad()
        reconstructed_img, mu, log_sigma_sq = model(images)
        kl_div, recon_loss = vae_loss(reconstructed_img, images, mu, log_sigma_sq)
        loss = kl_div + recon_loss
        loss.backward()
        optimizer.step()
        # .item() yields plain Python floats. Accumulating the tensors
        # themselves would keep every batch's autograd graph alive, and
        # indexing a 0-dim tensor (loss.data[0]) raises on current PyTorch.
        train_loss += loss.item()
        likelihood += recon_loss.item()
        divergence += kl_div.item()
        if batch_idx % args.log_interval == 0:
            # The loss is already a per-element mean, so it is printed
            # as-is. Progress counts samples, not items of the batch tuple.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * args.batch_size, len(trainloader.dataset),
                100. * batch_idx / len(trainloader),
                loss.item()))
    # Each accumulated value is a per-batch mean, so the epoch average is
    # the sum divided by the number of batches (guarded for an empty loader).
    num_batches = max(len(trainloader), 1)
    print('Epoch: %d, Average loss: %f, Average reconstruction loss: %f, Average kl divergence loss: %f'
          % (epoch, train_loss / num_batches,
             likelihood / num_batches,
             divergence / num_batches))
92+
93+
# Define the test step
94+
def test(epoch):
    """Evaluate on the test set and save a reconstruction grid for the first batch.

    Writes ``results/reconstruction_<epoch>.png`` (originals on the first
    row, reconstructions on the second) and prints the average test loss.

    Args:
        epoch: Epoch number, used in the output file name.
    """
    import os

    model.eval()
    test_loss = 0.0
    # No gradients are needed for evaluation; skipping them saves memory.
    with torch.no_grad():
        for batch_idx, (images, _) in enumerate(testloader):
            if args.cuda:
                images = images.cuda()
            reconstructed_img, mu, log_sigma_sq = model(images)
            kl_div, recon_loss = vae_loss(reconstructed_img, images, mu, log_sigma_sq)
            test_loss += (kl_div + recon_loss).item()
            if batch_idx == 0:
                n = min(images.size(0), 8)
                # The decoder output already has the image shape, so slice
                # it directly; forcing a view to args.batch_size would fail
                # for a smaller batch.
                comparison = torch.cat([images[:n], reconstructed_img[:n]])
                if not os.path.isdir('results'):
                    os.makedirs('results')
                # Data and reconstructions live in [-1, 1]; map them back
                # to [0, 1], which is what save_image expects.
                save_image(comparison.cpu() * 0.5 + 0.5,
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    # Per-batch means are averaged over the number of batches.
    test_loss /= max(len(testloader), 1)
    print('Test set loss: %f' % (test_loss))
113+
114+
# Set up the training loop
115+
for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    # Draw latent codes from the standard normal prior. The width must
    # follow --latent_dim; a hard-coded 100 breaks any other setting.
    sample = torch.randn(args.batch_size, args.latent_dim)
    if args.cuda:
        sample = sample.cuda()
    # Decode in eval mode (batch norm uses running statistics) and without
    # building an autograd graph.
    model.eval()
    with torch.no_grad():
        sample = model.decoder(sample)
    # Decoder output is in [-1, 1] (tanh); map to [0, 1] for save_image.
    save_image(sample.cpu() * 0.5 + 0.5,
               'results/sample_' + str(epoch) + '.png')
125+
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)