-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtest_batch_segmentation.py
More file actions
70 lines (51 loc) · 2.21 KB
/
test_batch_segmentation.py
File metadata and controls
70 lines (51 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#!/usr/bin/env python
import numpy as np
import glob
import tensorflow as tf
import scipy
import scipy.misc as misc
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib import colors as mpl_colors
# Command-line flags read by main(). model_path, image_dir_path, output_path
# and image_extension default to None and must be supplied on the command line.
tf.app.flags.DEFINE_string(
    'image_dir_path', None, 'Directory for images to segment')
# NOTE(review): dataset_type is not read anywhere in this file -- confirm
# whether it is still needed.
tf.app.flags.DEFINE_string(
    'dataset_type', 'rgb', 'Type of images to segment')
tf.app.flags.DEFINE_string(
    'output_path', None, 'Output directory where images need to be stored')
# model_path is opened as a single file and parsed as a serialized GraphDef,
# so the help text must describe a file path, not a directory.
tf.app.flags.DEFINE_string(
    'model_path', None, 'Path to the frozen model file (serialized GraphDef)')
# The extension is appended after "*." in the glob pattern (e.g. "png").
tf.app.flags.DEFINE_string(
    'image_extension', None, 'Extension of image files in the directory')
FLAGS = tf.app.flags.FLAGS
def _voc_color(index):
    """Return the (r, g, b) triple, each in [0, 1), assigned to ``index``.

    Bits 0, 1 and 2 of ``index`` feed the red, green and blue channels; each
    following group of three bits lands one bit lower in the 8-bit channel
    value. This reproduces the PASCAL VOC label colormap, scaled by 1/256.
    """
    red = green = blue = 0
    for shift in range(7, -1, -1):
        red |= (index & 1) << shift
        green |= ((index >> 1) & 1) << shift
        blue |= ((index >> 2) & 1) << shift
        index >>= 3
    # Channel values are multiples of 64, so dividing by 256 is exact.
    return (red / 256.0, green / 256.0, blue / 256.0)


# 21 colors, presumably one per label id 0..20 (entry 0 is black).
palette = list(map(_voc_color, range(21)))
# Discrete colormap with exactly one bin per palette entry (21 colors).
my_cmap = mpl_colors.LinearSegmentedColormap.from_list(
    name='Custom cmap', colors=palette, N=21)
def main(_):
    """Run the frozen segmentation graph over every matching image in a directory.

    Parses the serialized GraphDef at ``--model_path``, then, for each file
    matching ``<image_dir_path>/*.<image_extension>``, resizes it to 512 rows
    x 473 columns, feeds it to the ``ph_input_x`` tensor, and saves the first
    element of the ``predictions`` output into ``--output_path`` under the
    input file's basename.

    Args:
      _: Unused argv list passed in by ``tf.app.run``.

    Raises:
      ValueError: If model_path, image_dir_path, output_path or
        image_extension was not supplied.
    """
    import os  # only main() needs path handling

    missing = [name for name in ('model_path', 'image_dir_path',
                                 'output_path', 'image_extension')
               if getattr(FLAGS, name) is None]
    if missing:
        raise ValueError('Missing required flag(s): ' +
                         ', '.join('--' + name for name in missing))

    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with open(FLAGS.model_path, 'rb') as model_file:
            graph_def.ParseFromString(model_file.read())
        tf.import_graph_def(graph_def, name="")
    input_x = graph.get_operation_by_name('ph_input_x').outputs[0]
    pred = graph.get_operation_by_name('predictions').outputs[0]

    output_path = FLAGS.output_path
    if not tf.gfile.Exists(output_path):
        tf.gfile.MakeDirs(output_path)

    # Accept the extension with or without a leading dot ("png" or ".png").
    extension = FLAGS.image_extension.lstrip('.')
    pattern = os.path.join(FLAGS.image_dir_path, '*.' + extension)
    # glob's order is arbitrary; sort for a reproducible processing order.
    filenames = sorted(glob.glob(pattern))
    if not filenames:
        print("\nNo files match " + pattern)

    # The context manager closes the session (releasing its resources) on exit;
    # the original never closed it.
    with tf.Session(graph=graph) as sess:
        for filename in filenames:
            print("\nProcessing : " + str(filename))
            input_image_ori = scipy.misc.imread(filename)
            # Every input is resized to a fixed 512x473 before inference.
            input_image = scipy.misc.imresize(input_image_ori, (512, 473))
            p = sess.run(pred, feed_dict={input_x: input_image})[0]
            # basename/join handle OS-native separators; slicing at the last
            # '/' broke on paths that use '\\' (e.g. glob results on Windows).
            out_file = os.path.join(output_path, os.path.basename(filename))
            # NOTE(review): scipy.misc.imsave byte-scales arrays to the full
            # 0-255 range, so if p holds class ids the saved pixel values are
            # not the raw ids and the scaling varies per image. The unused
            # palette/my_cmap suggest a colorized output may have been
            # intended -- confirm.
            scipy.misc.imsave(out_file, p)
if __name__ == '__main__':
    # Hand control to tf.app.run, naming the entry point explicitly instead of
    # relying on its lookup of __main__.main.
    tf.app.run(main=main)