Skip to content

Commit

Permalink
blow up in case of error so the operator gets more information about …
Browse files Browse the repository at this point in the history
…what's wrong (#813)
  • Loading branch information
rogthefrog authored Jan 22, 2025
1 parent 01d9987 commit 2a04ee8
Showing 1 changed file with 30 additions and 33 deletions.
63 changes: 30 additions & 33 deletions src/modelgauge/tests/safe_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,39 +152,36 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]

# We're not using a context manager here for better exception handling,
# so the operator has some insight into why things aren't working.
try:
csvfile = open(data_file, "r")
reader = csv.DictReader(csvfile)
for row in reader:
text = row["prompt_text"].strip()
if not text:
continue

# Check that prompt is for correct hazard/persona/locale.
hazard = row["hazard"].split("_")[0]
persona = SafePersonasVersion1(row["persona"])
locale = row["locale"].lower()
if not hazard == self.hazard:
continue
if persona not in self.persona_types:
continue
if locale != self.locale:
continue

prompt = PromptWithContext(
prompt=TextPrompt(text=text, options=sut_options),
source_id=row["release_prompt_id"],
)
test_items.append(
TestItem(
prompts=[prompt],
context=SafeTestItemContext(persona_type=persona),
),
)
if len(test_items) == 0:
raise RuntimeError(f"No test items created from {data_file}")
except Exception as exc:
raise RuntimeError(f"Error making test items from {data_file}: {exc}")
csvfile = open(data_file, "r")
reader = csv.DictReader(csvfile)
for row in reader:
text = row["prompt_text"].strip()
if not text:
continue

# Check that prompt is for correct hazard/persona/locale.
hazard = row["hazard"].split("_")[0]
persona = SafePersonasVersion1(row["persona"])
locale = row["locale"].lower()
if not hazard == self.hazard:
continue
if persona not in self.persona_types:
continue
if locale != self.locale:
continue

prompt = PromptWithContext(
prompt=TextPrompt(text=text, options=sut_options),
source_id=row["release_prompt_id"],
)
test_items.append(
TestItem(
prompts=[prompt],
context=SafeTestItemContext(persona_type=persona),
),
)
if len(test_items) == 0:
raise RuntimeError(f"No test items created from {data_file}")

return test_items

Expand Down

0 comments on commit 2a04ee8

Please sign in to comment.