Skip to content

Commit 2a04ee8

Browse files
authored
blow up in case of error so the operator gets more information about what's wrong (#813)
1 parent 01d9987 commit 2a04ee8

File tree

1 file changed

+30
-33
lines changed

1 file changed

+30
-33
lines changed

src/modelgauge/tests/safe_v1.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -152,39 +152,36 @@ def make_test_items(self, dependency_helper: DependencyHelper) -> List[TestItem]
152152

153153
# We're not using a context manager here for better exception handling,
154154
# so the operator has some insight into why things aren't working.
155-
try:
156-
csvfile = open(data_file, "r")
157-
reader = csv.DictReader(csvfile)
158-
for row in reader:
159-
text = row["prompt_text"].strip()
160-
if not text:
161-
continue
162-
163-
# Check that prompt is for correct hazard/persona/locale.
164-
hazard = row["hazard"].split("_")[0]
165-
persona = SafePersonasVersion1(row["persona"])
166-
locale = row["locale"].lower()
167-
if not hazard == self.hazard:
168-
continue
169-
if persona not in self.persona_types:
170-
continue
171-
if locale != self.locale:
172-
continue
173-
174-
prompt = PromptWithContext(
175-
prompt=TextPrompt(text=text, options=sut_options),
176-
source_id=row["release_prompt_id"],
177-
)
178-
test_items.append(
179-
TestItem(
180-
prompts=[prompt],
181-
context=SafeTestItemContext(persona_type=persona),
182-
),
183-
)
184-
if len(test_items) == 0:
185-
raise RuntimeError(f"No test items created from {data_file}")
186-
except Exception as exc:
187-
raise RuntimeError(f"Error making test items from {data_file}: {exc}")
155+
csvfile = open(data_file, "r")
156+
reader = csv.DictReader(csvfile)
157+
for row in reader:
158+
text = row["prompt_text"].strip()
159+
if not text:
160+
continue
161+
162+
# Check that prompt is for correct hazard/persona/locale.
163+
hazard = row["hazard"].split("_")[0]
164+
persona = SafePersonasVersion1(row["persona"])
165+
locale = row["locale"].lower()
166+
if not hazard == self.hazard:
167+
continue
168+
if persona not in self.persona_types:
169+
continue
170+
if locale != self.locale:
171+
continue
172+
173+
prompt = PromptWithContext(
174+
prompt=TextPrompt(text=text, options=sut_options),
175+
source_id=row["release_prompt_id"],
176+
)
177+
test_items.append(
178+
TestItem(
179+
prompts=[prompt],
180+
context=SafeTestItemContext(persona_type=persona),
181+
),
182+
)
183+
if len(test_items) == 0:
184+
raise RuntimeError(f"No test items created from {data_file}")
188185

189186
return test_items
190187

0 commit comments

Comments
 (0)