diff --git a/examples/offline_inference/longbench/eval.py b/examples/offline_inference/longbench/eval.py
new file mode 100644
index 000000000000..a6fc6ed0cd1c
--- /dev/null
+++ b/examples/offline_inference/longbench/eval.py
@@ -0,0 +1,109 @@
+import argparse
+import json
+import os
+
+import numpy as np
+
+from metrics import (
+    qa_f1_score,
+    rouge_score,
+    classification_score,
+    retrieval_score,
+    count_score,
+    code_sim_score,
+)
+from task_templates import TASK_TEMPLATES
+
+dataset2metric = {
+    "narrativeqa": qa_f1_score,
+    "qasper": qa_f1_score,
+    "multifieldqa_en": qa_f1_score,
+    "hotpotqa": qa_f1_score,
+    "2wikimqa": qa_f1_score,
+    "musique": qa_f1_score,
+    "gov_report": rouge_score,
+    "qmsum": rouge_score,
+    "multi_news": rouge_score,
+    "trec": classification_score,
+    "triviaqa": qa_f1_score,
+    "samsum": rouge_score,
+    "passage_count": count_score,
+    "passage_retrieval_en": retrieval_score,
+    "lcc": code_sim_score,
+    "repobench-p": code_sim_score,
+}
+
+def evaluate_on_dataset(args):
+    score_path = f"scores_{args.method_name}" if args.method_name else "scores"
+    os.makedirs(score_path, exist_ok=True)
+    pred_file = f"{args.pred_file_name}.jsonl"
+    score_file = os.path.join(score_path, f"{os.path.basename(pred_file).split('.jsonl')[0]}_score.jsonl")
+
+    with open(pred_file, "r", encoding="utf-8") as f:
+        predictions, answers = [], []
+        comments = []
+        for line in f:
+            data = json.loads(line)
+            predictions.append(data["pred"])
+            answers.append(data["answers"])
+            if "all_classes" in data:
+                all_classes = data["all_classes"]
+            else:
+                all_classes = None
+            if "comments" in data:
+                comments.append(data["comments"])
+            else:
+                comments.append(None)
+    print(f"Loaded {len(predictions)} predictions from {pred_file}")
+
+    print(f"Scoring {args.dataset} predictions from {args.model_id}...")
+    total_score = 0.
+    for idx, (prediction, ground_truths, comment) in enumerate(zip(predictions, answers, comments)):
+        score = 0.
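+        # A sample's score is the best match across its reference answers,
+        # using the metric registered for the dataset in dataset2metric.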
+        for ground_truth in ground_truths:
+            score = max(score, dataset2metric[args.dataset](prediction, ground_truth, all_classes=all_classes))
+        total_score += score
+
+        print(f"Sample {idx} - Score: {round(100 * score, 2)}")
+        json_data = {
+            "score": round(100 * score, 2),
+            **({"comments": comment} if comment else {}),
+            "pred": prediction,
+            "answers": ground_truths,
+        }
+        with open(score_file, "a", encoding="utf-8") as f:
+            json.dump(json_data, f, ensure_ascii=False)
+            f.write('\n')
+
+    avg_score = round(100 * total_score / len(predictions), 2)
+    os.rename(score_file, f"{score_file.split('.jsonl')[0]}_{avg_score}.jsonl")
+
+    print(f"Average score of {len(predictions)} samples for {args.dataset} in prediction file {pred_file} on {args.model_id}: {avg_score}")
+    return avg_score, score_path
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model-id", type=str, default="meta-llama/Llama-3.1-8B-Instruct", help="Model ID to evaluate")
+    parser.add_argument("--dataset", type=str, default="qasper", help="Dataset to evaluate on")
+    parser.add_argument("--datasets", type=str, nargs="+", default=None, help="List of datasets to evaluate on; if omitted, only --dataset is evaluated")
+    parser.add_argument("--method-name", type=str, default="", help="Name of the method to evaluate")
+    parser.add_argument("--pred-file-name", type=str, default="", help="Prediction file name to evaluate")
+    args = parser.parse_args()
+
+    if args.datasets is not None:
+        all_tasks = False
+        if "all" in args.datasets:
+            args.datasets = list(TASK_TEMPLATES.keys())
+            all_tasks = True
+        avg_scores = []
+        for dataset in args.datasets:
+            args.dataset = dataset
+            avg_score, score_path = evaluate_on_dataset(args)
+            avg_scores.append(avg_score)
+        tasks_avg_score = np.mean(avg_scores)
+        print(f"Average score for {args.datasets} on {args.model_id} "
+              f"{'with method ' + args.method_name if args.method_name else ''}: {tasks_avg_score:.2f}")
+        if all_tasks:
+            new_score_path = f"{score_path}_{tasks_avg_score:.2f}"
+            if not os.path.exists(new_score_path):
+                os.rename(score_path, new_score_path)
+    else:
+        evaluate_on_dataset(args)
\ No newline at end of file
diff --git a/examples/offline_inference/longbench/longbench.py b/examples/offline_inference/longbench/longbench.py
new file mode 100644
index 000000000000..bd842904efca
--- /dev/null
+++ b/examples/offline_inference/longbench/longbench.py
@@ -0,0 +1,52 @@
+from vllm import LLM, EngineArgs
+from vllm.utils import FlexibleArgumentParser
+
+import json
+import os
+
+from datasets import load_dataset
+from task_templates import TASK_TEMPLATES
+
+
+def predict_on_data(args, task="qasper", sample_size=None):
+    llm = LLM(**args)
+    dataset = load_dataset("zai-org/LongBench", task, split="test", trust_remote_code=True)
+    task_temp = TASK_TEMPLATES.get(task)
+
+    conversations = []
+    for idx, item in enumerate(dataset):
+        conversations.append([
+            {"role": "system", "content": "You are a helpful assistant"},
+            {"role": "user", "content": task_temp["template"].format(context=item["context"], input=item["input"])},
+        ])
+        if idx + 1 == sample_size:
+            break
+
+    sampling_params = llm.get_default_sampling_params()
+    sampling_params.temperature = 0  # greedy decoding
+    sampling_params.max_tokens = 128
+
+    outputs = llm.chat(conversations, sampling_params, use_tqdm=True)
+    os.makedirs("pred", exist_ok=True)  # output directory for the prediction file
+    for idx in range(len(outputs)):
+        pred_data = {
+            "sample_idx": idx,
+            "pred": outputs[idx].outputs[0].text,
+            "answers": dataset[idx]["answers"],
+        }
+        with open(f"pred/longbench_{task}_preds.jsonl", "a") as f:
+            json.dump(pred_data, f, ensure_ascii=False)
+            f.write("\n")
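+
+# Example invocation (any vLLM engine flag registered by EngineArgs.add_cli_args
+# can be passed on the command line to override the defaults set below):
+#   python longbench.py --model meta-llama/Llama-3.1-8B-Instruct --max-model-len 32768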
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser()
+    # Add engine args
+    EngineArgs.add_cli_args(parser)
+    # enforce eager execution (no dynamo, no cudagraphs)
+    parser.set_defaults(enforce_eager=True)
+
+    parser.set_defaults(model="meta-llama/Llama-3.1-8B-Instruct")
+    parser.set_defaults(max_model_len=32768)
+    parser.set_defaults(max_num_seqs=1)
+
+    args: dict = vars(parser.parse_args())
+
+    predict_on_data(args, sample_size=10)
diff --git a/examples/offline_inference/longbench/metrics.py b/examples/offline_inference/longbench/metrics.py
new file mode 100644
index 000000000000..12e841ad9fc0
--- /dev/null
+++ b/examples/offline_inference/longbench/metrics.py
@@ -0,0 +1,100 @@
+import re
+import string
+from collections import Counter
+
+from fuzzywuzzy import fuzz
+from rouge import Rouge
+
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+def count_score(prediction, ground_truth, **kwargs):
+    numbers = re.findall(r"\d+", prediction)
+    right_num = 0
+    for number in numbers:
+        if str(number) == str(ground_truth):
+            right_num += 1
+    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+    return float(final_score)
+
+def retrieval_score(prediction, ground_truth, **kwargs):
+    pattern = r'Paragraph (\d+)'
+    matches = re.findall(pattern, ground_truth)
+    ground_truth_id = matches[0]
+    numbers = re.findall(r"\d+", prediction)
+    right_num = 0
+    for number in numbers:
+        if str(number) == str(ground_truth_id):
+            right_num += 1
+    final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
+    return float(final_score)
+
+def code_sim_score(prediction, ground_truth, **kwargs):
+    all_lines = prediction.lstrip('\n').split('\n')
+    prediction = ""
+    for line in all_lines:
+        if ('`' not in line) and ('#' not in line) and ('//' not in line):
+            prediction = line
+            break
+    return fuzz.ratio(prediction, ground_truth) / 100
+
+def classification_score(prediction, ground_truth, **kwargs):
+    em_match_list = []
+    all_classes = kwargs["all_classes"]
+    for class_name in all_classes:
+        if class_name in prediction:
+            em_match_list.append(class_name)
+    # Drop matched classes that are strict substrings of the ground truth;
+    # iterate over a copy so removals do not skip elements.
+    for match_term in list(em_match_list):
+        if match_term in ground_truth and match_term != ground_truth:
+            em_match_list.remove(match_term)
+    if ground_truth in em_match_list:
+        score = 1.0 / len(em_match_list)
+    else:
+        score = 0.0
+    return score
+
+def rouge_score(prediction, ground_truth, **kwargs):
+    rouge = Rouge()
+    try:
+        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
+    except Exception:
+        return 0.0
+    return scores["rouge-l"]["f"]
+
+
+def f1_score(prediction, ground_truth, **kwargs):
+    common = Counter(prediction) & Counter(ground_truth)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0
+    precision = 1.0 * num_same / len(prediction)
+    recall = 1.0 * num_same / len(ground_truth)
+    f1 = (2 * precision * recall) / (precision + recall)
+    return f1
+
+def qa_f1_score(prediction, ground_truth, **kwargs):
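+    """Token-level F1 (SQuAD-style) between the normalized prediction and reference."""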
normalized_prediction = normalize_answer(prediction) + normalized_ground_truth = normalize_answer(ground_truth) + + prediction_tokens = normalized_prediction.split() + ground_truth_tokens = normalized_ground_truth.split() + return f1_score(prediction_tokens, ground_truth_tokens) diff --git a/examples/offline_inference/longbench/task_templates.py b/examples/offline_inference/longbench/task_templates.py new file mode 100644 index 000000000000..a289b946b07c --- /dev/null +++ b/examples/offline_inference/longbench/task_templates.py @@ -0,0 +1,123 @@ +TASK_TEMPLATES = { + + "narrativeqa": { + "template": ("You are given a story, which can be either a novel or a movie script, and a question. " + "Answer the question as concisely as you can, using a single phrase if possible. " + "Do not provide any explanation.\n\nStory: {context}\n\n" + "Now, answer the question based on the story as concisely as you can, " + "using a single phrase if possible. Do not provide any explanation.\n\n" + "Question: {input}\n\n"), + "answer_prefix": "Answer:" + }, + + "qasper": { + "template": ("You are given a scientific article and a question. Answer the question as concisely as you can, " + "using a single phrase or sentence if possible. If the question cannot be answered based on " + "the information in the article, write \"unanswerable\". If the question is a yes/no question, " + "answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\n" + "Article: {context}\n\nAnswer the question based on the above article as concisely as you can, " + "using a single phrase or sentence if possible. If the question cannot be answered based on " + "the information in the article, write \"unanswerable\". If the question is a yes/no question, " + "answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\n" + "Question: {input}\n\n"), + "answer_prefix": "Answer:" + }, + + "multifieldqa_en": { + "template": ("Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question " + "based on the above text, only give me the answer and do not output any other words.\n\n" + "Question: {input}\n"), + "answer_prefix": "Answer:" + }, + + "hotpotqa": { + "template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n" + "The following are given passages.\n{context}\n\nAnswer the question based on the given passages. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {input}\n"), + "answer_prefix": "Answer:" + }, + + "2wikimqa": { + "template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n" + "The following are given passages.\n{context}\n\nAnswer the question based on the given passages. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {input}\n"), + "answer_prefix": "Answer:" + }, + + "musique": { + "template": ("Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\n" + "The following are given passages.\n{context}\n\nAnswer the question based on the given passages. " + "Only give me the answer and do not output any other words.\n\n" + "Question: {input}\n"), + "answer_prefix": "Answer:" + }, + + "gov_report": { + "template": ("You are given a report by a government agency. 
+                     "Report:\n{context}\n\nNow, write a one-page summary of the report.\n\n"),
+        "answer_prefix": "Summary:"
+    },
+
+    "qmsum": {
+        "template": ("You are given a meeting transcript and a query containing a question or instruction. "
+                     "Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\n"
+                     "Now, answer the query based on the above meeting transcript in one or more sentences.\n\n"
+                     "Query: {input}\n"),
+        "answer_prefix": "Answer:"
+    },
+
+    "multi_news": {
+        "template": ("You are given several news passages. Write a one-page summary of all news. \n\n"
+                     "News:\n{context}\n\nNow, write a one-page summary of all the news.\n\n"),
+        "answer_prefix": "Summary:"
+    },
+
+    "trec": {
+        "template": ("Please determine the type of the question below. "
+                     "Here are some examples of questions.\n\n{context}\n\n{input}"),
+        "answer_prefix": ""
+    },
+
+    "triviaqa": {
+        "template": ("Answer the question based on the given passage. Only give me the answer and do not output any other words. "
+                     "The following are some examples.\n\n{context}\n\n{input}"),
+        "answer_prefix": ""
+    },
+
+    "samsum": {
+        "template": ("Summarize the dialogue into a few short sentences. "
+                     "The following are some examples.\n\n{context}\n\n{input}"),
+        "answer_prefix": ""
+    },
+
+    "passage_count": {
+        "template": ("There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. "
+                     "Please carefully read these paragraphs and determine how many unique paragraphs there are "
+                     "after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n"
+                     "{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. "
+                     "The output format should only contain the number, such as 1, 2, 3, and so on.\n\n"),
+        "answer_prefix": "The final answer is: "
+    },
+
+    "passage_retrieval_en": {
+        "template": ("Here are 30 paragraphs from Wikipedia, along with an abstract. "
+                     "Please determine which paragraph the abstract is from.\n\n"
+                     "{context}\n\nThe following is an abstract.\n\n{input}\n\n"
+                     "Please enter the number of the paragraph that the abstract is from. "
+                     "The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\n"),
+        "answer_prefix": "The answer is: "
+    },
+
+    "lcc": {
+        "template": ("Please complete the code given below. \n{context}"),
+        "answer_prefix": "Next line of code:\n"
+    },
+
+    "repobench-p": {
+        "template": ("Please complete the code given below. \n{context}{input}"),
+        "answer_prefix": "Next line of code:\n"
+    },
+
+}
\ No newline at end of file
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 79879b6805bc..6d1065365b56 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -28,7 +28,7 @@ GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
-from vllm.utils import GiB_bytes, direct_register_custom_op
+from vllm.utils import GiB_bytes, direct_register_custom_op, load_and_pack_head_ids
 
 logger = init_logger(__name__)
 
 USE_XFORMERS_OPS = None
@@ -620,6 +620,21 @@ def unified_attention_with_output(
                               output=output,
                               output_scale=output_scale,
                               output_block_scale=output_block_scale)
+
+    if attn_metadata:
+        # Keep the first 128 prompt tokens and the trailing 256-token observation
+        # window intact; zero the value cache (kv_cache[1]) of the selected heads
+        # for the context tokens in between. Assumes a paged block size of 16.
+        context_tokens_start = 128
+        obs_window_size = 256
+        if query.shape[0] > context_tokens_start + obs_window_size:
+            context_tokens_end = query.shape[0] - obs_window_size
+            head_block_ids = load_and_pack_head_ids()
+            # e.g. "model.layers.<idx>.self_attn.attn" -> <idx>
+            layer_idx = int(layer_name.split(".")[2])
+            if layer_idx in head_block_ids[0]:
+                context_token_blocks = attn_metadata.block_table[0, context_tokens_start // 16 : context_tokens_end // 16]
+                head_idx = head_block_ids[1][layer_idx]
+                i, j = torch.meshgrid(
+                    context_token_blocks,
+                    torch.tensor(head_idx, dtype=context_token_blocks.dtype, device=context_token_blocks.device),
+                    indexing='ij')
+                kv_cache[1, i, :, j, :] = 0.0
 
     maybe_save_kv_layer_to_connector(layer_name, kv_cache)
 
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 11d6686009b2..ff12e300ddf7 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -83,6 +83,29 @@
 
 logger = init_logger(__name__)
 
+
+def load_and_pack_head_ids():
+    """Load (layer, head) pairs from the packed .npy file and group them by layer.
+
+    Returns [layers_present, per_layer_heads], where per_layer_heads[i] lists the
+    KV-head indices selected for layer i (32 decoder layers for Llama-3.1-8B).
+    """
+    head_block_ids_fname = os.path.join(
+        os.path.dirname(__file__),
+        "llama-3.1-8b-inst_head_block_ids_sp-0.5.npy",
+    )
+    with open(head_block_ids_fname, "rb") as f:
+        head_block_ids = np.load(f)
+
+    pairs = head_block_ids.tolist()
+
+    by_layer = defaultdict(list)
+    for L, H in pairs:
+        by_layer[L].append(H)
+
+    layers_present = sorted(by_layer.keys())
+    per_layer_heads = [list(by_layer.get(i, ())) for i in range(32)]
+
+    assert sum(len(heads) for heads in per_layer_heads) == len(pairs), "Mismatch between packed pairs and original pairs"
+
+    return [layers_present, per_layer_heads]
+
+
 # This value is chosen to have a balance between ITL and TTFT. Note it is
 # not optimized for throughput.
 DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
diff --git a/vllm/utils/llama-3.1-8b-inst_head_block_ids_sp-0.5.npy b/vllm/utils/llama-3.1-8b-inst_head_block_ids_sp-0.5.npy
new file mode 100644
index 000000000000..88b4b7ed42c6
Binary files /dev/null and b/vllm/utils/llama-3.1-8b-inst_head_block_ids_sp-0.5.npy differ
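For reference, a minimal self-contained sketch of the cache-zeroing indexing performed in the unified_attention_with_output hunk above, assuming a FlashAttention-style paged layout [2, num_blocks, block_size, num_kv_heads, head_size] in which index 1 selects the value plane. The shapes, block IDs, and head IDs below are made up for illustration; the real values come from attn_metadata.block_table and the packed .npy file.

import torch

# Toy cache: [K/V, num_blocks, block_size, num_kv_heads, head_size]
kv_cache = torch.randn(2, 64, 16, 8, 128)

context_token_blocks = torch.tensor([2, 3, 4, 5])  # blocks covering the pruned context span
head_idx = [1, 5]                                  # KV heads selected for this layer

i, j = torch.meshgrid(
    context_token_blocks,
    torch.tensor(head_idx, dtype=context_token_blocks.dtype),
    indexing="ij",
)
# Advanced indexing pairs every selected block with every selected head; all
# token slots and head dimensions of those (block, head) pairs are zeroed in
# the value plane.
kv_cache[1, i, :, j, :] = 0.0

assert torch.all(kv_cache[1, 2, :, 1, :] == 0)      # selected block/head is cleared
assert not torch.all(kv_cache[1, 2, :, 0, :] == 0)  # unselected head is untouched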