forked from pbcquoc/cifar_dcgan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar_loader.py
80 lines (59 loc) · 2.16 KB
/
cifar_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import os
import cPickle as pickle
import glob
# Root directory holding the extracted CIFAR archives. It is a relative path,
# so it is resolved against the current working directory.
data_dir = "data"
data_dir_cifar10 = os.path.join(data_dir, "cifar-10-batches-py")
data_dir_cifar100 = os.path.join(data_dir, "cifar-100-python")

# CIFAR-10 metadata, read once at import time (importing this module fails if
# the file is missing). "batches.meta" is a plain pickle file, so it is read
# with pickle directly: np.load() refuses pickled payloads unless
# allow_pickle=True on NumPy >= 1.16.3.
# NOTE(review): despite the name, this holds whatever object the file
# contains (presumably a dict with a label-name list) -- confirm before use.
# NOTE: unpickling can execute arbitrary code; only use trusted dataset files.
with open(os.path.join(data_dir_cifar10, "batches.meta"), "rb") as _meta_file:
    class_names_cifar10 = pickle.load(_meta_file)
# The CIFAR-100 equivalent is left disabled:
#class_names_cifar100 = np.load(os.path.join(data_dir_cifar100, "meta"))
def one_hot(x, n):
    """
    Turn a flat sequence of class indices into one-hot rows.

    x: 1-D sequence (or array) of integer class indices.
    n: number of classes, i.e. the width of each one-hot row.

    Returns a float array of shape (len(x), n) in which row i is the
    x[i]-th row of the n-by-n identity matrix.
    """
    indices = np.array(x)
    # Only a flat index vector is supported.
    assert indices.ndim == 1
    identity = np.eye(n)
    # Fancy-indexing the identity matrix picks one unit row per index.
    return identity[indices]
def _load_batch_cifar10(filename, dtype='float64'):
    """
    load a batch in the CIFAR-10 format

    filename: name of the batch file inside data_dir_cifar10.
    dtype: dtype both returned arrays are cast to.

    Returns a (data, labels) tuple: batch['data'] divided by 255.0 and
    batch['labels'] converted to one-hot rows over 10 classes.
    """
    path = os.path.join(data_dir_cifar10, filename)
    # The batch files are plain pickles. np.load() rejects pickled payloads
    # unless allow_pickle=True on NumPy >= 1.16.3, so unpickle directly.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as batch_file:
        batch = pickle.load(batch_file)
    data = batch['data'] / 255.0  # scale between [0, 1]
    labels = one_hot(batch['labels'], n=10)  # convert labels to one-hot representation
    return data.astype(dtype), labels.astype(dtype)
def _grayscale(a):
return a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1)
def cifar10(dtype='float64', grayscale=True):
    """
    Load the CIFAR-10 train and test splits.

    dtype: dtype passed through to the batch loader for all returned arrays.
    grayscale: when true, images are channel-averaged via _grayscale.

    Returns (x_train, t_train, x_test, t_test). The training arrays are the
    five files "data_batch_1" .. "data_batch_5" concatenated in order along
    axis 0; the test arrays come from "test_batch".
    """
    # train -- range() instead of the Python-2-only xrange(); same values.
    batches = [
        _load_batch_cifar10("data_batch_%d" % index, dtype=dtype)
        for index in range(1, 6)
    ]
    x_train = np.concatenate([images for images, _ in batches], axis=0)
    t_train = np.concatenate([targets for _, targets in batches], axis=0)

    # test
    x_test, t_test = _load_batch_cifar10("test_batch", dtype=dtype)

    if grayscale:
        x_train = _grayscale(x_train)
        x_test = _grayscale(x_test)

    return x_train, t_train, x_test, t_test
def _load_batch_cifar100(filename, dtype='float64'):
    """
    load a batch in the CIFAR-100 format

    filename: name of the batch file inside data_dir_cifar100.
    dtype: dtype both returned arrays are cast to.

    Returns a (data, labels) tuple: batch['data'] divided by 255.0 and
    batch['fine_labels'] converted to one-hot rows over 100 classes.
    """
    path = os.path.join(data_dir_cifar100, filename)
    # The batch files are plain pickles. np.load() rejects pickled payloads
    # unless allow_pickle=True on NumPy >= 1.16.3, so unpickle directly.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as batch_file:
        batch = pickle.load(batch_file)
    data = batch['data'] / 255.0
    labels = one_hot(batch['fine_labels'], n=100)
    return data.astype(dtype), labels.astype(dtype)
def cifar100(dtype='float64', grayscale=True):
    """
    Load the CIFAR-100 train and test splits.

    dtype: dtype passed through to the batch loader for all returned arrays.
    grayscale: when true, images are channel-averaged via _grayscale.

    Returns (x_train, t_train, x_test, t_test), read from the "train" and
    "test" files respectively.
    """
    train_images, train_targets = _load_batch_cifar100("train", dtype=dtype)
    test_images, test_targets = _load_batch_cifar100("test", dtype=dtype)

    if not grayscale:
        return train_images, train_targets, test_images, test_targets

    # Train images are converted before test images, as in the colour path
    # only the image arrays change; the label arrays pass through untouched.
    train_gray = _grayscale(train_images)
    test_gray = _grayscale(test_images)
    return train_gray, train_targets, test_gray, test_targets