utils.py
"""Utilities for loading ToxicCompletions data and training probes."""

from collections import Counter
import random

import torch

from toxic_completions import get_toxic_completions_messages
from models.base import BaseProbe
from models.logistic import LogisticProbe

def parse_dtype(dtype: str) -> torch.dtype:
    """Map a dtype name ("float16", "float32", "float64") to a torch dtype."""
    dtypes = {
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
    }
    if dtype not in dtypes:
        raise ValueError(f"Invalid dtype: {dtype}")
    return dtypes[dtype]
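

# Usage sketch (illustrative only, not part of the original module):
# parse_dtype maps a dtype string, e.g. from a CLI flag, onto the torch dtype
# used when loading a model.
#
#     dtype = parse_dtype("float16")   # -> torch.float16
#     model = model.to(dtype=dtype)    # `model` is a hypothetical nn.Module
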
def get_data(
    dataset: str,
    num_samples: int,
    max_num_generate: int = 0,
    do_few_shot: bool = False,
    label_key: str | None = None,
    cache_dir: str | None = None,
):
    """Load ToxicCompletions messages and split them into train and val sets.

    The split is made at the prompt level, so no prompt appears in both
    splits. Returns (train_messages, val_messages, completion_messages,
    num_train), where num_train is the target train-set size and
    completion_messages holds up to max_num_generate validation messages
    with distinct prompts. The dataset and label_key arguments are
    currently unused.
    """
    messages = get_toxic_completions_messages(
        max_messages=num_samples + 2 * max_num_generate,
        do_few_shot=do_few_shot,
        cache_dir=cache_dir,
    )
    # Count messages per prompt and shuffle the prompts with a fixed seed so
    # the split is deterministic across runs.
    prompt_counts = Counter(message["prompt"] for message in messages)
    all_prompts = list(prompt_counts.keys())
    random.Random(0).shuffle(all_prompts)
    num_train = int(0.8 * num_samples)
    num_val = num_samples - num_train
    # Assign whole prompts to the train split until it covers ~80% of the
    # requested samples; the remaining prompts form the validation split.
    train_count = 0
    train_prompts, val_prompts = [], []
    for prompt in all_prompts:
        if train_count < num_train:
            train_prompts.append(prompt)
            train_count += prompt_counts[prompt]
        else:
            val_prompts.append(prompt)
    train_messages = [message for message in messages if message["prompt"] in train_prompts]
    val_messages = [message for message in messages if message["prompt"] in val_prompts]
    # Select up to max_num_generate validation messages with distinct prompts
    # to use for generation.
    completion_messages = []
    seen_prompts = set()
    for message in val_messages:
        if message["prompt"] not in seen_prompts:
            completion_messages.append(message)
            seen_prompts.add(message["prompt"])
            if len(completion_messages) >= max_num_generate:
                break
    val_messages = val_messages[:num_val]
    return train_messages, val_messages, completion_messages, num_train
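

# Usage sketch (illustrative; the argument values are made up, not taken from
# this module): the splits come back as lists of message dicts that contain
# at least a "prompt" key.
#
#     train_messages, val_messages, completion_messages, num_train = get_data(
#         dataset="toxic_completions",
#         num_samples=1000,
#         max_num_generate=16,
#     )
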
# def get_cb_data(circuit_breaker_dataset, num_samples, max_num_generate):
#     harmless_messages = circuit_breaker_dataset.harmless_set
#     harmful_messages = circuit_breaker_dataset.circuit_breaker.orig
#     val_messages = circuit_breaker_dataset.val_orig
#     train_messages = harmless_messages + harmful_messages
#
#     # Split a raw message string into its user prompt and assistant response.
#     def split_prompt_response(message):
#         user_tag = "user\n\n"
#         assistant_tag = "assistant\n\n"
#         separator = "<SEPARATOR>"
#         # Case 1: <SEPARATOR> present; the text before it is the prompt.
#         if separator in message:
#             prompt_part, response_part = message.split(separator)
#             prompt = prompt_part.replace(user_tag, "").replace(assistant_tag, "").strip()
#             response = response_part.strip()
#         # Case 2: user and assistant tags present, but no <SEPARATOR>.
#         elif user_tag in message and assistant_tag in message:
#             parts = message.split(assistant_tag)
#             prompt = parts[0].replace(user_tag, "").strip()
#             response = parts[1].strip()
#         else:
#             return None  # Skip messages that don't match the expected format.
#         return {"prompt": prompt, "response": response}
#
#     # Parse the train and validation messages.
#     train_messages = [split_prompt_response(message) for message in train_messages]
#     val_messages = [split_prompt_response(message) for message in val_messages]
#     # Filter out any None values from improperly formatted messages.
#     train_messages = [msg for msg in train_messages if msg is not None]
#     val_messages = [msg for msg in val_messages if msg is not None]
#     # Select up to max_num_generate validation messages with distinct prompts.
#     completion_messages = []
#     seen_prompts = set()
#     for message in val_messages:
#         if message["prompt"] not in seen_prompts:
#             completion_messages.append(message)
#             seen_prompts.add(message["prompt"])
#             if len(completion_messages) >= max_num_generate:
#                 break
#     return train_messages, val_messages, completion_messages

def train_probe(probe, train_xs, train_labels, device="cpu") -> BaseProbe:
    """Fit a probe of the named type on the given activations and labels."""
    if probe == "logistic":
        model = LogisticProbe(normalize=True, l2_penalty=1.0)
        model.fit(train_xs, train_labels)
    else:
        raise NotImplementedError(f"Invalid probe: {probe}")
    return model
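

# Usage sketch (illustrative; the shapes and label convention are assumptions,
# not taken from this module): fit a logistic probe on a batch of activations.
#
#     xs = torch.randn(256, 768)                    # 256 examples, 768-dim activations
#     labels = torch.randint(0, 2, (256,)).float()  # binary labels
#     probe = train_probe("logistic", xs, labels)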