Skip to content

Commit

Permalink
update judge argument ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
rchan26 committed Sep 11, 2024
1 parent f70ecc3 commit 9c9f077
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
14 changes: 7 additions & 7 deletions src/prompto/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ class Judge:
A list of dictionaries containing the responses to judge.
Each dictionary should contain the keys "prompt",
and "response"
judge_settings : dict
A dictionary of judge settings with the keys "api",
"model_name", "parameters". Used to define the
judge LLMs to be used in the judging process
template_prompt : dict[str, str]
A dictionary containing the template prompt strings
to be used for the judge LLMs. The keys should be the
Expand All @@ -94,20 +90,24 @@ class Judge:
for the input prompt (INPUT_PROMPT) and the
output response (OUTPUT_RESPONSE) which will be formatted
with the prompt and response from the completed prompt dict
judge_settings : dict
A dictionary of judge settings with the keys "api",
"model_name", "parameters". Used to define the
judge LLMs to be used in the judging process
"""

def __init__(
self,
completed_responses: list[dict],
judge_settings: dict,
template_prompts: dict[str, str],
judge_settings: dict,
):
self.check_judge_settings(judge_settings)
if not isinstance(template_prompts, dict):
raise TypeError("template_prompts must be a dictionary")
self.check_judge_settings(judge_settings)
self.completed_responses = completed_responses
self.judge_settings = judge_settings
self.template_prompts = template_prompts
self.judge_settings = judge_settings

@staticmethod
def check_judge_settings(judge_settings: dict[str, dict]) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion src/prompto/scripts/create_judge_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def main():
# create judge object from the parsed arguments
j = Judge(
completed_responses=responses,
judge_settings=judge_settings,
template_prompts=template_prompts,
judge_settings=judge_settings,
)

# create judge file
Expand Down
28 changes: 14 additions & 14 deletions tests/core/test_judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,33 +274,33 @@ def test_check_judge_init():
):
Judge()

# raise error if judge_settings is not a valid dictionary
# raise error if template_prompts is not a dictionary
with pytest.raises(
TypeError,
match="judge_settings must be a dictionary",
match="template_prompts must be a dictionary",
):
Judge(
completed_responses="completed_responses",
judge_settings="not_a_dict",
template_prompts="template_prompt",
completed_responses="completed_responses (no check on list of dicts)",
template_prompts="not_a_dict",
judge_settings=JUDGE_SETTINGS,
)

# raise error if template_prompts is not a dictionary
# raise error if judge_settings is not a valid dictionary
with pytest.raises(
TypeError,
match="template_prompts must be a dictionary",
match="judge_settings must be a dictionary",
):
Judge(
completed_responses="completed_responses",
judge_settings=JUDGE_SETTINGS,
template_prompts="not_a_dict",
completed_responses="completed_responses (no check on list of dicts)",
template_prompts={"template": "some template"},
judge_settings="not_a_dict",
)

tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"}
judge = Judge(
completed_responses=COMPLETED_RESPONSES,
judge_settings=JUDGE_SETTINGS,
template_prompts=tp,
judge_settings=JUDGE_SETTINGS,
)
assert judge.completed_responses == COMPLETED_RESPONSES
assert judge.judge_settings == JUDGE_SETTINGS
Expand All @@ -311,8 +311,8 @@ def test_judge_create_judge_inputs_errors():
tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"}
judge = Judge(
completed_responses=COMPLETED_RESPONSES,
judge_settings=JUDGE_SETTINGS,
template_prompts=tp,
judge_settings=JUDGE_SETTINGS,
)

# raise error if judge not provided
Expand Down Expand Up @@ -348,8 +348,8 @@ def test_judge_create_judge_inputs():
tp = {"temp": "prompt: {INPUT_PROMPT} || response: {OUTPUT_RESPONSE}"}
judge = Judge(
completed_responses=COMPLETED_RESPONSES,
judge_settings=JUDGE_SETTINGS,
template_prompts=tp,
judge_settings=JUDGE_SETTINGS,
)

# "judge1" case
Expand Down Expand Up @@ -467,8 +467,8 @@ def test_judge_create_judge_file(temporary_data_folder_judge):
}
judge = Judge(
completed_responses=COMPLETED_RESPONSES,
judge_settings=JUDGE_SETTINGS,
template_prompts=tp,
judge_settings=JUDGE_SETTINGS,
)

# raise error if nothing is provided
Expand Down

0 comments on commit 9c9f077

Please sign in to comment.