Commit c6b901a

Merge commit: 2 parents 8c65dbe + 68d6c63
delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 50 additions & 44 deletions
@@ -168,25 +168,22 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
         """
         Calculates the feature's decoder vector, subtracting the decoder bias.
         """
-
-
+
         d_latent = sae.encoder.out_features
         sae_device = sae.encoder.weight.device
 
         # Create a one-hot activation for our single feature.
         one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device)
-
+
         if feature_id >= d_latent:
             print(f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds
             for d_latent {d_latent}""")
             return torch.zeros(1)
-
+
         one_hot_activation[0, 0, feature_id] = 1.0
 
         # Create the corresponding indices needed for the decode method.
-        indices = torch.tensor(
-            [[[feature_id]]], device=sae_device, dtype=torch.long
-        )
+        indices = torch.tensor([[[feature_id]]], device=sae_device, dtype=torch.long)
 
         with torch.no_grad():
             try:
@@ -197,24 +194,25 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
                 return torch.zeros(1)
 
         decoder_vector = vector_before_sub - decoded_zero
-
+
         final_norm = decoder_vector.norm().item()
-
+
         # --- MODIFIED DEBUG BLOCK ---
         # Only print if the feature is "decoder-live"
         if final_norm > 1e-6:
             print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---")
             print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}")
             print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}")
             print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}")
-            print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}")
+            print(
+                f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}"
+            )
             print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}")
             print("--- END DEBUG ---\n")
         # --- END MODIFIED BLOCK ---
 
         return decoder_vector.squeeze()
 
-
     async def __call__(self, record: LatentRecord) -> ScorerResult:
 
         record_copy = copy.deepcopy(record)
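
Note: the two hunks above reformat _get_intervention_vector, which isolates a feature's decoder direction by decoding a one-hot latent activation and subtracting the decoded all-zero activation, so the decoder bias cancels. A minimal sketch of that trick against a plain linear decoder (TinySAE and its shapes are hypothetical stand-ins, not the project's actual SAE API):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for an SAE decoder: decode(x) = x @ W_dec.T + b_dec
    class TinySAE(nn.Module):
        def __init__(self, d_latent: int = 16, d_model: int = 8):
            super().__init__()
            self.decoder = nn.Linear(d_latent, d_model)

        def decode(self, acts: torch.Tensor) -> torch.Tensor:
            return self.decoder(acts)

    def feature_direction(sae: TinySAE, feature_id: int) -> torch.Tensor:
        d_latent = sae.decoder.in_features
        one_hot = torch.zeros(1, d_latent)
        one_hot[0, feature_id] = 1.0
        with torch.no_grad():
            # decode(one_hot) = W_dec[:, f] + b_dec and decode(0) = b_dec,
            # so subtracting cancels the bias and leaves only the column.
            direction = sae.decode(one_hot) - sae.decode(torch.zeros(1, d_latent))
        return direction.squeeze(0)

    sae = TinySAE()
    v = feature_direction(sae, feature_id=3)
    assert torch.allclose(v, sae.decoder.weight[:, 3], atol=1e-6)
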
@@ -241,7 +239,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy)
         if not sae:
             raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}")
-
+
         intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id)
 
         tuned_strength, initial_kl = await self._tune_strength(
@@ -254,10 +252,18 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
 
         for prompt in truncated_prompts:
             clean_text, clean_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=0.0,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )
             int_text, int_logp_dist = await self._generate_with_intervention(
-                prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True
+                prompt,
+                record_copy,
+                strength=tuned_strength,
+                intervention_vector=intervention_vector,
+                get_logp_dist=True,
             )
 
             logp_clean = await self._score_explanation(
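
The loop above runs each truncated prompt twice through _generate_with_intervention: once at strength=0.0 as a clean baseline and once at the tuned strength, with get_logp_dist=True so both calls also return a next-token log-probability distribution. A toy sketch of that pairing (toy_generate and run_pair are illustrative stand-ins; the real method is async and hooks the subject model):

    import torch

    def toy_generate(prompt: str, strength: float):
        # Stand-in mimicking _generate_with_intervention's return shape:
        # (generated text, log-prob distribution over the vocabulary).
        logits = torch.randn(50) * (1.0 + strength)
        return prompt + " ...", torch.log_softmax(logits, dim=-1)

    def run_pair(prompt: str, tuned_strength: float):
        clean_text, clean_logp = toy_generate(prompt, strength=0.0)
        int_text, int_logp = toy_generate(prompt, strength=tuned_strength)
        return (clean_text, clean_logp), (int_text, int_logp)

    clean, intervened = run_pair("Once upon a time", tuned_strength=4.0)
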
@@ -301,7 +307,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
         )
         return ScorerResult(record=record_copy, score=final_output_list)
 
-
     async def _get_latent_activations(
         self, prompt: str, record: LatentRecord
     ) -> torch.Tensor:
@@ -340,7 +345,6 @@ def capture_hook(module, inp, out):
 
         return feature_acts[0, :, record.feature_id].cpu()
 
-
     async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         """
         Truncates prompt to end just before the first token where latent activates.
@@ -357,17 +361,18 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         first_activation_idx = all_activation_indices[all_activation_indices > 0]
 
         if first_activation_idx.numel() > 0:
-            truncation_point = first_activation_idx[0].item()
+            truncation_point = first_activation_idx[0].item()
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
-            truncated_ids = input_ids[:truncation_point + 1]
+            truncated_ids = input_ids[: truncation_point + 1]
             return self.tokenizer.decode(truncated_ids, skip_special_tokens=True)
 
         return prompt
 
-
     async def _tune_strength(
-        self, prompts: List[str], record: LatentRecord,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        intervention_vector: torch.Tensor,
     ) -> Tuple[float, float]:
         """
         Performs a binary search to find intervention strength that matches target_kl.
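
The truncation logic in this hunk keeps everything up to and including the first token at which the latent fires (the input_ids[: truncation_point + 1] slice), so generation starts exactly where the feature becomes active. A self-contained sketch of that slicing with fake activations (the real method gets them from _get_latent_activations):

    import torch

    # Toy per-token latent activations for an 8-token prompt.
    acts = torch.tensor([0.0, 0.0, 0.0, 2.5, 0.1, 0.0, 3.0, 0.0])

    # Positions where the latent fires, skipping position 0 to mirror
    # all_activation_indices[all_activation_indices > 0].
    active = torch.nonzero(acts > 0).squeeze(-1)
    active = active[active > 0]

    token_ids = torch.arange(8)  # stand-in for tokenizer(...).input_ids[0]
    if active.numel() > 0:
        truncation_point = active[0].item()
        truncated_ids = token_ids[: truncation_point + 1]  # keep first active token
    else:
        truncated_ids = token_ids

    print(truncated_ids)  # tensor([0, 1, 2, 3])
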
@@ -409,22 +414,26 @@ async def _tune_strength
                 best_strength = mid_strength
 
         # Return the best found strength and the corresponding KL
-        final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector)
+        final_kl = await self._calculate_avg_kl(
+            prompts, record, best_strength, intervention_vector
+        )
         return best_strength, final_kl
 
-
     async def _calculate_avg_kl(
-        self, prompts: List[str], record: LatentRecord, strength: float,
-        intervention_vector: torch.Tensor
+        self,
+        prompts: List[str],
+        record: LatentRecord,
+        strength: float,
+        intervention_vector: torch.Tensor,
     ) -> float:
         total_kl = 0.0
         n = 0
         for prompt in prompts:
             _, clean_logp = await self._generate_with_intervention(
-                prompt, record, 0.0, intervention_vector,True
+                prompt, record, 0.0, intervention_vector, True
             )
             _, int_logp = await self._generate_with_intervention(
-                prompt, record, strength, intervention_vector,True
+                prompt, record, strength, intervention_vector, True
             )
             p_clean = torch.exp(clean_logp)
             kl_div = F.kl_div(
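
Per its docstring, _tune_strength binary-searches for the intervention strength whose average KL divergence from the clean distribution matches a target. A minimal sketch of such a search over a monotone KL-vs-strength curve (the kl_at callable, bounds, and iteration count are illustrative assumptions, not the repository's actual values):

    from typing import Callable, Tuple

    def tune_strength(
        kl_at: Callable[[float], float],
        target_kl: float,
        lo: float = 0.0,
        hi: float = 32.0,
        iters: int = 20,
    ) -> Tuple[float, float]:
        """Assumes kl_at is monotonically increasing in strength."""
        best_strength = lo
        for _ in range(iters):
            mid = (lo + hi) / 2.0
            if kl_at(mid) < target_kl:
                lo = mid  # too weak: push the strength up
            else:
                hi = mid  # too strong: pull it down
            best_strength = mid
        return best_strength, kl_at(best_strength)

    # Toy monotone curve kl(s) = 0.1 * s: the search should land near 10.
    strength, kl = tune_strength(lambda s: 0.1 * s, target_kl=1.0)
    print(round(strength, 3), round(kl, 3))
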
@@ -434,7 +443,6 @@ async def _calculate_avg_kl
             n += 1
         return total_kl / n if n > 0 else 0.0
 
-
     async def _generate_with_intervention(
         self,
         prompt: str,
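
One subtlety in _calculate_avg_kl: torch.nn.functional.kl_div expects its input in log-space and, by default, its target as probabilities, which is why the diff exponentiates clean_logp into p_clean first. The F.kl_div call itself is cut off in the hunk, so the argument order and reduction below are assumptions; the identity being checked is standard:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    clean_logp = torch.log_softmax(torch.randn(50), dim=-1)  # clean next-token dist
    int_logp = torch.log_softmax(torch.randn(50), dim=-1)    # intervened dist

    # F.kl_div(input, target) with log-space input and probability target
    # computes KL(target || input).
    p_clean = torch.exp(clean_logp)
    kl = F.kl_div(int_logp, p_clean, reduction="sum")

    # Same thing by hand: sum_i p_clean[i] * (clean_logp[i] - int_logp[i])
    manual = (p_clean * (clean_logp - int_logp)).sum()
    assert torch.allclose(kl, manual, atol=1e-5)
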
@@ -474,8 +482,9 @@ def hook_fn(module, inp, out):
             intervention_start_index = prompt_length - 1
 
             if current_seq_len >= prompt_length:
-                new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype)
-
+                new_hiddens[:, intervention_start_index:, :] += delta.to(
+                    original_dtype
+                )
 
             return (
                 (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens
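
The hook reformatted above adds the strength-scaled intervention vector to the hidden states of every position from the prompt's final token onward, converting dtypes via delta.to(original_dtype). A self-contained sketch of that pattern on a toy module (shapes and the Identity layer are illustrative; the real hook also passes through tuple outputs from transformer blocks, as its return statement shows):

    import torch
    import torch.nn as nn

    d_model, seq_len, prompt_length = 4, 6, 3
    delta = 0.5 * torch.ones(d_model)  # strength * intervention_vector

    def hook_fn(module, inp, out):
        new_hiddens = out.clone()
        intervention_start_index = prompt_length - 1
        if new_hiddens.shape[1] >= prompt_length:
            # Steer every position from the last prompt token onward.
            new_hiddens[:, intervention_start_index:, :] += delta.to(new_hiddens.dtype)
        return new_hiddens  # returning a tensor replaces the module's output

    layer = nn.Identity()
    handle = layer.register_forward_hook(hook_fn)
    steered = layer(torch.zeros(1, seq_len, d_model))
    handle.remove()
    print(steered[0, :, 0])  # 0.0 at positions 0-1, 0.5 from position 2 onward
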
@@ -485,7 +494,7 @@ def hook_fn(module, inp, out):
 
         try:
             with torch.no_grad():
-                outputs =self.subject_model(input_ids, attention_mask=attention_mask)
+                outputs = self.subject_model(input_ids, attention_mask=attention_mask)
             next_token_logits = outputs.logits[0, -1, :]
             log_probs_next_token = (
                 F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None
@@ -507,10 +516,9 @@ def hook_fn(module, inp, out):
             log_probs_next_token.cpu() if get_logp_dist else torch.empty(0)
         )
 
-
     async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         """
-        Computes log P(explanation | generated_text) using the paper's 
+        Computes log P(explanation | generated_text) using the paper's
         prompt format.
         """
         device = self._get_device()
@@ -519,9 +527,9 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         prompt_template = (
             "<PASSAGE>\n"
             f"{generated_text}\n"
-            "The above passage contains an amplified amount of \""
+            'The above passage contains an amplified amount of "'
         )
-        explanation_suffix = f"{explanation}\""
+        explanation_suffix = f'{explanation}"'
 
         # Tokenize the parts
         context_enc = self.tokenizer(prompt_template, return_tensors="pt")
@@ -537,7 +545,7 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
 
         # We only need to score the explanation part
         context_len = context_enc.input_ids.shape[1]
-
+
         # Get logits for positions that predict the explanation tokens
         # Shape: [batch_size, explanation_len, vocab_size]
         explanation_logits = logits[:, context_len - 1 : -1, :]
@@ -549,14 +557,11 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
         log_probs = F.log_softmax(explanation_logits, dim=-1)
 
         # Gather the log-probabilities of the actual explanation tokens
-        token_log_probs = log_probs.gather(
-            2, target_ids.unsqueeze(-1)
-        ).squeeze(-1)
+        token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
 
         # Return the sum of log-probs for the explanation
         return token_log_probs.sum().item()
 
-
     def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
         """
         Retrieves the correct SAE model, handling the specific functools.partial
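
_score_explanation sums the log-probabilities the subject model assigns to the explanation tokens, conditioned on the generated passage. The gather call consolidated in the hunk above is the standard way to pull per-token log-probs out of a [batch, seq, vocab] tensor; a shapes-only sketch with random logits (no real model or tokenizer involved):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    batch, context_len, expl_len, vocab = 1, 5, 3, 11
    logits = torch.randn(batch, context_len + expl_len, vocab)
    target_ids = torch.randint(vocab, (batch, expl_len))  # explanation token ids

    # Logits at position i predict token i + 1, so the slice that predicts
    # the explanation tokens starts one position before the explanation.
    explanation_logits = logits[:, context_len - 1 : -1, :]
    log_probs = F.log_softmax(explanation_logits, dim=-1)

    # Pick out log P(actual explanation token) at each position, then sum.
    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
    score = token_log_probs.sum().item()  # log P(explanation | context)
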
@@ -568,13 +573,13 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
             candidate = record.sae
         elif self.explainer_model and isinstance(self.explainer_model, dict):
             full_key = self._get_full_hookpoint_path(hookpoint_str)
-            short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp"
+            short_key = ".".join(hookpoint_str.split(".")[-2:])  # e.g., "layers.6.mlp"
 
             for key in [hookpoint_str, full_key, short_key]:
                 if self.explainer_model.get(key) is not None:
                     candidate = self.explainer_model.get(key)
                     break
-
+
             if candidate is None:
                 # This will raise an error if the key isn't found
                 raise ValueError(f"ERROR: Surprisal scorer could not find an SAE "
@@ -591,8 +596,9 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
                     find the 'sae' keyword.
                     func: {candidate.func}
                     args: {candidate.args}
-                    keywords: {candidate.keywords}""")
-
+                    keywords: {candidate.keywords}"""
+                )
+
         # This will raise an error if the candidate isn't a partial
         raise ValueError(f"""ERROR: Candidate for {hookpoint_str} was not a partial
         object, which was not expected. Type: {type(candidate)}""")
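
The final hunks tidy the error paths of _get_sae_for_hookpoint, which expects each hookpoint's entry to be a functools.partial carrying the SAE in its keywords. A minimal sketch of that unwrap-or-raise shape (the encode stand-in and the dictionary are illustrative; only the 'sae' keyword and the partial's func/args/keywords introspection come from the diff):

    import functools
    from typing import Any

    def encode(x: Any, sae: Any = None) -> Any:
        return sae  # stand-in for a hook function that closes over an SAE

    def unwrap_sae(candidate: Any, hookpoint_str: str) -> Any:
        if isinstance(candidate, functools.partial):
            sae = candidate.keywords.get("sae")
            if sae is None:
                raise ValueError(
                    f"partial for {hookpoint_str} has no 'sae' keyword: "
                    f"func={candidate.func}, args={candidate.args}, "
                    f"keywords={candidate.keywords}"
                )
            return sae
        raise ValueError(
            f"Candidate for {hookpoint_str} was not a partial object. "
            f"Type: {type(candidate)}"
        )

    hooks = {"layers.6.mlp": functools.partial(encode, sae="my-sae-object")}
    print(unwrap_sae(hooks["layers.6.mlp"], "layers.6.mlp"))  # my-sae-object
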
