-
Notifications
You must be signed in to change notification settings - Fork 938
/
train.py
86 lines (73 loc) · 3.71 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# -*- coding: utf-8 -*-
# file: train.py
# author: JinTian
# time: 11/03/2017 9:53 AM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import os
import tensorflow as tf
from poems.model import rnn_model
from poems.poems import process_poems, generate_batch
# Command-line flags for the training run. The two path defaults are made
# absolute with os.path.abspath when this module is imported, i.e. relative to
# the working directory at that moment.
tf.app.flags.DEFINE_integer(
    'batch_size', 64, 'batch size.')
tf.app.flags.DEFINE_float(
    'learning_rate', 0.01, 'learning rate.')
tf.app.flags.DEFINE_string(
    'model_dir', os.path.abspath('./model'), 'model save path.')
tf.app.flags.DEFINE_string(
    'file_path', os.path.abspath('./data/poems.txt'), 'file name of poems.')
tf.app.flags.DEFINE_string(
    'model_prefix', 'poems', 'model save prefix.')
tf.app.flags.DEFINE_integer(
    'epochs', 50, 'train how many epochs.')

# Shared handle through which the rest of the script reads the flag values.
FLAGS = tf.app.flags.FLAGS
def run_training():
    """Train the poem model and checkpoint it into ``FLAGS.model_dir``.

    Steps:
      1. Create ``FLAGS.model_dir`` if it does not exist.
      2. Load the corpus from ``FLAGS.file_path`` with ``process_poems`` and
         split it into batches with ``generate_batch``.
      3. Build the graph with ``rnn_model`` (2-layer LSTM, ``rnn_size=128``).
      4. If ``tf.train.latest_checkpoint`` finds a checkpoint, restore it and
         resume from the epoch number encoded after the last ``-`` in the
         checkpoint name.
      5. Train up to ``FLAGS.epochs``, saving every 6th epoch and once more
         if the run is interrupted with Ctrl-C.
    """
    if not os.path.exists(FLAGS.model_dir):
        os.makedirs(FLAGS.model_dir)

    poems_vector, word_to_int, vocabularies = process_poems(FLAGS.file_path)
    batches_inputs, batches_outputs = generate_batch(FLAGS.batch_size, poems_vector, word_to_int)

    input_data = tf.placeholder(tf.int32, [FLAGS.batch_size, None])
    output_targets = tf.placeholder(tf.int32, [FLAGS.batch_size, None])

    # batch_size follows the flag (it used to be hard-coded to 64) so the model
    # agrees with the placeholders and batches above when --batch_size is
    # overridden on the command line.
    end_points = rnn_model(
        model='lstm',
        input_data=input_data,
        output_data=output_targets,
        vocab_size=len(vocabularies),
        rnn_size=128,
        num_layers=2,
        batch_size=FLAGS.batch_size,
        learning_rate=FLAGS.learning_rate)

    saver = tf.train.Saver(tf.global_variables())
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    checkpoint_prefix = os.path.join(FLAGS.model_dir, FLAGS.model_prefix)

    with tf.Session() as sess:
        # To hunt NaN/Inf values, wrap the session with
        # tf_debug.LocalCLIDebugWrapperSession and register the
        # tf_debug.has_inf_or_nan tensor filter here.
        sess.run(init_op)

        start_epoch = 0
        checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)
            print("## restore from the checkpoint {0}".format(checkpoint))
            # Checkpoints are saved with global_step=epoch, so the suffix
            # after the last '-' is the epoch to resume from.
            start_epoch += int(checkpoint.split('-')[-1])

        print('## start training...')
        # Bind `epoch` up front so the interrupt handler below can always
        # reference it, even when Ctrl-C arrives before the first iteration
        # or the epoch range is empty.
        epoch = start_epoch
        try:
            n_chunk = len(poems_vector) // FLAGS.batch_size
            for epoch in range(start_epoch, FLAGS.epochs):
                for batch in range(n_chunk):
                    loss, _, _ = sess.run([
                        end_points['total_loss'],
                        end_points['last_state'],
                        end_points['train_op']
                    ], feed_dict={input_data: batches_inputs[batch],
                                  output_targets: batches_outputs[batch]})
                    print('Epoch: %d, batch: %d, training loss: %.6f' % (epoch, batch, loss))
                if epoch % 6 == 0:
                    saver.save(sess, checkpoint_prefix, global_step=epoch)
        except KeyboardInterrupt:
            print('## Interrupt manually, try saving checkpoint for now...')
            saver.save(sess, checkpoint_prefix, global_step=epoch)
            print('## Last epoch were saved, next time will start from epoch {}.'.format(epoch))
def main(_):
    """Run the training routine; the single positional argument is ignored."""
    run_training()
# When executed as a script (not imported), hand control to TensorFlow's
# app runner.
if __name__ == '__main__':
    tf.app.run()