Skip to content

Commit 52e40ee

Browse files
committed
Merge branch 'dev' [patch]
2 parents 4f2723e + db545b2 commit 52e40ee

17 files changed

Lines changed: 1643 additions & 250 deletions

easyroutine/interpretability/activation_cache.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,12 @@ def __init__(self):
118118
re.compile(r"attn_out_\d+"),
119119
re.compile(r"avg_attn_pattern_L\dH\d+"),
120120
re.compile(r"pattern_L\dH+\d+"),
121-
re.compile(r"values_\d+"),
121+
re.compile(r"head_values_L\dH+\d+"),
122+
re.compile(r"head_keys_L\dH+\d+"),
123+
re.compile(r"head_queries_L\dH+\d+"),
124+
re.compile(r"values_L\d+"),
125+
re.compile(r"keys_L\d+"),
126+
re.compile(r"queries_L\d+"),
122127
re.compile(r"input_ids"),
123128
re.compile(r"mapping_index"),
124129
re.compile(r"mlp_out_\d+"),
@@ -503,7 +508,13 @@ def memory_tree(self, print_tree: bool = False, grouped_tree: bool = False) -> d
503508
r"pattern_L(\d+)H\d+": lambda m: f"pattern_L{m.group(1)}",
504509
r"head_out_\d+": "head_out",
505510
r"mlp_out_\d+": "mlp_out",
506-
r"values_\d+": "values",
511+
r"head_values_L(\d+)H\d+": lambda m: f"head_values_L{m.group(1)}",
512+
r"head_keys_L(\d+)H\d+": lambda m: f"head_keys_L{m.group(1)}",
513+
r"head_queries_L(\d+)H\d+": lambda m: f"head_queries_L{m.group(1)}",
514+
r"values_L(\d+)": lambda m: f"values_L{m.group(1)}",
515+
r"keys_L(\d+)": lambda m: f"keys_L{m.group(1)}",
516+
r"queries_L(\d+)": lambda m: f"queries_L{m.group(1)}"
517+
507518
}
508519

509520
for key, size in tree.items():

easyroutine/interpretability/config/config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ models:
1313
google/gemma-3-4b-it: ["<start_of_image>", null, "<end_of_image>"]
1414
google/gemma-3-12b-it: ["<start_of_image>", null, "<end_of_image>"]
1515
google/gemma-3-27b-it: ["<start_of_image>", null, "<end_of_image>"]
16+
llava-hf/llava-onevision-qwen2-7b-ov-hf: ["<image>", null, "<image>"]
1617

1718
tokenizer_placeholder:
1819
facebook/chameleon-7b: "<image>"
@@ -28,6 +29,7 @@ tokenizer_placeholder:
2829
google/gemma-3-4b-it: "<start_of_image>"
2930
google/gemma-3-12b-it: "<start_of_image>"
3031
google/gemma-3-27b-it: "<start_of_image>"
32+
llava-hf/llava-onevision-qwen2-7b-ov-hf: "<image>"
3133

3234
token_position:
3335
- last
@@ -40,6 +42,7 @@ token_position:
4042
- end-image
4143
- special
4244
- random-text
45+
- random-text-10
4346
- random-image
4447
- random-image-10
4548

easyroutine/interpretability/hooked_model.py

Lines changed: 110 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import yaml
4848

4949

