-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
58 lines (43 loc) · 2.18 KB
/
test.py
File metadata and controls
58 lines (43 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import csv

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.nn.utils.rnn import pad_sequence
def _best_span(start_prob, end_prob, max_answer_len=8):
    """Return ``(start, end)`` token indices maximizing ``start_prob[start] * end_prob[end]``.

    Only spans with ``start <= end < start + max_answer_len`` are considered;
    ``max_answer_len`` must be at least 1 and both inputs must be non-empty
    1-D tensors of equal length.
    """
    # Outer product: scores[i, j] = P(start = i) * P(end = j).
    scores = start_prob[:, None] * end_prob[None, :]
    # triu keeps end >= start; tril(diagonal=max_answer_len - 1) keeps
    # end - start < max_answer_len. Masked cells are 0, as in the original loop.
    scores = torch.triu(torch.tril(scores, diagonal=max_answer_len - 1))
    flat_index = torch.argmax(scores).item()
    # Row-major flattening: row = start index, column = end index.
    return divmod(flat_index, len(end_prob))


def test_model(model, tokenizer, PATH, *, test_dataset=None, device="cuda",
               max_answer_len=8, threshold=0.0):
    """Run extractive-QA inference over the test set and write a submission CSV.

    For each sample, the model's start/end logits are restricted to tokens
    whose ``token_type_ids`` is 1 (dropping the first and last of those),
    converted to probabilities with softmax, and the best span of at most
    ``max_answer_len`` tokens is mapped back to a character slice of
    ``sample['context']`` via ``sample['position']``.

    Args:
        model: Module called as ``model(input_ids=..., token_type_ids=...)``
            whose output exposes ``start_logits`` and ``end_logits``.
        tokenizer: Unused; kept so existing callers keep working.
        PATH: Submission name; output goes to ``submissions/{PATH}.csv``.
        test_dataset: Iterable of samples with ``input_ids``,
            ``token_type_ids``, ``position``, ``guid`` and ``context`` keys.
            ``position[i]`` is used as a ``(char_start, char_end)`` pair for
            scored token ``i``. Defaults to the global ``indexed_test_dataset``.
        device: Device the input tensors are created on.
        max_answer_len: Maximum answer length in tokens (must be >= 1).
        threshold: The span is kept when its start OR end probability is at
            least this value; otherwise the prediction is an empty string.
            The default 0.0 keeps every span, since probabilities are >= 0.

    Returns:
        The ``csv.writer`` used for output (its file is closed on return).

    Raises:
        ValueError: If ``max_answer_len`` is less than 1.
    """
    if max_answer_len < 1:
        raise ValueError(f"max_answer_len must be >= 1, got {max_answer_len}")
    # Resolved at call time so passing test_dataset never needs the global.
    dataset = indexed_test_dataset if test_dataset is None else test_dataset
    # Once, up front: disables dropout etc. for every sample.
    model.eval()
    rows = []
    # newline='' is required by the csv module; contexts may be non-ASCII.
    with torch.no_grad(), open(f'submissions/{PATH}.csv', 'w',
                               newline='', encoding='utf-8') as fd:
        writer = csv.writer(fd)
        writer.writerow(['Id', 'Predicted'])
        for sample in tqdm(dataset, "Testing"):
            input_ids, token_type_ids = (
                torch.tensor(sample[key], dtype=torch.long, device=device)
                for key in ("input_ids", "token_type_ids")
            )
            # Add a batch dimension of 1, then drop it from the outputs.
            output = model(input_ids=input_ids[None, :],
                           token_type_ids=token_type_ids[None, :])
            start_logits = output.start_logits.squeeze(0)
            end_logits = output.end_logits.squeeze(0)
            # Score only tokens with token_type_ids == 1, minus the first and
            # last of them; sample['position'] is indexed in this same space.
            mask = token_type_ids.bool()
            start_prob = start_logits[mask][1:-1].softmax(-1)
            end_prob = end_logits[mask][1:-1].softmax(-1)

            answer = ''
            # An empty scored slice would make argmax raise; predict '' instead.
            if len(start_prob) > 0:
                start, end = _best_span(start_prob, end_prob, max_answer_len)
                # Drop the span when both endpoint probabilities are too low.
                if start_prob[start] >= threshold or end_prob[end] >= threshold:
                    start_char = sample['position'][start][0]
                    end_char = sample['position'][end][1]
                    answer = sample['context'][start_char:end_char]
            rows.append([sample["guid"], answer])
        writer.writerows(rows)
    return writer