-
Notifications
You must be signed in to change notification settings - Fork 0
/
finetune.py
35 lines (30 loc) · 966 Bytes
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import os

import gpt_2_simple as gpt2
# --- Configuration ----------------------------------------------------------
# GPT-2 model size to fine-tune. Current gpt-2-simple / OpenAI names are
# "124M", "355M", "774M" and "1558M"; "117M" and "345M" are the original
# names of the two smallest sizes. NOTE(review): confirm the legacy name
# below still downloads before relying on it.
model_name = "117M"
# Plain-text corpus to fine-tune on.
dataset = 'sampledata.txt'
# Name of this fine-tuning run; gpt-2-simple keeps the run's checkpoint
# (by default under checkpoint/<run_name>/) separate from other runs.
run_name = 'sampledata'
# False: the original "train new model" call.
# True:  also pass overwrite=True (the former commented-out "train existing
#        model" variant) so continuing an existing run replaces its
#        checkpoint instead of keeping duplicate copies.
resume_training = False


def main(resume=False):
    """Fetch the base model if needed, fine-tune it on `dataset`, then sample.

    Args:
        resume: forwarded to ``gpt2.finetune`` as ``overwrite``; set True when
            continuing training of an existing ``run_name`` checkpoint.
    """
    # The base model is stored in ./models/<model_name>/. download_gpt2
    # re-fetches its files on every call, so only download when that
    # directory is missing (same guard as the gpt-2-simple README).
    if not os.path.isdir(os.path.join("models", model_name)):
        gpt2.download_gpt2(model_name=model_name)

    sess = gpt2.start_tf_sess()
    gpt2.finetune(sess,
                  dataset,
                  model_name=model_name,
                  run_name=run_name,
                  steps=1000,  # steps is max number of training steps
                  save_every=50,
                  sample_every=50,
                  overwrite=resume)
    # generate() resolves the encoder/hparams from run_name (default 'run1'),
    # so it must be pointed at the run that was just trained.
    gpt2.generate(sess, run_name=run_name)


if __name__ == "__main__":
    main(resume=resume_training)