 import yaml
 
 
-
 def load_config() -> dict:
     with importlib.resources.open_text(
         "easyroutine.interpretability.config", "config.yaml"
@@ -100,6 +99,10 @@ class ExtractionConfig:
         extract_head_keys (bool): if True, extract the keys of the attention
         extract_head_values (bool): if True, extract the values of the attention
         extract_head_queries (bool): if True, extract the queries of the attention
+        extract_values (bool): if True, extract the values. Unlike extract_head_values, this does not reshape the values into attention heads
+        extract_keys (bool): if True, extract the keys. Unlike extract_head_keys, this does not reshape the keys into attention heads
+        extract_queries (bool): if True, extract the queries. Unlike extract_head_queries, this does not reshape the queries into attention heads
+        extract_last_layernorm (bool): if True, extract the last layernorm of the model
         extract_head_out (bool): if True, extract the output of the heads [DEPRECATED]
         extract_attn_out (bool): if True, extract the output of the attention of the attn_heads passed
         extract_attn_in (bool): if True, extract the input of the attention of the attn_heads passed
@@ -124,6 +127,9 @@ class ExtractionConfig:
     extract_head_keys: bool = False
     extract_head_values: bool = False
     extract_head_queries: bool = False
+    extract_values: bool = False
+    extract_keys: bool = False
+    extract_queries: bool = False
     extract_head_out: bool = False
     extract_attn_out: bool = False
     extract_attn_in: bool = False
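For context, a minimal usage sketch of the new flags (a sketch assuming ExtractionConfig is importable from easyroutine.interpretability; the exact import path is an assumption):

    from easyroutine.interpretability import ExtractionConfig  # assumed import path

    # request the full, unreshaped q/k/v projections at every layer
    config = ExtractionConfig(
        extract_values=True,
        extract_keys=True,
        extract_queries=True,
    )
    assert config.is_not_empty()  # at least one extraction option is enabled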
@@ -142,7 +148,7 @@ class ExtractionConfig:
 
     def is_not_empty(self):
         """
-        Return True if at least one of the attributes is True, False otherwise, i.e. if the model should extract something!
+        Return True if at least one extraction option is enabled in the config, False otherwise.
         """
         return any(
             [
@@ -167,13 +173,15 @@ def is_not_empty(self):
         )
 
     def to_dict(self):
+        """
+        Return the configuration as a dictionary.
+        """
         return self.__dict__
 
 
 class HookedModel:
     """
-    This class is a wrapper around the huggingface model that allows to extract the activations of the model. It is support
-    advanced mechanistic intepretability methods like ablation, patching, etc.
+    Wrapper around a HuggingFace model for extracting activations and supporting mechanistic interpretability methods.
     """
 
     def __init__(self, config: HookedModelConfig, log_file_path: Optional[str] = None):
@@ -286,27 +294,30 @@ def assert_all_modules_exist(self):
 
     def set_custom_modules(self):
         """
-        Apply the wrap of the custom modules. for now just the attention module
+        Substitute custom modules (e.g., attention) into the model for advanced interpretability.
         """
         logger.info("HookedModel: Setting custom modules.")
         self.module_wrapper_manager.substitute_attention_module(self.hf_model)
 
     def restore_original_modules(self):
         """
-        Restore the original modules of the model unloading the custom modules.
+        Restore the original modules of the model, removing any custom substitutions.
         """
         logger.info("HookedModel: Restoring original modules.")
         self.module_wrapper_manager.restore_original_attention_module(self.hf_model)
 
     def is_multimodal(self) -> bool:
         """
-        Get if the model is multimodal or not
+        Return True if the model supports multimodal inputs (e.g., images), False otherwise.
         """
         if self.processor is not None:
             return True
         return False
 
     def use_full_model(self):
+        """
+        Switch to the full model (including multimodal components if available).
+        """
         if self.processor is not None:
             logger.debug("HookedModel: Using full model capabilities")
             if self.base_model is not None:
@@ -319,6 +330,9 @@ def use_full_model(self):
             logger.debug("HookedModel: Using full text only model capabilities")
 
     def use_language_model_only(self):
+        """
+        Switch to using only the language model component (text-only mode).
+        """
         if self.hf_language_model is None:
             logger.warning(
                 "HookedModel: The model does not have a separate language model that can be used",
@@ -333,6 +347,9 @@ def use_language_model_only(self):
         logger.debug("HookedModel: Using only language model capabilities")
 
     def get_tokenizer(self):
+        """
+        Return the tokenizer associated with the model.
+        """
         return self.hf_tokenizer
 
     def get_text_tokenizer(self):
@@ -366,41 +383,42 @@ def get_processor(self):
         return self.processor
 
     def get_lm_head(self):
+        """
+        Return the language modeling head (output projection layer) of the model.
+        """
         return get_attribute_by_name(self.hf_model, self.model_config.unembed_matrix)
 
     def get_last_layernorm(self):
+        """
+        Return the last layer normalization module of the model.
+        """
         return get_attribute_by_name(self.hf_model, self.model_config.last_layernorm)
 
     def get_image_placeholder(self) -> str:
+        """
+        Return the image placeholder string used by the tokenizer for multimodal models.
+        """
         return self.image_placeholder
 
     def eval(self):
-        r"""
-        Set the model in evaluation mode
+        """
+        Set the model to evaluation mode.
         """
         self.hf_model.eval()
 
     def device(self):
-        r"""
-        Return the device of the model. If the model is in multiple devices, it will return the first device
-
-        Args:
-            None
-
-        Returns:
-            device: the device of the model
+        """
+        Return the device (e.g., 'cuda', 'cpu') where the model is located.
         """
         return self.first_device
 
     def register_forward_hook(self, component: str, hook_function: Callable):
         r"""
-        Add a new hook to the model. The hook will be called in the forward pass of the model.
+        Register a forward hook on a model component.
 
         Args:
-            component (str): the component of the model where the hook will be added.
-            hook_function (Callable): the function that will be called in the forward pass of the model. The function must have the following signature:
-                def hook_function(module, input, output):
-                    pass
+            component (str): Name of the model component.
+            hook_function (Callable): Function to call during forward pass.
 
         Returns:
             None
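For illustration, a minimal hook with the standard PyTorch forward-hook signature (the model instance and the component string below are hypothetical):

    def shape_logger(module, input, output):
        # standard torch forward-hook signature: (module, input, output)
        print(type(module).__name__, getattr(output, "shape", None))

    # model: a HookedModel constructed elsewhere; the component name is an example
    model.register_forward_hook("model.layers.0.self_attn", shape_logger)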
@@ -447,10 +465,16 @@ def to_string_tokens(
         return string_tokens
 
     def register_interventions(self, interventions: List[Intervention]):
+        """
+        Register a list of interventions to be applied during forward passes.
+        """
         self.additional_interventions = interventions
         logger.debug(f"HookedModel: Registered {len(interventions)} interventions")
 
     def clean_interventions(self):
+        """
+        Remove all registered interventions.
+        """
         self.additional_interventions = []
         logger.debug(
             f"HookedModel: Removed {len(self.additional_interventions)} interventions"
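A sketch of the intended call pattern (model is a HookedModel instance; constructing the Intervention objects themselves is outside this diff):

    interventions = [...]  # List[Intervention], built elsewhere
    model.register_interventions(interventions)  # applied during subsequent forward passes
    # ... run extraction or generation with the interventions active ...
    model.clean_interventions()  # restore a clean model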
@@ -503,6 +527,11 @@ def create_hooks(
             raise ValueError(
                 "attn_heads must be 'all' or a list of dictionaries as [{'layer': 0, 'head': 0}]"
             )
+        # register the intervention hooks first, before any extraction hooks
+        if self.additional_interventions is not None:
+            hooks += self.intervention_manager.create_intervention_hooks(
+                interventions=self.additional_interventions, token_dict=token_dict
+            )
 
         if extraction_config.extract_resid_out:
             # assert that the component exists in the model
@@ -589,13 +618,14 @@ def create_hooks(
589618 "intervention" : partial (
590619 query_key_value_hook ,
591620 cache = cache ,
592- cache_key = "queries_ " ,
621+ cache_key = "head_queries_ " ,
593622 token_indexes = token_indexes ,
594623 head_dim = self .model_config .head_dim ,
595624 avg = extraction_config .avg ,
596625 layer = i ,
597626 head = head ,
598627 num_key_value_groups = self .model_config .num_key_value_groups ,
628+ num_attention_heads = self .model_config .num_attention_heads ,
599629 ),
600630 }
601631 for i , head in zip (layer_indexes , head_indexes )
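With this rename, head-level activations are cached under head_queries_/head_keys_/head_values_ keys, leaving the bare queries_/keys_/values_ prefixes free for the new whole-projection extractors added below. An illustrative key scheme (the exact layer/head suffix format is an assumption):

    # head_queries_..., head_keys_..., head_values_...  -> per-head tensors, reshaped via head_dim
    # queries_L0, keys_L0, values_L0                    -> full projections, not split into heads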
@@ -608,13 +638,14 @@ def create_hooks(
608638 "intervention" : partial (
609639 query_key_value_hook ,
610640 cache = cache ,
611- cache_key = "values_ " ,
641+ cache_key = "head_values_ " ,
612642 token_indexes = token_indexes ,
613643 head_dim = self .model_config .head_dim ,
614644 avg = extraction_config .avg ,
615645 layer = i ,
616646 head = head ,
617647 num_key_value_groups = self .model_config .num_key_value_groups ,
648+ num_attention_heads = self .model_config .num_attention_heads ,
618649 ),
619650 }
620651 for i , head in zip (layer_indexes , head_indexes )
@@ -627,18 +658,64 @@ def create_hooks(
627658 "intervention" : partial (
628659 query_key_value_hook ,
629660 cache = cache ,
630- cache_key = "keys_ " ,
661+ cache_key = "head_keys_ " ,
631662 token_indexes = token_indexes ,
632663 head_dim = self .model_config .head_dim ,
633664 avg = extraction_config .avg ,
634665 layer = i ,
635666 head = head ,
636667 num_key_value_groups = self .model_config .num_key_value_groups ,
668+ num_attention_heads = self .model_config .num_attention_heads ,
637669 ),
638670 }
639671 for i , head in zip (layer_indexes , head_indexes )
640672 ]
641673
+        if extraction_config.extract_values:
+            hooks += [
+                {
+                    "component": self.model_config.head_value_hook_name.format(i),
+                    "intervention": partial(
+                        save_resid_hook,
+                        cache=cache,
+                        cache_key=f"values_L{i}",
+                        token_indexes=token_indexes,
+                        avg=extraction_config.avg,
+                    ),
+                }
+                for i in range(0, self.model_config.num_hidden_layers)
+            ]
+
+        if extraction_config.extract_keys:
+            hooks += [
+                {
+                    "component": self.model_config.head_key_hook_name.format(i),
+                    "intervention": partial(
+                        save_resid_hook,
+                        cache=cache,
+                        cache_key=f"keys_L{i}",
+                        token_indexes=token_indexes,
+                        avg=extraction_config.avg,
+                    ),
+                }
+                for i in range(0, self.model_config.num_hidden_layers)
+            ]
+
+        if extraction_config.extract_queries:
+            hooks += [
+                {
+                    "component": self.model_config.head_query_hook_name.format(i),
+                    "intervention": partial(
+                        save_resid_hook,
+                        cache=cache,
+                        cache_key=f"queries_L{i}",
+                        token_indexes=token_indexes,
+                        avg=extraction_config.avg,
+                    ),
+                }
+                for i in range(0, self.model_config.num_hidden_layers)
+            ]
+
         if extraction_config.extract_head_out:
             hooks += [
                 {
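Once the new extract_values/extract_keys/extract_queries hooks above run, the cache holds one unreshaped tensor per layer under values_L{i}, keys_L{i}, and queries_L{i}. A retrieval sketch (the extract_cache argument names are assumptions):

    cache = model.extract_cache(
        dataloader,
        extraction_config=ExtractionConfig(extract_values=True, extract_keys=True),
    )
    v0 = cache["values_L0"]  # value projections at layer 0, not reshaped into heads
    k0 = cache["keys_L0"]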
@@ -772,10 +849,6 @@ def create_hooks(
             hooks += self.intervention_manager.create_intervention_hooks(
                 interventions=interventions, token_dict=token_dict
             )
-        if self.additional_interventions is not None:
-            hooks += self.intervention_manager.create_intervention_hooks(
-                interventions=self.additional_interventions, token_dict=token_dict
-            )
         if extraction_config.extract_head_values_projected:
             hooks += [
                 {
@@ -1167,7 +1240,7 @@ def set_hooks(self, hooks: List[Dict[str, Any]]):
 
     def remove_hooks(self, hook_handlers):
         """
-        Remove all the hooks from the model
+        Remove all hooks from the model using the provided handlers.
         """
         for hook_handler in hook_handlers:
             hook_handler.remove()
@@ -1201,7 +1274,10 @@ def generate(
         # raise NotImplementedError("This method is not working. It needs to be fixed")
         cache = ActivationCache()
         hook_handlers = None
-        if target_token_positions is not None or self.additional_interventions is not None:
+        if (
+            target_token_positions is not None
+            or self.additional_interventions is not None
+        ):
             string_tokens = self.to_string_tokens(
                 self.input_handler.get_input_ids(inputs).squeeze()
             )
@@ -1293,8 +1369,9 @@ def extract_cache(
         # example_dict = {}
         n_batches = 0  # Initialize batch counter
 
-        for batch in progress(dataloader, desc="Extracting cache", total=len(dataloader)):
-
+        for batch in progress(
+            dataloader, desc="Extracting cache", total=len(dataloader)
+        ):
             # log_memory_usage("Extract cache - Before batch")
             # tokens, others = batch
             # inputs = {k: v.to(self.first_device) for k, v in tokens.items()}