-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_reader.py
More file actions
356 lines (303 loc) · 14.2 KB
/
train_reader.py
File metadata and controls
356 lines (303 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import time
import sys
import torch
import transformers
import numpy as np
from pathlib import Path
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
from src.options import Options
import src.slurm
import src.util
import src.evaluation
import src.data
import src.models_tp
import wandb
from tqdm import tqdm
from transformers import GenerationConfig, AutoConfig
from collections import Counter
import string
import regex
import time
import torch.nn.functional as F
def get_gumbel_temperature(current_step, base_steps, warmup_steps, initial_temp, final_temp=0.1):
    """Return the Gumbel-softmax temperature to use at ``current_step``.

    The schedule has three phases:

    * ``current_step < base_steps``: hold at ``initial_temp``;
    * ``base_steps <= current_step < base_steps + warmup_steps``: interpolate
      linearly from ``initial_temp`` (exactly at ``base_steps``) towards
      ``final_temp``;
    * afterwards: hold at ``final_temp``.

    A non-positive ``warmup_steps`` makes the interpolation phase empty, so the
    temperature drops straight to ``final_temp`` at ``base_steps`` (and no
    division by ``warmup_steps`` happens).

    Args:
        current_step: Current training step.
        base_steps: Number of steps to hold ``initial_temp`` before annealing.
        warmup_steps: Length of the linear annealing phase, in steps.
        initial_temp: Starting temperature.
        final_temp: Temperature reached after annealing.

    Returns:
        The temperature for ``current_step``.
    """
    if current_step < base_steps:
        return initial_temp
    # Inclusive lower bound: at exactly ``base_steps`` the ratio is 0, so the
    # ramp starts from ``initial_temp`` instead of jumping to ``final_temp``
    # for one step.
    if current_step < base_steps + warmup_steps:
        ratio = (current_step - base_steps) / float(warmup_steps)
        return initial_temp + (final_temp - initial_temp) * ratio
    return final_temp
def train(tokenizer, model, optimizer, scheduler, step, train_dataset, eval_dataset, opt, collator, best_dev_em, checkpoint_path):
    """Train the token-pruning reader until ``opt.total_steps`` is reached.

    Every step anneals the Gumbel temperature of the encoder's pruning layers
    (via ``get_gumbel_temperature``), runs a forward pass, and back-propagates
    ``pruning_loss * opt.pruning_scale + abs_loss + kl_loss * opt.rerank_loss_scale``.
    Every ``opt.accumulation_steps`` steps the gradients are clipped to
    ``opt.clip`` and the optimizer and LR scheduler are stepped. Every
    ``opt.eval_freq`` steps the model is evaluated on ``eval_dataset``; the
    main process logs the metrics and saves a ``best_dev`` checkpoint when
    the exact-match score improves. The main process also saves a
    ``step-{step}`` checkpoint every ``opt.save_freq`` steps.

    Relies on the module-level ``logger`` created in the ``__main__`` block.

    Args:
        tokenizer: Tokenizer forwarded to ``evaluate`` for decoding.
        model: The reader, optionally wrapped in DistributedDataParallel.
        optimizer: Optimizer to step.
        scheduler: LR scheduler to step once per optimizer update.
        step: Step counter to resume from (0 for a fresh run).
        train_dataset: Training dataset, batched by ``collator``.
        eval_dataset: Evaluation dataset, passed to ``evaluate``.
        opt: Parsed options.
        collator: Collate function yielding
            ``(idx, labels, _, context_ids, context_mask, question_length)``.
        best_dev_em: Best exact-match score seen so far.
        checkpoint_path: Directory passed to ``src.util.save``.
    """
    tb_logger = None
    if opt.is_main:
        try:
            # `import torch` does not load the tensorboard submodule, so it must
            # be imported explicitly; a missing tensorboard install raises here.
            from torch.utils.tensorboard import SummaryWriter
            tb_logger = SummaryWriter(Path(opt.checkpoint_dir) / opt.name)
        except Exception:
            logger.warning('Tensorboard is not available.')

    # Different seed per global rank so each process samples different batches.
    torch.manual_seed(opt.global_rank + opt.seed)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=10,
        collate_fn=collator,
    )

    curr_loss = 0.0
    epoch = 0
    model.train()
    # Unwrap DDP so the pruning layers' attributes can be set directly.
    actual_model = model.module if hasattr(model, "module") else model
    while step < opt.total_steps:
        epoch += 1
        with tqdm(train_dataloader, unit='batch') as train_batch:
            train_batch.set_description(f"Training Epoch {epoch}")
            for batch in train_batch:
                step += 1
                (_, labels, _, context_ids, context_mask, question_length) = batch

                current_temp = get_gumbel_temperature(
                    current_step=step,
                    base_steps=opt.pruning_warmup_steps,
                    warmup_steps=opt.temp_warmup_steps,
                    initial_temp=opt.gumbel_temperature,
                    final_temp=0.1,
                )
                for layer_block in actual_model.encoder.pruning_layers:
                    layer_block.gumbel_temperature = current_temp
                actual_model.encoder.last_pruning_layer.gumbel_temperature = current_temp

                outputs = model(
                    input_ids=context_ids.cuda(),
                    attention_mask=context_mask.cuda(),
                    labels=labels.cuda(),
                    question_length=question_length,
                    output_attentions=True,
                    return_dict=True,
                )
                abs_loss, pruning_loss, kl_loss = outputs.loss
                # Sum of the final pruning mask over the total number of input
                # positions; logged under the name "pruned_ratio".
                pruned_ratio = outputs.pruning_masks[-1].sum() / context_ids.numel()

                train_loss = pruning_loss * opt.pruning_scale + abs_loss + kl_loss * opt.rerank_loss_scale
                train_loss.backward()
                if step % opt.accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                    optimizer.step()
                    # PyTorch documents that scheduler.step() must follow
                    # optimizer.step(); it is advanced once per optimizer update.
                    scheduler.step()
                    model.zero_grad()

                train_loss = src.util.average_main(train_loss, opt)
                curr_loss += train_loss.item()
                train_batch.set_postfix(
                    train_step=step,
                    train_loss=train_loss.item(),
                    kl_loss=kl_loss.item(),
                    pruning_loss=pruning_loss.item(),
                    pruned_ratio=pruned_ratio.item(),
                    abs_loss=abs_loss.item(),
                    temp=current_temp,
                    lr=optimizer.param_groups[0]['lr'],
                )

                if step % opt.eval_freq == 0:
                    # evaluate() is run on every rank, not only the main one.
                    dev_em, f1, pr, act_pr, tpq, strem, rouge = evaluate(model, eval_dataset, tokenizer, collator, opt)
                    model.train()
                    if opt.is_main:
                        if dev_em > best_dev_em:
                            best_dev_em = dev_em
                            src.util.save(model, optimizer, scheduler, step, best_dev_em,
                                          opt, checkpoint_path, 'best_dev')
                        log = f"{step} / {opt.total_steps} |"
                        log += f"train: {curr_loss/opt.eval_freq:.3f} |"
                        log += f"evaluation: {100*dev_em:.2f}EM |"
                        log += f"evaluation: {100*f1:.2f}F1 |"
                        log += f"evaluation: {100*pr:.2f}PR |"
                        log += f"evaluation: {100*act_pr:.2f}ACT_PR |"
                        log += f"evaluation: {tpq}TPQ |"
                        log += f"evaluation: {100*strem:.2f}STREM |"
                        log += f"evaluation: {100*rouge:.2f}RougeL |"
                        log += f"lr: {scheduler.get_last_lr()[0]:.5f}"
                        logger.info(log)
                        if tb_logger is not None:
                            tb_logger.add_scalar("Evaluation", dev_em, step)
                            tb_logger.add_scalar("Training", curr_loss / (opt.eval_freq), step)
                    # Reset the running loss on every rank so it does not grow unbounded.
                    curr_loss = 0.0

                if opt.is_main and step % opt.save_freq == 0:
                    src.util.save(model, optimizer, scheduler, step, best_dev_em,
                                  opt, checkpoint_path, f"step-{step}")
                # Stop exactly at total_steps (a `>` check would run one extra step).
                if step >= opt.total_steps:
                    break

    if tb_logger is not None:
        tb_logger.close()
