@@ -114,20 +114,21 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module:
114114
115115
116116 def _get_full_hookpoint_path (self , hookpoint_str : str ) -> str :
117- """
118- Heuristically finds the model's prefix and constructs the full hookpoint path string.
119- e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
120- """
121- # Heuristically find the model prefix.
122- prefix = None
123- for p in ["gpt_neox" , "transformer" , "model" ]:
124- if hasattr (self .subject_model , p ):
125- candidate_body = getattr (self .subject_model , p )
126- if hasattr (candidate_body , "h" ) or hasattr (candidate_body , "layers" ):
127- prefix = p
128- break
129-
130- return f"{ prefix } .{ hookpoint_str } " if prefix else hookpoint_str
117+ """
118+ Heuristically finds the model's prefix and constructs the full hookpoint
119+ path string.
120+ e.g., 'layers.6.mlp' -> 'model.layers.6.mlp'
121+ """
122+ # Heuristically find the model prefix.
123+ prefix = None
124+ for p in ["gpt_neox" , "transformer" , "model" ]:
125+ if hasattr (self .subject_model , p ):
126+ candidate_body = getattr (self .subject_model , p )
127+ if hasattr (candidate_body , "h" ) or hasattr (candidate_body , "layers" ):
128+ prefix = p
129+ break
130+
131+ return f"{ prefix } .{ hookpoint_str } " if prefix else hookpoint_str
131132
132133
133134 def _resolve_hookpoint (self , model : Any , hookpoint_str : str ) -> Any :
@@ -144,6 +145,7 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
144145 Original error: { e } """
145146 )
146147
148+
147149 def _sanitize_examples (self , examples : List [Any ]) -> List [Dict [str , Any ]]:
148150 """
149151 Function used for formatting results to run smoothly in the delphi pipeline
@@ -276,7 +278,7 @@ def capture_hook(module, inp, out):
276278
277279 async def _truncate_prompt (self , prompt : str , record : LatentRecord ) -> str :
278280 """
279- Truncates a prompt to end just before the first token where the latent activates.
281+ Truncates prompt to end just before the first token where latent activates.
280282 """
281283 activations = await self ._get_latent_activations (prompt , record )
282284 if activations .numel () == 0 :
@@ -299,7 +301,7 @@ async def _tune_strength(
299301 self , prompts : List [str ], record : LatentRecord
300302 ) -> Tuple [float , float ]:
301303 """
302- Performs a binary search to find the intervention strength that matches ` target_kl` .
304+ Performs a binary search to find intervention strength that matches target_kl.
303305 """
304306 low_strength , high_strength = 0.0 , 40.0 # Heuristic search range
305307 best_strength = self .target_kl # Default to target_kl if search fails
@@ -384,7 +386,7 @@ def hook_fn(module, inp, out):
384386 # 2. Create the corresponding indices needed for the decode method.
385387 indices = torch .tensor ([[[record .feature_id ]]], device = sae_device , dtype = torch .long )
386388
387- # 3. Decode this one-hot vector to get the feature's direction in the hidden space.
389+ # 3. Decode one-hot vector to get feature's direction in hidden space.
388390 # We subtract the decoded zero vector to remove any decoder bias.
389391 decoded_zero = sae .decode (torch .zeros_like (one_hot_activation ), indices )
390392 decoder_vector = sae .decode (one_hot_activation , indices ) - decoded_zero
@@ -485,13 +487,17 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo
485487 return instance # Unwrapped successfully.
486488
487489 # If we found a partial but failed to unwrap it, we cannot proceed.
488- print (f"ERROR: Found a partial for { hookpoint_str } but could not unwrap the SAE instance." )
490+ print (
491+ f"""ERROR: Found a partial for { hookpoint_str } but could not
492+ unwrap the SAE instance.""" )
489493 return None
490494
491495 # If it's not a partial, it's the model itself.
492496 return candidate
493497
494- print (f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{ hookpoint_str } '" )
498+ print (
499+ f"""ERROR: Surprisal scorer could not find
500+ an SAE for hookpoint '{ hookpoint_str } '""" )
495501 return None
496502
497503
@@ -518,7 +524,8 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
518524
519525 return candidate
520526
521- print (f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{ hookpoint_str } '" )
527+ print (f"""ERROR: Surprisal scorer could not find
528+ an SAE for hookpoint '{ hookpoint_str } '""" )
522529 return None
523530
524531
0 commit comments