Skip to content

Commit

Permalink
Merge pull request #112 from alan-turing-institute/rephrasals-pipeline
Browse files Browse the repository at this point in the history
Add rephrasal parser option to run_experiment command
  • Loading branch information
rchan26 authored Nov 26, 2024
2 parents 52e4a39 + 72aca2a commit 61d0e7d
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 28 deletions.
17 changes: 9 additions & 8 deletions src/prompto/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
from datetime import datetime
from typing import Callable

import pandas as pd
from tqdm import tqdm
Expand Down Expand Up @@ -318,7 +319,7 @@ def grouped_experiment_prompts_summary(self) -> dict[str, str]:
return queries_and_rates_per_group

async def process(
self, evaluation_funcs: list[callable] | None = None
self, evaluation_funcs: list[Callable] | None = None
) -> tuple[dict, float]:
"""
Function to process the experiment.
Expand All @@ -336,7 +337,7 @@ async def process(
Parameters
----------
evaluation_funcs : list[callable], optional
evaluation_funcs : list[Callable], optional
List of evaluation functions to run on the completed responses.
Each function should take a prompt_dict as input and return a prompt dict
as output. The evaluation functions can use keys in the prompt_dict to
Expand Down Expand Up @@ -452,7 +453,7 @@ async def send_requests(
attempt: int,
rate_limit: int,
group: str | None = None,
evaluation_funcs: list[callable] | None = None,
evaluation_funcs: list[Callable] | None = None,
) -> tuple[list[dict], list[dict | Exception]]:
"""
Send requests to the API asynchronously.
Expand Down Expand Up @@ -538,7 +539,7 @@ async def send_requests_retry(
prompt_dicts: list[dict],
rate_limit: int,
group: str | None = None,
evaluation_funcs: list[callable] | None = None,
evaluation_funcs: list[Callable] | None = None,
) -> None:
"""
Send requests to the API asynchronously and retry failed queries
Expand Down Expand Up @@ -612,7 +613,7 @@ async def query_model_and_record_response(
prompt_dict: dict,
index: int | str | None,
attempt: int,
evaluation_funcs: list[callable] | None = None,
evaluation_funcs: list[Callable] | None = None,
) -> dict | Exception:
"""
Send request to generate response from a LLM and record the response in a jsonl file.
Expand Down Expand Up @@ -729,7 +730,7 @@ async def generate_text(
self,
prompt_dict: dict,
index: int | None,
evaluation_funcs: list[callable] | None = None,
evaluation_funcs: list[Callable] | None = None,
) -> dict:
"""
Generate text by querying an LLM.
Expand Down Expand Up @@ -783,7 +784,7 @@ async def generate_text(
return response

async def evaluate_responses(
self, prompt_dict, evaluation_funcs: list[callable]
self, prompt_dict, evaluation_funcs: list[Callable]
) -> dict:
"""
Runs evaluation functions on a prompt dictionary. Note that the list of functions
Expand All @@ -794,7 +795,7 @@ async def evaluate_responses(
prompt_dict : dict
Dictionary for the evaluation functions to run on. Note: in the process function,
this will be run on self.completed_responses.
evaluation_funcs : list[callable]
evaluation_funcs : list[Callable]
List of evaluation functions to run on the completed responses. Each function should
take a prompt_dict as input and return a prompt dict as output. The evaluation
functions can use keys in the prompt_dict to parameterise the functions.
Expand Down
74 changes: 58 additions & 16 deletions src/prompto/rephrasal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
from typing import Callable

from tqdm import tqdm

Expand Down Expand Up @@ -280,7 +281,9 @@ def create_rephrase_file(
return rephrase_prompts

@staticmethod
def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
def _convert_rephrased_prompt_dict_to_input(
rephrased_prompt: dict, parser: Callable | None = None
) -> dict | list[dict]:
"""
Method to convert a completed rephrased prompt dictionary to an input prompt dictionary.
This is done by:
Expand All @@ -298,6 +301,10 @@ def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
A dictionary containing the rephrased prompt. Should usually contain
the keys "id", "prompt", "input-prompt" and "input-id". Should also
contain "input-api", "input-model_name" and "input-parameters" keys
parser : Callable, optional
A parser function to apply to the rephrased prompt response. This
function should take a string and return a string or a list of strings.
If None, no parser will be applied to the rephrased prompt response
Returns
-------
Expand All @@ -308,25 +315,51 @@ def _convert_rephrased_prompt_dict_to_input(rephrased_prompt: dict) -> dict:
The "id" key will indicate the rephrased prompt id. The "api", "model_name",
and other keys from the original input will be restored
"""
input_prompt = {
"id": rephrased_prompt["id"],
"prompt": rephrased_prompt["response"],
"input-prompt": rephrased_prompt["input-prompt"],
"input-id": rephrased_prompt.get("input-id", "NA"),
}
# apply parser (if it is provided) to response in the rephrased_prompt
# this may return a single string or a list of strings
if parser is not None:
response = parser(rephrased_prompt["response"])
if isinstance(response, str):
response = [response]

if not isinstance(response, list):
raise TypeError(
"Applying parser on rephrased_prompt['response'] must return a string or list of strings"
)
else:
response = [rephrased_prompt["response"]]

input_prompts = []
for i, resp in enumerate(response):
id = rephrased_prompt["id"]
if len(response) > 1:
id = f"{id}-{i}"

new_prompt_dict = {
"id": id,
"prompt": resp,
"input-prompt": rephrased_prompt["input-prompt"],
"input-id": rephrased_prompt.get("input-id", "NA"),
}

# restore the original input keys (e.g. "api", "model_name", "parameters")
for k, v in rephrased_prompt.items():
if k.startswith("input-") and k not in ["input-prompt", "input-id"]:
new_prompt_dict[k[6:]] = v

input_prompts.append(new_prompt_dict)

# restore the original input keys (e.g. "api", "model_name", "parameters")
for k, v in rephrased_prompt.items():
if k.startswith("input-") and k not in ["input-prompt", "input-id"]:
input_prompt[k[6:]] = v
if len(response) == 1:
return input_prompts[0]

return input_prompt
return input_prompts

def create_new_input_file(
self,
keep_original: bool,
completed_rephrase_responses: list[dict],
out_filepath: str,
parser: Callable | None = None,
) -> list[dict]:
"""
Method to create a new input file given the original input prompts and
Expand All @@ -349,6 +382,10 @@ def create_new_input_file(
out_filepath : str
The path to the output file where the new input prompts will
be saved as a jsonl file
parser : Callable, optional
A parser function to apply to the rephrased prompt response. This
function should take a string and return a string or a list of strings.
If None, no parser will be applied to the rephrased prompt response
Returns
-------
Expand All @@ -363,10 +400,15 @@ def create_new_input_file(
raise ValueError("out_filepath must end with '.jsonl'")

# obtain the new rephrased prompts
new_input_prompts = [
self._convert_rephrased_prompt_dict_to_input(rephrased_prompt)
for rephrased_prompt in completed_rephrase_responses
]
new_input_prompts = []
for rephrased_prompt in completed_rephrase_responses:
input_prompts = self._convert_rephrased_prompt_dict_to_input(
rephrased_prompt, parser=parser
)
if isinstance(input_prompts, dict):
new_input_prompts.append(input_prompts)
else:
new_input_prompts += input_prompts

# add the original input prompts if keep_original is True
if keep_original:
Expand Down
80 changes: 80 additions & 0 deletions src/prompto/rephrasal_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import logging
from typing import Callable

import regex as re


def obtain_parser_functions(
parser: str | list[str], parser_functions_dict: dict[str, Callable]
) -> list[Callable]:
"""
Check if the parser(s) provided are in the parser_functions_dict.
Parameters
----------
parser : str | list[str]
A single parser or a list of parsers to check if they
are keys in the parser_functions_dict dictionary
parser_functions_dict : dict[str, Callable]
A dictionary of parser functions with the keys as the
parser names and the values as the parser functions
Returns
-------
list[Callable]
List of parser functions that correspond to the parsers
"""
if isinstance(parser, str):
parser = [parser]

functions = []
for p in parser:
if not isinstance(p, str):
raise TypeError("If parser is a list, each element must be a string")
if p not in parser_functions_dict.keys():
raise KeyError(
f"Parser '{p}' is not a key in parser_functions_dict. "
f"Available parsers are: {list(parser_functions_dict.keys())}"
)

functions.append(parser_functions_dict[p])

logging.info(f"parser functions to be used: {parser}")
return functions


def remove_brackets(text: str) -> str:
    """
    Remove every round-bracketed segment, brackets included, from the text.

    The match is non-greedy, so each "(" is paired with the nearest
    following ")". Whitespace around a removed segment is left untouched.

    Parameters
    ----------
    text : str
        The text to remove bracketed segments from

    Returns
    -------
    str
        The text without the bracketed segments
    """
    return re.sub(r"\(.*?\)", "", text)


def remove_quotation_marks(text: str) -> str:
    """
    Remove one pair of double quotation marks wrapping the whole text.

    The quotation marks are only removed if they are at the beginning and
    the end of the string. Anything else is returned unchanged.

    Parameters
    ----------
    text : str
        The text to remove the wrapping quotation marks from

    Returns
    -------
    str
        The text without the wrapping quotation marks
    """
    # the length check stops a lone '"' from being treated as both the opening
    # and the closing quotation mark (which would reduce it to an empty string)
    if len(text) >= 2 and text.startswith('"') and text.endswith('"'):
        return text[1:-1]
    return text


def split_numbered_list(text: str) -> list[str]:
    """
    Split a numbered list (e.g. "1. first 2. second") into its items.

    Each item has bracketed segments removed, surrounding whitespace
    stripped and a wrapping pair of quotation marks removed.

    Parameters
    ----------
    text : str
        The text containing the numbered list

    Returns
    -------
    list[str]
        The cleaned items of the list. If no list marker is found, this
        is a list containing only the cleaned input text. An empty input
        returns an empty list
    """
    # regex pattern matches a list marker:
    # - not directly preceded by a digit or a period ((?<![\d.]))
    # - one or more digits (\d+)
    # - followed by a period (\.)
    # - which is not directly followed by a digit ((?!\d))
    # - followed by optional whitespace (\s*)
    # the two look-arounds stop decimal numbers inside an item (e.g. "3.14")
    # from being mistaken for list markers
    pattern = r"(?<![\d.])\d+\.(?!\d)\s*"

    parts = re.split(pattern, text)

    # text starting with a marker yields an empty first element, so drop it
    if parts and not parts[0]:
        parts = parts[1:]

    # brackets are removed before stripping so that whitespace left behind
    # at the ends by a removed segment is also stripped
    return [remove_quotation_marks(remove_brackets(part).strip()) for part in parts]


# Registry of the available rephrasal parser functions, keyed by parser name.
# This is the kind of dictionary obtain_parser_functions expects as its
# parser_functions_dict argument: names are looked up as keys here.
PARSER_FUNCTIONS: dict[str, Callable] = {
    "split_numbered_list": split_numbered_list,
}
9 changes: 5 additions & 4 deletions src/prompto/scorer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import logging
from typing import Callable


def obtain_scoring_functions(
scorer: str | list[str], scoring_functions_dict: dict[str, callable]
) -> list[callable]:
scorer: str | list[str], scoring_functions_dict: dict[str, Callable]
) -> list[Callable]:
"""
Check if the scorer(s) provided are in the scoring_functions_dict.
Expand All @@ -12,13 +13,13 @@ def obtain_scoring_functions(
scorer : str | list[str]
A single scorer or a list of scorers to check if they
are keys in the scoring_functions_dict dictionary
scoring_functions_dict : dict[str, callable]
scoring_functions_dict : dict[str, Callable]
A dictionary of scoring functions with the keys as the
scorer names and the values as the scoring functions
Returns
-------
list[callable]
list[Callable]
List of scoring functions that correspond to the scorers
"""
if isinstance(scorer, str):
Expand Down
19 changes: 19 additions & 0 deletions src/prompto/scripts/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from prompto.experiment import Experiment
from prompto.judge import Judge, load_judge_folder
from prompto.rephrasal import Rephraser, load_rephrase_folder
from prompto.rephrasal_parser import PARSER_FUNCTIONS, obtain_parser_functions
from prompto.scorer import SCORING_FUNCTIONS, obtain_scoring_functions
from prompto.settings import Settings
from prompto.utils import copy_file, create_folder, move_file, parse_list_arg
Expand Down Expand Up @@ -524,6 +525,16 @@ async def main():
type=str,
default=None,
)
parser.add_argument(
"--rephrase-parser",
"-rp",
help=(
"Parser to be used. "
"This must be a key in the parser functions dictionary"
),
type=str,
default=None,
),
parser.add_argument(
"--remove-original",
"-ro",
Expand Down Expand Up @@ -677,10 +688,18 @@ async def main():
rephrased_experiment_path = (
f"{settings.input_folder}/{rephrased_experiment_file_name}"
)
if args.rephrase_parser is not None:
parser_function = obtain_parser_functions(
parser=args.rephrase_parser, parser_functions_dict=PARSER_FUNCTIONS
)[0]
else:
parser_function = None

rephraser.create_new_input_file(
keep_original=not args.remove_original,
completed_rephrase_responses=rephrase_experiment.completed_responses,
out_filepath=rephrased_experiment_path,
parser=parser_function,
)

if args.only_rephrase:
Expand Down

0 comments on commit 61d0e7d

Please sign in to comment.