-
Notifications
You must be signed in to change notification settings - Fork 3
/
train.py
72 lines (55 loc) · 1.99 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
from time import sleep
import infolog
from hparams import config2parser
from infolog import log
log = infolog.log  # NOTE(review): rebinds the same function already bound by `from infolog import log`; redundant but harmless.
def save_seq(file, sequence, input_path):
    '''Write the training-state flags and input path to *file*.

    Every flag is serialised as an integer string (``True`` -> ``'1'``),
    *input_path* is appended as the final field, and all fields are joined
    with ``'|'``. The file is overwritten. read_seq parses this format back.
    '''
    fields = list(map(str, map(int, sequence)))
    fields.append(input_path)
    with open(file, 'w') as handle:
        handle.write('|'.join(fields))
def read_seq(file):
    '''Load training state from disk. (To skip if not first run)

    The file is expected in the format written by save_seq:
    ``flag|flag|...|input_path``, where each flag is an integer string.

    Args:
        file: path of the state file.

    Returns:
        A ``(flags, input_path)`` tuple. ``flags`` is a list of bools parsed
        from every field except the last; ``input_path`` is the last field.
        When the file does not exist, returns ``([False], '')``.

    Raises:
        ValueError: if a flag field is not an integer string.
    '''
    if not os.path.isfile(file):
        # Use a bool so the element type matches the parsed branch
        # (False == 0, so callers comparing against 0 are unaffected).
        return [False], ''
    with open(file, 'r') as f:
        # Drop a trailing newline (e.g. left by hand-editing the file) so it
        # does not end up inside input_path.
        sequence = f.read().rstrip('\r\n').split('|')
    return [bool(int(s)) for s in sequence[:-1]], sequence[-1]
def prepare_run(args):
    '''Configure the process environment and logging for a training run.

    Sets ``TF_CPP_MIN_LOG_LEVEL`` from ``args.tf_log_level``, creates the
    directory ``<args.base_dir>/logs-<run name>`` (the run name is
    ``args.name``, falling back to ``args.model`` when that is empty), and
    initialises infolog to write its terminal log inside that directory.

    Returns:
        The path of the log directory.
    '''
    tf_level = str(args.tf_log_level)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = tf_level
    run_name = args.name if args.name else args.model
    log_dir = os.path.join(args.base_dir, 'logs-{}'.format(run_name))
    os.makedirs(log_dir, exist_ok=True)
    terminal_log = os.path.join(log_dir, 'Terminal_train_log')
    infolog.init(terminal_log, run_name, args.slack_url)
    return log_dir
def train(args, log_dir):
    '''Run speech-recognition training unless a previous run already finished.

    Training state is persisted in ``<log_dir>/state_log`` via save_seq and
    read_seq. When that file records a completed run, training is skipped.
    Otherwise the model selected by ``args.model`` is trained and the
    completed state is written back so later runs can skip it.

    Args:
        args: parsed configuration. ``args.model`` selects the model (only
            'LAS' is handled here) and the object is forwarded to sr_train.
        log_dir: directory holding the run's logs and the state file.

    Raises:
        ValueError: if ``args.model`` is not supported, or if training did
            not return a checkpoint.
    '''
    state_file = os.path.join(log_dir, 'state_log')
    # Get training states. read_seq returns a list of flags (a single 0/False
    # when no state file exists). A non-empty list is always truthy, so test
    # the flags themselves rather than the list.
    state, input_path = read_seq(state_file)
    if state and all(state):
        log('TRAINING IS ALREADY COMPLETE!!')
        return

    log('\n#############################################################\n')
    log('Speech Recognition Train\n')
    log('#############################################################\n')
    if args.model != 'LAS':
        # Without this guard, `checkpoint` below would be unbound (NameError).
        raise ValueError('Unsupported model: {}'.format(args.model))

    from modules.train import sr_train
    checkpoint = sr_train(args, log_dir)
    # Sleep 1/2 second to let previous graph close
    sleep(0.5)
    if checkpoint is None:
        raise ValueError('Error occured while training, Exiting!')
    # Record completion so the next run can skip training.
    save_seq(state_file, [1], input_path)
def main():
    '''Parse the LAS configuration, validate the chosen model, and train it.

    Raises:
        ValueError: if the parsed ``args.model`` is not a supported model.
    '''
    supported_models = ['LAS']
    args = config2parser('LAS')
    if args.model in supported_models:
        train(args, prepare_run(args))
    else:
        raise ValueError('please enter a valid model to train: {}'.format(supported_models))


if __name__ == '__main__':
    main()