def evaluate(model, dataset, tokenizer, collator, opt):
    """Generate answers for ``dataset`` and return aggregated metrics.

    Batches are processed in dataset order (``SequentialSampler``) under
    ``torch.no_grad()``, generating at most 128 tokens per answer. Per-example
    scores are averaged on this rank and then passed, together with this
    rank's example count, through ``src.util.weighted_average``.

    The model is left in eval mode; callers that continue training must call
    ``model.train()`` themselves (as ``train`` does).

    Args:
        model: The reader, optionally wrapped in DistributedDataParallel; its
            ``generate`` output must expose ``sequences``, ``pruning_masks``
            and ``encoder_attention_mask``.
        dataset: Evaluation dataset; ``get_example(i)`` must return a mapping
            with a ``'long_answers'`` key and, optionally, ``'answers'``.
        tokenizer: Tokenizer used to decode generated ids.
        collator: Collate function yielding
            ``(idx, _, _, context_ids, context_mask, question_length)``.
        opt: Parsed options.

    Returns:
        ``(exactmatch, f1, pruned_ratio, actual_pruned_ratio, tpq, strem, rouge)``.
    """
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=False,
        num_workers=10,
        collate_fn=collator,
    )
    model.eval()
    total = 0
    exactmatch = []
    f1_scores = []
    strem = []
    rougeL = []
    pruned_ratios = []
    actual_pruned_ratios = []
    tpqs = []
    generation_config = GenerationConfig(
        return_dict_in_generate=True,
        output_attentions=True,
    )
    # Call generate() on the underlying module rather than the DDP wrapper.
    model = model.module if hasattr(model, "module") else model
    with torch.no_grad():
        with tqdm(dataloader, unit='batch') as eval_batch:
            eval_batch.set_description("Evaluation")
            for batch in eval_batch:
                (idx, _, _, context_ids, context_mask, question_length) = batch
                start = time.time()
                outputs = model.generate(
                    input_ids=context_ids.cuda(),
                    attention_mask=context_mask.cuda(),
                    max_length=128,
                    question_length=question_length,
                    generation_config=generation_config,
                )
                end = time.time()
                # NOTE(review): this is the wall time of the whole batch and every
                # example in the batch records it, so "tpq" is a per-query time
                # only when per_gpu_batch_size == 1 -- confirm intent.
                tpq = end - start

                # Batch-level token statistics, shared by every example below.
                origin_token_num = context_ids.numel()
                passed_token_num = outputs.pruning_masks[-1].cpu().sum().item()
                actual_token_num = torch.nonzero(context_mask.view(-1), as_tuple=False).size(0)
                actual_passed_token_num = outputs.encoder_attention_mask.cpu().sum().item()
                pruned_ratio = passed_token_num / origin_token_num
                actual_pruned_ratio = actual_passed_token_num / actual_token_num

                for k, sequence in enumerate(outputs.sequences):
                    ans = tokenizer.decode(sequence, skip_special_tokens=True)
                    # Fetch the example once instead of once per field.
                    example = dataset.get_example(idx[k])
                    gold = example.get('answers', None)
                    long_gold = example['long_answers']
                    # Short-answer metrics fall back to 0 when no 'answers' are given.
                    score = src.evaluation.ems(ans, gold) if gold is not None else 0.0
                    strem_score = src.evaluation.strem(ans, gold) if gold is not None else 0.0
                    rouge_scores = src.evaluation.rouge(ans, long_gold, use_stemmer=False, mode="best")
                    f1_score = src.evaluation.f1_score(ans, long_gold)

                    total += 1
                    exactmatch.append(score)
                    strem.append(strem_score)
                    rougeL.append(rouge_scores['rougeL'])
                    f1_scores.append(f1_score)
                    pruned_ratios.append(pruned_ratio)
                    actual_pruned_ratios.append(actual_pruned_ratio)
                    tpqs.append(tpq)

                    eval_batch.set_postfix(abs_score=score,
                                           f1_score=f1_score,
                                           abs_acc=sum(exactmatch)/total,
                                           f1_acc=sum(f1_scores)/total,
                                           pruned_ratio=pruned_ratio,
                                           actual_pruned_ratio=actual_pruned_ratio,
                                           tpq=tpq,
                                           tpqs=sum(tpqs)/total,
                                           strem=strem_score,
                                           strems=sum(strem)/total,
                                           rougel=rouge_scores['rougeL'],
                                           rougeL=sum(rougeL)/total)

    # weighted_average receives `opt`, presumably to combine results across
    # ranks -- keep the call order identical on every rank.
    exactmatch, _ = src.util.weighted_average(np.mean(exactmatch), total, opt)
    f1, _ = src.util.weighted_average(np.mean(f1_scores), total, opt)
    pruned_ratio, _ = src.util.weighted_average(np.mean(pruned_ratios), total, opt)
    actual_pruned_ratio, _ = src.util.weighted_average(np.mean(actual_pruned_ratios), total, opt)
    tpq, _ = src.util.weighted_average(np.mean(tpqs), total, opt)
    strem, _ = src.util.weighted_average(np.mean(strem), total, opt)
    rouge, _ = src.util.weighted_average(np.mean(rougeL), total, opt)
    return exactmatch, f1, pruned_ratio, actual_pruned_ratio, tpq, strem, rouge
