@@ -112,7 +112,7 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module:
112112
113113 def _get_full_hookpoint_path (self , hookpoint_str : str ) -> str :
114114 """
115- Heuristically finds the model's prefix and constructs the full hookpoint
115+ Heuristically finds the model's prefix and constructs the full hookpoint
116116 path string.
117117 e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
118118 """
@@ -124,7 +124,7 @@ def _get_full_hookpoint_path(self, hookpoint_str: str) -> str:
124124 if hasattr (candidate_body , "h" ) or hasattr (candidate_body , "layers" ):
125125 prefix = p
126126 break
127-
127+
128128 return f"{ prefix } .{ hookpoint_str } " if prefix else hookpoint_str
129129
130130 def _resolve_hookpoint (self , model : Any , hookpoint_str : str ) -> Any :
@@ -141,7 +141,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
141141 Original error: { e } """
142142 )
143143
144-
145144 def _sanitize_examples (self , examples : List [Any ]) -> List [Dict [str , Any ]]:
146145 """
147146 Function used for formatting results to run smoothly in the delphi pipeline
@@ -517,16 +516,18 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
517516
518517 # If we found a partial but failed to unwrap it, we cannot proceed.
519518 print (
520- f"""ERROR: Found a partial for { hookpoint_str } but could not
521- unwrap the SAE instance.""" )
519+ f"""ERROR: Found a partial for { hookpoint_str } but could not
520+ unwrap the SAE instance."""
521+ )
522522 return None
523523
524524 # If it's not a partial, it's the model itself.
525525 return candidate
526526
527527 print (
528- f"""ERROR: Surprisal scorer could not find
529- an SAE for hookpoint '{ hookpoint_str } '""" )
528+ f"""ERROR: Surprisal scorer could not find
529+ an SAE for hookpoint '{ hookpoint_str } '"""
530+ )
530531 return None
531532
532533 def _get_sae_for_hookpoint (self , hookpoint_str : str , record : LatentRecord ) -> Any :
@@ -552,8 +553,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
552553
553554 return candidate
554555
555- print (f"""ERROR: Surprisal scorer could not find
556- an SAE for hookpoint '{ hookpoint_str } '""" )
556+ print (
557+ f"""ERROR: Surprisal scorer could not find
558+ an SAE for hookpoint '{ hookpoint_str } '"""
559+ )
557560 return None
558561
559562 def _get_intervention_direction (self , record : LatentRecord ) -> torch .Tensor :
0 commit comments