Skip to content

Commit 3967ff9

Browse files
authored
Initial upload of main files
1 parent ffbc193 commit 3967ff9

16 files changed

+3238
-0
lines changed

animal_hash.py

Lines changed: 439 additions & 0 deletions
Large diffs are not rendered by default.

calculate_inception_moments.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# This script iterates over the dataset and calculates the moments of the
2+
# activations of the Inception net (needed for FID), and also returns
3+
# the Inception Score of the training data.
4+
import numpy as np
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
9+
import utils
10+
import inception_utils
11+
from tqdm import tqdm, trange
12+
from argparse import ArgumentParser
13+
14+
def prepare_parser():
15+
usage = 'Calculate and store inception metrics.'
16+
parser = ArgumentParser(description=usage)
17+
parser.add_argument(
18+
'--dataset', type=str, default='I128',
19+
help='Which Dataset to train on, out of I128, C10, C100, MN10, MN40, STL10 (default: %(default)s)')
20+
parser.add_argument(
21+
'--batch_size', type=int, default=64,
22+
help='Default overall batchsize (default: %(default)s)')
23+
parser.add_argument(
24+
'--parallel', action='store_true', default=False,
25+
help='Train with multiple GPUs (default: %(default)s)')
26+
parser.add_argument(
27+
'--augment', action='store_true', default=False,
28+
help='Augment with random crops and flips (default: %(default)s)')
29+
parser.add_argument(
30+
'--hdf5', action='store_true', default=False,
31+
help='Use the HDF5 version of the dataset (default: %(default)s)')
32+
parser.add_argument(
33+
'--num_workers', type=int, default=8,
34+
help='Number of dataloader workers (default: %(default)s)')
35+
parser.add_argument(
36+
'--shuffle', action='store_true', default=False,
37+
help='Shuffle the data? (default: %(default)s)')
38+
parser.add_argument(
39+
'--seed', type=int, default=0,
40+
help='Random seed to use.')
41+
return parser
42+
43+
def run(config):
44+
# Get loader
45+
loaders = utils.get_data_loaders(**config)
46+
47+
# Load inception net
48+
net = inception_utils.load_inception_net(parallel=config['parallel'])
49+
pool, logits, labels = [], [], []
50+
device = 'cuda'
51+
for i, (x, y) in enumerate(tqdm(loaders[0])):
52+
x = x.to(device)
53+
with torch.no_grad():
54+
pool_val, logits_val = net(x)
55+
pool += [np.asarray(pool_val.cpu())]
56+
logits += [np.asarray(F.softmax(logits_val, 1).cpu())]
57+
labels += [np.asarray(y.cpu())]
58+
59+
pool, logits, labels = [np.concatenate(item, 0) for item in [pool, logits, labels]]
60+
# uncomment to save pool, logits, and labels to disk
61+
# print('Saving pool, logits, and labels to disk...')
62+
# np.savez(config['dataset']+'_inception_activations.npz', {'pool': pool, 'logits': logits, 'labels': labels})
63+
# Calculate inception metrics and report them
64+
print('Calculating inception metrics...')
65+
IS_mean, IS_std = inception_utils.calculate_inception_score(logits)
66+
print('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config['dataset'], IS_mean, IS_std))
67+
# Prepare mu and sigma, save to disk
68+
print('Calculating means and covariances...')
69+
mu, sigma = np.mean(pool, axis=0), np.cov(pool, rowvar=False)
70+
print('Saving calculated means and covariances to disk...')
71+
np.savez(config['dataset']+'_inception_moments.npz', **{'mu' : mu, 'sigma' : sigma})
72+
73+
def main():
74+
# parse command line
75+
parser = prepare_parser()
76+
config = vars(parser.parse_args())
77+
print(config)
78+
run(config)
79+
# run; replace this with a for loop to run multiple sequential jobs.
80+
# train_test(**vars(args))
81+
82+
83+
if __name__ == '__main__':
84+
main()

0 commit comments

Comments
 (0)