Skip to content

Commit b588d62

Browse files
committed
Improve docs, update scripts
1 parent ec17149 commit b588d62

40 files changed

+459
-401
lines changed

model.py renamed to BigGAN.py

File renamed without changes.

calculate_inception_moments.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
# This script iterates over the dataset and calculates the moments of the
2-
# activations of the Inception net (needed for FID), and also returns
3-
# the Inception Score of the training data.
1+
''' Calculate Inception Moments
2+
This script iterates over the dataset and calculates the moments of the
3+
activations of the Inception net (needed for FID), and also returns
4+
the Inception Score of the training data.
5+
6+
Note that if you don't shuffle the data, the IS of true data will be under-
7+
estimated as it is label-ordered. By default, the data is not shuffled
8+
so as to reduce non-determinism. '''
49
import numpy as np
510
import torch
611
import torch.nn as nn
@@ -15,11 +20,11 @@ def prepare_parser():
1520
usage = 'Calculate and store inception metrics.'
1621
parser = ArgumentParser(description=usage)
1722
parser.add_argument(
18-
'--dataset', type=str, default='I128',
23+
'--dataset', type=str, default='I128_hdf5',
1924
help='Which Dataset to train on, out of I128, I256, C10, C100...'
2025
'Append _hdf5 to use the hdf5 version of the dataset. (default: %(default)s)')
2126
parser.add_argument(
22-
'--dataset_root', type=str, default='/home/s1580274/scratch/data/',
27+
'--dataset_root', type=str, default='data',
2328
help='Default location where data is stored (default: %(default)s)')
2429
parser.add_argument(
2530
'--batch_size', type=int, default=64,
@@ -67,7 +72,8 @@ def run(config):
6772
print('Calculating inception metrics...')
6873
IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
6974
print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std))
70-
# Prepare mu and sigma, save to disk
75+
# Prepare mu and sigma, save to disk. Remove "hdf5" by default
76+
# (the FID code also knows to strip "hdf5")
7177
print('Calculating means and covariances...')
7278
mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
7379
print('Saving calculated means and covariances to disk...')

datasets.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
''' Datasets
2+
This file contains definitions for our CIFAR, ImageFolder, and HDF5 datasets
3+
'''
14
import os
25
import os.path
36
import sys
@@ -10,8 +13,7 @@
1013
from torchvision.datasets.utils import download_url, check_integrity
1114
import torch.utils.data as data
1215
from torch.utils.data import DataLoader
13-
14-
# Stuff for full imagenet
16+
1517
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']
1618

1719

@@ -106,9 +108,12 @@ def __init__(self, root, transform=None, target_transform=None,
106108
loader=default_loader, load_in_mem=False,
107109
index_filename='imagenet_imgs.npz', **kwargs):
108110
classes, class_to_idx = find_classes(root)
111+
# Load pre-computed image directory walk
109112
if os.path.exists(index_filename):
110113
print('Loading pre-saved Index file %s...' % index_filename)
111114
imgs = np.load(index_filename)['imgs']
115+
# If first time, walk the folder directory and save the
116+
# results to a pre-computed file.
112117
else:
113118
print('Generating Index file %s...' % index_filename)
114119
imgs = make_dataset(root, class_to_idx)
@@ -171,19 +176,9 @@ def __repr__(self):
171176
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
172177
return fmt_str
173178

174-
175-
# class ImageNetA(ImageFolder):
176-
# def __init__(self, root, transform=None, target_transform=None,
177-
# loader=default_loader, load_in_mem=False,
178-
# train=True,download=False, validate_seed=0,
179-
# val_split=0):
180-
# super(ImageNetA, self).__init__(root, transform, target_transform,
181-
# default_loader, load_in_mem, train, download, validate_seed,
182-
# val_split):
183-
184-
185-
# Imagenet at 256 with '/home/s1580274/scratch/ILSVRC256.hdf5'
186179

