From fd34c1004da6d2d46e889e182edb9ec90e1752ef Mon Sep 17 00:00:00 2001
From: shruthan
Date: Tue, 23 Sep 2025 23:29:49 -0700
Subject: [PATCH] add gpqa diamond

---
 metrics/README.md                          |   9 +
 metrics/multiple_choice_metrics.py         | 167 ++++++++++++++++++
 tasks/README.md                            |   1 +
 tasks/spoken_language_reasoning/README.md  |   3 +-
 .../gpqa_diamond/base.yaml                 |  17 ++
 .../gpqa_diamond/gpqa_diamond_audio.yaml   |   6 +
 .../gpqa_diamond/gpqa_diamond_text.yaml    |   6 +
 .../gsm8k/gsm8k_text.yaml                  |   2 -
 utils/constants.py                         |   3 +
 9 files changed, 211 insertions(+), 3 deletions(-)
 create mode 100644 metrics/multiple_choice_metrics.py
 create mode 100644 tasks/spoken_language_reasoning/gpqa_diamond/base.yaml
 create mode 100644 tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_audio.yaml
 create mode 100644 tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_text.yaml

diff --git a/metrics/README.md b/metrics/README.md
index c4d3dd3..b703e4c 100644
--- a/metrics/README.md
+++ b/metrics/README.md
@@ -25,6 +25,7 @@ For more detailed documentation regarding which metrics can be used for which ta
 | `bfcl_match_score` (↑) | Structured logic form comparison | bfcl_match_score |
 | `sql_score` (↑) | SQL correctness and execution match | text2sql_score |
 | `instruction_following` (↑) | LLM-judged instruction following capability | final |
+| `multiple_choice_accuracy` (↑) | Accuracy of predicting the correct option letter in multiple choice tasks | multiple_choice_accuracy |
 
 ---
 
@@ -148,3 +149,11 @@
 - **Description**: Measure the instruction following capabilities of LALMs by averaging accuracy across (1) strict-prompt, (2) strict-instruction, (3)loose-prompt and (4) loose-instruction.
 - **Scoring (record-level)** Score between `0` and `1`, higher is better.
 - **Used In**: Audio Instruction Following (`ifeval`)
+
+---
+
+### `multiple_choice_accuracy`
+- **Type**: Multiple choice accuracy metric
+- **Description**: Measure the accuracy of predicting the correct option letter in multiple choice tasks. The predicted option is expected in the format `Answer: A`; several common fallback answer formats are also recognized.
+- **Scoring (record-level)** Score of `0` or `1` per record; the final score is accuracy between `0` and `100`, higher is better.
+- **Used In**: GPQA Diamond (`gpqa_diamond`)
\ No newline at end of file
diff --git a/metrics/multiple_choice_metrics.py b/metrics/multiple_choice_metrics.py
new file mode 100644
index 0000000..1a41a9f
--- /dev/null
+++ b/metrics/multiple_choice_metrics.py
@@ -0,0 +1,167 @@
+"""Multiple Choice Question metrics implementation for GPQA Diamond.
+
+Evaluates model performance on multiple choice questions by extracting the predicted
+answer choice (A-J) and comparing it to the reference answer.
+"""
+import re
+from typing import List, Dict, Optional, Tuple, Any, Union
+
+from metrics.metrics import Metrics
+from utils import util
+from utils.custom_logging import write_record_log, append_final_score
+
+
+class MultipleChoiceMetrics(Metrics):
+    """Multiple Choice Question evaluation metric.
+
+    Computes accuracy for multiple choice questions by extracting the predicted
+    answer choice (A-J) and comparing it to the reference answer.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.name = "multiple_choice_accuracy"
+        self.instructions = None
+        self.model_responses = []
+        self.record_level_scores = None
+
+    def __call__(
+        self,
+        candidates: List[str],
+        references: List[str],
+        instructions: Optional[str] = None,
+        *,
+        task_name: Optional[str] = None,
+        model_name: Optional[str] = None,
+        model_responses: Optional[List[Any]] = None
+    ) -> Dict[str, float]:
+        """Evaluate multiple choice accuracy and optionally log results.
+
+        Args:
+            candidates: List of model-generated text responses
+            references: List of reference answers (single letters A-J)
+            instructions: Optional instructions text
+            task_name: Task identifier for logging
+            model_name: Model identifier for logging
+            model_responses: Optional model responses for logging
+
+        Returns:
+            Dictionary with accuracy percentage under the 'multiple_choice_accuracy' key
+        """
+        self.instructions = instructions
+        self.model_responses = model_responses if model_responses else []
+
+        scores, normalized_candidates, normalized_references = self.compute_record_level_scores(candidates, references)
+        overall = self.get_score(candidates, references)
+
+        if task_name and model_name:
+            score_list = scores.get(self.name, [])
+            write_record_log(
+                self,
+                normalized_references,
+                normalized_candidates,
+                score_list,
+                task_name,
+                model_name,
+                instructions=self.instructions,
+                model_responses=self.model_responses
+            )
+            append_final_score(self, overall, task_name, model_name, self.model_responses)
+
+        return overall
+
+    def _extract_mc_answer(self, prediction: str) -> Optional[str]:
+        """
+        Extracts the multiple-choice answer letter (A-J) from a prediction string.
+        Uses a staged approach: try the primary pattern first, then fallbacks in order.
+        Returns the last match from the first successful pattern, or None if nothing is found.
+        Patterns based on: https://artificialanalysis.ai/methodology/intelligence-benchmarking
+
+        Args:
+            prediction: The model's prediction string
+
+        Returns:
+            Uppercase letter (A-J) if found, None otherwise
+        """
+        if not isinstance(prediction, str):
+            return None
+
+        patterns = [
+            # Primary pattern: "Answer: X" (optionally wrapped in markdown emphasis)
+            r"(?i)[\*\_]{0,2}Answer[\*\_]{0,2}\s*:[\s\*\_]{0,2}\s*([A-J])(?![a-zA-Z0-9])",
+            # LaTeX boxed notation, e.g. \boxed{A}
+            r"\\boxed\{[^}]*([A-J])[^}]*\}",
+            # Natural language: "answer is A"
+            r"answer is ([a-jA-J])",
+            # With parenthesis: "answer is (A)"
+            r"answer is\s*\(\s*([a-jA-J])\s*\)",
+            # Choice format: "D) ..."
+            r"([A-J])\)\s*[^A-J]*",
+            # Explicit statement: "E is the correct answer"
+            r"([A-J])\s+is\s+the\s+correct\s+answer",
+            # Standalone letter at end
+            r"([A-J])\s*$",
+            # Letter followed by a period
+            r"([A-J])\s*\.",
+            # Letter followed by a non-word character
+            r"([A-J])\s*[^\w]",
+        ]
+
+        for pat in patterns:
+            matches = re.findall(pat, prediction, re.IGNORECASE)
+            if matches:
+                return matches[-1].upper()
+
+        return None
+
+
+    def compute_record_level_scores(
+        self,
+        candidates: List[str],
+        references: List[str]
+    ) -> Tuple[Dict[str, List[float]], List[str], List[str]]:
+        """Compute per-record scores for multiple choice answers.
+
+        Args:
+            candidates: List of model-generated text responses
+            references: List of reference answers (single letters A-J)
+
+        Returns:
+            Tuple of (scores dict, normalized candidates, normalized references)
+        """
+        if len(candidates) != len(references):
+            raise ValueError(f"Mismatched lengths: {len(candidates)} candidates vs {len(references)} references")
+
+        scores = []
+        normalized_candidates = []
+        normalized_references = []
+
+        for candidate, reference in zip(candidates, references):
+            pred = self._extract_mc_answer(candidate)
+
+            normalized_candidates.append(pred)
+            normalized_references.append(reference)
+
+            score = 1.0 if (pred is not None and pred == reference) else 0.0
+            scores.append(score)
+
+        return {self.name: scores}, normalized_candidates, normalized_references
+
+    def get_score(self, candidates: List[str], references: List[str]) -> Dict[str, float]:
+        """Compute overall accuracy percentage.
+
+        Args:
+            candidates: Generated text from the model
+            references: Reference text from the dataset
+
+        Returns:
+            Dictionary with accuracy percentage under the metric name
+        """
+
+        if not self.record_level_scores:
+            self.record_level_scores, _, _ = self.compute_record_level_scores(candidates, references)
+
+        scores = self.record_level_scores.get(self.name, [])
+        accuracy = (sum(scores) / len(scores) * 100.0 if scores else 0.0)
+
+        return {self.name: util.smart_round(accuracy, 2)}
\ No newline at end of file
diff --git a/tasks/README.md b/tasks/README.md
index 3e49500..99a4030 100644
--- a/tasks/README.md
+++ b/tasks/README.md
@@ -48,6 +48,7 @@ For more detailed documentation regarding individual metrics, refer to [Metrics
 | `spoken_language_reasoning` | `ifeval` | `instruction_following` |
 | `spoken_language_reasoning` | `mtbench` | `mt_bench_llm_judge` |
 | `spoken_language_reasoning` | `gsm8k` | `gsm8k_exact_match` |
+| `spoken_language_reasoning` | `gpqa_diamond` | `multiple_choice_accuracy` |
 | `safety_and_security` | `safety` | `detailed_judge_prompt` |
 | `safety_and_security` | `spoofing` | `detailed_judge_prompt`, `llm_judge_binary` |
diff --git a/tasks/spoken_language_reasoning/README.md b/tasks/spoken_language_reasoning/README.md
index 2e2f80f..fa96cc1 100644
--- a/tasks/spoken_language_reasoning/README.md
+++ b/tasks/spoken_language_reasoning/README.md
@@ -31,4 +31,5 @@ bash data/scripts/downnload_spider.sh
 | **IFEVAL** | Speech Instruction Following | [spoken_language_reasoning/ifeval](./ifeval/base.yaml)| Speech-based complex instruction following dataset | Apache-2.0 |
 | **BFCL** | Speech Function Calling | [spoken_language_reasoning/bfcl](./bfcl/base.yaml)| Speech-based complex function calling dataset with audio input | Apache-2.0 |
 | **SPEECH_TO_SQL** | Speech-to-Coding | [spoken_language_reasoning/speech_to_sql](./speech_to_sql/base.yaml)| Speech-based dataset involving following instructions to produce executable code | Apache-2.0 |
-| **GSM8k** | Grade School Math | [spoken_language_reasoning/gsm8k](./gsm8k/base.yaml)| Speech-based math dataset with grade school math word problems | MIT (text dataset) |
\ No newline at end of file
+| **GSM8k** | Grade School Math | [spoken_language_reasoning/gsm8k](./gsm8k/base.yaml)| Speech-based math dataset with grade school math word problems | MIT (text dataset) |
+| **GPQA Diamond** | Graduate-Level Science QA | [spoken_language_reasoning/gpqa_diamond](./gpqa_diamond/base.yaml)| Speech-based dataset of difficult questions written and validated by experts in biology, physics, and chemistry | cc-by-4.0 |
\ No newline at end of file
diff --git a/tasks/spoken_language_reasoning/gpqa_diamond/base.yaml b/tasks/spoken_language_reasoning/gpqa_diamond/base.yaml
new file mode 100644
index 0000000..61da0aa
--- /dev/null
+++ b/tasks/spoken_language_reasoning/gpqa_diamond/base.yaml
@@ -0,0 +1,17 @@
+# Base configuration for GPQA Diamond tasks
+language: en
+split: test
+preprocessor: GeneralPreprocessor
+postprocessor: GeneralPostprocessor
+target_column: answer
+long_audio_processing_logic: truncate
+# Prompt from https://artificialanalysis.ai/methodology/intelligence-benchmarking
+user_prompt: >
+  Answer the following multiple choice question. The last line of your response should be in the following format: 'Answer: A/B/C/D' (e.g. 'Answer: A').
+
+generation_kwargs:
+  temperature: 0.001
+  max_completion_tokens: 2048
+
+metrics:
+  - metric: multiple_choice_accuracy
\ No newline at end of file
diff --git a/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_audio.yaml b/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_audio.yaml
new file mode 100644
index 0000000..54f510c
--- /dev/null
+++ b/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_audio.yaml
@@ -0,0 +1,6 @@
+task_name: gpqa_diamond_audio
+dataset_path: ServiceNow-AI/gpqa_audio
+split: test
+extends: ["./base.yaml#"]
+modality: audio
+audio_column: audio
diff --git a/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_text.yaml b/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_text.yaml
new file mode 100644
index 0000000..e83b4c0
--- /dev/null
+++ b/tasks/spoken_language_reasoning/gpqa_diamond/gpqa_diamond_text.yaml
@@ -0,0 +1,6 @@
+task_name: gpqa_diamond_text
+dataset_path: ServiceNow-AI/gpqa_audio
+split: test
+extends: ["./base.yaml#"]
+modality: text
+textual_input_column: text_prompt
diff --git a/tasks/spoken_language_reasoning/gsm8k/gsm8k_text.yaml b/tasks/spoken_language_reasoning/gsm8k/gsm8k_text.yaml
index 7454984..d444a71 100644
--- a/tasks/spoken_language_reasoning/gsm8k/gsm8k_text.yaml
+++ b/tasks/spoken_language_reasoning/gsm8k/gsm8k_text.yaml
@@ -5,5 +5,3 @@ split: test
 extends: ["./base.yaml#"]
 modality: text
 textual_input_column: question
-user_prompt: >
-  Solve the following math problem step by step. Put your final answer inside \boxed{}.
diff --git a/utils/constants.py b/utils/constants.py
index a1560cc..d72e7e3 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -69,6 +69,7 @@
     'sql_score': ("metrics.sql_score", "SqlScore"),
     "word_error_rate": ("metrics.word_error_rate_metrics", "WERMetrics"),
     "comet": ("metrics.comet_score", "CometScore"),
+    "multiple_choice_accuracy": ("metrics.multiple_choice_metrics", "MultipleChoiceMetrics"),
     "mt_bench_llm_judge": ("metrics.llm_judge", "MtbenchLLMJudgeMetric"),
 }
@@ -135,6 +136,8 @@
     'ifeval': ['instruction_following'],
     'speech_to_sql': ['sql_score'],
     'gsm8k': ['gsm8k_exact_match'],
+    'gpqa_diamond': ['multiple_choice_accuracy'],
+    'mmlu': ['multiple_choice_accuracy'],
 
     # Safety and Security
     'safety': ['llm_judge_redteaming'],
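
Review note: below is a minimal usage sketch of the `MultipleChoiceMetrics` class added in `metrics/multiple_choice_metrics.py`. The sample responses are illustrative (not drawn from GPQA Diamond), and the sketch assumes the repository's `metrics` and `utils` packages are on the import path; the printed score is the expected result given the scoring logic above.

```python
# Illustrative sketch: exercising the metric added in this patch.
from metrics.multiple_choice_metrics import MultipleChoiceMetrics

metric = MultipleChoiceMetrics()

# Hypothetical model outputs covering the primary 'Answer: X' pattern
# and the 'answer is (X)' fallback pattern handled by _extract_mc_answer.
candidates = [
    "The compound undergoes an SN2 substitution.\nAnswer: C",
    "Summing the contributions gives 2.4 eV, so the answer is (B).",
]
references = ["C", "A"]

# Without task_name/model_name no record logs are written; only the
# overall accuracy dict is returned.
print(metric(candidates, references))  # expected: {'multiple_choice_accuracy': 50.0}
```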