@@ -168,25 +168,22 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
     """
     Calculates the feature's decoder vector, subtracting the decoder bias.
     """
-
-
+
     d_latent = sae.encoder.out_features
     sae_device = sae.encoder.weight.device

     # Create a one-hot activation for our single feature.
     one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device)
-
+
     if feature_id >= d_latent:
         print(f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds
             for d_latent {d_latent}""")
         return torch.zeros(1)
-
+
     one_hot_activation[0, 0, feature_id] = 1.0

     # Create the corresponding indices needed for the decode method.
-    indices = torch.tensor(
-        [[[feature_id]]], device=sae_device, dtype=torch.long
-    )
+    indices = torch.tensor([[[feature_id]]], device=sae_device, dtype=torch.long)

     with torch.no_grad():
         try:
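For reference, this is the standard bias-cancellation identity the method builds on: with an affine decoder, decoding a one-hot latent and then subtracting the decoded zero vector leaves exactly the feature's decoder direction. A minimal self-contained sketch (all tensors here are hypothetical):

import torch

# Hypothetical affine decoder: decode(a) = a @ W_dec + b_dec
W_dec = torch.randn(8, 4)  # [d_latent, d_model]
b_dec = torch.randn(4)
decode = lambda a: a @ W_dec + b_dec

e_i = torch.zeros(8)
e_i[3] = 1.0  # one-hot latent for feature 3
# decode(e_i) - decode(0) = W_dec[3], the bias-free decoder row
assert torch.allclose(decode(e_i) - decode(torch.zeros(8)), W_dec[3])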
@@ -197,24 +194,25 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor:
             return torch.zeros(1)

     decoder_vector = vector_before_sub - decoded_zero
-
+
     final_norm = decoder_vector.norm().item()
-
+
     # --- MODIFIED DEBUG BLOCK ---
     # Only print if the feature is "decoder-live"
     if final_norm > 1e-6:
         print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---")
         print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}")
         print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}")
         print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}")
-        print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}")
+        print(
+            f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}"
+        )
         print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}")
         print("--- END DEBUG ---\n")
     # --- END MODIFIED BLOCK ---

     return decoder_vector.squeeze()

-
 async def __call__(self, record: LatentRecord) -> ScorerResult:

     record_copy = copy.deepcopy(record)
@@ -241,7 +239,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
     sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy)
     if not sae:
         raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}")
-
+
     intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id)

     tuned_strength, initial_kl = await self._tune_strength(
@@ -254,10 +252,18 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:

     for prompt in truncated_prompts:
         clean_text, clean_logp_dist = await self._generate_with_intervention(
-            prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True
+            prompt,
+            record_copy,
+            strength=0.0,
+            intervention_vector=intervention_vector,
+            get_logp_dist=True,
         )
         int_text, int_logp_dist = await self._generate_with_intervention(
-            prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True
+            prompt,
+            record_copy,
+            strength=tuned_strength,
+            intervention_vector=intervention_vector,
+            get_logp_dist=True,
         )

         logp_clean = await self._score_explanation(
@@ -301,7 +307,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult:
     )
     return ScorerResult(record=record_copy, score=final_output_list)

-
 async def _get_latent_activations(
     self, prompt: str, record: LatentRecord
 ) -> torch.Tensor:
@@ -340,7 +345,6 @@ def capture_hook(module, inp, out):

     return feature_acts[0, :, record.feature_id].cpu()

-
 async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
     """
     Truncates prompt to end just before the first token where latent activates.
@@ -357,17 +361,18 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str:
     first_activation_idx = all_activation_indices[all_activation_indices > 0]

     if first_activation_idx.numel() > 0:
-        truncation_point = first_activation_idx[0].item()
+        truncation_point = first_activation_idx[0].item()
         input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0]
-        truncated_ids = input_ids[:truncation_point + 1]
+        truncated_ids = input_ids[: truncation_point + 1]
         return self.tokenizer.decode(truncated_ids, skip_special_tokens=True)

     return prompt

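A toy version of the truncation rule above, with hypothetical tensors (the real method derives all_activation_indices from the latent activations): collect the firing positions, skip position 0, and keep tokens through the first remaining hit.

import torch

acts = torch.tensor([5.0, 0.0, 3.2, 0.0, 0.1])  # per-token activations
idx = acts.nonzero().squeeze(-1)                 # positions where the latent fires
idx = idx[idx > 0]                               # ignore position 0
if idx.numel() > 0:
    keep = int(idx[0]) + 1                       # inclusive cut at the first hit
    print(f"keep tokens [0:{keep})")             # the real method slices input_ids[:keep]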
-
 async def _tune_strength(
-    self, prompts: List[str], record: LatentRecord,
-    intervention_vector: torch.Tensor
+    self,
+    prompts: List[str],
+    record: LatentRecord,
+    intervention_vector: torch.Tensor,
 ) -> Tuple[float, float]:
     """
     Performs a binary search for the intervention strength that matches target_kl.
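Most of the search loop is elided between hunks; below is a generic bisection sketch under the assumption that average KL increases monotonically with strength. The callable, bounds, and iteration count are hypothetical, not necessarily this commit's exact rule.

def tune(avg_kl, target_kl: float, hi: float = 64.0, iters: int = 10) -> float:
    """Bisect for the strength whose average KL best matches target_kl."""
    lo, best = 0.0, 0.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if avg_kl(mid) < target_kl:
            lo, best = mid, mid  # too weak: push strength up
        else:
            hi = mid             # too strong: back off
    return best

# e.g. tune(lambda s: 0.1 * s, target_kl=1.0) converges toward 10.0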
@@ -409,22 +414,26 @@ async def _tune_strength(
             best_strength = mid_strength

     # Return the best found strength and the corresponding KL
-    final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector)
+    final_kl = await self._calculate_avg_kl(
+        prompts, record, best_strength, intervention_vector
+    )
     return best_strength, final_kl

-
 async def _calculate_avg_kl(
-    self, prompts: List[str], record: LatentRecord, strength: float,
-    intervention_vector: torch.Tensor
+    self,
+    prompts: List[str],
+    record: LatentRecord,
+    strength: float,
+    intervention_vector: torch.Tensor,
 ) -> float:
     total_kl = 0.0
     n = 0
     for prompt in prompts:
         _, clean_logp = await self._generate_with_intervention(
-            prompt, record, 0.0, intervention_vector,True
+            prompt, record, 0.0, intervention_vector, True
         )
         _, int_logp = await self._generate_with_intervention(
-            prompt, record, strength, intervention_vector,True
+            prompt, record, strength, intervention_vector, True
         )
         p_clean = torch.exp(clean_logp)
         kl_div = F.kl_div(
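One detail worth keeping in mind about the F.kl_div call opened above: it takes log-probabilities as input and probabilities as target, computing sum p * (log p - log q) = KL(p || q) with reduction="sum", which is why the clean log-probs are exponentiated first. A self-contained check:

import torch
import torch.nn.functional as F

p = torch.softmax(torch.randn(5), dim=-1)          # clean next-token distribution
log_q = torch.log_softmax(torch.randn(5), dim=-1)  # intervened log-probs
kl = F.kl_div(log_q, p, reduction="sum")           # KL(p || q)
manual = (p * (p.log() - log_q)).sum()
assert torch.allclose(kl, manual)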
@@ -434,7 +443,6 @@ async def _calculate_avg_kl(
         n += 1
     return total_kl / n if n > 0 else 0.0

-
 async def _generate_with_intervention(
     self,
     prompt: str,
@@ -474,8 +482,9 @@ def hook_fn(module, inp, out):
         intervention_start_index = prompt_length - 1

         if current_seq_len >= prompt_length:
-            new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype)
-
+            new_hiddens[:, intervention_start_index:, :] += delta.to(
+                original_dtype
+            )

         return (
             (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens
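The hunk above reformats the core of the steering hook: a forward hook that adds strength times the feature direction to every position from the prompt's last token onward. A condensed sketch with hypothetical names, assuming out is either the hidden-state tensor or a tuple whose first element is:

import torch

def make_steering_hook(direction: torch.Tensor, strength: float, start: int):
    def hook_fn(module, inp, out):
        hidden = out[0] if isinstance(out, tuple) else out
        delta = (strength * direction).to(hidden.dtype)
        hidden = hidden.clone()
        hidden[:, start:, :] += delta  # steer from `start` onward
        return (hidden,) + out[1:] if isinstance(out, tuple) else hidden
    return hook_fn

# registered via layer.register_forward_hook(make_steering_hook(vec, s, prompt_len - 1))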
@@ -485,7 +494,7 @@ def hook_fn(module, inp, out):

     try:
         with torch.no_grad():
-            outputs = self.subject_model(input_ids, attention_mask=attention_mask)
+            outputs = self.subject_model(input_ids, attention_mask=attention_mask)
             next_token_logits = outputs.logits[0, -1, :]
             log_probs_next_token = (
                 F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None
@@ -507,10 +516,9 @@ def hook_fn(module, inp, out):
         log_probs_next_token.cpu() if get_logp_dist else torch.empty(0)
     )

-
 async def _score_explanation(self, generated_text: str, explanation: str) -> float:
     """
-    Computes log P(explanation | generated_text) using the paper's
+    Computes log P(explanation | generated_text) using the paper's
     prompt format.
     """
     device = self._get_device()
@@ -519,9 +527,9 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
     prompt_template = (
         "<PASSAGE>\n"
         f"{generated_text}\n"
-        " The above passage contains an amplified amount of \""
+        ' The above passage contains an amplified amount of "'
     )
-    explanation_suffix = f" {explanation}\""
+    explanation_suffix = f' {explanation}"'

     # Tokenize the parts
     context_enc = self.tokenizer(prompt_template, return_tensors="pt")
@@ -537,7 +545,7 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:

     # We only need to score the explanation part
     context_len = context_enc.input_ids.shape[1]
-
+
     # Get logits for positions that predict the explanation tokens
     # Shape: [batch_size, explanation_len, vocab_size]
     explanation_logits = logits[:, context_len - 1 : -1, :]
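The slice being cleaned up here encodes the usual off-by-one: logits at position t predict token t + 1, so rows context_len - 1 through -1 line up with the explanation tokens. A toy check with hypothetical sizes:

import torch
import torch.nn.functional as F

vocab, context_len = 11, 4
ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7]])    # 4 context + 3 explanation tokens
logits = torch.randn(1, ids.shape[1], vocab)
target_ids = ids[:, context_len:]              # explanation tokens
rows = logits[:, context_len - 1 : -1, :]      # the rows that predict them
log_probs = F.log_softmax(rows, dim=-1)
score = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1).sum().item()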
@@ -549,14 +557,11 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> float:
     log_probs = F.log_softmax(explanation_logits, dim=-1)

     # Gather the log-probabilities of the actual explanation tokens
-    token_log_probs = log_probs.gather(
-        2, target_ids.unsqueeze(-1)
-    ).squeeze(-1)
+    token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)

     # Return the sum of log-probs for the explanation
     return token_log_probs.sum().item()

-
 def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
     """
     Retrieves the correct SAE model, handling the specific functools.partial
@@ -568,13 +573,13 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
         candidate = record.sae
     elif self.explainer_model and isinstance(self.explainer_model, dict):
         full_key = self._get_full_hookpoint_path(hookpoint_str)
-        short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp"
+        short_key = ".".join(hookpoint_str.split(".")[-2:])  # e.g., "layers.6.mlp"

         for key in [hookpoint_str, full_key, short_key]:
             if self.explainer_model.get(key) is not None:
                 candidate = self.explainer_model.get(key)
                 break
-
+
     if candidate is None:
         # This will raise an error if the key isn't found
         raise ValueError(f"ERROR: Surprisal scorer could not find an SAE "
@@ -591,8 +596,9 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any:
             find the 'sae' keyword.
             func: {candidate.func}
             args: {candidate.args}
-            keywords: {candidate.keywords}""")
-
+            keywords: {candidate.keywords}"""
+        )
+
         # This will raise an error if the candidate isn't a partial
         raise ValueError(f"""ERROR: Candidate for {hookpoint_str} was not a partial
             object, which was not expected. Type: {type(candidate)}""")
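For context on the unwrapping this method performs: each hookpoint entry is expected to be a functools.partial that carries the SAE in its keyword arguments, so retrieval reduces to a lookup like this sketch (names hypothetical):

from functools import partial

def hook(module, inp, out, sae=None):
    pass

candidate = partial(hook, sae="the-sae-object")  # as stored per hookpoint
sae = candidate.keywords.get("sae") if isinstance(candidate, partial) else candidate
assert sae == "the-sae-object"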