-
Notifications
You must be signed in to change notification settings - Fork 19
Expand file tree
/
Copy pathmnist_example.py
More file actions
70 lines (46 loc) · 1.81 KB
/
mnist_example.py
File metadata and controls
70 lines (46 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import tensorflow
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.layers import Conv2D, Input, Dense, MaxPool2D, BatchNormalization, GlobalAvgPool2D
from tensorflow.python.keras import activations
from deeplearning_models import functional_model, MyCustomModel
from my_utils import display_some_examples
def _build_seq_model():
    """Build the reference Sequential CNN for 28x28x1 inputs and 10 classes.

    Layout: two 3x3 ReLU convolutions (32, 64 filters), max-pool + batch-norm,
    a third 3x3 ReLU convolution (128 filters), max-pool + batch-norm, global
    average pooling, a 64-unit ReLU dense layer and a 10-way softmax head.
    """
    def relu_conv(filters):
        # Every convolution in this model uses a 3x3 kernel with ReLU.
        return Conv2D(filters, (3, 3), activation='relu')

    stack = [Input(shape=(28, 28, 1))]
    stack.append(relu_conv(32))
    stack.append(relu_conv(64))
    stack.append(MaxPool2D())
    stack.append(BatchNormalization())
    stack.append(relu_conv(128))
    stack.append(MaxPool2D())
    stack.append(BatchNormalization())
    # Collapse spatial dims before the fully-connected classifier head.
    stack.append(GlobalAvgPool2D())
    stack.append(Dense(64, activation='relu'))
    stack.append(Dense(10, activation='softmax'))
    return tensorflow.keras.Sequential(stack)


# Module-level Sequential model, built at import time. The __main__ routine in
# this file trains MyCustomModel instead; this one is kept as an alternative.
seq_model = _build_seq_model()
if __name__ == '__main__':
    # Script entry point: load MNIST, preprocess it, train a model, evaluate on
    # the test split.
    show_examples = False         # set True to plot a few training samples first
    use_functional_model = False  # set True to train functional_model() instead
    num_classes = 10              # MNIST digits 0-9

    (x_train, y_train), (x_test, y_test) = tensorflow.keras.datasets.mnist.load_data()
    print("x_train.shape = ", x_train.shape)
    print("y_train.shape = ", y_train.shape)
    print("x_test.shape = ", x_test.shape)
    print("y_test.shape = ", y_test.shape)

    if show_examples:
        display_some_examples(x_train, y_train)

    # Scale raw pixel values (0-255) to float32 in [0, 1].
    x_train = x_train.astype('float32') / 255
    x_test = x_test.astype('float32') / 255

    # Append a channel axis so images match the (28, 28, 1) model input.
    x_train = np.expand_dims(x_train, axis=-1)
    x_test = np.expand_dims(x_test, axis=-1)

    # One-hot encode integer labels (e.g. 2 -> [0, 0, 1, 0, 0, 0, 0, 0, 0, 0])
    # so they match the categorical_crossentropy loss used below.
    y_train = tensorflow.keras.utils.to_categorical(y_train, num_classes)
    y_test = tensorflow.keras.utils.to_categorical(y_test, num_classes)

    model = functional_model() if use_functional_model else MyCustomModel()
    # `metrics` is documented as a list; a bare string is rejected by Keras 3.
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    # Model training; 20% of the training data is held out for validation.
    model.fit(x_train, y_train, batch_size=64, epochs=3, validation_split=0.2)

    # Evaluation on test set
    model.evaluate(x_test, y_test, batch_size=64)