
Commit 1a6fa0c: Pre-commit clears
Parent: 4ffc891

File tree: 1 file changed (+27, -20 lines)


delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 27 additions & 20 deletions
@@ -114,20 +114,21 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module:
 
 
     def _get_full_hookpoint_path(self, hookpoint_str: str) -> str:
-        """
-        Heuristically finds the model's prefix and constructs the full hookpoint path string.
-        e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
-        """
-        # Heuristically find the model prefix.
-        prefix = None
-        for p in ["gpt_neox", "transformer", "model"]:
-            if hasattr(self.subject_model, p):
-                candidate_body = getattr(self.subject_model, p)
-                if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"):
-                    prefix = p
-                    break
-
-        return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str
+        """
+        Heuristically finds the model's prefix and constructs the full hookpoint
+        path string.
+        e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
+        """
+        # Heuristically find the model prefix.
+        prefix = None
+        for p in ["gpt_neox", "transformer", "model"]:
+            if hasattr(self.subject_model, p):
+                candidate_body = getattr(self.subject_model, p)
+                if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"):
+                    prefix = p
+                    break
+
+        return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str
 
 
     def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
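
Note: a minimal standalone sketch of the prefix heuristic above, using hypothetical stand-in model objects (SimpleNamespace shells, not the scorer's actual subject_model):

from types import SimpleNamespace

def get_full_hookpoint_path(subject_model, hookpoint_str: str) -> str:
    # Same heuristic as the method above: find the top-level attribute that
    # holds the transformer body (its layers live under `.h` or `.layers`).
    prefix = None
    for p in ["gpt_neox", "transformer", "model"]:
        if hasattr(subject_model, p):
            candidate_body = getattr(subject_model, p)
            if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"):
                prefix = p
                break
    return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str

llama_like = SimpleNamespace(model=SimpleNamespace(layers=[]))   # Llama-style body
print(get_full_hookpoint_path(llama_like, "layers.6.mlp"))       # model.layers.6.mlp

gpt2_like = SimpleNamespace(transformer=SimpleNamespace(h=[]))   # GPT-2-style body
print(get_full_hookpoint_path(gpt2_like, "h.6.mlp"))             # transformer.h.6.mlp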
@@ -144,6 +145,7 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
                 Original error: {e}"""
             )
 
+
     def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
         """
         Function used for formatting results to run smoothly in the delphi pipeline
@@ -276,7 +278,7 @@ def capture_hook(module, inp, out):
 
     async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
         """
-        Truncates a prompt to end just before the first token where the latent activates.
+        Truncates prompt to end just before the first token where latent activates.
         """
         activations = await self._get_latent_activations(prompt, record)
         if activations.numel() == 0:
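
Note: a toy sketch of the truncation rule in _truncate_prompt, using whitespace tokens and a made-up activation tensor (the real method works on the scorer's tokenizer output):

import torch

def truncate_before_first_activation(tokens, activations):
    # Keep everything strictly before the first token whose activation is non-zero.
    nonzero = (activations > 0).nonzero()
    if activations.numel() == 0 or nonzero.numel() == 0:
        return " ".join(tokens)  # No activation found: keep the full prompt.
    first_idx = int(nonzero[0].item())
    return " ".join(tokens[:first_idx])

tokens = ["The", "cat", "sat", "on", "the", "mat"]
acts = torch.tensor([0.0, 0.0, 0.0, 2.5, 0.0, 1.0])
print(truncate_before_first_activation(tokens, acts))  # "The cat sat"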
@@ -299,7 +301,7 @@ async def _tune_strength(
         self, prompts: List[str], record: LatentRecord
     ) -> Tuple[float, float]:
         """
-        Performs a binary search to find the intervention strength that matches `target_kl`.
+        Performs a binary search to find intervention strength that matches target_kl.
         """
         low_strength, high_strength = 0.0, 40.0  # Heuristic search range
         best_strength = self.target_kl  # Default to target_kl if search fails
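
Note: a minimal sketch of the binary-search idea behind _tune_strength, assuming a hypothetical monotone measure_kl(strength) callable rather than the scorer's actual KL measurement:

def tune_strength(measure_kl, target_kl, low=0.0, high=40.0, iters=10, tol=1e-3):
    # Binary-search the strength whose induced KL is closest to target_kl,
    # assuming measure_kl is (roughly) monotonically increasing in strength.
    best_strength = high
    for _ in range(iters):
        mid = (low + high) / 2.0
        kl = measure_kl(mid)
        if abs(kl - target_kl) < tol:
            return mid
        if kl < target_kl:
            low = mid    # Intervention too weak: search the upper half.
        else:
            high = mid   # Intervention too strong: search the lower half.
        best_strength = mid
    return best_strength

print(tune_strength(lambda s: 0.05 * s, target_kl=1.0))  # converges to ~20.0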
@@ -384,7 +386,7 @@ def hook_fn(module, inp, out):
         # 2. Create the corresponding indices needed for the decode method.
         indices = torch.tensor([[[record.feature_id]]], device=sae_device, dtype=torch.long)
 
-        # 3. Decode this one-hot vector to get the feature's direction in the hidden space.
+        # 3. Decode one-hot vector to get feature's direction in hidden space.
         # We subtract the decoded zero vector to remove any decoder bias.
         decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices)
         decoder_vector = sae.decode(one_hot_activation, indices) - decoded_zero
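
Note: the "decode a one-hot, subtract the decoded zero" step isolates the feature's decoder direction while cancelling the decoder bias. A small sketch with a plain linear decoder (W_dec and b_dec are stand-ins, not the actual sae.decode API):

import torch

d_sae, d_model, feature_id = 16, 8, 3
W_dec = torch.randn(d_sae, d_model)   # stand-in decoder weights
b_dec = torch.randn(d_model)          # stand-in decoder bias

def decode(latents):
    # A plain linear SAE decoder: x_hat = latents @ W_dec + b_dec.
    return latents @ W_dec + b_dec

one_hot = torch.zeros(d_sae)
one_hot[feature_id] = 1.0

# decode(one_hot) - decode(0) == W_dec[feature_id]: the bias cancels out,
# leaving only this feature's direction in the hidden space.
decoder_vector = decode(one_hot) - decode(torch.zeros(d_sae))
assert torch.allclose(decoder_vector, W_dec[feature_id], atol=1e-5)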
@@ -485,13 +487,17 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
                 return instance  # Unwrapped successfully.
 
                 # If we found a partial but failed to unwrap it, we cannot proceed.
-                print(f"ERROR: Found a partial for {hookpoint_str} but could not unwrap the SAE instance.")
+                print(
+                    f"""ERROR: Found a partial for {hookpoint_str} but could not
+                    unwrap the SAE instance.""")
                 return None
 
             # If it's not a partial, it's the model itself.
             return candidate
 
-        print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'")
+        print(
+            f"""ERROR: Surprisal scorer could not find
+            an SAE for hookpoint '{hookpoint_str}'""")
         return None
 
 
@@ -518,7 +524,8 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
 
             return candidate
 
-        print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'")
+        print(f"""ERROR: Surprisal scorer could not find
+            an SAE for hookpoint '{hookpoint_str}'""")
         return None
 
 