@@ -177,7 +177,8 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
177177 one_hot_activation = torch .zeros (1 , 1 , d_latent , device = sae_device )
178178
179179 if feature_id >= d_latent :
180- print (f"DEBUG: ERROR - Feature ID { feature_id } is out of bounds for d_latent { d_latent } " )
180+ print (f"""DEBUG: ERROR - Feature ID { feature_id } is out of bounds
181+ for d_latent { d_latent } """ )
181182 return torch .zeros (1 )
182183
183184 one_hot_activation [0 , 0 , feature_id ] = 1.0
@@ -576,22 +577,25 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An
576577
577578 if candidate is None :
578579 # This will raise an error if the key isn't found
579- raise ValueError (f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{ hookpoint_str } ' in self.explainer_model" )
580+ raise ValueError (f"ERROR: Surprisal scorer could not find an SAE "
581+ f"for hookpoint '{ hookpoint_str } ' in self.explainer_model" )
580582
581583 if isinstance (candidate , functools .partial ):
582584 # As shown in load_sparsify.py, the SAE is in the 'sae' keyword.
583585 if candidate .keywords and "sae" in candidate .keywords :
584586 return candidate .keywords ["sae" ] # Unwrapped successfully
585587 else :
586588 # This will raise an error if the partial is missing the keyword
587- raise ValueError (f"""ERROR: Found a partial for { hookpoint_str } but could not
589+ raise ValueError (f"""ERROR: Found a partial for
590+ { hookpoint_str } but could not
588591 find the 'sae' keyword.
589592 func: { candidate .func }
590593 args: { candidate .args }
591594 keywords: { candidate .keywords } """ )
592595
593596 # This will raise an error if the candidate isn't a partial
594- raise ValueError (f"ERROR: Candidate for { hookpoint_str } was not a partial object, which was not expected. Type: { type (candidate )} " )
597+ raise ValueError (f"""ERROR: Candidate for { hookpoint_str } was not a partial
598+ object, which was not expected. Type: { type (candidate )} """ )
595599
596600 def _get_intervention_direction (self , record : LatentRecord ) -> torch .Tensor :
597601 hookpoint_str = self .hookpoint_str or getattr (record , "hookpoint" , None )
0 commit comments