"""Answer-extraction filter for the AI2D multiple-choice task.

This module is the ``lm_eval/tasks/ai2d/utils.py`` half of the change. The
task config that references it (``lm_eval/tasks/ai2d/ai2d.yaml``) is kept
below as a comment for reference; it wires ``flexible_extract`` in as a
``custom`` filter followed by ``take_first``.
"""

# --- lm_eval/tasks/ai2d/ai2d.yaml (reference copy, indentation reconstructed) ---
# dataset_path: lmms-lab/ai2d
# task: ai2d
# test_split: test
# output_type: generate_until
# doc_to_image:
#   - image
# doc_to_text: "Look at the scientific diagram carefully and answer the following question: {{question | capitalize}}\n{% for option in options -%}{{['A', 'B', 'C', 'D', 'E', 'F'][loop.index0]}}. {{option | capitalize}}\n{% endfor -%}
#
# Think step by step and finally respond to the question with only the correct option number as \"FINAL ANSWER\"."
# gen_prefix: "Let's think step by step."
# doc_to_target: "{{ ['A', 'B', 'C', 'D', 'E', 'F'][answer|int] }}"
# generation_kwargs:
#   until: []
#   temperature: 0.0
#   do_sample: false
#   max_gen_toks: 512
# filter_list:
#   - name: "strict-match"
#     filter:
#       - function: "regex"
#         regex_pattern: "(?:[`\\*_]*(?i:FINAL ANSWER|Final Answer|Answer|answer is)[`\\*_]*)[:\\s]*[`\\*_]*([A-D])[`\\*_]*"
#       - function: "take_first"
#   - name: "flexible-extract"
#     filter:
#       - function: "custom"
#         filter_fn: !function utils.flexible_extract
#       - function: "take_first"
# metric_list:
#   - metric: exact_match
#     aggregation: mean
#     higher_is_better: true
# metadata:
#   version: 0.0
# --- end of ai2d.yaml ---

import re
import string


# Matches an answer keyword (case-insensitive), optionally wrapped in markdown
# emphasis characters, followed by an option letter A-D. The letter itself is
# matched case-sensitively because the inline ``(?i:...)`` flag is scoped to
# the keyword group only. Raw string so that ``\s`` is not an invalid escape.
REGEX = re.compile(
    r"[`*_]*(?i:FINAL ANSWER|Final Answer|Answer|answer is)[`*_]*[:\s]*[`*_]*([A-D])[`*_]*"
)

# Fallback: an option letter right after a colon, e.g. "option: A".
_COLON_OPTION = re.compile(r":\s*([A-D])")

_OPTION_LETTERS = frozenset("ABCD")


def _extract_option(resp):
    """Return the option letter found in one response, or the cleaned response."""
    # First choice: the last keyword-anchored answer in the response.
    keyword_hits = REGEX.findall(resp)
    if keyword_hits:
        return keyword_hits[-1]

    # Otherwise drop trailing punctuation and look at the final character.
    # ``rstrip`` is safe on empty or all-punctuation text, where indexing
    # ``resp[-1]`` in a loop would raise IndexError.
    stripped = resp.rstrip(string.punctuation)
    if stripped and stripped[-1] in _OPTION_LETTERS:
        return stripped[-1]

    # Last resort: the last "<colon> <letter>" occurrence.
    colon_hits = _COLON_OPTION.findall(stripped)
    if colon_hits:
        return colon_hits[-1]

    # Nothing recognisable: hand back the punctuation-stripped text unchanged.
    return stripped


def flexible_extract(resps, docs):
    """Extract an A-D option letter from every model response.

    Args:
        resps: Iterable of instances, each an iterable of response strings.
        docs: Unused; accepted so the function fits the two-argument call
            made for custom filters.

    Returns:
        A list with one list per instance. Each inner list holds one entry
        per response: the extracted letter, or the response with trailing
        punctuation removed when no letter could be found.
    """
    # Every response is processed and appended, so each instance always maps
    # to a list of the same length as its input (never a bare string).
    return [[_extract_option(resp) for resp in inst] for inst in resps]