109 changes: 109 additions & 0 deletions examples/offline_inference/longbench/eval.py
@@ -0,0 +1,109 @@
import argparse
import json
import os

import numpy as np

from metrics import (
    qa_f1_score,
    rouge_score,
    classification_score,
    retrieval_score,
    count_score,
    code_sim_score,
)
from task_templates import TASK_TEMPLATES

dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "passage_count": count_score,
    "passage_retrieval_en": retrieval_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def evaluate_on_dataset(args):
    score_path = f"scores_{args.method_name}" if args.method_name else "scores"
    os.makedirs(score_path, exist_ok=True)
    pred_file = f"{args.pred_file_name}.jsonl"
    score_file = os.path.join(
        score_path, f"{os.path.basename(pred_file).split('.jsonl')[0]}_score.jsonl")

    with open(pred_file, "r", encoding="utf-8") as f:
        predictions, answers, comments = [], [], []
        all_classes = None
        for line in f:
            data = json.loads(line)
            predictions.append(data["pred"])
            answers.append(data["answers"])
            all_classes = data.get("all_classes")
            comments.append(data.get("comments"))
    print(f"Loaded {len(predictions)} predictions from {pred_file}")

    print(f"Scoring {args.dataset} predictions from {args.model_id}...")
    total_score = 0.
    for idx, (prediction, ground_truths, comment) in enumerate(
            zip(predictions, answers, comments)):
        score = 0.
        for ground_truth in ground_truths:
            score = max(score, dataset2metric[args.dataset](
                prediction, ground_truth, all_classes=all_classes))
        total_score += score

        print(f"Sample {idx} - Score: {round(100 * score, 2)}")
        json_data = {
            "score": round(100 * score, 2),
            **({"comments": comment} if comment else {}),
            "pred": prediction,
            "answers": ground_truths,
        }
        with open(score_file, "a", encoding="utf-8") as f:
            json.dump(json_data, f, ensure_ascii=False)
            f.write('\n')

    avg_score = round(100 * total_score / len(predictions), 2)
    os.rename(score_file, f"{score_file.split('.jsonl')[0]}_{avg_score}.jsonl")

    print(f"Average score of {len(predictions)} samples for {args.dataset} in "
          f"prediction file {pred_file} on {args.model_id}: {avg_score}")
    return avg_score, score_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", type=str, default="meta-llama/Llama-3.1-8B-Instruct",
                        help="Model ID to evaluate")
    parser.add_argument("--dataset", type=str, default="qasper",
                        help="Dataset to evaluate on")
    parser.add_argument("--datasets", type=str, nargs="+", default=None,
                        help="List of datasets to evaluate on; if None, evaluate only on --dataset")
    parser.add_argument("--method-name", type=str, default="",
                        help="Name of the method to evaluate")
    parser.add_argument("--pred-file-name", type=str, default="",
                        help="Prediction file name to evaluate (without the .jsonl extension)")
    args = parser.parse_args()

    if args.datasets is not None:
        all_tasks = False
        if "all" in args.datasets:
            args.datasets = list(TASK_TEMPLATES.keys())
            all_tasks = True
        avg_scores = []
        for dataset in args.datasets:
            args.dataset = dataset
            avg_score, score_path = evaluate_on_dataset(args)
            avg_scores.append(avg_score)
        tasks_avg_score = np.mean(avg_scores)
        print(f"Average score for {args.datasets} on {args.model_id}"
              f"{' with method ' + args.method_name if args.method_name else ''}: "
              f"{tasks_avg_score:.2f}")
        if all_tasks:
            new_score_path = f"{score_path}_{tasks_avg_score:.2f}"
            if not os.path.exists(new_score_path):
                os.rename(score_path, new_score_path)
    else:
        evaluate_on_dataset(args)
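
The script can also be driven from Python rather than the CLI. A minimal sketch, assuming a predictions file such as the one longbench.py (below) writes for qasper already exists; the Namespace fields mirror the CLI flags and the file name here is illustrative.

# Minimal sketch: score an existing predictions file programmatically.
from argparse import Namespace

from eval import evaluate_on_dataset

args = Namespace(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    dataset="qasper",
    method_name="",
    pred_file_name="pred/longbench_qasper_preds",  # ".jsonl" is appended by the script
)
avg_score, score_dir = evaluate_on_dataset(args)
print(avg_score, score_dir)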
52 changes: 52 additions & 0 deletions examples/offline_inference/longbench/longbench.py
@@ -0,0 +1,52 @@
import json
import os

from datasets import load_dataset
from task_templates import TASK_TEMPLATES

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def predict_on_data(args, task="qasper", sample_size=None):
    llm = LLM(**args)
    dataset = load_dataset("zai-org/LongBench", task, split="test", trust_remote_code=True)
    task_temp = TASK_TEMPLATES.get(task)

    conversations = []
    for idx, item in enumerate(dataset):
        conversations.append([
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": task_temp["template"].format(
                context=item["context"], input=item["input"])},
        ])
        if idx + 1 == sample_size:
            break

    sampling_params = llm.get_default_sampling_params()
    sampling_params.temperature = 0  # set temperature to 0 for greedy decoding
    sampling_params.max_tokens = 128

    outputs = llm.chat(conversations, sampling_params, use_tqdm=True)

    # Create the output directory before writing and keep the file open
    # for the whole loop instead of reopening it per sample.
    os.makedirs("pred", exist_ok=True)
    with open(f"pred/longbench_{task}_preds.jsonl", "a") as f:
        for idx, output in enumerate(outputs):
            pred_data = {
                "sample_idx": idx,
                "pred": output.outputs[0].text,
                "answers": dataset[idx]["answers"],
            }
            json.dump(pred_data, f, ensure_ascii=False)
            f.write("\n")


if __name__ == "__main__":
    parser = FlexibleArgumentParser()
    # Add engine args
    EngineArgs.add_cli_args(parser)
    # enforce eager execution (no dynamo, no cudagraphs)
    parser.set_defaults(enforce_eager=True)

    parser.set_defaults(model="meta-llama/Llama-3.1-8B-Instruct")
    parser.set_defaults(max_model_len=32768)
    parser.set_defaults(max_num_seqs=1)

    args: dict = vars(parser.parse_args())

    predict_on_data(args, sample_size=10)
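
Because predict_on_data simply unpacks its first argument into LLM(**args), it can also be called without the CLI. A minimal sketch, assuming a GPU and local access to the model weights; the engine-arg values below are illustrative.

# Minimal sketch: generate predictions for one task without the CLI.
from longbench import predict_on_data

engine_args = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "max_model_len": 32768,
    "max_num_seqs": 1,
    "enforce_eager": True,
}
predict_on_data(engine_args, task="hotpotqa", sample_size=2)
# Predictions are appended to pred/longbench_hotpotqa_preds.jsonl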
100 changes: 100 additions & 0 deletions examples/offline_inference/longbench/metrics.py
@@ -0,0 +1,100 @@
import re
import string
from collections import Counter

from fuzzywuzzy import fuzz
from rouge import Rouge

def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def count_score(prediction, ground_truth, **kwargs):
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def retrieval_score(prediction, ground_truth, **kwargs):
    pattern = r'Paragraph (\d+)'
    matches = re.findall(pattern, ground_truth)
    ground_truth_id = matches[0]
    numbers = re.findall(r"\d+", prediction)
    right_num = 0
    for number in numbers:
        if str(number) == str(ground_truth_id):
            right_num += 1
    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
    return float(final_score)

def code_sim_score(prediction, ground_truth, **kwargs):
    all_lines = prediction.lstrip('\n').split('\n')
    prediction = ""
    for line in all_lines:
        if ('`' not in line) and ('#' not in line) and ('//' not in line):
            prediction = line
            break
    return fuzz.ratio(prediction, ground_truth) / 100

def classification_score(prediction, ground_truth, **kwargs):
    em_match_list = []
    all_classes = kwargs["all_classes"]
    for class_name in all_classes:
        if class_name in prediction:
            em_match_list.append(class_name)
    # Iterate over a copy, since removing items from the list being
    # iterated would skip elements.
    for match_term in list(em_match_list):
        if match_term in ground_truth and match_term != ground_truth:
            em_match_list.remove(match_term)
    if ground_truth in em_match_list:
        score = 1.0 / len(em_match_list)
    else:
        score = 0.0
    return score

def rouge_score(prediction, ground_truth, **kwargs):
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:
        # Rouge raises on empty or whitespace-only hypotheses.
        return 0.0
    return scores["rouge-l"]["f"]


def f1_score(prediction, ground_truth, **kwargs):
    # Token-level F1: here `prediction` and `ground_truth` are lists of tokens.
    common = Counter(prediction) & Counter(ground_truth)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction)
    recall = 1.0 * num_same / len(ground_truth)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

def qa_f1_score(prediction, ground_truth, **kwargs):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    return f1_score(prediction_tokens, ground_truth_tokens)
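
A worked example of the QA metric: qa_f1_score lowercases both strings, strips punctuation and articles, tokenizes on whitespace, and returns token-level F1. The strings below are illustrative.

# Worked example of qa_f1_score.
from metrics import qa_f1_score

pred = "The answer is Paris."
gold = "Paris"
# normalize_answer maps pred to "answer is paris" (3 tokens) and gold to
# "paris" (1 token); one shared token gives precision 1/3, recall 1, F1 = 0.5.
print(qa_f1_score(pred, gold))  # 0.5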
123 changes: 123 additions & 0 deletions examples/offline_inference/longbench/task_templates.py
@@ -0,0 +1,123 @@
TASK_TEMPLATES = {

"narrativeqa": {
"template": ("You are given a story, which can be either a novel or a movie script, and a question. "
"Answer the question as concisely as you can, using a single phrase if possible. "
"Do not provide any explanation.\n\nStory: {context}\n\n"
"Now, answer the question based on the story as concisely as you can, "
"using a single phrase if possible. Do not provide any explanation.\n\n"
"Question: {input}\n\n"),
"answer_prefix": "Answer:"
},

"qasper": {
"template": ("You are given a scientific article and a question. Answer the question as concisely as you can, "
"using a single phrase or sentence if possible. If the question cannot be answered based on "
"the information in the article, write \"unanswerable\". If the question is a yes/no question, "
"answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\n"
"Article: {context}\n\nAnswer the question based on the above article as concisely as you can, "
"using a single phrase or sentence if possible. If the question cannot be answered based on "
"the information in the article, write \"unanswerable\". If the question is a yes/no question, "
"answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\n"
"Question: {input}\n\n"),
"answer_prefix": "Answer:"
},

"multifieldqa_en": {
"template": ("Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question "
"based on the above text, only give me the answer and do not output any other words.\n\n"
"Question: {input}\n"),
"answer_prefix": "Answer:"
},

"hotpotqa": {
"template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n"
"The following are given passages.\n{context}\n\nAnswer the question based on the given passages. "
"Only give me the answer and do not output any other words.\n\n"
"Question: {input}\n"),
"answer_prefix": "Answer:"
},

"2wikimqa": {
"template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n"
"The following are given passages.\n{context}\n\nAnswer the question based on the given passages. "
"Only give me the answer and do not output any other words.\n\n"
"Question: {input}\n"),
"answer_prefix": "Answer:"
},

"musique": {
"template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n"
"The following are given passages.\n{context}\n\nAnswer the question based on the given passages. "
"Only give me the answer and do not output any other words.\n\n"
"Question: {input}\n"),
"answer_prefix": "Answer:"
},

"gov_report": {
"template": ("You are given a report by a government agency. Write a one-page summary of the report.\n\n"
"Report:\n{context}\n\nNow, write a one-page summary of the report.\n\n"),
"answer_prefix": "Summary:"
},

"qmsum": {
"template": ("You are given a meeting transcript and a query containing a question or instruction. "
"Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\n"
"Now, answer the query based on the above meeting transcript in one or more sentences.\n\n"
"Query: {input}\n"),
"answer_prefix": "Answer:"
},

"multi_news": {
"template": ("You are given several news passages. Write a one-page summary of all news. \n\n"
"News:\n{context}\n\nNow, write a one-page summary of all the news.\n\n"),
"answer_prefix": "Summary:"
},

"trec": {
"template": ("Please determine the type of the question below. "
"Here are some examples of questions.\n\n{context}\n\n{input}"),
"answer_prefix": ""
},

"triviaqa": {
"template": ("Answer the question based on the given passage. Only give me the answer and do not output any other words. "
"The following are some examples.\n\n{context}\n\n{input}"),
"answer_prefix": ""
},

"samsum": {
"template": ("Summarize the dialogue into a few short sentences. "
"The following are some examples.\n\n{context}\n\n{input}"),
"answer_prefix": ""
},

"passage_count": {
"template": ("There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. "
"Please carefully read these paragraphs and determine how many unique paragraphs there are "
"after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n"
"{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. "
"The output format should only contain the number, such as 1, 2, 3, and so on.\n\n"),
"answer_prefix": "The final answer is: "
},

"passage_retrieval_en": {
"template": ("Here are 30 paragraphs from Wikipedia, along with an abstract. "
"Please determine which paragraph the abstract is from.\n\n"
"{context}\n\nThe following is an abstract.\n\n{input}\n\n"
"Please enter the number of the paragraph that the abstract is from. "
"The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\n"),
"answer_prefix": "The answer is: "
},

"lcc": {
"template": ("Please complete the code given below. \n{context}"),
"answer_prefix": "Next line of code:\n"
},

"repobench-p": {
"template": ("Please complete the code given below. \n{context}{input}"),
"answer_prefix": "Next line of code:\n"
},

}
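
Each entry pairs a prompt template with an answer_prefix; longbench.py only formats the template and does not append the prefix. A short illustrative sketch of assembling a qasper prompt (the context and question are made up):

# Illustrative: build a qasper prompt the same way longbench.py does.
from task_templates import TASK_TEMPLATES

entry = TASK_TEMPLATES["qasper"]
prompt = entry["template"].format(
    context="...full article text...",
    input="Which datasets do the authors evaluate on?",
)
# entry["answer_prefix"] ("Answer:") is available for methods that want to
# seed the assistant turn with it.
print(prompt[:200])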