1
+ # This script iterates over the dataset and calculates the moments of the
2
+ # activations of the Inception net (needed for FID), and also returns
3
+ # the Inception Score of the training data.
4
+ import numpy as np
5
+ import torch
6
+ import torch .nn as nn
7
+ import torch .nn .functional as F
8
+
9
+ import utils
10
+ import inception_utils
11
+ from tqdm import tqdm , trange
12
+ from argparse import ArgumentParser
13
+
14
+ def prepare_parser ():
15
+ usage = 'Calculate and store inception metrics.'
16
+ parser = ArgumentParser (description = usage )
17
+ parser .add_argument (
18
+ '--dataset' , type = str , default = 'I128' ,
19
+ help = 'Which Dataset to train on, out of I128, C10, C100, MN10, MN40, STL10 (default: %(default)s)' )
20
+ parser .add_argument (
21
+ '--batch_size' , type = int , default = 64 ,
22
+ help = 'Default overall batchsize (default: %(default)s)' )
23
+ parser .add_argument (
24
+ '--parallel' , action = 'store_true' , default = False ,
25
+ help = 'Train with multiple GPUs (default: %(default)s)' )
26
+ parser .add_argument (
27
+ '--augment' , action = 'store_true' , default = False ,
28
+ help = 'Augment with random crops and flips (default: %(default)s)' )
29
+ parser .add_argument (
30
+ '--hdf5' , action = 'store_true' , default = False ,
31
+ help = 'Use the HDF5 version of the dataset (default: %(default)s)' )
32
+ parser .add_argument (
33
+ '--num_workers' , type = int , default = 8 ,
34
+ help = 'Number of dataloader workers (default: %(default)s)' )
35
+ parser .add_argument (
36
+ '--shuffle' , action = 'store_true' , default = False ,
37
+ help = 'Shuffle the data? (default: %(default)s)' )
38
+ parser .add_argument (
39
+ '--seed' , type = int , default = 0 ,
40
+ help = 'Random seed to use.' )
41
+ return parser
42
+
43
+ def run (config ):
44
+ # Get loader
45
+ loaders = utils .get_data_loaders (** config )
46
+
47
+ # Load inception net
48
+ net = inception_utils .load_inception_net (parallel = config ['parallel' ])
49
+ pool , logits , labels = [], [], []
50
+ device = 'cuda'
51
+ for i , (x , y ) in enumerate (tqdm (loaders [0 ])):
52
+ x = x .to (device )
53
+ with torch .no_grad ():
54
+ pool_val , logits_val = net (x )
55
+ pool += [np .asarray (pool_val .cpu ())]
56
+ logits += [np .asarray (F .softmax (logits_val , 1 ).cpu ())]
57
+ labels += [np .asarray (y .cpu ())]
58
+
59
+ pool , logits , labels = [np .concatenate (item , 0 ) for item in [pool , logits , labels ]]
60
+ # uncomment to save pool, logits, and labels to disk
61
+ # print('Saving pool, logits, and labels to disk...')
62
+ # np.savez(config['dataset']+'_inception_activations.npz', {'pool': pool, 'logits': logits, 'labels': labels})
63
+ # Calculate inception metrics and report them
64
+ print ('Calculating inception metrics...' )
65
+ IS_mean , IS_std = inception_utils .calculate_inception_score (logits )
66
+ print ('Training data from dataset %s has IS of %5.5f +/- %5.5f' % (config ['dataset' ], IS_mean , IS_std ))
67
+ # Prepare mu and sigma, save to disk
68
+ print ('Calculating means and covariances...' )
69
+ mu , sigma = np .mean (pool , axis = 0 ), np .cov (pool , rowvar = False )
70
+ print ('Saving calculated means and covariances to disk...' )
71
+ np .savez (config ['dataset' ]+ '_inception_moments.npz' , ** {'mu' : mu , 'sigma' : sigma })
72
+
73
+ def main ():
74
+ # parse command line
75
+ parser = prepare_parser ()
76
+ config = vars (parser .parse_args ())
77
+ print (config )
78
+ run (config )
79
+ # run; replace this with a for loop to run multiple sequential jobs.
80
+ # train_test(**vars(args))
81
+
82
+
83
+ if __name__ == '__main__' :
84
+ main ()
0 commit comments