diff --git a/tensorflow_model/mnist_convnet_keras.py b/tensorflow_model/mnist_convnet_keras.py index 8d1fdfb..57887b0 100644 --- a/tensorflow_model/mnist_convnet_keras.py +++ b/tensorflow_model/mnist_convnet_keras.py @@ -74,11 +74,11 @@ def train(model, x_train, y_train, x_test, y_test): validation_data=(x_test, y_test)) -def export_model(saver, model, input_node_names, output_node_name): +def export_model(input_node_names, output_node_name): tf.train.write_graph(K.get_session().graph_def, 'out', \ MODEL_NAME + '_graph.pbtxt') - saver.save(K.get_session(), 'out/' + MODEL_NAME + '.chkp') + tf.train.Saver().save(K.get_session(), 'out/' + MODEL_NAME + '.chkp') freeze_graph.freeze_graph('out/' + MODEL_NAME + '_graph.pbtxt', None, \ False, 'out/' + MODEL_NAME + '.chkp', output_node_name, \ @@ -109,7 +109,7 @@ def main(): train(model, x_train, y_train, x_test, y_test) - export_model(tf.train.Saver(), model, ["conv2d_1_input"], "dense_2/Softmax") + export_model(["conv2d_1_input"], "dense_2/Softmax") if __name__ == '__main__':