180+
''' ILSVRC_HDF5: A dataset to support I/O from an HDF5 to avoid
181+
having to load individual images all the time. '''
187182
import h5py as h5
188183
import torch
189184
class ILSVRC_HDF5(data.Dataset):
@@ -250,7 +245,7 @@ class CIFAR10(dset.CIFAR10):
250245
def __init__(self, root, train=True,
251246
transform=None, target_transform=None,
252247
download=True, validate_seed=0,
253-
val_split=0, load_in_mem=True):
248+
val_split=0, load_in_mem=True, **kwargs):
254249
self.root = os.path.expanduser(root)
255250
self.transform = transform
256251
self.target_transform = target_transform
@@ -264,9 +259,7 @@ def __init__(self, root, train=True,
264259
raise RuntimeError('Dataset not found or corrupted.' +
265260
' You can use download=True to download it')
266261

267-
# now load the picked numpy arrays
268-
269-
262+
# now load the picked numpy arrays
270263
self.data = []
271264
self.labels= []
272265
for fentry in self.train_list:
@@ -294,15 +287,10 @@ def __init__(self, root, train=True,
294287

295288
# randomly grab 500 elements of each class
296289
np.random.seed(validate_seed)
297-
298290
self.val_indices = []
299-
300-
301-
302291
for l_i in label_indices:
303292
self.val_indices += list(l_i[np.random.choice(len(l_i), int(len(self.data) * val_split) // (max(self.labels) + 1) ,replace=False)])
304-
305-
293+
306294
if self.train=='validate':
307295
self.data = self.data[self.val_indices]
308296
self.labels = list(np.asarray(self.labels)[self.val_indices])

inception_tf13.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
## Tensorflow inception score code
2-
# Derived from https://github.com/openai/improved-gan
3-
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
4-
# THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in BATCH MODE
1+
''' Tensorflow inception score code
2+
Derived from https://github.com/openai/improved-gan
3+
Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
4+
THIS CODE REQUIRES TENSORFLOW 1.3 or EARLIER to run in PARALLEL BATCH MODE
5+
6+
To use this code, run sample.py on your model with --sample_npz, and then
7+
pass the experiment name in the --experiment_name
8+
'''
59
from __future__ import absolute_import
610
from __future__ import division
711
from __future__ import print_function
@@ -17,20 +21,18 @@
1721
from six.moves import urllib
1822
import tensorflow as tf
1923

20-
MODEL_DIR = '/home/s1580274/scratch'
24+
MODEL_DIR = ''
2125
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
2226
softmax = None
2327

24-
# Run on eddie with /home/s1580274/group/myconda/tensorflow/bin/python
25-
#fname='BigGAN_I128_hdf5_seed0_Gch64_Dch64_bs128_nDa4_nGa4_Glr1.0e-04_Dlr4.0e-04_Gnlrelu_Dnlrelu_Ginitxavier_Dinitxavier_Gattn64_Dattn64_Gshared_ema_SAGAN_bs128x4_ema'
2628
def prepare_parser():
2729
usage = 'Parser for TF1.3- Inception Score scripts.'
2830
parser = ArgumentParser(description=usage)
2931
parser.add_argument(
3032
'--experiment_name', type=str, default='',
3133
help='Which experiment''s samples.npz file to pull and evaluate')
3234
parser.add_argument(
33-
'--experiment_root', type=str, default='/home/s1580274/scratch/samples/',
35+
'--experiment_root', type=str, default='samples',
3436
help='Default location where samples are stored (default: %(default)s)')
3537
parser.add_argument(
3638
'--batch_size', type=int, default=500,
@@ -110,13 +112,12 @@ def _progress(count, block_size, total_size):
110112
logits = tf.matmul(tf.squeeze(pool3), w)
111113
softmax = tf.nn.softmax(logits)
112114

113-
# if softmax is None:
115+
# if softmax is None: # No need to functionalize like this.
114116
_init_inception()
115117

116118
fname = '%s/%s/samples.npz' % (config['experiment_root'], config['experiment_name'])
117119
print('loading %s ...'%fname)
118-
ims = np.load(fname)['x']# + '_samples.npz')['x']
119-
# ims =
120+
ims = np.load(fname)['x']
120121
import time
121122
t0 = time.time()
122123
inc = get_inception_score(list(ims.swapaxes(1,2).swapaxes(2,3)), splits=10)

inception_utils.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
1+
''' Inception utilities
2+
This file contains methods for calculating IS and FID, using either
3+
the original numpy code or an accelerated fully-pytorch version that
4+
uses a fast newton-schulz approximation for the matrix sqrt. There are also
5+
methods for acquiring a desired number of samples from the Generator,
6+
and parallelizing the inbuilt PyTorch inception network.
7+
8+
NOTE that Inception Scores and FIDs calculated using these methods will
9+
*not* be directly comparable to values calculated using the original TF
10+
IS/FID code. You *must* use the TF model if you wish to report and compare
11+
numbers. This code tends to produce IS values that are 5-10% lower than
12+
those obtained through TF.
13+
'''
114
import numpy as np
2-
from scipy import linalg # For FID
15+
from scipy import linalg # For numpy FID
316
import time
417

518
import torch
@@ -8,6 +21,7 @@
821
from torch.nn import Parameter as P
922
from torchvision.models.inception import inception_v3
1023

24+
1125
# Module that wraps the inception network to enable use with dataparallel and
1226
# returning pool features and logits.
1327
class WrapInception(nn.Module):
@@ -69,6 +83,7 @@ def forward(self, x):
6983
# 1000 (num_classes)
7084
return pool, logits
7185

86+
7287
# A pytorch implementation of cov, from Modar M. Alfadly
7388
# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
7489
def torch_cov(m, rowvar=False):
@@ -103,9 +118,9 @@ def torch_cov(m, rowvar=False):
103118
mt = m.t() # if complex: mt = m.t().conj()
104119
return fact * m.matmul(mt).squeeze()
105120

121+
106122
# Pytorch implementation of matrix sqrt, from Tsung-Yu Lin, and Subhransu Maji
107123
# https://github.com/msubhransu/matrix-sqrt
108-
109124
def sqrt_newton_schulz(A, numIters, dtype=None):
110125
with torch.no_grad():
111126
if dtype is None:
@@ -123,9 +138,9 @@ def sqrt_newton_schulz(A, numIters, dtype=None):
123138
sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)
124139
return sA
125140

141+
126142
# FID calculator from TTUR--consider replacing this with GPU-accelerated cov
127143
# calculations using torch?
128-
129144
def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
130145
"""Numpy implementation of the Frechet Distance.
131146
Taken from https://github.com/bioinf-jku/TTUR
@@ -180,7 +195,7 @@ def numpy_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
180195

181196
out = diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
182197
return out
183-
# # return () +
198+
184199

185200
def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
186201
"""Pytorch implementation of the Frechet Distance.
@@ -214,6 +229,8 @@ def torch_calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
214229
out = (diff.dot(diff) + torch.trace(sigma1) + torch.trace(sigma2)
215230
- 2 * torch.trace(covmean))
216231
return out
232+
233+
217234
# Calculate Inception Score mean + std given softmax'd logits and number of splits
218235
def calculate_inception_score(pred, num_splits=10):
219236
scores = []

layers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
''' Layers
2+
This file contains various layers for the BigGAN models.
3+
'''
14
import numpy as np
25
import torch
36
import torch.nn as nn
@@ -8,6 +11,7 @@
811

912
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
1013

14+
1115
# Projection of x onto y
1216
def proj(x, y):
1317
return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
@@ -45,11 +49,13 @@ def power_iteration(W, u_, update=True, eps=1e-12):
4549
#svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
4650
return svs, us, vs
4751

52+
4853
# Convenience passthrough function
4954
class identity(nn.Module):
5055
def forward(self, input):
5156
return input
5257

58+
5359
# Spectral normalization base class
5460
class SN(object):
5561
def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
@@ -105,6 +111,7 @@ def forward(self, x):
105111
return F.conv2d(x, self.W_(), self.bias, self.stride,
106112
self.padding, self.dilation, self.groups)
107113

114+
108115
# Linear layer with spectral norm
109116
class SNLinear(nn.Linear, SN):
110117
def __init__(self, in_features, out_features, bias=True,
@@ -114,6 +121,7 @@ def __init__(self, in_features, out_features, bias=True,
114121
def forward(self, x):
115122
return F.linear(x, self.W_(), self.bias)
116123

124+
117125
# Embedding layer with spectral norm
118126
# We use num_embeddings as the dim instead of embedding_dim here
119127
# for convenience sake
@@ -319,7 +327,8 @@ def extra_repr(self):
319327
s = 'out: {output_size}, in: {input_size},'
320328
s +=' cross_replica={cross_replica}'
321329
return s.format(**self.__dict__)
322-
330+
331+
323332
# Normal, non-class-conditional BN
324333
class bn(nn.Module):
325334
def __init__(self, output_size, eps=1e-5, momentum=0.1,
@@ -355,15 +364,14 @@ def forward(self, x, y=None):
355364
else:
356365
return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain,
357366
self.bias, self.training, self.momentum, self.eps)
358-
367+
368+
359369
# Generator blocks
360370
# Note that this class assumes the kernel size and padding (and any other
361371
# settings) have been selected in the main generator module and passed in
362372
# through the which_conv arg. Similar rules apply with which_bn (the input
363373
# size [which is actually the number of channels of the conditional info] must
364374
# be preselected)
365-
""" Andy's note: I changed activation to NONE to enforce passing in an activation
366-
"""
367375
class GBlock(nn.Module):
368376
def __init__(self, in_channels, out_channels,
369377
which_conv=nn.Conv2d, which_bn=bn, activation=None,
@@ -401,8 +409,6 @@ def forward(self, x, y):
401409

402410

403411
# Residual block for the discriminator
404-
""" Andy's note: I changed activation to NONE to enforce passing in an activation
405-
"""
406412
class DBlock(nn.Module):
407413
def __init__(self, in_channels, out_channels, which_conv=SNConv2d, wide=True,
408414
preactivation=False, activation=None, downsample=None,):

0 commit comments

Comments
 (0)