diff --git a/src/prompto/experiment.py b/src/prompto/experiment.py
index 6c76eff..bf111da 100644
--- a/src/prompto/experiment.py
+++ b/src/prompto/experiment.py
@@ -4,6 +4,7 @@
 import os
 import time
 from datetime import datetime
+from typing import Callable
 
 import pandas as pd
 from tqdm import tqdm
@@ -318,7 +319,7 @@ def grouped_experiment_prompts_summary(self) -> dict[str, str]:
         return queries_and_rates_per_group
 
     async def process(
-        self, evaluation_funcs: list[callable] | None = None
+        self, evaluation_funcs: list[Callable] | None = None
     ) -> tuple[dict, float]:
         """
         Function to process the experiment.
@@ -336,7 +337,7 @@ async def process(
 
         Parameters
         ----------
-        evaluation_funcs : list[callable], optional
+        evaluation_funcs : list[Callable], optional
             List of evaluation functions to run on the completed responses.
             Each function should take a prompt_dict as input and return a prompt
             dict as output. The evaluation functions can use keys in the prompt_dict to
@@ -452,7 +453,7 @@ async def send_requests(
         attempt: int,
         rate_limit: int,
         group: str | None = None,
-        evaluation_funcs: list[callable] | None = None,
+        evaluation_funcs: list[Callable] | None = None,
     ) -> tuple[list[dict], list[dict | Exception]]:
         """
         Send requests to the API asynchronously.
@@ -538,7 +539,7 @@ async def send_requests_retry(
         prompt_dicts: list[dict],
         rate_limit: int,
         group: str | None = None,
-        evaluation_funcs: list[callable] | None = None,
+        evaluation_funcs: list[Callable] | None = None,
     ) -> None:
         """
         Send requests to the API asynchronously and retry failed queries
@@ -612,7 +613,7 @@ async def query_model_and_record_response(
         prompt_dict: dict,
         index: int | str | None,
         attempt: int,
-        evaluation_funcs: list[callable] | None = None,
+        evaluation_funcs: list[Callable] | None = None,
     ) -> dict | Exception:
         """
         Send request to generate response from a LLM and record the response in a jsonl file.
@@ -729,7 +730,7 @@ async def generate_text(
         self,
         prompt_dict: dict,
         index: int | None,
-        evaluation_funcs: list[callable] | None = None,
+        evaluation_funcs: list[Callable] | None = None,
     ) -> dict:
         """
         Generate text by querying an LLM.
@@ -783,7 +784,7 @@ async def generate_text(
         return response
 
     async def evaluate_responses(
-        self, prompt_dict, evaluation_funcs: list[callable]
+        self, prompt_dict, evaluation_funcs: list[Callable]
     ) -> dict:
         """
         Runs evaluation functions on a prompt dictionary. Note that the list of functions
@@ -794,7 +795,7 @@ async def evaluate_responses(
         prompt_dict : dict
             Dictionary for the evaluation functions to run on. Note: in the process
             function, this will be run on self.completed_responses.
-        evaluation_funcs : list[callable]
+        evaluation_funcs : list[Callable]
             List of evaluation functions to run on the completed responses. Each function
             should take a prompt_dict as input and return a prompt dict as output.
             The evaluation functions can use keys in the prompt_dict to parameterise the functions.
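The process() changes above are only a typing swap (callable -> typing.Callable), but they restate the contract for evaluation_funcs: each function takes a prompt_dict and returns a prompt_dict, and is run on the completed responses. A minimal sketch of plugging one in; the synchronous function style, the "mentions_python" key and the already-built Experiment instance are illustrative assumptions, not part of this diff:

def mentions_python(prompt_dict: dict) -> dict:
    # illustrative evaluation: record whether the completed response mentions "python";
    # the "response" key is where the generated output is stored
    response = str(prompt_dict.get("response", ""))
    prompt_dict["mentions_python"] = "python" in response.lower()
    return prompt_dict


async def run_with_evaluation(experiment) -> None:
    # `experiment` is an Experiment instance constructed elsewhere; process is a
    # coroutine and returns a tuple[dict, float] as in the signature above
    result, timing = await experiment.process(evaluation_funcs=[mentions_python])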
diff --git a/src/prompto/rephrasal.py b/src/prompto/rephrasal.py
index 5ad67a5..d6740c8 100644
--- a/src/prompto/rephrasal.py
+++ b/src/prompto/rephrasal.py
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+from typing import Callable
 
 from tqdm import tqdm
 
@@ -280,7 +281,9 @@ def create_rephrase_file(
         return rephrase_prompts
 
     @staticmethod
-    def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
+    def _convert_rephrased_prompt_dict_to_input(
+        rephrased_prompt: dict, parser: Callable | None = None
+    ) -> dict | list[dict]:
         """
         Method to convert a completed rephrased prompt dictionary to an input
         prompt dictionary. This is done by:
@@ -298,6 +301,10 @@ def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
             A dictionary containing the rephrased prompt. Should usually contain
             the keys "id", "prompt", "input-prompt" and "input-id". Should also
             contain "input-api", "input-model_name" and "input-parameters" keys
+        parser : Callable, optional
+            A parser function to apply to the rephrased prompt response. This
+            function should take a string and return a string or a list of strings.
+            If None, no parser will be applied to the rephrased prompt response
 
         Returns
         -------
@@ -308,25 +315,51 @@ def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
             The "id" key will indicate the rephrased prompt id. The "api", "model_name",
             and other keys from the original input will be restored
         """
-        input_prompt = {
-            "id": rephrased_prompt["id"],
-            "prompt": rephrased_prompt["response"],
-            "input-prompt": rephrased_prompt["input-prompt"],
-            "input-id": rephrased_prompt.get("input-id", "MA"),
-        }
+        # apply parser (if it is provided) to response in the rephrased_prompt
+        # this may return a single string or a list of strings
+        if parser is not None:
+            response = parser(rephrased_prompt["response"])
+            if isinstance(response, str):
+                response = [response]
+
+            if not isinstance(response, list):
+                raise TypeError(
+                    "Applying parser on rephrased_prompt['response'] must return a string or list of strings"
+                )
+        else:
+            response = [rephrased_prompt["response"]]
+
+        input_prompts = []
+        for i, resp in enumerate(response):
+            id = rephrased_prompt["id"]
+            if len(response) > 1:
+                id = f"{id}-{i}"
+
+            new_prompt_dict = {
+                "id": id,
+                "prompt": resp,
+                "input-prompt": rephrased_prompt["input-prompt"],
+                "input-id": rephrased_prompt.get("input-id", "MA"),
+            }
+
+            # restore the original input keys (e.g. "api", "model_name", "parameters")
+            for k, v in rephrased_prompt.items():
+                if k.startswith("input-") and k not in ["input-prompt", "input-id"]:
+                    new_prompt_dict[k[6:]] = v
+
+            input_prompts.append(new_prompt_dict)
 
-        # restore the original input keys (e.g. "api", "model_name", "parameters")
-        for k, v in rephrased_prompt.items():
-            if k.startswith("input-") and k not in ["input-prompt", "input-id"]:
-                input_prompt[k[6:]] = v
+        if len(response) == 1:
+            return input_prompts[0]
 
-        return input_prompt
+        return input_prompts
 
     def create_new_input_file(
         self,
         keep_original: bool,
         completed_rephrase_responses: list[dict],
         out_filepath: str,
+        parser: Callable | None = None,
     ) -> list[dict]:
         """
         Method to create a new input file given the original input prompts and
@@ -349,6 +382,10 @@ def create_new_input_file(
         out_filepath : str
             The path to the output file where the new input prompts will be saved
             as a jsonl file
+        parser : Callable, optional
+            A parser function to apply to the rephrased prompt response. This
+            function should take a string and return a string or a list of strings.
+            If None, no parser will be applied to the rephrased prompt response
 
         Returns
         -------
@@ -363,10 +400,15 @@ def create_new_input_file(
             raise ValueError("out_filepath must end with '.jsonl'")
 
         # obtain the new rephrased prompts
-        new_input_prompts = [
-            self._convert_rephrased_prompt_dict_to_input(rephrased_prompt)
-            for rephrased_prompt in completed_rephrase_responses
-        ]
+        new_input_prompts = []
+        for rephrased_prompt in completed_rephrase_responses:
+            input_prompts = self._convert_rephrased_prompt_dict_to_input(
+                rephrased_prompt, parser=parser
+            )
+            if isinstance(input_prompts, dict):
+                new_input_prompts.append(input_prompts)
+            else:
+                new_input_prompts += input_prompts
 
         # add the original input prompts if keep_original is True
         if keep_original:
diff --git a/src/prompto/rephrasal_parser.py b/src/prompto/rephrasal_parser.py
new file mode 100644
index 0000000..173b1e0
--- /dev/null
+++ b/src/prompto/rephrasal_parser.py
@@ -0,0 +1,80 @@
+import logging
+from typing import Callable
+
+import regex as re
+
+
+def obtain_parser_functions(
+    parser: str | list[str], parser_functions_dict: dict[str, Callable]
+) -> list[Callable]:
+    """
+    Check if the parser(s) provided are in the parser_functions_dict.
+
+    Parameters
+    ----------
+    parser : str | list[str]
+        A single parser or a list of parsers to check if they
+        are keys in the parser_functions_dict dictionary
+    parser_functions_dict : dict[str, Callable]
+        A dictionary of parser functions with the keys as the
+        parser names and the values as the parser functions
+
+    Returns
+    -------
+    list[Callable]
+        List of parser functions that correspond to the parsers
+    """
+    if isinstance(parser, str):
+        parser = [parser]
+
+    functions = []
+    for p in parser:
+        if not isinstance(p, str):
+            raise TypeError("If parser is a list, each element must be a string")
+        if p not in parser_functions_dict.keys():
+            raise KeyError(
+                f"Parser '{p}' is not a key in parser_functions_dict. "
+                f"Available parsers are: {list(parser_functions_dict.keys())}"
+            )
+
+        functions.append(parser_functions_dict[p])
+
+    logging.info(f"parser functions to be used: {parser}")
+    return functions
+
+
+def remove_brackets(text: str) -> str:
+    # regex to remove brackets and anything between them
+    return re.sub(r"\(.*?\)", "", text)
+
+
+def remove_quotation_marks(text: str) -> str:
+    # remove quotation marks only if they are at the beginning and end of the string
+    if text.startswith('"') and text.endswith('"'):
+        return text[1:-1]
+    return text
+
+
+def split_numbered_list(text: str) -> list[str]:
+    # regex pattern matches:
+    # - starts with one or more digits (\d+)
+    # - followed by a period (\.)
+    # - followed by optional whitespace (\s*)
+    pattern = r"\d+\.\s*"
+
+    # split the text and clean each part
+    parts = re.split(pattern, text)
+
+    # remove empty strings from the beginning if they exist
+    if parts and not parts[0]:
+        parts = parts[1:]
+
+    # strip whitespace and newlines from each part
+    parts = [remove_quotation_marks(remove_brackets(part).strip()) for part in parts]
+
+    return parts
+
+
+PARSER_FUNCTIONS = {
+    "split_numbered_list": split_numbered_list,
+}
diff --git a/src/prompto/scorer.py b/src/prompto/scorer.py
index 47558d0..680c514 100644
--- a/src/prompto/scorer.py
+++ b/src/prompto/scorer.py
@@ -1,9 +1,10 @@
 import logging
+from typing import Callable
 
 
 def obtain_scoring_functions(
-    scorer: str | list[str], scoring_functions_dict: dict[str, callable]
-) -> list[callable]:
+    scorer: str | list[str], scoring_functions_dict: dict[str, Callable]
+) -> list[Callable]:
     """
     Check if the scorer(s) provided are in the scoring_functions_dict.
 
@@ -12,13 +13,13 @@ def obtain_scoring_functions(
     scorer : str | list[str]
         A single scorer or a list of scorers to check if they
         are keys in the scoring_functions_dict dictionary
-    scoring_functions_dict : dict[str, callable]
+    scoring_functions_dict : dict[str, Callable]
         A dictionary of scoring functions with the keys as the
         scorer names and the values as the scoring functions
 
     Returns
     -------
-    list[callable]
+    list[Callable]
        List of scoring functions that correspond to the scorers
     """
     if isinstance(scorer, str):
diff --git a/src/prompto/scripts/run_experiment.py b/src/prompto/scripts/run_experiment.py
index 0ce11e2..5531c49 100644
--- a/src/prompto/scripts/run_experiment.py
+++ b/src/prompto/scripts/run_experiment.py
@@ -9,6 +9,7 @@
 from prompto.experiment import Experiment
 from prompto.judge import Judge, load_judge_folder
 from prompto.rephrasal import Rephraser, load_rephrase_folder
+from prompto.rephrasal_parser import PARSER_FUNCTIONS, obtain_parser_functions
 from prompto.scorer import SCORING_FUNCTIONS, obtain_scoring_functions
 from prompto.settings import Settings
 from prompto.utils import copy_file, create_folder, move_file, parse_list_arg
@@ -524,6 +525,16 @@ async def main():
         type=str,
         default=None,
     )
+    parser.add_argument(
+        "--rephrase-parser",
+        "-rp",
+        help=(
+            "Parser to be used. "
+            "This must be a key in the parser functions dictionary"
+        ),
+        type=str,
+        default=None,
+    )
     parser.add_argument(
         "--remove-original",
         "-ro",
@@ -681,6 +692,14 @@ async def main():
             keep_original=not args.remove_original,
             completed_rephrase_responses=rephrase_experiment.completed_responses,
             out_filepath=rephrased_experiment_path,
+            parser=(
+                obtain_parser_functions(
+                    parser=args.rephrase_parser,
+                    parser_functions_dict=PARSER_FUNCTIONS,
+                )[0]
+                if args.rephrase_parser is not None
+                else None
+            ),
         )
 
     if args.only_rephrase:
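The helpers in rephrasal_parser.py are small, pure string transforms, so their behaviour follows directly from the code above. A quick illustration with made-up strings; the expected values in the comments come from tracing the functions as written:

from prompto.rephrasal_parser import (
    PARSER_FUNCTIONS,
    obtain_parser_functions,
    remove_brackets,
    remove_quotation_marks,
    split_numbered_list,
)

# look up parsers by name; an unknown name raises KeyError
parsers = obtain_parser_functions(
    parser="split_numbered_list", parser_functions_dict=PARSER_FUNCTIONS
)
# parsers == [split_numbered_list]

remove_brackets("keep this (drop this)")  # -> "keep this " (trailing space is kept)
remove_quotation_marks('"quoted text"')   # -> "quoted text"

response = "1. What is the best way to learn Python?\n2. How should I start learning Python?"
split_numbered_list(response)
# -> ["What is the best way to learn Python?", "How should I start learning Python?"]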
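Putting the pieces together: when a parser splits one rephrasal response into several strings, _convert_rephrased_prompt_dict_to_input returns one input prompt dict per part, suffixing the rephrased prompt id with the part index and restoring the original input-* keys. A sketch that calls the private helper directly, assuming the method lives on the Rephraser class imported in run_experiment.py and using made-up field values:

from prompto.rephrasal import Rephraser
from prompto.rephrasal_parser import split_numbered_list

rephrased_prompt = {
    "id": "rephrase-0",
    "response": "1. What is the best way to learn Python?\n2. How should I start learning Python?",
    "input-prompt": "How do I learn Python?",
    "input-id": "0",
    "input-api": "openai",         # restored as "api"
    "input-model_name": "gpt-4o",  # restored as "model_name"
}

new_inputs = Rephraser._convert_rephrased_prompt_dict_to_input(
    rephrased_prompt, parser=split_numbered_list
)
# the parser returns two strings, so a list of two input dicts comes back with
# ids "rephrase-0-0" and "rephrase-0-1", each carrying "prompt", "input-prompt",
# "input-id" and the restored "api" and "model_name" keys

On the command line, the same behaviour is reached through the new --rephrase-parser flag, whose value must be a key of PARSER_FUNCTIONS (currently only "split_numbered_list").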