Skip to content

Commit 6e18bba

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 572da19 commit 6e18bba

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module:
112112

113113
def _get_full_hookpoint_path(self, hookpoint_str: str) -> str:
114114
"""
115-
Heuristically finds the model's prefix and constructs the full hookpoint
115+
Heuristically finds the model's prefix and constructs the full hookpoint
116116
path string.
117117
e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
118118
"""
@@ -124,7 +124,7 @@ def _get_full_hookpoint_path(self, hookpoint_str: str) -> str:
124124
if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"):
125125
prefix = p
126126
break
127-
127+
128128
return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str
129129

130130
def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
@@ -141,7 +141,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
141141
Original error: {e}"""
142142
)
143143

144-
145144
def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
146145
"""
147146
Function used for formatting results to run smoothly in the delphi pipeline
@@ -517,16 +516,18 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
517516

518517
# If we found a partial but failed to unwrap it, we cannot proceed.
519518
print(
520-
f"""ERROR: Found a partial for {hookpoint_str} but could not
521-
unwrap the SAE instance.""")
519+
f"""ERROR: Found a partial for {hookpoint_str} but could not
520+
unwrap the SAE instance."""
521+
)
522522
return None
523523

524524
# If it's not a partial, it's the model itself.
525525
return candidate
526526

527527
print(
528-
f"""ERROR: Surprisal scorer could not find
529-
an SAE for hookpoint '{hookpoint_str}'""")
528+
f"""ERROR: Surprisal scorer could not find
529+
an SAE for hookpoint '{hookpoint_str}'"""
530+
)
530531
return None
531532

532533
def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
@@ -552,8 +553,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
552553

553554
return candidate
554555

555-
print(f"""ERROR: Surprisal scorer could not find
556-
an SAE for hookpoint '{hookpoint_str}'""")
556+
print(
557+
f"""ERROR: Surprisal scorer could not find
558+
an SAE for hookpoint '{hookpoint_str}'"""
559+
)
557560
return None
558561

559562
def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor:

0 commit comments

Comments
 (0)