50-
5150
def load_config() -> dict:
5251
with importlib.resources.open_text(
5352
"easyroutine.interpretability.config", "config.yaml"
@@ -100,6 +99,10 @@ class ExtractionConfig:
10099
extract_head_keys (bool): if True, extract the keys of the attention
101100
extract_head_values (bool): if True, extract the values of the attention
102101
extract_head_queries (bool): if True, extract the queries of the attention
102+
extract_values (bool): if True, extract the values. This does not reshape the values to the attention heads as extract_head_values does
103+
extract_keys (bool): if True, extract the keys. This does not reshape the keys to the attention heads as extract_head_keys does
104+
extract_queries (bool): if True, extract the queries. This does not reshape the queries to the attention heads as extract_head_queries does
105+
extract_last_layernorm (bool): if True, extract the last layernorm of the model
103106
extract_head_out (bool): if True, extract the output of the heads [DEPRECATED]
104107
extract_attn_out (bool): if True, extract the output of the attention of the attn_heads passed
105108
extract_attn_in (bool): if True, extract the input of the attention of the attn_heads passed
@@ -124,6 +127,9 @@ class ExtractionConfig:
124127
extract_head_keys: bool = False
125128
extract_head_values: bool = False
126129
extract_head_queries: bool = False
130+
extract_values: bool = False
131+
extract_keys: bool = False
132+
extract_queries: bool = False
127133
extract_head_out: bool = False
128134
extract_attn_out: bool = False
129135
extract_attn_in: bool = False
@@ -142,7 +148,7 @@ class ExtractionConfig:
142148

143149
def is_not_empty(self):
144150
"""
145-
Return True if at least one of the attributes is True, False otherwise, i.e. if the model should extract something!
151+
Return True if at least one extraction option is enabled in the config, False otherwise.
146152
"""
147153
return any(
148154
[
@@ -167,13 +173,15 @@ def is_not_empty(self):
167173
)
168174

169175
def to_dict(self):
176+
"""
177+
Return the configuration as a dictionary.
178+
"""
170179
return self.__dict__
171180

172181

173182
class HookedModel:
174183
"""
175-
This class is a wrapper around the huggingface model that allows to extract the activations of the model. It is support
176-
advanced mechanistic intepretability methods like ablation, patching, etc.
184+
Wrapper around a HuggingFace model for extracting activations and supporting mechanistic interpretability methods.
177185
"""
178186

179187
def __init__(self, config: HookedModelConfig, log_file_path: Optional[str] = None):
@@ -286,27 +294,30 @@ def assert_all_modules_exist(self):
286294

287295
def set_custom_modules(self):
288296
"""
289-
Apply the wrap of the custom modules. for now just the attention module
297+
Substitute custom modules (e.g., attention) into the model for advanced interpretability.
290298
"""
291299
logger.info("HookedModel: Setting custom modules.")
292300
self.module_wrapper_manager.substitute_attention_module(self.hf_model)
293301

294302
def restore_original_modules(self):
295303
"""
296-
Restore the original modules of the model unloading the custom modules.
304+
Restore the original modules of the model, removing any custom substitutions.
297305
"""
298306
logger.info("HookedModel: Restoring original modules.")
299307
self.module_wrapper_manager.restore_original_attention_module(self.hf_model)
300308

301309
def is_multimodal(self) -> bool:
302310
"""
303-
Get if the model is multimodal or not
311+
Return True if the model supports multimodal inputs (e.g., images), False otherwise.
304312
"""
305313
if self.processor is not None:
306314
return True
307315
return False
308316

309317
def use_full_model(self):
318+
"""
319+
Switch to the full model (including multimodal components if available).
320+
"""
310321
if self.processor is not None:
311322
logger.debug("HookedModel: Using full model capabilities")
312323
if self.base_model is not None:
@@ -319,6 +330,9 @@ def use_full_model(self):
319330
logger.debug("HookedModel: Using full text only model capabilities")
320331

321332
def use_language_model_only(self):
333+
"""
334+
Switch to using only the language model component (text-only mode).
335+
"""
322336
if self.hf_language_model is None:
323337
logger.warning(
324338
"HookedModel: The model does not have a separate language model that can be used",
@@ -333,6 +347,9 @@ def use_language_model_only(self):
333347
logger.debug("HookedModel: Using only language model capabilities")
334348

335349
def get_tokenizer(self):
350+
"""
351+
Return the tokenizer associated with the model.
352+
"""
336353
return self.hf_tokenizer
337354

338355
def get_text_tokenizer(self):
@@ -366,41 +383,42 @@ def get_processor(self):
366383
return self.processor
367384

368385
def get_lm_head(self):
386+
"""
387+
Return the language modeling head (output projection layer) of the model.
388+
"""
369389
return get_attribute_by_name(self.hf_model, self.model_config.unembed_matrix)
370390

371391
def get_last_layernorm(self):
392+
"""
393+
Return the last layer normalization module of the model.
394+
"""
372395
return get_attribute_by_name(self.hf_model, self.model_config.last_layernorm)
373396

374397
def get_image_placeholder(self) -> str:
398+
"""
399+
Return the image placeholder string used by the tokenizer for multimodal models.
400+
"""
375401
return self.image_placeholder
376402

377403
def eval(self):
378-
r"""
379-
Set the model in evaluation mode
404+
"""
405+
Set the model to evaluation mode.
380406
"""
381407
self.hf_model.eval()
382408

383409
def device(self):
384-
r"""
385-
Return the device of the model. If the model is in multiple devices, it will return the first device
386-
387-
Args:
388-
None
389-
390-
Returns:
391-
device: the device of the model
410+
"""
411+
Return the device (e.g., 'cuda', 'cpu') where the model is located.
392412
"""
393413
return self.first_device
394414

395415
def register_forward_hook(self, component: str, hook_function: Callable):
396416
r"""
397-
Add a new hook to the model. The hook will be called in the forward pass of the model.
417+
Register a forward hook on a model component.
398418
399419
Args:
400-
component (str): the component of the model where the hook will be added.
401-
hook_function (Callable): the function that will be called in the forward pass of the model. The function must have the following signature:
402-
def hook_function(module, input, output):
403-
pass
420+
component (str): Name of the model component.
421+
hook_function (Callable): Function to call during forward pass.
404422
405423
Returns:
406424
None
@@ -447,10 +465,16 @@ def to_string_tokens(
447465
return string_tokens
448466

449467
def register_interventions(self, interventions: List[Intervention]):
468+
"""
469+
Register a list of interventions to be applied during forward passes.
470+
"""
450471
self.additional_interventions = interventions
451472
logger.debug(f"HookedModel: Registered {len(interventions)} interventions")
452473

453474
def clean_interventions(self):
475+
"""
476+
Remove all registered interventions.
477+
"""
454478
self.additional_interventions = []
455479
logger.debug(
456480
f"HookedModel: Removed {len(self.additional_interventions)} interventions"
@@ -503,6 +527,11 @@ def create_hooks(
503527
raise ValueError(
504528
"attn_heads must be 'all' or a list of dictionaries as [{'layer': 0, 'head': 0}]"
505529
)
530+
# register the intervention hooks as first thing to do
531+
if self.additional_interventions is not None:
532+
hooks += self.intervention_manager.create_intervention_hooks(
533+
interventions=self.additional_interventions, token_dict=token_dict
534+
)
506535

507536
if extraction_config.extract_resid_out:
508537
# assert that the component exists in the model
@@ -589,13 +618,14 @@ def create_hooks(
589618
"intervention": partial(
590619
query_key_value_hook,
591620
cache=cache,
592-
cache_key="queries_",
621+
cache_key="head_queries_",
593622
token_indexes=token_indexes,
594623
head_dim=self.model_config.head_dim,
595624
avg=extraction_config.avg,
596625
layer=i,
597626
head=head,
598627
num_key_value_groups=self.model_config.num_key_value_groups,
628+
num_attention_heads=self.model_config.num_attention_heads,
599629
),
600630
}
601631
for i, head in zip(layer_indexes, head_indexes)
@@ -608,13 +638,14 @@ def create_hooks(
608638
"intervention": partial(
609639
query_key_value_hook,
610640
cache=cache,
611-
cache_key="values_",
641+
cache_key="head_values_",
612642
token_indexes=token_indexes,
613643
head_dim=self.model_config.head_dim,
614644
avg=extraction_config.avg,
615645
layer=i,
616646
head=head,
617647
num_key_value_groups=self.model_config.num_key_value_groups,
648+
num_attention_heads=self.model_config.num_attention_heads,
618649
),
619650
}
620651
for i, head in zip(layer_indexes, head_indexes)
@@ -627,18 +658,64 @@ def create_hooks(
627658
"intervention": partial(
628659
query_key_value_hook,
629660
cache=cache,
630-
cache_key="keys_",
661+
cache_key="head_keys_",
631662
token_indexes=token_indexes,
632663
head_dim=self.model_config.head_dim,
633664
avg=extraction_config.avg,
634665
layer=i,
635666
head=head,
636667
num_key_value_groups=self.model_config.num_key_value_groups,
668+
num_attention_heads=self.model_config.num_attention_heads,
637669
),
638670
}
639671
for i, head in zip(layer_indexes, head_indexes)
640672
]
641673

674+
if extraction_config.extract_values:
675+
hooks += [
676+
{
677+
"component": self.model_config.head_value_hook_name.format(i),
678+
"intervention": partial(
679+
save_resid_hook,
680+
cache=cache,
681+
cache_key=f"values_L{i}",
682+
token_indexes=token_indexes,
683+
avg=extraction_config.avg,
684+
),
685+
}
686+
for i in range(0, self.model_config.num_hidden_layers)
687+
]
688+
689+
if extraction_config.extract_keys:
690+
hooks += [
691+
{
692+
"component": self.model_config.head_key_hook_name.format(i),
693+
"intervention": partial(
694+
save_resid_hook,
695+
cache=cache,
696+
cache_key=f"keys_L{i}",
697+
token_indexes=token_indexes,
698+
avg=extraction_config.avg,
699+
),
700+
}
701+
for i in range(0, self.model_config.num_hidden_layers)
702+
]
703+
704+
if extraction_config.extract_queries:
705+
hooks += [
706+
{
707+
"component": self.model_config.head_query_hook_name.format(i),
708+
"intervention": partial(
709+
save_resid_hook,
710+
cache=cache,
711+
cache_key=f"queries_L{i}",
712+
token_indexes=token_indexes,
713+
avg=extraction_config.avg,
714+
),
715+
}
716+
for i in range(0, self.model_config.num_hidden_layers)
717+
]
718+
642719
if extraction_config.extract_head_out:
643720
hooks += [
644721
{
@@ -772,10 +849,6 @@ def create_hooks(
772849
hooks += self.intervention_manager.create_intervention_hooks(
773850
interventions=interventions, token_dict=token_dict
774851
)
775-
if self.additional_interventions is not None:
776-
hooks += self.intervention_manager.create_intervention_hooks(
777-
interventions=self.additional_interventions, token_dict=token_dict
778-
)
779852
if extraction_config.extract_head_values_projected:
780853
hooks += [
781854
{
@@ -1167,7 +1240,7 @@ def set_hooks(self, hooks: List[Dict[str, Any]]):
11671240

11681241
def remove_hooks(self, hook_handlers):
11691242
"""
1170-
Remove all the hooks from the model
1243+
Remove all hooks from the model using the provided handlers.
11711244
"""
11721245
for hook_handler in hook_handlers:
11731246
hook_handler.remove()
@@ -1201,7 +1274,10 @@ def generate(
12011274
# raise NotImplementedError("This method is not working. It needs to be fixed")
12021275
cache = ActivationCache()
12031276
hook_handlers = None
1204-
if target_token_positions is not None or self.additional_interventions is not None:
1277+
if (
1278+
target_token_positions is not None
1279+
or self.additional_interventions is not None
1280+
):
12051281
string_tokens = self.to_string_tokens(
12061282
self.input_handler.get_input_ids(inputs).squeeze()
12071283
)
@@ -1293,8 +1369,9 @@ def extract_cache(
12931369
# example_dict = {}
12941370
n_batches = 0 # Initialize batch counter
12951371

1296-
for batch in progress(dataloader, desc="Extracting cache", total=len(dataloader)):
1297-
1372+
for batch in progress(
1373+
dataloader, desc="Extracting cache", total=len(dataloader)
1374+
):
12981375
# log_memory_usage("Extract cache - Before batch")
12991376
# tokens, others = batch
13001377
# inputs = {k: v.to(self.first_device) for k, v in tokens.items()}

0 commit comments

Comments
 (0)