-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSafetyClassifierForward.py
More file actions
34 lines (31 loc) · 1.25 KB
/
SafetyClassifierForward.py
File metadata and controls
34 lines (31 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from torch import load, no_grad, softmax
from transformers import AutoTokenizer

from SafetyClassifier import SafetyClassifier

# akhil.bdu.st/simple-safety, 8098
class SafetyClassifierForward:
    """Inference wrapper around a fine-tuned SafetyClassifier checkpoint.

    Loads CPU weights from a mounted ``data`` directory and scores raw text
    for safety via :meth:`eval_text`.
    """

    def __init__(self, base_model_name, checkpoint_dir='data/Qwen3-0.6B_safety'):
        """Build the classifier and tokenizer.

        Args:
            base_model_name: Hugging Face model id passed to both
                ``SafetyClassifier`` and ``AutoTokenizer``.
            checkpoint_dir: Directory holding ``simple_safety_cpu.bin``.
                Defaults to the previously hard-coded path (expects a
                mounted directory called "data").
        """
        self.model = SafetyClassifier(base_model_name)
        self.model.load_state_dict(load(
            f"{checkpoint_dir}/simple_safety_cpu.bin",
            map_location='cpu'
        ))
        # This wrapper is inference-only: switch off dropout/batch-norm
        # training behavior after loading the weights.
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)
        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # that padding=True in eval_text works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def eval_text(self, text_input):
        """Return the probability that *text_input* is unsafe.

        Args:
            text_input: Raw user text; the training-time prompt prefix is
                prepended before tokenization.

        Returns:
            float in [0, 1] — softmax probability of the first logit class.
        """
        text_input = self.prepend_prompt(text_input)
        inputs = self.tokenizer(
            text_input,
            return_tensors='pt',
            truncation=True,
            padding=True
        )
        # no_grad: skip building an autograd graph for pure inference.
        with no_grad():
            outputs = self.model(**inputs)
        torch_probs = softmax(outputs['logits'], dim=-1)
        # Index 0 is presumed to be the "unsafe" class — TODO confirm
        # against the label order used when training SafetyClassifier.
        prob_unsafe = float(torch_probs[0][0])
        return prob_unsafe

    def prepend_prompt(self, text_input):
        """Wrap raw text in the classification prompt used at training time."""
        prefix = 'Is the following text considered safe? The user said: '
        return prefix + text_input