if __name__ == "__main__":
    options = Options()
    options.add_reader_options()
    options.add_optim_options()
    opt = options.parse()
    torch.manual_seed(opt.seed)
    src.slurm.init_distributed_mode(opt)
    src.slurm.init_signal_handler()

    checkpoint_path = Path(opt.checkpoint_dir) / opt.name
    # Checked before mkdir below: an already-existing run directory means
    # "resume from checkpoint/latest" when no explicit model_path is given.
    checkpoint_exists = checkpoint_path.exists()
    if opt.is_distributed:
        torch.distributed.barrier()
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    # Module-level global: train() logs through it.
    logger = src.util.init_logger(
        opt.is_main,
        opt.is_distributed,
        checkpoint_path / 'run.log',
    )

    model_name = 't5-' + opt.model_size
    model_class = src.models_tp.FiDT5_TP
    # Token-pruning hyper-parameters copied from the options onto the T5 config.
    pruning_config_keys = ("gumbel_temperature", "theta1", "theta2", "last_theta", "n_context")

    # Load data.
    tokenizer = transformers.T5Tokenizer.from_pretrained(model_name)
    collator = src.data.Collator(opt.text_maxlength, tokenizer, answer_maxlength=opt.answer_maxlength)
    # Global rank and world size are passed so that the train and eval sets
    # are split across GPUs.
    train_examples = src.data.load_data(
        opt.train_data,
        global_rank=opt.global_rank,
        world_size=opt.world_size,
    )
    train_dataset = src.data.Dataset(train_examples, opt.n_context)
    eval_examples = src.data.load_data(
        opt.eval_data,
        global_rank=opt.global_rank,
        world_size=opt.world_size,
    )
    eval_dataset = src.data.Dataset(eval_examples, opt.n_context)

    if not checkpoint_exists and opt.model_path == "none":
        # Fresh run: build the pruning reader from pretrained T5 weights.
        t5 = transformers.T5ForConditionalGeneration.from_pretrained(model_name)
        for key in pruning_config_keys:
            setattr(t5.config, key, getattr(opt, key))
        model = model_class(t5.config)
        model.load_t5(t5.config, t5.state_dict())
        del t5  # drop this script's reference to the pretrained model
        model = model.to(opt.local_rank)
        optimizer, scheduler = src.util.set_optim(opt, model)
        step, best_dev_em = 0, 0.0
    elif opt.model_path == "none":
        # Resume this run from its latest checkpoint.
        load_path = checkpoint_path / 'checkpoint' / 'latest'
        model, optimizer, scheduler, _opt_checkpoint, step, best_dev_em = \
            src.util.load(model_class, load_path, opt, reset_params=False)
        logger.info(f"Model loaded from {load_path}")
    else:
        # Start from opt.model_path. Only the T5 *config* is handed to _load,
        # so load the config alone instead of the full pretrained weights.
        t5_config = AutoConfig.from_pretrained(model_name)
        for key in pruning_config_keys:
            setattr(t5_config, key, getattr(opt, key))
        model, optimizer, scheduler, _opt_checkpoint, step, best_dev_em = \
            src.util._load(model_class, opt.model_path, opt, t5_config, reset_params=True)
        model = model.to(opt.device)
        logger.info(f"Model loaded from {opt.model_path}")

    model.set_checkpoint(opt.use_checkpoint)
    if opt.is_distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            find_unused_parameters=True,
        )

    logger.info("Start training")
    train(
        tokenizer,
        model,
        optimizer,
        scheduler,
        step,
        train_dataset,
        eval_dataset,
        opt,
        collator,
        best_dev_em,
        checkpoint_path,
    )