import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
from time import time
import os

def save(sess, saver, checkpoint_dir, model_name, step):
    # Create the checkpoint directory on first use, then write the variables,
    # encoding the global step into the checkpoint filename.
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    saver.save(sess, os.path.join(checkpoint_dir, model_name + '.model'), global_step=step)


def load(sess, saver, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
        counter = int(ckpt_name.split('-')[-1])  # recover the global step from the filename suffix
        print(" [*] Successfully read {}".format(ckpt_name))
        return True, counter
    else:
        print(" [*] Failed to find a checkpoint")
        return False, 0

def normalize(X_train, X_test):
    # Scale pixels from [0, 255] to [0, 1] and cast to float32 so the graph
    # does not silently operate in float64.
    X_train = X_train.astype(np.float32) / 255.0
    X_test = X_test.astype(np.float32) / 255.0

    return X_train, X_test

def load_mnist():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    train_data = np.expand_dims(train_data, axis=-1)  # [N, 28, 28] -> [N, 28, 28, 1]
    test_data = np.expand_dims(test_data, axis=-1)    # [N, 28, 28] -> [N, 28, 28, 1]

    train_data, test_data = normalize(train_data, test_data)

    train_labels = to_categorical(train_labels, 10)  # [N,] -> [N, 10]
    test_labels = to_categorical(test_labels, 10)    # [N,] -> [N, 10]

    return train_data, train_labels, test_data, test_labels
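# Resulting shapes for the standard MNIST split:
# train_data (60000, 28, 28, 1), train_labels (60000, 10),
# test_data (10000, 28, 28, 1), test_labels (10000, 10).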

def classification_loss(logit, label):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=label, logits=logit))
    prediction = tf.equal(tf.argmax(logit, -1), tf.argmax(label, -1))
    accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))

    return loss, accuracy

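# Cross-check sketch (illustrative only, never called below): the same softmax
# cross-entropy can be reproduced in plain numpy; `np_xent` is a hypothetical
# helper added here for verification, not part of the original graph.
def np_xent(logit, label):
    # numerically stable softmax over the last axis
    e = np.exp(logit - logit.max(axis=-1, keepdims=True))
    prob = e / e.sum(axis=-1, keepdims=True)
    # mean negative log-likelihood of the one-hot labels
    return -np.mean(np.sum(label * np.log(prob + 1e-12), axis=-1))
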
def network(x, reuse=False):
    with tf.variable_scope('network', reuse=reuse):
        x = tf.layers.flatten(x)  # [N, 28, 28, 1] -> [N, 784]

        weight_init = tf.random_normal_initializer()

        # [N, 784] -> [N, 10]
        hypothesis = tf.layers.dense(inputs=x, units=10, use_bias=True, kernel_initializer=weight_init, name='fully_connected_logit')

        return hypothesis  # hypothesis = logit

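# Sketch of a deeper variant (an illustrative assumption, defined but never
# called: the model above is intentionally a single softmax layer). A hidden
# ReLU layer would slot in between flatten and the logit layer like this:
def network_mlp(x, reuse=False):
    with tf.variable_scope('network_mlp', reuse=reuse):
        x = tf.layers.flatten(x)                                  # [N, 784]
        x = tf.layers.dense(x, units=256, activation=tf.nn.relu)  # hidden layer
        return tf.layers.dense(x, units=10)                       # logits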

""" dataset """
train_x, train_y, test_x, test_y = load_mnist()

""" parameters """
learning_rate = 0.001
batch_size = 128

training_epochs = 1
training_iterations = len(train_x) // batch_size

img_size = 28
c_dim = 1
label_dim = 10
train_flag = True  # set False to skip training and only evaluate a restored checkpoint

| 84 | +""" Graph Input using Dataset API """ |
| 85 | +train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)).\ |
| 86 | + shuffle(buffer_size=100000).\ |
| 87 | + prefetch(buffer_size=batch_size).\ |
| 88 | + batch(batch_size).\ |
| 89 | + repeat() |
| 90 | + |
| 91 | +test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y)).\ |
| 92 | + shuffle(buffer_size=100000).\ |
| 93 | + prefetch(buffer_size=len(test_x)).\ |
| 94 | + batch(len(test_x)).\ |
| 95 | + repeat() |
| 96 | + |
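# Note: because both pipelines end in repeat(), the iterators never raise
# tf.errors.OutOfRangeError; epochs are tracked manually via
# training_iterations in the training loop below.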
| 97 | +""" Model """ |
| 98 | +train_iterator = train_dataset.make_one_shot_iterator() |
| 99 | +test_iterator = test_dataset.make_one_shot_iterator() |
| 100 | + |
| 101 | +train_inputs, train_labels = train_iterator.get_next() |
| 102 | +test_inputs, test_labels = test_iterator.get_next() |
| 103 | + |
| 104 | +train_logits = network(train_inputs) |
| 105 | +test_logits = network(test_inputs, reuse=True) |
| 106 | + |
| 107 | +train_loss, train_accuracy = classification_loss(logit=train_logits, label=train_labels) |
| 108 | +_, test_accuracy = classification_loss(logit=test_logits, label=test_labels) |
| 109 | + |
| 110 | +""" Training """ |
| 111 | +optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(train_loss) |
| 112 | + |
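# Equivalent two-step form (an illustrative sketch, not used below) in case
# the gradients ever need to be inspected or clipped before being applied:
#   opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
#   grads_and_vars = opt.compute_gradients(train_loss)
#   train_op = opt.apply_gradients(grads_and_vars)
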
| 113 | +"""" Summary """ |
| 114 | +summary_train_loss = tf.summary.scalar("train_loss", train_loss) |
| 115 | +summary_train_accuracy = tf.summary.scalar("train_accuracy", train_accuracy) |
| 116 | + |
| 117 | +summary_test_accuracy = tf.summary.scalar("test_accuracy", test_accuracy) |
| 118 | + |
| 119 | +train_summary = tf.summary.merge([summary_train_loss, summary_train_accuracy]) |
| 120 | +test_summary = tf.summary.merge([summary_test_accuracy]) |
| 121 | + |
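# The merged summaries are written to logs/nn_softmax by the FileWriter below
# and can be inspected with TensorBoard: tensorboard --logdir logs
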

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    start_time = time()

    saver = tf.train.Saver()
    checkpoint_dir = 'checkpoints'
    logs_dir = 'logs'

    model_dir = 'nn_softmax'
    model_name = 'dense'

    checkpoint_dir = os.path.join(checkpoint_dir, model_dir)
    logs_dir = os.path.join(logs_dir, model_dir)

    if train_flag:
        writer = tf.summary.FileWriter(logs_dir, sess.graph)
    else:
        writer = None

    # restore the checkpoint if it exists
    could_load, checkpoint_counter = load(sess, saver, checkpoint_dir)

    if could_load:
        # resume exactly where the previous run stopped
        start_epoch = checkpoint_counter // training_iterations
        start_batch_index = checkpoint_counter - start_epoch * training_iterations
        counter = checkpoint_counter
        print(" [*] Load SUCCESS")
    else:
        start_epoch = 0
        start_batch_index = 0
        counter = 1
        print(" [!] Load failed...")

    if train_flag:
        """ Training phase """
        for epoch in range(start_epoch, training_epochs):
            for idx in range(start_batch_index, training_iterations):

                # train
                _, summary_str, train_loss_val, train_accuracy_val = sess.run([train_op, train_summary, train_loss, train_accuracy])
                writer.add_summary(summary_str, counter)

                # test
                summary_str, test_accuracy_val = sess.run([test_summary, test_accuracy])
                writer.add_summary(summary_str, counter)

                counter += 1
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, train_loss: %.8f, train_accuracy: %.2f, test_accuracy: %.2f" \
                      % (epoch, idx, training_iterations, time() - start_time, train_loss_val, train_accuracy_val, test_accuracy_val))

            start_batch_index = 0
            save(sess, saver, checkpoint_dir, model_name, counter)

        # final save, so a checkpoint exists even if the loop body never ran
        save(sess, saver, checkpoint_dir, model_name, counter)
        print('Learning Finished!')

        test_accuracy_val = sess.run(test_accuracy)
        print("Test accuracy: %.8f" % (test_accuracy_val))

    else:
        """ Test phase """
        test_accuracy_val = sess.run(test_accuracy)
        print("Test accuracy: %.8f" % (test_accuracy_val))

    """ Get test image """
    r = np.random.randint(low=0, high=len(test_x) - 1)
    print("Label: ", np.argmax(test_y[r: r+1], axis=-1))
    print("Prediction: ", sess.run(tf.argmax(test_logits, axis=-1), feed_dict={test_inputs: test_x[r: r+1]}))

    plt.imshow(test_x[r:r + 1].reshape(28, 28), cmap='Greys', interpolation='nearest')
    plt.show()