From 9c9f0779b6878546b519a2acb2be4238819275b5 Mon Sep 17 00:00:00 2001 From: rchan Date: Wed, 11 Sep 2024 17:42:30 +0100 Subject: [PATCH] update judge argument ordering --- src/prompto/judge.py | 14 ++++++------ src/prompto/scripts/create_judge_file.py | 2 +- tests/core/test_judge.py | 28 ++++++++++++------------ 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/src/prompto/judge.py b/src/prompto/judge.py index 4c1afb63..51759ad7 100644 --- a/src/prompto/judge.py +++ b/src/prompto/judge.py @@ -81,10 +81,6 @@ class Judge: A list of dictionaries containing the responses to judge. Each dictionary should contain the keys "prompt", and "response" - judge_settings : dict - A dictionary of judge settings with the keys "api", - "model_name", "parameters". Used to define the - judge LLMs to be used in the judging process template_prompt : dict[str, str] A dictionary containing the template prompt strings to be used for the judge LLMs. The keys should be the @@ -94,20 +90,24 @@ class Judge: for the input prompt (INPUT_PROMPT) and the output response (OUTPUT_RESPONSE) which will be formatted with the prompt and response from the completed prompt dict + judge_settings : dict + A dictionary of judge settings with the keys "api", + "model_name", "parameters". Used to define the + judge LLMs to be used in the judging process """ def __init__( self, completed_responses: list[dict], - judge_settings: dict, template_prompts: dict[str, str], + judge_settings: dict, ): - self.check_judge_settings(judge_settings) if not isinstance(template_prompts, dict): raise TypeError("template_prompts must be a dictionary") + self.check_judge_settings(judge_settings) self.completed_responses = completed_responses - self.judge_settings = judge_settings self.template_prompts = template_prompts + self.judge_settings = judge_settings @staticmethod def check_judge_settings(judge_settings: dict[str, dict]) -> bool: diff --git a/src/prompto/scripts/create_judge_file.py b/src/prompto/scripts/create_judge_file.py index 69c05d99..991e10bd 100644 --- a/src/prompto/scripts/create_judge_file.py +++ b/src/prompto/scripts/create_judge_file.py @@ -92,8 +92,8 @@ def main(): # create judge object from the parsed arguments j = Judge( completed_responses=responses, - judge_settings=judge_settings, template_prompts=template_prompts, + judge_settings=judge_settings, ) # create judge file diff --git a/tests/core/test_judge.py b/tests/core/test_judge.py index 6fd8ea3c..966a061b 100644 --- a/tests/core/test_judge.py +++ b/tests/core/test_judge.py @@ -274,33 +274,33 @@ def test_check_judge_init(): ): Judge() - # raise error if judge_settings is not a valid dictionary + # raise error if template_prompts is not a dictionary with pytest.raises( TypeError, - match="judge_settings must be a dictionary", + match="template_prompts must be a dictionary", ): Judge( - completed_responses="completed_responses", - judge_settings="not_a_dict", - template_prompts="template_prompt", + completed_responses="completed_responses (no check on list of dicts)", + template_prompts="not_a_dict", + judge_settings=JUDGE_SETTINGS, ) - # raise error if template_prompts is not a dictionary + # raise error if judge_settings is not a valid dictionary with pytest.raises( TypeError, - match="template_prompts must be a dictionary", + match="judge_settings must be a dictionary", ): Judge( - completed_responses="completed_responses", - judge_settings=JUDGE_SETTINGS, - template_prompts="not_a_dict", + completed_responses="completed_responses (no check on list of dicts)", + template_prompts={"template": "some template"}, + judge_settings="not_a_dict", ) tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"} judge = Judge( completed_responses=COMPLETED_RESPONSES, - judge_settings=JUDGE_SETTINGS, template_prompts=tp, + judge_settings=JUDGE_SETTINGS, ) assert judge.completed_responses == COMPLETED_RESPONSES assert judge.judge_settings == JUDGE_SETTINGS @@ -311,8 +311,8 @@ def test_judge_create_judge_inputs_errors(): tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"} judge = Judge( completed_responses=COMPLETED_RESPONSES, - judge_settings=JUDGE_SETTINGS, template_prompts=tp, + judge_settings=JUDGE_SETTINGS, ) # raise error if judge not provided @@ -348,8 +348,8 @@ def test_judge_create_judge_inputs(): tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"} judge = Judge( completed_responses=COMPLETED_RESPONSES, - judge_settings=JUDGE_SETTINGS, template_prompts=tp, + judge_settings=JUDGE_SETTINGS, ) # "judge1" case @@ -467,8 +467,8 @@ def test_judge_create_judge_file(temporary_data_folder_judge): } judge = Judge( completed_responses=COMPLETED_RESPONSES, - judge_settings=JUDGE_SETTINGS, template_prompts=tp, + judge_settings=JUDGE_SETTINGS, ) # raise error if nothing is provided