|
| 1 | +# coding: utf-8 |
| 2 | +try: |
| 3 | + import urllib.request |
| 4 | +except ImportError: |
| 5 | + raise ImportError('You should use Python 3.x') |
| 6 | +import os.path |
| 7 | +import gzip |
| 8 | +import pickle |
| 9 | +import os |
| 10 | +import numpy as np |
| 11 | + |
| 12 | + |
# Base URL that the archive file names in key_file are appended to when downloading.
url_base = 'http://yann.lecun.com/exdb/mnist/'
# Logical dataset key -> file name of the corresponding gzip archive.
key_file = {
    'train_img':'train-images-idx3-ubyte.gz',
    'train_label':'train-labels-idx1-ubyte.gz',
    'test_img':'t10k-images-idx3-ubyte.gz',
    'test_label':'t10k-labels-idx1-ubyte.gz'
}

# The downloaded archives and the pickle cache are stored in the directory containing this module.
dataset_dir = os.path.dirname(os.path.abspath(__file__))
save_file = dataset_dir + "/mnist.pkl"

# NOTE(review): train_num, test_num and img_dim are not referenced anywhere in the
# visible code; presumably they record the sample counts and the per-image
# (channel, height, width) shape -- confirm against callers.
train_num = 60000
test_num = 10000
img_dim = (1, 28, 28)
# Number of values per flattened image; used as the row width when reshaping image data.
img_size = 784
| 28 | + |
| 29 | + |
def _download(file_name):
    """Download one archive from ``url_base`` into ``dataset_dir``.

    The download is skipped when the target file already exists.  Data is
    first written to a temporary ``.part`` file and only moved to its final
    name once the transfer has finished, so an interrupted download can never
    leave a truncated archive that later runs would mistake for a complete one.

    Parameters
    ----------
    file_name : name of the archive, relative to both ``url_base`` and ``dataset_dir``

    Raises
    ------
    Whatever ``urllib.request.urlretrieve`` or the file-system calls raise;
    the partial file is removed before the exception propagates.
    """
    file_path = dataset_dir + "/" + file_name

    if os.path.exists(file_path):
        return

    print("Downloading " + file_name + " ... ")
    partial_path = file_path + ".part"
    try:
        urllib.request.urlretrieve(url_base + file_name, partial_path)
        # Publish the archive under its final name only after a complete transfer.
        os.replace(partial_path, file_path)
    finally:
        # After a successful os.replace the partial file no longer exists,
        # so this only cleans up after a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print("Done")
| 39 | + |
def download_mnist():
    """Download every archive listed in ``key_file`` via ``_download``."""
    for archive_name in key_file.values():
        _download(archive_name)
| 43 | + |
def _load_label(file_name):
    """Read a gzip-compressed label archive into a NumPy array.

    Parameters
    ----------
    file_name : name of the archive inside ``dataset_dir``

    Returns
    -------
    1-D ``np.uint8`` array holding every byte of the decompressed file
    after the first 8 bytes
    """
    archive_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(archive_path, 'rb') as archive:
        raw_bytes = archive.read()
    # The first 8 bytes are skipped; the rest is one uint8 per label.
    label_array = np.frombuffer(raw_bytes, dtype=np.uint8, offset=8)
    print("Done")

    return label_array
| 53 | + |
def _load_img(file_name):
    """Read a gzip-compressed image archive into a 2-D NumPy array.

    Parameters
    ----------
    file_name : name of the archive inside ``dataset_dir``

    Returns
    -------
    ``np.uint8`` array of shape (number of images, ``img_size``) built from
    the decompressed bytes after the first 16 bytes
    """
    archive_path = dataset_dir + "/" + file_name

    print("Converting " + file_name + " to NumPy Array ...")
    with gzip.open(archive_path, 'rb') as archive:
        raw_bytes = archive.read()
    # The first 16 bytes are skipped; each row of the result is one flattened image.
    pixels = np.frombuffer(raw_bytes, dtype=np.uint8, offset=16)
    images = pixels.reshape(-1, img_size)
    print("Done")

    return images
| 64 | + |
def _convert_numpy():
    """Load all four archives and return them as a dict of NumPy arrays.

    Returns
    -------
    dict with the keys 'train_img', 'train_label', 'test_img', 'test_label'
    """
    # Dict displays evaluate their entries top to bottom, so the archives are
    # converted (and their progress messages printed) in this order.
    return {
        'train_img': _load_img(key_file['train_img']),
        'train_label': _load_label(key_file['train_label']),
        'test_img': _load_img(key_file['test_img']),
        'test_label': _load_label(key_file['test_label']),
    }
| 73 | + |
def init_mnist():
    """Download the archives, convert them to NumPy arrays and pickle them to ``save_file``."""
    download_mnist()
    arrays = _convert_numpy()
    print("Creating pickle file ...")
    with open(save_file, 'wb') as pickle_file:
        # Same as passing -1: a negative protocol selects the highest available one.
        pickle.dump(arrays, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)
    print("Done!")
| 81 | + |
def _change_one_hot_label(X):
    """Convert integer class labels into one-hot rows.

    Parameters
    ----------
    X : 1-D array (or sequence) of integer labels in the range 0-9

    Returns
    -------
    float64 array of shape (number of labels, 10) in which row ``i`` is all
    zeros except for a 1 in column ``X[i]``

    Raises
    ------
    IndexError : if a label falls outside the 10 available columns
    """
    labels = np.asarray(X).reshape(-1)
    T = np.zeros((labels.size, 10))
    if labels.size:
        # One vectorized assignment instead of a Python-level loop over every row.
        T[np.arange(labels.size), labels] = 1

    return T
| 88 | + |
| 89 | + |
def load_mnist(normalize=True, flatten=True, one_hot_label=False):
    """Load the MNIST dataset.

    Parameters
    ----------
    normalize : scale the image pixel values to the range 0.0-1.0
    one_hot_label :
        if True, the labels are returned as one-hot arrays,
        e.g. [0,0,1,0,0,0,0,0,0,0]
    flatten : whether to flatten each image into a one-dimensional array

    Returns
    -------
    (training images, training labels), (test images, test labels)
    """
    image_keys = ('train_img', 'test_img')
    label_keys = ('train_label', 'test_label')

    # Build the pickle cache first when it does not exist yet.
    if not os.path.exists(save_file):
        init_mnist()

    with open(save_file, 'rb') as cache:
        data = pickle.load(cache)

    if normalize:
        for name in image_keys:
            data[name] = data[name].astype(np.float32) / 255.0

    if one_hot_label:
        for name in label_keys:
            data[name] = _change_one_hot_label(data[name])

    if not flatten:
        # Restore the (N, 1, 28, 28) image layout.
        for name in image_keys:
            data[name] = data[name].reshape(-1, 1, 28, 28)

    train_pair = (data['train_img'], data['train_label'])
    test_pair = (data['test_img'], data['test_label'])
    return train_pair, test_pair
| 125 | + |
| 126 | + |
# When executed as a script, download the archives and build the pickle cache.
if __name__ == '__main__':
    init_mnist()
0 commit comments