@@ -17,7 +17,7 @@ class SurprisalInterventionResult:
1717
1818 Attributes:
1919 score: The final computed score.
20- avg_kl: The average KL divergence between clean & intervened
20+ avg_kl: The average KL divergence between clean & intervened
2121 next-token distributions.
2222 explanation: The explanation string that was scored.
2323 """
@@ -47,7 +47,7 @@ class SurprisalInterventionScorer(Scorer):
4747 2. Compute the log-probability of the explanation conditioned on both the clean
4848 and intervened generated texts: log P(explanation | text).
4949 3. Compute KL divergence between the clean & intervened next-token distributions.
50- 4. The final score is the mean change in explanation log-prob, divided by the
50+ 4. The final score is the mean change in explanation log-prob, divided by the
5151 mean KL divergence:
5252 score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε).
5353 """
@@ -63,7 +63,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs):
6363 strength (float): The magnitude of the intervention. Default: 5.0.
6464 num_prompts (int): Number of activating examples to test. Default: 3.
6565 max_new_tokens (int): Max tokens to generate for continuations.
66- hookpoint (str): The module name (e.g., 'transformer.h.10.mlp')
66+ hookpoint (str): The module name (e.g., 'transformer.h.10.mlp')
6767 for the intervention.
6868 """
6969 self.subject_model = subject_model
@@ -121,9 +121,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
121121
122122 if not is_valid_format:
123123 if len(parts) == 1 and hasattr(model, hookpoint_str):
124- return getattr(model, hookpoint_str)
125- raise ValueError(f"""Hookpoint string '{hookpoint_str}' is not in a recognized format
126- like 'layers.6.mlp'.""")
124+ return getattr(model, hookpoint_str)
125+ raise ValueError(
126+ f"""Hookpoint string '{hookpoint_str}' is not in a recognized format
127+ like 'layers.6.mlp'."""
128+ )
127129
128130 # Heuristically find the model prefix.
129131 prefix = None
@@ -251,7 +253,7 @@ async def _generate_with_and_without_intervention(
251253 Returns:
252254 A tuple containing:
253255 - The generated text (string).
254- - The log-probability distribution for the token immediately following
256+ - The log-probability distribution for the token immediately following
255257 the prompt (Tensor).
256258 """
257259 device = self._get_device()
0 commit comments