# inference.py
import json
from pathlib import Path

import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)
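
# Inference settings: cap on tokenized sequence length and the trained checkpoint to load.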
INFERENCE_MAX_LENGTH = 256
model_path = "output/albert-base-v2_128"
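

# Rebuild each document's raw text from its pre-split tokens and record a
# character-level map back to the original token index; -1 marks whitespace
# inserted between tokens, so subword offsets can later be traced to tokens.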
def tokenize(example, tokenizer):
    text = []
    token_map = []
    idx = 0
    for t, ws in zip(example["tokens"], example["trailing_whitespace"]):
        text.append(t)
        # one map entry per character of the token
        token_map.extend([idx] * len(t))
        if ws:
            text.append(" ")
            token_map.append(-1)
        idx += 1
    tokenized = tokenizer(
        "".join(text),
        return_offsets_mapping=True,
        truncation=True,
        max_length=INFERENCE_MAX_LENGTH,
    )
    return {
        **tokenized,
        "token_map": token_map,
    }
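

# Load the test documents and wrap the fields needed for inference in a Dataset.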
with open("data/test.json") as f:
    data = json.load(f)
ds = Dataset.from_dict(
    {
        "full_text": [x["full_text"] for x in data],
        "document": [x["document"] for x in data],
        "tokens": [x["tokens"] for x in data],
        "trailing_whitespace": [x["trailing_whitespace"] for x in data],
    }
)
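
# Load the tokenizer and fine-tuned model, tokenize every document, and build a
# Trainer used purely for batched prediction.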
tokenizer = AutoTokenizer.from_pretrained(model_path)
ds = ds.map(tokenize, fn_kwargs={"tokenizer": tokenizer}, num_proc=2)
model = AutoModelForTokenClassification.from_pretrained(model_path)
collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=16)
args = TrainingArguments(".", per_device_eval_batch_size=1, report_to="none")
trainer = Trainer(model=model, args=args, data_collator=collator, tokenizer=tokenizer)
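
# Run inference and convert the output logits into per-class probabilities.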
predictions = trainer.predict(ds).predictions
# softmax over the class dimension; shift by the row max for numerical stability
pred_logits = predictions - predictions.max(axis=2, keepdims=True)
pred_softmax = np.exp(pred_logits) / np.sum(np.exp(pred_logits), axis=2, keepdims=True)
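
# Recover the id -> label mapping from the saved config, then decode with a
# confidence threshold: take the best non-"O" class wherever the "O"
# probability falls below the threshold, otherwise keep the plain argmax.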
with open(Path(model_path) / "config.json") as f:
    config = json.load(f)
id2label = config["id2label"]
preds = predictions.argmax(-1)
# class index 12 is treated as the "O" label here; drop it when taking the entity argmax
preds_without_O = pred_softmax[:, :, :12].argmax(-1)
O_preds = pred_softmax[:, :, 12]
threshold = 0.9
preds_final = np.where(O_preds < threshold, preds_without_O, preds)
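
# Map each subword prediction back to its original token via token_map and
# collect one row per unique non-"O" (label, token) prediction.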
triplets = set()
document, token, label, token_str = [], [], [], []
for p, token_map, offsets, tokens, doc in zip(
    preds_final, ds["token_map"], ds["offset_mapping"], ds["tokens"], ds["document"]
):
    for token_pred, (start_idx, end_idx) in zip(p, offsets):
        label_pred = id2label[str(token_pred)]
        # special tokens have a (0, 0) offset; skip them
        if start_idx + end_idx == 0:
            continue
        # if the subword starts on inserted whitespace, step to the next character
        if token_map[start_idx] == -1:
            start_idx += 1
        # skip whitespace-only tokens such as "\n\n"
        while start_idx < len(token_map) and tokens[token_map[start_idx]].isspace():
            start_idx += 1
        if start_idx >= len(token_map):
            break
        token_id = token_map[start_idx]
        # ignore "O" predictions and whitespace positions
        if label_pred != "O" and token_id != -1:
            # deduplicate per document: several subwords can map to the same token
            triplet = (doc, label_pred, token_id, tokens[token_id])
            if triplet not in triplets:
                document.append(doc)
                token.append(token_id)
                label.append(label_pred)
                token_str.append(tokens[token_id])
                triplets.add(triplet)
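
# Write submission.csv with the expected (row_id, document, token, label) columns.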
df = pd.DataFrame(
    {"document": document, "token": token, "label": label, "token_str": token_str}
)
df["row_id"] = list(range(len(df)))
df[["row_id", "document", "token", "label"]].to_csv("submission.csv", index=False)