Commit 3844797

Add a script for testing F1 Score
1 parent 53e65ba commit 3844797

File tree

5 files changed  +725 -0 lines changed

eval/README.md

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
## Accuracy Testing of Sparse Methods

### Overview
We use two Chinese subsets of [LongBench](https://huggingface.co/datasets/zai-org/LongBench), single-document QA (multifieldqa_zh) and multi-document QA (dureader), to test the accuracy of sparse methods. The F1 score is adopted as the accuracy metric. For more information about LongBench, please refer to https://github.com/THUDM/LongBench.
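For reference, the token-level F1 used for Chinese QA (`qa_f1_zh_score` in `eval/eval.py`, added in this commit) segments both the prediction and the reference with jieba and computes precision and recall over the resulting token bags. A minimal sketch of that computation, with the script's punctuation normalization omitted for brevity:

```python
# Minimal sketch of the token-level F1 used for Chinese QA.
# Mirrors qa_f1_zh_score / f1_score in eval/eval.py; the punctuation
# normalization performed by the real script is omitted here.
from collections import Counter

import jieba


def token_f1_zh(prediction: str, ground_truth: str) -> float:
    pred_tokens = [t for t in jieba.cut(prediction, cut_all=False) if t.strip()]
    gt_tokens = [t for t in jieba.cut(ground_truth, cut_all=False) if t.strip()]
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)


print(token_f1_zh("北京是中国的首都", "北京是中国的首都"))  # identical strings -> 1.0
```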
### Quick Start

#### Environment Preparation

```shell
pip install jieba fuzzywuzzy rouge
```

#### Test Data Preparation

Download the LongBench dataset:

```shell
wget https://huggingface.co/datasets/THUDM/LongBench/resolve/main/data.zip && unzip data.zip
```
#### Configure Specific Sparse Method

Settings for different sparse methods are written in a JSON file, for example:

```json
{
    "ESA": {
        "init_window_sz": 1,
        "local_window_sz": 2,
        "min_blocks": 4,
        "sparse_ratio": 0.2,
        "retrieval_stride": 10
    }
}
```
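If you want to sanity-check a configuration before launching a run, the file is plain JSON and can be inspected with standard tooling. A small illustrative snippet (the file name matches the `--config` example below; adjust the path to wherever the config lives):

```python
# Illustrative only: load the sparse-method config and inspect the ESA settings.
# The evaluation run itself receives this file via the --config argument of
# eval_inference_F1.sh; this snippet just checks that the JSON is well formed.
import json

with open("ucm_sparse_config_esa.json", "r", encoding="utf-8") as f:
    config = json.load(f)

esa = config["ESA"]
print(esa["sparse_ratio"])   # 0.2
print(sorted(esa.keys()))    # ['init_window_sz', 'local_window_sz', 'min_blocks', ...]
```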
Run accuracy testing with:

```shell
cd eval

# Run with default settings: Qwen2.5-14B-Instruct, batch=20
bash eval_inference_F1.sh

# Run with custom parameters
# --strip_think: extract the text after </think> from model predictions
# --batch: number of requests processed per batch
bash eval_inference_F1.sh \
    --model /home/models/QwQ-32B \
    --config ./eval/ucm_sparse_config_esa.json \
    --data ./eval/data \
    --strip_think 1 \
    --batch 1
```

The result files will be saved in the eval/ucm_sparse_predictions folder.
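Each prediction file is JSON Lines, with one record per request containing at least a `pred` string and an `answers` list (plus an optional `length` field), which is the format `eval.py` reads. To re-score such a file without the shell wrapper, something along these lines works (the JSONL path is only an example; run it from inside `eval/` so that `eval.py` is importable):

```python
# Re-score an existing prediction file with the helpers from eval/eval.py.
import json

from eval import scorer  # eval/eval.py in this commit

predictions, answers = [], []
with open("ucm_sparse_predictions/multifieldqa_zh.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        predictions.append(record["pred"])
        answers.append(record["answers"])

# scorer() averages the per-sample metric (token-level F1 for multifieldqa_zh)
# and reports it on a 0-100 scale, as in eval.py's __main__ block.
print(scorer("multifieldqa_zh", predictions, answers, all_classes=None))
```

Equivalently, `python eval.py --answer <prediction_file> --dataset multifieldqa_zh` runs the same scoring, with `--strip_think` available to keep only the text after `</think>`.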
### Results

Test results of Full Attention (Qwen2.5-14B-Instruct):

| Dataset         | F1-Score |
|-----------------|---------:|
| multifieldqa_zh |     66.6 |
| dureader        |    29.33 |

eval/eval.py

Lines changed: 282 additions & 0 deletions
@@ -0,0 +1,282 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import argparse
import numpy as np
import re
import string

import jieba
from fuzzywuzzy import fuzz
from collections import Counter
from rouge import Rouge


def extract_pred_after_think(text):
    if text is None:
        return ""
    t = text.strip()
    idx = t.find("</think>")
    if idx != -1:
        return t[idx + len("</think>"):].strip()
    return t.strip()

def has_think_tag(text):
    if text is None:
        return False
    return ("</think>" in text)


def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)
    def white_space_fix(text):
        return " ".join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def normalize_zh_answer(s):
    def white_space_fix(text):
        return "".join(text.split())
    def remove_punc(text):
        cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
        all_punctuation = set(string.punctuation + cn_punctuation)
        return "".join(ch for ch in text if ch not in all_punctuation)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_punc(lower(s)))

def count_score(prediction, ground_truth, **kwargs):
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_zh_score(prediction, ground_truth, **kwargs):
    pattern = r'段落(\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def code_sim_score(prediction, ground_truth, **kwargs):
    all_lines = prediction.lstrip('\n').split('\n')
    prediction = ""
    for line in all_lines:
        if ('`' not in line) and ('#' not in line) and ('//' not in line):
            prediction = line
            break
    return (fuzz.ratio(prediction, ground_truth) / 100)

def classification_score(prediction, ground_truth, **kwargs):
    em_match_list = []
    all_classes = kwargs["all_classes"]
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    for match_term in em_match_list:
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if ground_truth in em_match_list:
        score = (1.0 / len(em_match_list))
    else:
        score = 0.0
    return score

def rouge_score(prediction, ground_truth, **kwargs):
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except:
        return 0.0
    return scores["rouge-l"]["f"]

def rouge_zh_score(prediction, ground_truth, **kwargs):
    prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
    ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
    score = rouge_score(prediction, ground_truth)
    return score

def f1_score(prediction, ground_truth, **kwargs):
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)
    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)

def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    prediction_tokens = list(jieba.cut(prediction, cut_all=False))
    ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
    prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
    ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
    prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
    ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
    return f1_score(prediction_tokens, ground_truth_tokens)

dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "clongeval": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('--answer', type=str, default=None)
    parser.add_argument('--dataset', type=str, default=None)
    parser.add_argument('--strip_think', action='store_true', help="Extract the content after </think>")
    parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
    return parser.parse_args(args)

def scorer_e(dataset, predictions, answers, lengths, all_classes):
    scores = {"0-4k": [], "4-8k": [], "8k+": []}
    for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
        if length < 4000:
            scores["0-4k"].append(score)
        elif length < 8000:
            scores["4-8k"].append(score)
        else:
            scores["8k+"].append(score)
    for key in scores.keys():
        scores[key] = round(100 * np.mean(scores[key]), 2)
    return scores

def scorer(dataset, predictions, answers, all_classes):
    total_score = 0.
    # count = 0
    for (prediction, ground_truths) in zip(predictions, answers):
        score = 0.
        if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
            prediction = prediction.lstrip('\n').split('\n')[0]
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))

        total_score += score
    return round(100 * total_score / len(predictions), 2)

def fix_json_format(line):
    line = re.sub(r'"answers": \[\[(.*?)\]\]', r'"answers": [\1]', line)

    line = line.replace("'", '"')
    line = line.replace("None", "null")
    line = line.strip().replace("\n", "").replace("\r", "").replace("\t", "")
    pattern = re.compile(r'"pred":"(.*?)"(?=,)', re.DOTALL)
    def escape_quotes(match):
        escaped_value = match.group(1).replace('"', '\\"')
        return f'"pred":"{escaped_value}"'

    line = pattern.sub(escape_quotes, line)

    pattern = re.compile(r'"answers":\s*\[([^\]]+)\]', re.DOTALL)
    def escape_quotes_in_answers(match):
        internal_content = match.group(1)

        items = internal_content.split('","')
        # import pdb; pdb.set_trace()
        escaped_items = [item.replace('"', '\\"') for item in items]

        escaped_content = '","'.join(escaped_items)

        return f'"answers": ["{escaped_content}"]'
    line = pattern.sub(escape_quotes_in_answers, line)

    return line

if __name__ == '__main__':
    args = parse_args()

    predictions, answers, lengths = [], [], []
    all_classes = None
    with open(args.answer, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            pred_raw = data["pred"]
            if args.strip_think:
                if not has_think_tag(pred_raw):
                    continue
                pred_clean = extract_pred_after_think(pred_raw)
                print(pred_clean)
            else:
                pred_clean = pred_raw

            predictions.append(pred_clean)
            answers.append(data["answers"])

            if "length" in data:
                lengths.append(data["length"])

    print("----"*10)
    print("Valid samples:", len(predictions))
    print("----"*10)

    if args.e:
        score = scorer_e(args.dataset, predictions, answers, lengths, all_classes)
        print("All score:", score)
    else:
        score50 = scorer(args.dataset, predictions[:50], answers[:50], all_classes)
        score_all = scorer(args.dataset, predictions, answers, all_classes)
        print("50 score:", score50)
        print("All score:", score_all)
