-
Notifications
You must be signed in to change notification settings - Fork 994
/
Copy pathpredict.py
53 lines (41 loc) · 1.67 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import tensorflow as tf
# import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import sys
# Command-line flags. The help strings are kept in Korean; in English they read:
#   --output_graph   "location where the trained neural network is stored"
#   --output_labels  "label data file for training"
#   --show_image     "shows the image after running inference on it"
_flags = tf.app.flags

_flags.DEFINE_string("output_graph", "./workspace/flowers_graph.pb",
                     "학습된 신경망이 저장된 위치")
_flags.DEFINE_string("output_labels", "./workspace/flowers_labels.txt",
                     "학습할 레이블 데이터 파일")
_flags.DEFINE_boolean("show_image", True,
                      "이미지 추론 후 이미지를 보여줍니다.")

# Flag values, read by main() once tf.app.run() has parsed the command line.
FLAGS = _flags.FLAGS
def _load_labels(path):
    """Return the lines of the label file at *path*, trailing whitespace stripped."""
    with tf.gfile.GFile(path) as fp:
        return [line.rstrip() for line in fp]


def _load_graph(path):
    """Parse the serialized GraphDef at *path* and import it into the default graph."""
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(path, 'rb') as fp:
        graph_def.ParseFromString(fp.read())
    # name='' keeps the original node names, so tensors can be looked up
    # directly (e.g. 'final_result:0') without an import-scope prefix.
    tf.import_graph_def(graph_def, name='')


def _read_bytes(path):
    """Return the raw contents of the file at *path*, closing it afterwards."""
    with tf.gfile.FastGFile(path, 'rb') as fp:
        return fp.read()


def main(_):
    """Classify one image with the trained graph and print a score per label.

    Args:
        _: argv list handed to ``main`` by ``tf.app.run()`` after flag
            parsing; element 1 is the path of the image to classify. Falls
            back to ``sys.argv`` when empty or ``None``.

    Returns:
        ``None`` on success, or ``1`` when no image path was given
        (``tf.app.run()`` passes this to ``sys.exit`` as the exit status).
    """
    # Use the argv left over after flag parsing rather than sys.argv, so that
    # "predict.py --show_image=False image.jpg" still finds the image path.
    argv = _ or sys.argv
    if len(argv) < 2:
        print('usage: %s [flags] IMAGE_PATH' % sys.argv[0], file=sys.stderr)
        return 1
    image_path = argv[1]

    labels = _load_labels(FLAGS.output_labels)
    _load_graph(FLAGS.output_graph)

    with tf.Session() as sess:
        # NOTE(review): tensor names assume the graph was produced by an
        # Inception-style retraining script ('final_result' output, raw encoded
        # image bytes fed to 'DecodeJpeg/contents') -- confirm against the
        # exporter; presumably only JPEG input decodes correctly.
        logits = sess.graph.get_tensor_by_name('final_result:0')
        image = _read_bytes(image_path)
        prediction = sess.run(logits, {'DecodeJpeg/contents:0': image})

    # Row 0 holds the scores for the single fed image; printed as percentages
    # (presumably softmax probabilities -- confirm).
    scores = prediction[0]
    if len(scores) != len(labels):
        # A mismatch usually means the label file does not belong to this graph.
        print('warning: %d labels but %d scores; check --output_labels'
              % (len(labels), len(scores)), file=sys.stderr)

    print('=== 예측 결과 ===')
    for name, score in zip(labels, scores):
        print('%s (%.2f%%)' % (name, score * 100))

    if FLAGS.show_image:
        img = mpimg.imread(image_path)
        plt.imshow(img)
        plt.show()
if __name__ == "__main__":
    # tf.app.run() parses the TensorFlow flags, then calls main() with the
    # remaining argv and exits with main()'s return value.
    tf.app.run(main=main)