Commit d25ce41

FEAT: add support for multimodal data from HarmBench (#1110)
1 parent c0fb5cd commit d25ce41

File tree

5 files changed: +495 -0 lines changed


doc/api.rst

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ API Reference
    fetch_examples
    fetch_forbidden_questions_dataset
    fetch_harmbench_dataset
+   fetch_harmbench_multimodal_dataset_async
    fetch_librAI_do_not_answer_dataset
    fetch_llm_latent_adversarial_training_harmful_dataset
    fetch_jbb_behaviors_by_harm_category

pyrit/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
 from pyrit.datasets.dataset_helper import fetch_examples
 from pyrit.datasets.forbidden_questions_dataset import fetch_forbidden_questions_dataset
 from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset
+from pyrit.datasets.harmbench_multimodal_dataset import fetch_harmbench_multimodal_dataset_async
 from pyrit.datasets.librAI_do_not_answer_dataset import fetch_librAI_do_not_answer_dataset
 from pyrit.datasets.llm_latent_adversarial_training_harmful_dataset import (
     fetch_llm_latent_adversarial_training_harmful_dataset,

@@ -47,6 +48,7 @@
     "fetch_examples",
     "fetch_forbidden_questions_dataset",
     "fetch_harmbench_dataset",
+    "fetch_harmbench_multimodal_dataset_async",
     "fetch_librAI_do_not_answer_dataset",
     "fetch_llm_latent_adversarial_training_harmful_dataset",
     "fetch_many_shot_jailbreaking_dataset",
pyrit/datasets/harmbench_multimodal_dataset.py

Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import uuid
from enum import Enum
from pathlib import Path
from typing import List, Literal, Optional

from pyrit.common.net_utility import make_request_and_raise_if_error_async
from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.models import SeedPromptDataset, data_serializer_factory
from pyrit.models.seed_prompt import SeedPrompt

logger = logging.getLogger(__name__)


class SemanticCategory(Enum):
    CYBERCRIME_INTRUSION = "cybercrime_intrusion"  # n=54
    ILLEGAL = "illegal"  # n=36
    HARMFUL = "harmful"  # n=9
    CHEMICAL_BIOLOGICAL = "chemical_biological"  # n=4
    HARASSMENT_BULLYING = "harassment_bullying"  # n=4
    MISINFORMATION_DISINFORMATION = "misinformation_disinformation"  # n=3


async def fetch_harmbench_multimodal_dataset_async(
    *,
    source: str = (
        "https://raw.githubusercontent.com/centerforaisafety/HarmBench/c0423b9/data/behavior_datasets/"
        "harmbench_behaviors_multimodal_all.csv"
    ),
    source_type: Literal["public_url", "file"] = "public_url",
    cache: bool = True,
    data_home: Optional[Path] = None,
    categories: Optional[List[SemanticCategory]] = None,
) -> SeedPromptDataset:
    """
    Fetch HarmBench multimodal examples and create a SeedPromptDataset.

    The HarmBench multimodal dataset contains 110 harmful behaviors.
    Each example consists of an image ("image_path") and a behavior string referencing the image ("text").
    The text and image prompts that belong to the same example are linked by a shared ``prompt_group_id``;
    the grouped prompts can be extracted with the ``group_seed_prompts_by_prompt_group_id`` method.

    Note: The first call may be slow because the images must be downloaded from the remote repository.
    Subsequent calls are faster since the images are cached locally and are not re-downloaded.

    Args:
        source (str): The source from which to fetch examples. Defaults to the HarmBench repository.
        source_type (Literal["public_url", "file"]): The type of source. Defaults to "public_url".
        cache (bool): Whether to cache the fetched examples. Defaults to True.
        data_home (Optional[Path]): Directory to store cached data. Defaults to None.
        categories (Optional[List[SemanticCategory]]): List of semantic categories used to filter
            examples. If None (the default), all categories are included.

    Returns:
        SeedPromptDataset: A SeedPromptDataset containing the multimodal examples.

    Raises:
        ValueError: If any of the specified categories are invalid.

    Note:
        For more information on the HarmBench project and the original dataset, visit:
        https://www.harmbench.org/
        Paper: https://arxiv.org/abs/2402.04249
        Authors: Mantas Mazeika, Long Phan, Xuwang Yin, Andy Zou, Zifan Wang, Norman Mu,
        Elham Sakhaee, Nathaniel Li, Steven Basart, Bo Li, David Forsyth, Dan Hendrycks
    """
    if categories is not None:
        valid_categories = {category.value for category in SemanticCategory}
        invalid_categories = (
            set(cat.value if isinstance(cat, SemanticCategory) else cat for cat in categories) - valid_categories
        )
        if invalid_categories:
            raise ValueError(f"Invalid semantic categories: {', '.join(invalid_categories)}")

    required_keys = {"Behavior", "BehaviorID", "FunctionalCategory", "SemanticCategory", "ImageFileName"}
    examples = fetch_examples(source, source_type, cache, data_home)
    prompts = []
    failed_image_count = 0

    for example in examples:
        missing_keys = required_keys - example.keys()
        if missing_keys:
            raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}")

        if example["FunctionalCategory"] != "multimodal":
            continue

        semantic_category = example["SemanticCategory"]

        if categories is not None:
            category_values = {cat.value for cat in categories}
            if semantic_category not in category_values:
                continue

        behavior_text = example["Behavior"]
        behavior_id = example["BehaviorID"]
        image_filename = example["ImageFileName"]
        image_description = example.get("ImageDescription", "")
        redacted_description = example.get("RedactedImageDescription", "")

        # A unique group ID links the text and image prompts that belong to the same example.
        group_id = uuid.uuid4()

        # Note: All images in the HarmBench dataset are stored as .png files, even when the
        # ImageFileName field specifies a different extension (.jpg or .jpeg), so the .png
        # extension is always used here.
        image_url = (
            "https://raw.githubusercontent.com/centerforaisafety/HarmBench/c0423b9/data/multimodal_behavior_images/"
            f"{image_filename.rsplit('.', 1)[0]}.png"
        )

        try:
            # Only include examples where the image fetch succeeds.
            local_image_path = await _fetch_and_save_image_async(image_url, behavior_id)

            image_prompt = SeedPrompt(
                value=local_image_path,
                data_type="image_path",
                name=f"HarmBench Multimodal Image - {behavior_id}",
                dataset_name="HarmBench Multimodal Examples",
                harm_categories=[semantic_category],
                description=f"An image prompt from the HarmBench multimodal dataset, BehaviorID: {behavior_id}",
                source=source,
                prompt_group_id=group_id,
                sequence=0,
                metadata={
                    "behavior_id": behavior_id,
                    "image_description": image_description,
                    "redacted_image_description": redacted_description,
                    "original_image_url": image_url,
                },
            )
            prompts.append(image_prompt)
        except Exception as e:
            failed_image_count += 1
            logger.warning(f"Failed to fetch image for behavior {behavior_id}: {e}. Skipping this example.")
        else:
            text_prompt = SeedPrompt(
                value=behavior_text,
                data_type="text",
                name=f"HarmBench Multimodal Text - {behavior_id}",
                dataset_name="HarmBench Multimodal Examples",
                harm_categories=[semantic_category],
                description=f"A text prompt from the HarmBench multimodal dataset, BehaviorID: {behavior_id}",
                source=source,
                prompt_group_id=group_id,
                sequence=0,
                metadata={
                    "behavior_id": behavior_id,
                },
                authors=[
                    "Mantas Mazeika",
                    "Long Phan",
                    "Xuwang Yin",
                    "Andy Zou",
                    "Zifan Wang",
                    "Norman Mu",
                    "Elham Sakhaee",
                    "Nathaniel Li",
                    "Steven Basart",
                    "Bo Li",
                    "David Forsyth",
                    "Dan Hendrycks",
                ],
                groups=[
                    "University of Illinois Urbana-Champaign",
                    "Center for AI Safety",
                    "Carnegie Mellon University",
                    "UC Berkeley",
                    "Microsoft",
                ],
            )
            prompts.append(text_prompt)

    if failed_image_count > 0:
        logger.warning(f"Total skipped examples: {failed_image_count} (image fetch failures)")

    seed_prompt_dataset = SeedPromptDataset(prompts=prompts)
    return seed_prompt_dataset


async def _fetch_and_save_image_async(image_url: str, behavior_id: str) -> str:
    filename = f"harmbench_{behavior_id}.png"
    serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

    # Return the existing path if an image is already cached for this BehaviorID.
    serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}")
    try:
        if await serializer._memory.results_storage_io.path_exists(serializer.value):
            return serializer.value
    except Exception as e:
        logger.warning(f"Failed to check whether image for {behavior_id} already exists: {e}")

    response = await make_request_and_raise_if_error_async(endpoint_uri=image_url, method="GET")
    await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

    return str(serializer.value)
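For context, here is a hedged usage sketch combining category filtering with prompt grouping. It assumes ``group_seed_prompts_by_prompt_group_id`` is the static helper on ``SeedPromptDataset`` referenced in the docstring above, and that each returned group exposes a ``prompts`` list; neither detail is shown in this diff.

import asyncio

from pyrit.datasets import fetch_harmbench_multimodal_dataset_async
from pyrit.datasets.harmbench_multimodal_dataset import SemanticCategory
from pyrit.models import SeedPromptDataset


async def main() -> None:
    # Restrict to one semantic category; categories=None fetches all 110 behaviors.
    dataset = await fetch_harmbench_multimodal_dataset_async(
        categories=[SemanticCategory.CYBERCRIME_INTRUSION],
    )

    # Each behavior yields a text prompt and an image prompt sharing a prompt_group_id;
    # regroup them into per-example pairs (assumed helper, per the docstring).
    groups = SeedPromptDataset.group_seed_prompts_by_prompt_group_id(dataset.prompts)
    for group in groups:
        for prompt in group.prompts:
            print(prompt.data_type, prompt.value)


asyncio.run(main())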

tests/integration/datasets/test_fetch_datasets.py

Lines changed: 16 additions & 0 deletions
@@ -13,6 +13,7 @@
     fetch_equitymedqa_dataset_unique_values,
     fetch_forbidden_questions_dataset,
     fetch_harmbench_dataset,
+    fetch_harmbench_multimodal_dataset_async,
     fetch_jbb_behaviors_by_harm_category,
     fetch_jbb_behaviors_by_jbb_category,
     fetch_jbb_behaviors_dataset,

@@ -72,6 +73,21 @@ def test_fetch_datasets(fetch_function, is_seed_prompt_dataset):
     assert len(data.prompts) > 0


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "fetch_function, number_of_prompts",
+    [
+        (fetch_harmbench_multimodal_dataset_async, 110 * 2),
+    ],
+)
+async def test_fetch_multimodal_datasets(fetch_function, number_of_prompts):
+    data = await fetch_function()
+
+    assert data is not None
+    assert isinstance(data, SeedPromptDataset)
+    assert len(data.prompts) == number_of_prompts
+
+
 @pytest.mark.integration
 def test_fetch_jbb_behaviors_by_harm_category():
     """Integration test for filtering by harm category with real data."""
