diff --git a/README.md b/README.md index 83e556c..2a26dad 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,220 @@ # EasyRoutine -This is just a simple collection of routines that I use frequently. I have found that I often need to do the same things over and over again, so I have created this repository to store them. I hope you find them useful. +A comprehensive Python toolkit for AI model interpretability, analysis, and efficient inference. EasyRoutine provides researchers and practitioners with powerful tools for understanding transformer models, extracting internal activations, and performing mechanistic interpretability studies. -## Installation +## 🚀 Key Features +- **🔍 Mechanistic Interpretability**: Deep analysis of transformer model internals + - Extract activations from any model component (residual streams, attention layers, MLPs) + - Support for attention pattern analysis and head-specific extractions + - Intervention capabilities for ablation studies and activation patching + - Works with both language models and vision-language models -## Interpretability -The interpretability module contains wrapper of huggingface LLM/VLM that help to perform interpretability tasks on the model. Currently, it supports: -- Extract activations of any component of the model -- Perform ablation study on the model during inference -- Perform activation patching on the model during inference +- **⚡ High-Performance Inference**: Optimized model inference with multiple backends + - VLLM integration for high-throughput inference + - Multi-GPU support and memory optimization + - Configurable generation parameters and chat templates + +- **📊 Smart Progress Tracking**: Adaptive progress bars for any environment + - Rich progress bars for interactive use + - Clean text logging for batch jobs (SLURM, PBS, etc.) 
+ - Automatic environment detection + +- **🛠️ Essential Utilities**: Common tools for ML workflows + - Robust logging with multiple output formats + - File system navigation helpers + - Memory management utilities + +## 📦 Installation + +```bash +pip install easyroutine +``` + +For development installation: +```bash +git clone https://github.com/francescortu/easyroutine.git +cd easyroutine +pip install -e . +``` + +## 🔍 Interpretability - Quick Start + +### Basic Activation Extraction -### Simple Tutorial ```python -# First we need to import the HookedModel and the config classes -from easyroutine.interpretability import HookedModel, ExtractionConfig +from easyroutine.interpretability import HookedModel, ExtractionConfig +from easyroutine.console import progress -hooked_model = HookedModel.from_pretrained( - model_name="mistral-community/pixtral-12b", # the model name - device_map = "auto" +# Load any Hugging Face transformer model +model = HookedModel.from_pretrained( + model_name="gpt2", # or any HF model + device_map="auto" ) -# Now let's define a simple dataset -dataset = [ - "This is a test", - "This is another test" -] +# Prepare your data +texts = ["Hello, world!", "How are you today?"] +tokenizer = model.get_tokenizer() +dataset = [tokenizer(text, return_tensors="pt") for text in texts] + +# Configure what activations to extract +config = ExtractionConfig( + extract_resid_out=True, # Residual stream outputs + extract_attn_pattern=True, # Attention patterns + extract_mlp_out=True, # MLP layer outputs + save_input_ids=True # Keep track of tokens +) -tokenizer = hooked_model.get_tokenizer() +# Extract activations with progress tracking +cache = model.extract_cache( + progress(dataset, description="Extracting activations"), + target_token_positions=["last"], # Focus on final token + extraction_config=config +) -dataset = tokenizer(dataset, padding=True, truncation=True, return_tensors="pt") +# Access extracted data +residual_activations = cache["resid_out_0"] # 
Layer 0 residual outputs +attention_patterns = cache["attn_pattern_0_5"] # Layer 0, Head 5 attention +print(f"Extracted activations: {list(cache.keys())}") +``` -cache = hooked_model.extract_cache( - dataset, - target_token_positions = ["last"], - extraction_config = ExtractionConfig( - extract_resid_out = True +### Advanced: Intervention Studies + +```python +from easyroutine.interpretability import Intervention + +# Define interventions for causal analysis +interventions = [ + Intervention( + component="resid_out_5", # Target layer 5 residual stream + intervention_type="ablation", # Zero out activations + positions=["last"] # Only affect last token ) +] + +# Run model with interventions +cache_with_intervention = model.extract_cache( + dataset, + target_token_positions=["last"], + extraction_config=config, + interventions=interventions +) + +# Compare original vs. intervened activations +original_logits = cache["logits"] +intervened_logits = cache_with_intervention["logits"] +effect = original_logits - intervened_logits +``` + +### Vision-Language Models + +```python +# Works seamlessly with VLMs like LLaVA, Pixtral, etc. 
+vlm = HookedModel.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf") + +# Process multimodal inputs +processor = vlm.get_processor() +inputs = processor(images=image, text="What do you see?", return_tensors="pt") + +# Extract activations from multimodal processing +cache = vlm.extract_cache( + [inputs], + target_token_positions=["last-image", "end-image"], + extraction_config=ExtractionConfig(extract_resid_out=True) +) +``` + +## ⚡ High-Performance Inference + +```python +from easyroutine.inference import VLLMInferenceModel + +# Initialize high-performance inference engine +model = VLLMInferenceModel.init_model( + model_name="microsoft/DialoGPT-large", + n_gpus=2, + dtype="bfloat16" +) + +# Generate responses +response = model.generate("Hello, how can I help you today?") + +# Multi-turn conversations with chat templates +chat_history = [] +chat_history = model.append_with_chat_template( + "What is machine learning?", + role="user", + chat_history=chat_history +) +response = model.generate_with_chat_template(chat_history) +``` + +## 🛠️ Utilities & Logging + +```python +from easyroutine.logger import setup_logging, logger +from easyroutine import path_to_parents +from easyroutine.console import progress + +# Flexible logging setup +setup_logging( + level="INFO", + file="experiment.log", + console=True, + fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) -```` +logger.info("Starting experiment...") + +# Convenient navigation helpers +path_to_parents(2) # Go up 2 directory levels + +# Smart progress tracking that adapts to your environment +for item in progress(large_dataset, description="Processing data"): + result = expensive_computation(item) + +# In interactive environments: rich progress bar with ETA +# In batch jobs (SLURM, etc.): clean timestamped logging +# [2024-01-15 10:30:15] Processing data: 1500/10000 (15.0%) - Elapsed: 2.3m, Remaining: 13.1m +``` + +## 📖 Documentation + +- **API Reference**: Comprehensive docstrings for all functions and 
classes +- **Examples**: Jupyter notebooks with common use cases +- **Tutorials**: Step-by-step guides for interpretability workflows + +## 🤝 Contributing + +We welcome contributions! Please see our contributing guidelines for details on: +- Code style and testing requirements +- Documentation standards +- How to submit issues and pull requests + +## 📄 License + +This project is licensed under the MIT License - see the LICENSE file for details. + +## 🔬 Citation + +If you use EasyRoutine in your research, please cite: +```bibtex +@software{easyroutine2024, + title={EasyRoutine: A Toolkit for AI Model Interpretability}, + author={Francesco Ortu}, + year={2024}, + url={https://github.com/francescortu/easyroutine} +} +``` +--- ### Development -For publish the package push a commit with the flag: - - `[patch]`: x.x.7 -> x.x.8 - - `[minor]`: x.7.x -> x.8.0 - - `[major]`: 2.x.x -> 3.0.0 -Example commit: `fix multiple bus [patch]` +For publishing new versions, use semantic version tags in commit messages: +- `[patch]`: Bug fixes (x.x.7 → x.x.8) +- `[minor]`: New features (x.7.x → x.8.0) +- `[major]`: Breaking changes (2.x.x → 3.0.0) -- \ No newline at end of file +Example: `git commit -m "Add support for Gemma models [minor]"` \ No newline at end of file diff --git a/easyroutine/__init__.py b/easyroutine/__init__.py index b7e0b3f..5823be9 100644 --- a/easyroutine/__init__.py +++ b/easyroutine/__init__.py @@ -1,2 +1,79 @@ +""" +EasyRoutine: A comprehensive toolkit for AI model interpretability and analysis. + +EasyRoutine is a powerful Python package that provides a collection of utilities +and tools for working with machine learning models, with a particular focus on +mechanistic interpretability of transformer models. It offers a unified interface +for model analysis, activation extraction, intervention studies, and more. 
+ +Key Features: + - 🔍 Mechanistic Interpretability: Comprehensive tools for analyzing transformer + model internals, including activation extraction, attention pattern analysis, + and intervention studies + - 🚀 High-Performance Inference: Optimized inference backends with support for + VLLM and custom implementations + - 📊 Progress Tracking: Adaptive progress bars that work in both interactive + and batch environments + - 🛠️ Utility Functions: Common utilities for file system navigation and + robust logging + - 🔧 Extensible Architecture: Modular design allowing easy extension and + customization + +Main Modules: + interpretability: Core functionality for transformer model analysis + - HookedModel: Wrapper for extracting activations from any transformer + - ExtractionConfig: Fine-grained control over activation extraction + - ActivationCache: Efficient storage and manipulation of activations + - Intervention tools: Ablation studies and activation patching + + inference: Model inference interfaces for various backends + - BaseInferenceModel: Abstract interface for inference implementations + - VLLMInferenceModel: High-performance inference with VLLM backend + - Configuration classes for flexible inference setup + + console: Progress tracking and console utilities + - progress(): tqdm-style progress bars with environment detection + - LoggingProgress: Text-based progress for batch environments + + logger: Robust logging functionality + - Structured logging with multiple output options + - Level control and formatting customization + - Integration with rich for enhanced console output + + utils: Common utility functions + - File system navigation helpers + - Path manipulation utilities + +Quick Start: + >>> import easyroutine + >>> from easyroutine.interpretability import HookedModel, ExtractionConfig + >>> from easyroutine.console import progress + >>> + >>> # Load a model for interpretability analysis + >>> model = HookedModel.from_pretrained("gpt2") + >>> + >>> # 
Configure what to extract + >>> config = ExtractionConfig( + ... extract_resid_out=True, + ... extract_attn_pattern=True + ... ) + >>> + >>> # Extract activations with progress tracking + >>> data = [{"input_ids": tokenized_inputs}] + >>> cache = model.extract_cache( + ... progress(data, description="Extracting activations"), + ... target_token_positions=["last"], + ... extraction_config=config + ... ) + +Installation: + The package can be installed via pip: + ```bash + pip install easyroutine + ``` + +For more detailed documentation and examples, visit the project repository. +""" + from .utils import path_to_parents from .logger import logger \ No newline at end of file diff --git a/easyroutine/console/progress.py b/easyroutine/console/progress.py index 27b050c..b12c69d 100644 --- a/easyroutine/console/progress.py +++ b/easyroutine/console/progress.py @@ -1,3 +1,52 @@ +""" +Progress tracking utilities for interactive and batch environments. + +This module provides adaptive progress tracking that automatically adjusts its +behavior based on the execution environment. It offers rich, interactive progress +bars for terminal sessions and clean, log-friendly progress updates for batch +jobs and non-interactive environments. + +Key Features: + - Automatic environment detection (interactive vs batch) + - Rich progress bars for interactive terminals + - Clean text-based logging for batch jobs + - Support for both time-based and item-count-based update intervals + - Compatible with common batch systems (SLURM, PBS, SGE, etc.) 
+ - tqdm-style interface for easy integration + +Main Components: + - LoggingProgress: Text-based progress tracker for batch environments + - _NoOpProgress: Null progress tracker for disabled progress + - progress(): Main function providing tqdm-style progress tracking + - get_progress_bar(): Factory function for creating appropriate progress trackers + +The module automatically detects execution context and chooses the most appropriate +progress tracking method: + - Interactive terminals: Rich progress bars with visual elements + - Batch jobs (sbatch, etc.): Text-based logging with timestamps + - Disabled mode: No-op tracker that doesn't display anything + +Environment Detection: + The module detects batch environments by checking for: + - Batch system environment variables (SLURM_JOB_ID, PBS_JOBID, etc.) + - Non-interactive terminals (!sys.stdout.isatty()) + - "dumb" terminal types + - Output redirection + +Example Usage: + >>> # Simple iteration with progress + >>> from easyroutine.console import progress + >>> for item in progress(my_list, description="Processing"): + ... process(item) + + >>> # Manual progress tracking + >>> with get_progress_bar() as pbar: + ... task = pbar.add_task("Training", total=epochs) + ... for epoch in range(epochs): + ... train_epoch() + ... pbar.update(task) +""" + from rich.progress import ( BarColumn, MofNCompleteColumn, @@ -26,17 +75,53 @@ class LoggingProgress: """ - A progress tracker designed for batch environments (sbatch, etc.) - that outputs clean, consistent progress updates to stdout/stderr. + A progress tracker designed for non-interactive batch environments. + + This class provides progress tracking functionality specifically optimized for + batch job environments (like sbatch, PBS, SGE) where fancy progress bars won't + display properly. Instead of using visual progress bars, it outputs clean, + timestamped progress updates to stdout/stderr that are suitable for job logs. 
+ + The progress tracker supports both time-based and item-count-based update + intervals, allowing flexible control over update frequency to balance + informativeness with log verbosity. + + Attributes: + tasks (dict): Internal storage for tracking multiple tasks and their progress. + log_interval (int): Minimum time interval between progress updates in seconds. + update_frequency (int): Item count interval for progress updates (0 = disabled). + + Example: + >>> with LoggingProgress(log_interval=10, update_frequency=100) as progress: + ... task_id = progress.add_task("Processing files", total=1000) + ... for i in range(1000): + ... # do work + ... progress.update(task_id) + + >>> # Or use the track method for iterables + >>> with LoggingProgress() as progress: + ... for item in progress.track(data_list, description="Processing"): + ... # process item """ def __init__(self, log_interval: int = 5, update_frequency: int = 0): """ - Initialize the logging progress tracker. + Initialize the logging progress tracker with customizable update intervals. Args: - log_interval: How often to log progress updates (in seconds) - update_frequency: Alternative to log_interval - update every N items (0 = use log_interval only) + log_interval (int, optional): Minimum time interval between progress + updates in seconds. Progress will be logged at most once per this + interval, regardless of how frequently update() is called. + Defaults to 5 seconds. + update_frequency (int, optional): Alternative update trigger based on + item count. If set to a positive value, progress will be logged + every N items processed, in addition to time-based updates. + Set to 0 to disable item-count-based updates. Defaults to 0. + + Note: + Both update triggers work together - progress is logged when either + the time interval has elapsed OR the item count threshold is reached, + whichever comes first. 
""" self.tasks = {} self.log_interval = log_interval @@ -49,7 +134,33 @@ def __exit__(self, exc_type, exc_value, traceback): pass def add_task(self, description: str, total: int = None, **kwargs): - """Add a task to track.""" + """ + Add a new task to track with the progress tracker. + + Args: + description (str): Human-readable description of the task to be + displayed in progress updates. Should be concise but descriptive. + total (int, optional): Total number of items/steps expected for this + task. If provided, enables percentage calculation and ETA estimates. + If None, only item counts and elapsed time will be displayed. + Defaults to None. + **kwargs: Additional keyword arguments (currently unused but accepted + for compatibility with other progress bar interfaces). + + Returns: + int: Unique task identifier that should be used with update() calls + to track progress for this specific task. + + Side Effects: + - Creates a new task entry in the internal tasks dictionary + - Prints an initial status message indicating the task has started + - Records the task start time for elapsed time calculations + + Example: + >>> progress = LoggingProgress() + >>> task_id = progress.add_task("Training model", total=100) + # Output: [PROGRESS] Starting: Training model (Total: 100) + """ task_id = len(self.tasks) self.tasks[task_id] = { "description": description, @@ -64,7 +175,40 @@ def add_task(self, description: str, total: int = None, **kwargs): return task_id def update(self, task_id, advance=1, **kwargs): - """Update task progress.""" + """ + Update the progress of a tracked task. + + This method increments the completion counter for a task and conditionally + logs progress updates based on the configured time and item count intervals. + Progress is logged when either sufficient time has elapsed since the last + update OR when enough items have been processed since the last log. + + Args: + task_id (int): The task identifier returned by add_task(). 
+ advance (int, optional): Number of items/steps to increment the + progress counter by. Defaults to 1. + **kwargs: Additional keyword arguments (currently unused but accepted + for compatibility with other progress bar interfaces). + + Returns: + None: This function has no return value. + + Side Effects: + - Increments the task's completion counter + - May print a progress update message if update conditions are met + - Updates internal timing information for the task + + Example: + >>> task_id = progress.add_task("Processing", total=100) + >>> for i in range(100): + ... # do work + ... progress.update(task_id) # Increment by 1 + >>> # Or increment by multiple items at once + >>> progress.update(task_id, advance=5) # Increment by 5 + + Note: + If the specified task_id doesn't exist, the call is silently ignored. + """ if task_id not in self.tasks: return @@ -110,7 +254,46 @@ def update(self, task_id, advance=1, **kwargs): def track( self, iterable: Iterable[T], total: Optional[int] = None, description: str = "" ) -> Iterable[T]: - """Track progress through an iterable.""" + """ + Track progress through an iterable with automatic progress updates. + + This method wraps an iterable and automatically updates progress as items + are yielded. It's a convenience method that combines add_task() and update() + calls for simple iteration tracking. + + Args: + iterable (Iterable[T]): The iterable to track progress through. + Can be any iterable object (list, generator, etc.). + total (int, optional): Total number of items in the iterable. + If None, attempts to determine the length using len(). + If the length cannot be determined, only item counts and + elapsed time will be displayed. Defaults to None. + description (str, optional): Description for the task to be displayed + in progress updates. Defaults to empty string. + + Yields: + T: Items from the original iterable, yielded one at a time. 
+ + Returns: + Iterable[T]: A generator that yields items from the original iterable + while tracking progress. + + Side Effects: + - Creates a new task for tracking this iterable + - Prints progress updates as items are processed + - Prints a completion message when the iterable is exhausted + + Example: + >>> data = list(range(1000)) + >>> with LoggingProgress() as progress: + ... for item in progress.track(data, description="Processing data"): + ... # process item + ... result = expensive_operation(item) + + Note: + This method creates a new task internally and manages all progress + updates automatically. No manual update() calls are needed. + """ if total is None: try: total = len(iterable) @@ -134,7 +317,36 @@ def track( def format_time(seconds: float) -> str: - """Format seconds into a readable time string.""" + """ + Format a time duration in seconds into a human-readable string. + + This utility function converts a floating-point duration in seconds + into a more readable format using appropriate time units (seconds, + minutes, or hours) based on the magnitude of the duration. + + Args: + seconds (float): Duration in seconds to format. Can be fractional. + + Returns: + str: Formatted time string with one decimal place and appropriate unit: + - Durations < 60 seconds: "X.Xs" (e.g., "45.2s") + - Durations < 3600 seconds: "X.Xm" (e.g., "12.5m") + - Durations >= 3600 seconds: "X.Xh" (e.g., "2.3h") + + Example: + >>> format_time(45.7) + '45.7s' + >>> format_time(125.3) + '2.1m' + >>> format_time(7890.5) + '2.2h' + >>> format_time(0.123) + '0.1s' + + Note: + The function rounds to one decimal place for readability. For very + large durations, the hours format provides a compact representation. + """ if seconds < 60: return f"{seconds:.1f}s" elif seconds < 3600: @@ -266,23 +478,80 @@ def progress( update_frequency: int = 0, ): """ - A tqdm-style progress bar that can be wrapped around an iterable. 
+ A tqdm-style progress tracker that automatically adapts to the environment. - This function automatically adapts to the environment: - - In interactive sessions (including interactive Slurm jobs), it shows a rich progress bar - - In non-interactive batch jobs (like sbatch), it uses simple text-based progress tracking + This function provides a drop-in replacement for tqdm that automatically + detects the execution environment and displays appropriate progress feedback: + - Rich progress bars in interactive terminals + - Clean text-based logging in batch jobs + - No output when disabled - e.g. `for i in progress(range(10)):` + The function wraps an iterable and yields items while tracking progress, + making it easy to add progress tracking to existing loops with minimal + code changes. Args: - iterable: The iterable to wrap with a progress bar. - description (str): Description to display. - total (int, optional): The total number of items. If None, it's inferred from len(iterable). - disable (bool): If True, the progress bar is disabled completely. - force_batch_mode (bool): If True, use text-based progress tracking even in interactive environments. - log_interval (int): In batch mode, how often (in seconds) to log progress updates. - update_frequency (int): In batch mode, update progress after processing this many items. - Set to 0 to use only time-based updates. + iterable: The iterable to wrap with progress tracking. Can be any + iterable object (list, tuple, generator, etc.). + description (str, optional): Description to display with the progress. + Shows what operation is being performed. Defaults to empty string. + desc (str, optional): Alternative parameter name for description, + provided for tqdm compatibility. If provided, takes precedence + over description parameter. Defaults to None. + total (int, optional): Expected total number of items in the iterable. + If None, attempts to determine automatically using len(iterable). 
+ Required for accurate percentage and ETA calculations. + Defaults to None. + disable (bool, optional): If True, completely disables progress tracking + and returns the original iterable unchanged. Useful for conditional + progress display. Defaults to False. + force_batch_mode (bool, optional): If True, forces text-based progress + tracking even in interactive environments. Useful for testing or + when rich output is not desired. Defaults to False. + log_interval (int, optional): In batch mode, minimum seconds between + progress updates. Prevents excessive logging while ensuring regular + updates. Defaults to 1 second. + update_frequency (int, optional): In batch mode, number of items to + process between progress updates. Works in addition to time-based + updates. Set to 0 to disable item-count-based updates. + Defaults to 0. + + Yields: + Items from the original iterable, one at a time, with progress tracking. + + Returns: + Generator: A generator that yields items from the iterable while + tracking and displaying progress. + + Examples: + >>> # Basic usage - simple iteration with progress + >>> for item in progress(range(1000), description="Processing"): + ... process_item(item) + + >>> # With custom total for generators + >>> data_gen = generate_data() + >>> for item in progress(data_gen, total=expected_count, desc="Loading"): + ... handle_item(item) + + >>> # Conditional progress (disable in quiet mode) + >>> for item in progress(data, disable=quiet_mode): + ... process(item) + + >>> # Force batch mode for consistent output + >>> for item in progress(items, force_batch_mode=True, log_interval=5): + ... 
slow_operation(item) + + Environment Behavior: + - Interactive terminals: Displays rich progress bar with percentage, + visual bar, item counts, and time estimates + - Batch jobs: Outputs timestamped text updates at specified intervals + - Non-interactive: Falls back to text-based logging + - Disabled: Returns original iterable with no progress tracking + + Note: + The function automatically handles cases where len(iterable) is not + available (e.g., generators) by falling back to count-only progress + without percentage or ETA calculations. """ if total is None: try: diff --git a/easyroutine/inference/__init__.py b/easyroutine/inference/__init__.py index 5390c10..d8a19fd 100644 --- a/easyroutine/inference/__init__.py +++ b/easyroutine/inference/__init__.py @@ -1,2 +1,63 @@ +""" +Model inference interfaces for various backends and deployment scenarios. + +This package provides standardized interfaces for running inference across different +model backends and deployment strategies. It abstracts the complexities of different +inference engines while providing a consistent API for text generation and model +interaction. + +Key Features: + - Unified interface across different inference backends + - Support for both single-GPU and multi-GPU deployments + - Configurable generation parameters (temperature, top_p, etc.) 
+ - Chat template management for conversational models + - Optimized inference with VLLM backend support + - Extensible architecture for custom inference implementations + +Supported Backends: + - VLLM: High-performance inference engine with PagedAttention + - Hugging Face Transformers: Direct model inference + - Custom implementations: Extensible base classes + +Main Components: + - BaseInferenceModel: Abstract base class for all inference implementations + - BaseInferenceModelConfig: Common configuration for inference parameters + - VLLMInferenceModel: VLLM-based high-performance inference + - VLLMInferenceModelConfig: VLLM-specific configuration options + +Example Usage: + >>> from easyroutine.inference import VLLMInferenceModel + >>> + >>> # Initialize with VLLM backend + >>> model = VLLMInferenceModel.init_model( + ... model_name="microsoft/DialoGPT-large", + ... n_gpus=2, + ... dtype="bfloat16" + ... ) + >>> + >>> # Generate responses + >>> response = model.generate("Hello, how are you?") + >>> print(response) + + >>> # Multi-turn conversation + >>> chat_history = [] + >>> chat_history = model.append_with_chat_template( + ... "What is machine learning?", + ... role="user", + ... chat_history=chat_history + ... ) + >>> response = model.generate_with_chat_template(chat_history) + +The package is designed to be backend-agnostic, allowing easy switching between +different inference engines based on performance requirements, memory constraints, +or deployment scenarios. 
+""" + from easyroutine.inference.base_model_interface import BaseInferenceModelConfig -from easyroutine.inference.vllm_model_interface import VLLMInferenceModel, VLLMInferenceModelConfig \ No newline at end of file + +# Optional VLLM imports - only available if vllm is installed +try: + from easyroutine.inference.vllm_model_interface import VLLMInferenceModel, VLLMInferenceModelConfig +except ImportError: + # VLLM is not available - this is optional + pass \ No newline at end of file diff --git a/easyroutine/inference/base_model_interface.py b/easyroutine/inference/base_model_interface.py index 234469a..539235d 100644 --- a/easyroutine/inference/base_model_interface.py +++ b/easyroutine/inference/base_model_interface.py @@ -6,7 +6,59 @@ @dataclass class BaseInferenceModelConfig: """ - Configuration for the model interface. + Base configuration class for inference model implementations. + + This configuration class provides common parameters needed for model + inference across different backends and implementations. It establishes + a standard interface for configuring model loading, generation parameters, + and hardware utilization. + + Attributes: + model_name (str): Identifier of the model to load. Can be: + - Hugging Face model repository name (e.g., "gpt-3.5-turbo") + - Local path to a model directory + - Model name supported by the specific inference backend + + n_gpus (int, optional): Number of GPUs to utilize for model inference. + Determines parallel processing capability and memory distribution. + Defaults to 1. + + dtype (str, optional): Data type precision for model parameters. + Common options include: + - "bfloat16": Good balance of speed and accuracy + - "float16": Faster inference, potential accuracy loss + - "float32": Full precision, slower inference + Defaults to "bfloat16". + + temperature (float, optional): Sampling temperature for text generation. 
+ Controls randomness in generation: + - 0.0: Deterministic, always select most likely token + - 0.1-0.7: Low randomness, more focused responses + - 0.8-1.2: Higher randomness, more creative responses + Defaults to 0 (deterministic). + + top_p (float, optional): Nucleus sampling parameter. Only tokens with + cumulative probability <= top_p are considered for sampling. + Range: [0.0, 1.0]. Lower values make output more focused. + Defaults to 0.95. + + max_new_tokens (int, optional): Maximum number of new tokens to generate + in a single inference call. Controls output length and prevents + runaway generation. Defaults to 5000. + + Example: + >>> config = BaseInferenceModelConfig( + ... model_name="microsoft/DialoGPT-large", + ... n_gpus=2, + ... dtype="bfloat16", + ... temperature=0.7, + ... max_new_tokens=1000 + ... ) + >>> model = SomeInferenceModel(config) + + Note: + This is a base configuration class. Specific inference implementations + may extend this with additional parameters relevant to their backend. """ model_name: str n_gpus: int = 1 @@ -20,8 +72,51 @@ class BaseInferenceModelConfig: class BaseInferenceModel(ABC): """ - Base class for inference models. - This class should be extended by specific model implementations. + Abstract base class for inference model implementations. + + This class defines the standard interface that all inference model implementations + should follow, ensuring consistency across different backends (VLLM, Hugging Face, + custom implementations, etc.). It provides common functionality and enforces + implementation of essential methods through abstract methods. + + The base class handles configuration management and provides utility methods + for common inference tasks like chat template application and message formatting. + Subclasses should implement the specific inference logic for their backend. 
+ + Key Design Principles: + - Uniform interface across different inference backends + - Configuration-driven initialization and behavior + - Support for both single-turn and multi-turn conversations + - Extensible for backend-specific optimizations + - Thread-safe inference operations + + Attributes: + config (BaseInferenceModelConfig): Configuration object containing + model parameters, generation settings, and hardware specifications. + + Abstract Methods: + Subclasses must implement backend-specific methods for: + - Model loading and initialization + - Text generation and inference + - Resource management and cleanup + + Example: + >>> class MyInferenceModel(BaseInferenceModel): + ... def __init__(self, config): + ... super().__init__(config) + ... # Initialize specific backend + ... + ... def generate(self, prompt): + ... # Implement generation logic + ... pass + >>> + >>> model = MyInferenceModel.init_model("gpt2", n_gpus=1) + >>> response = model.generate("Hello, world!") + + Note: + This class uses the ABC (Abstract Base Class) pattern to ensure + all subclasses implement required methods. Direct instantiation + of this class will raise a TypeError. """ def __init__(self, config: BaseInferenceModelConfig): @@ -30,15 +125,52 @@ def __init__(self, config: BaseInferenceModelConfig): @classmethod def init_model(cls, model_name: str, n_gpus: int = 1, dtype: str = 'bfloat16') -> 'BaseInferenceModel': """ - Initialize the model with the given configuration. + Class method for convenient model initialization with minimal configuration. + + This factory method provides a streamlined way to create model instances + with common default settings, automatically constructing the configuration + object and initializing the model. It's designed for quick setup scenarios + where detailed configuration isn't needed. + + Args: + model_name (str): Identifier of the model to initialize. 
Accepts: + - Hugging Face model repository names + - Local model paths + - Any model identifier supported by the implementation + + n_gpus (int, optional): Number of GPUs to allocate for the model. + Determines parallelization and memory distribution strategy. + Must be > 0. Defaults to 1. + + dtype (str, optional): Precision/data type for model parameters. + Affects inference speed and memory usage: + - "bfloat16": Recommended for most use cases + - "float16": Faster but may affect accuracy + - "float32": Full precision, slower + Defaults to "bfloat16". - Arguments: - model_name (str): Name of the model to initialize. - n_gpus (int): Number of GPUs to use. - dtype (str): Data type for the model. Returns: - - InferenceModel: An instance of the model. + BaseInferenceModel: An initialized instance of the implementing class + ready for inference operations. + + Example: + >>> # Quick initialization with defaults + >>> model = MyInferenceModel.init_model("gpt2") + >>> + >>> # Custom GPU and precision settings + >>> model = MyInferenceModel.init_model( + ... model_name="microsoft/DialoGPT-large", + ... n_gpus=2, + ... dtype="float16" + ... ) + >>> + >>> # Ready for inference + >>> response = model.generate("Hello!") + + Note: + This method creates a BaseInferenceModelConfig with default values + for unspecified parameters. For more detailed configuration, create + a custom config object and use the regular constructor. """ config = BaseInferenceModelConfig(model_name=model_name, n_gpus=n_gpus, dtype=dtype) return cls(config) diff --git a/easyroutine/interpretability/activation_cache.py b/easyroutine/interpretability/activation_cache.py index 764f59e..bd93f67 100644 --- a/easyroutine/interpretability/activation_cache.py +++ b/easyroutine/interpretability/activation_cache.py @@ -1,3 +1,45 @@ +""" +Activation cache management for storing and manipulating model activations. 
+ +This module provides the ActivationCache class and associated utilities for storing, +organizing, and manipulating activations extracted from transformer models during +interpretability analysis. The cache system is designed to handle large-scale +activation data efficiently while providing flexible access patterns. + +Key Features: + - Hierarchical storage of activations by layer and component type + - Automatic aggregation strategies for combining activations across batches + - Memory-efficient handling of large activation datasets + - Support for different data types and tensor shapes + - Token position mapping for precise activation indexing + - Lazy loading and CPU/GPU memory management + +Classes: + - ActivationCache: Main cache class for storing and accessing activations + - Various aggregation functions for different combination strategies + +The module supports different aggregation strategies for combining activations: + - just_old: Always use new values (replacement strategy) + - just_me: Accumulate values in lists + - sublist: Flatten and extend lists of activations + - aggregate_last_layernorm: Special handling for layer normalization outputs + +Activation Storage Format: + The cache stores activations as tensors with standardized naming conventions: + - "resid_out_0", "resid_out_1", ...: Residual stream outputs by layer + - "attn_pattern_0_5": Attention patterns for layer 0, head 5 + - "mlp_out_2": MLP outputs for layer 2 + - "input_ids": Original input token IDs + - "mapping_index": Token position mappings + +Example Usage: + >>> cache = ActivationCache() + >>> cache["resid_out_0"] = torch.randn(1, 10, 768) + >>> cache["mapping_index"] = {"last": [9]} + >>> print(cache.keys()) # Shows available activations + >>> resid_activations = cache["resid_out_0"] # Access specific activations +""" + import re import torch import contextlib @@ -103,9 +145,72 @@ def __repr__(self): class ActivationCache: """ - A dictionary-like cache for storing and 
aggregating model activation values. - Supports custom aggregation strategies registered for keys (by prefix match) - and falls back to a default aggregation that can dynamically switch types if needed. + A specialized dictionary-like container for storing and aggregating model activations. + + This class provides a flexible caching system specifically designed for handling + activation data from transformer models during interpretability analysis. It supports + automatic aggregation of activations across multiple forward passes, custom + aggregation strategies for different data types, and validation of activation keys. + + The cache system is built around the concept of registered aggregation strategies + that define how new activations should be combined with existing ones when the + same key is encountered multiple times (e.g., across different batches). + + Key Features: + - Dictionary-like interface for easy access to activations + - Automatic validation of activation keys using regex patterns + - Customizable aggregation strategies for different activation types + - Support for deferred caching to manage memory usage + - Built-in strategies for common activation patterns + + Supported Activation Types: + - Residual stream activations: "resid_out_N", "resid_in_N", "resid_mid_N" + - Attention activations: "attn_in_N", "attn_out_N" + - Attention patterns: "avg_attn_pattern_L{layer}H{head}", "pattern_L{layer}H{head}" + - MLP activations: "mlp_out_N" + - Special data: "input_ids", "mapping_index", "last_layernorm", "token_dict" + - Value vectors: "values_N" + + Aggregation Strategies: + The cache supports different strategies for combining multiple values: + - Default: Automatic type switching based on tensor compatibility + - just_old: Always replace with new values + - just_me: Accumulate values in lists + - sublist: Flatten and extend activation lists + - aggregate_last_layernorm: Special concatenation for layer norm outputs + + Attributes: + cache (dict): 
Internal storage for activation data + valid_keys (tuple): Compiled regex patterns for validating activation keys + aggregation_strategies (dict): Registered aggregation functions by key prefix + deferred_cache (bool): Whether to defer memory-intensive operations + + Example: + >>> cache = ActivationCache() + >>> + >>> # Store activations + >>> cache["resid_out_0"] = torch.randn(1, 10, 768) + >>> cache["mapping_index"] = {"last": [9]} + >>> + >>> # Access activations + >>> residual_activations = cache["resid_out_0"] + >>> token_positions = cache["mapping_index"] + >>> + >>> # Register custom aggregation + >>> cache.register_aggregation("custom_key", lambda old, new: new) + >>> + >>> # Check available activations + >>> print(list(cache.keys())) + >>> print(f"Cache contains {len(cache)} items") + + >>> # Aggregate across batches + >>> cache["resid_out_0"] = torch.randn(1, 10, 768) # First batch + >>> cache["resid_out_0"] = torch.randn(1, 10, 768) # Second batch (aggregated) + + Note: + The cache automatically validates activation keys against predefined patterns + to ensure data consistency. Invalid keys will raise warnings or errors + depending on the validation mode. """ def __init__(self): diff --git a/easyroutine/interpretability/hooked_model.py b/easyroutine/interpretability/hooked_model.py index 2226630..8ce2112 100644 --- a/easyroutine/interpretability/hooked_model.py +++ b/easyroutine/interpretability/hooked_model.py @@ -64,14 +64,63 @@ def load_config() -> dict: @dataclass class HookedModelConfig: """ - Configuration of the HookedModel - - Arguments: - model_name (str): the name of the model to load - device_map (Literal["balanced", "cuda", "cpu", "auto"]): the device to use for the model - torch_dtype (torch.dtype): the dtype of the model - attn_implementation (Literal["eager", "flash_attention_2"]): the implementation of the attention - batch_size (int): the batch size of the model. FOR NOW, ONLY BATCH SIZE 1 IS SUPPORTED. 
USE AT YOUR OWN RISK + Configuration class for HookedModel initialization and behavior. + + This dataclass contains all the configuration parameters needed to initialize + a HookedModel instance. It provides sensible defaults while allowing + customization of model loading parameters, device configuration, and + processing settings. + + Attributes: + model_name (str): The identifier of the model to load. Can be: + - A Hugging Face model repository name (e.g., "gpt2", "mistral-7b") + - A local path to a model directory + - Any model supported by transformers library + + device_map (Literal["balanced", "cuda", "cpu", "auto"], optional): + Device placement strategy for the model. Options: + - "balanced": Distribute model across available GPUs evenly + - "cuda": Place entire model on the first available GPU + - "cpu": Place model on CPU (slower but uses less GPU memory) + - "auto": Let transformers decide optimal placement + Defaults to "balanced". + + torch_dtype (torch.dtype, optional): The data type for model parameters. + Common options include torch.float16, torch.bfloat16, torch.float32. + bfloat16 provides good balance of speed and stability. + Defaults to torch.bfloat16. + + attn_implementation (Literal["eager", "custom_eager"], optional): + The attention mechanism implementation to use: + - "eager": Standard PyTorch attention implementation + - "custom_eager": Enhanced implementation with better hook support + The custom implementation is recommended for interpretability work + as it provides more comprehensive hook coverage. + Defaults to "custom_eager". + + batch_size (int, optional): The batch size for model inference. + Currently, only batch size 1 is fully supported and tested. + Using larger batch sizes may lead to unexpected behavior. + Defaults to 1. + + Example: + >>> config = HookedModelConfig( + ... model_name="gpt2", + ... device_map="auto", + ... torch_dtype=torch.float16 + ... 
) + >>> model = HookedModel(config) + + >>> # Or use the convenience method + >>> model = HookedModel.from_pretrained( + ... "gpt2", + ... device_map="cuda", + ... torch_dtype=torch.bfloat16 + ... ) + + Warning: + Batch sizes greater than 1 are experimental and may not work correctly + with all interpretability features. Use at your own risk. """ model_name: str @@ -86,30 +135,153 @@ class HookedModelConfig: @dataclass class ExtractionConfig: """ - Configuration of the extraction of the activations of the model. It store what activations you want to extract from the model. - - Arguments: - extract_resid_in (bool): if True, extract the input of the residual stream - extract_resid_mid (bool): if True, extract the output of the intermediate stream - extract_resid_out (bool): if True, extract the output of the residual stream - extract_resid_in_post_layernorm(bool): if True, extract the input of the residual stream after the layernorm - extract_attn_pattern (bool): if True, extract the attention pattern of the attn - extract_head_values_projected (bool): if True, extract the values vectors projected of the model - extract_head_keys_projected (bool): if True, extract the key vectors projected of the model - extract_head_queries_projected (bool): if True, extract the query vectors projected of the model - extract_head_keys (bool): if True, extract the keys of the attention - extract_head_values (bool): if True, extract the values of the attention - extract_head_queries (bool): if True, extract the queries of the attention - extract_head_out (bool): if True, extract the output of the heads [DEPRECATED] - extract_attn_out (bool): if True, extract the output of the attention of the attn_heads passed - extract_attn_in (bool): if True, extract the input of the attention of the attn_heads passed - extract_mlp_out (bool): if True, extract the output of the mlp of the attn - save_input_ids (bool): if True, save the input_ids in the cache - avg (bool): if True, extract the 
average of the activations over the target positions - avg_over_example (bool): if True, extract the average of the activations over the examples (it required a external cache to save the running avg) - attn_heads (Union[list[dict], Literal["all"]]): list of dictionaries with the layer and head to extract the attention pattern or 'all' to - attn_pattern_avg (Literal["mean", "sum", "baseline_ratio", "none"]): the type of average to perform over the attention pattern. See hook.py attention_pattern_head for more details - attn_pattern_row_positions (Optional[Union[List[int], List[Tuple], List[str], List[Union[int, Tuple, str]]]): the row positions of the attention pattern to extract. See hook.py attention_pattern_head for more details + Configuration class for specifying which model activations to extract. + + This comprehensive configuration class allows fine-grained control over + what internal activations and computations should be extracted from the + model during inference. It supports extraction from various model components + including residual streams, attention mechanisms, MLP layers, and more. + + The configuration follows a boolean flag pattern where each attribute + specifies whether to extract a particular type of activation. This allows + for flexible composition of extraction requirements based on research needs. + + Residual Stream Activations: + extract_resid_in (bool): Extract activations flowing into residual connections. + These are the inputs to each transformer layer before processing. + Defaults to False. + + extract_resid_mid (bool): Extract intermediate activations within layers. + These represent computational states between attention and MLP processing. + Defaults to False. + + extract_resid_out (bool): Extract activations flowing out of residual connections. + These are the final outputs of each transformer layer. + Defaults to False. + + extract_resid_in_post_layernorm (bool): Extract residual inputs after layer + normalization. 
Useful for studying the effect of normalization. + Defaults to False. + + Attention Mechanism Activations: + extract_attn_pattern (bool): Extract attention weight matrices showing + which tokens attend to which other tokens. Essential for attention + analysis and visualization. Defaults to False. + + extract_attn_out (bool): Extract attention layer outputs before residual + connection. Shows the contribution of attention to each position. + Defaults to False. + + extract_attn_in (bool): Extract attention layer inputs. Useful for + studying how different inputs affect attention computations. + Defaults to False. + + Attention Head Components: + extract_head_values_projected (bool): Extract value vectors after + projection in multi-head attention. Shows what information each + head is passing forward. Defaults to False. + + extract_head_keys_projected (bool): Extract key vectors after projection. + Combined with queries, determines attention patterns. + Defaults to False. + + extract_head_queries_projected (bool): Extract query vectors after + projection. Used with keys to compute attention weights. + Defaults to False. + + extract_head_keys (bool): Extract raw key vectors before projection. + Lower-level view of attention computation. Defaults to False. + + extract_head_values (bool): Extract raw value vectors before projection. + Shows pre-projection value representations. Defaults to False. + + extract_head_queries (bool): Extract raw query vectors before projection. + Shows pre-projection query representations. Defaults to False. + + extract_head_out (bool): [DEPRECATED] Extract head outputs. + Use extract_attn_out instead. Defaults to False. + + Other Layer Components: + extract_mlp_out (bool): Extract MLP (feed-forward) layer outputs. + Shows the contribution of position-wise processing. Defaults to False. + + extract_embed (bool): Extract embedding layer outputs. Shows initial + token representations before transformer processing. Defaults to False. 
+ + extract_last_layernorm (bool): Extract final layer normalization outputs. + Shows normalized representations before final predictions. + Defaults to False. + + Metadata and Processing Options: + save_input_ids (bool): Include input token IDs in the activation cache. + Useful for mapping activations back to specific tokens. + Defaults to False. + + save_logits (bool): Include model output logits in the cache. + Essential for studying model predictions. Defaults to True. + + keep_gradient (bool): Preserve gradient information in extracted + activations. Required for gradient-based analysis methods. + Defaults to False. + + Aggregation Options: + avg (bool): Compute average activations over specified target positions. + Reduces memory usage when only summary statistics are needed. + Defaults to False. + + avg_over_example (bool): Compute running average over multiple examples. + Requires external cache management for accumulation. Defaults to False. + + Attention Analysis Options: + attn_heads (Union[list[dict], Literal["all"]]): Specifies which attention + heads to extract patterns from. Can be: + - "all": Extract from all heads in all layers + - List of dicts: Specific heads, e.g., [{"layer": 0, "head": 5}] + Defaults to "all". + + attn_pattern_avg (Literal["mean", "sum", "baseline_ratio", "none"]): + How to aggregate attention patterns. Options: + - "mean": Average across specified dimensions + - "sum": Sum across specified dimensions + - "baseline_ratio": Ratio relative to baseline pattern + - "none": No aggregation, return full patterns + Defaults to "none". + + attn_pattern_row_positions (Optional[Union[List[int], List[Tuple], List[str], List[Union[int, Tuple, str]]]]): + Specific row positions in attention patterns to extract. + Can specify token positions, ranges, or special position names. + If None, extracts full attention patterns. Defaults to None. + + Example: + >>> # Basic residual stream extraction + >>> config = ExtractionConfig( + ... 
extract_resid_out=True, + ... save_input_ids=True + ... ) + + >>> # Comprehensive attention analysis + >>> config = ExtractionConfig( + ... extract_attn_pattern=True, + ... extract_head_values_projected=True, + ... attn_heads=[{"layer": 0, "head": 0}, {"layer": 1, "head": 3}], + ... attn_pattern_avg="mean" + ... ) + + >>> # Full model analysis + >>> config = ExtractionConfig( + ... extract_resid_in=True, + ... extract_resid_out=True, + ... extract_attn_pattern=True, + ... extract_mlp_out=True, + ... save_logits=True, + ... save_input_ids=True + ... ) + + Note: + Extracting many activation types can significantly increase memory usage + and computation time. Enable only the activations needed for your analysis. + The is_not_empty() method can be used to verify that at least one + extraction option is enabled. """ extract_embed: bool = False @@ -142,7 +314,34 @@ class ExtractionConfig: def is_not_empty(self): """ - Return True if at least one of the attributes is True, False otherwise, i.e. if the model should extract something! + Check if any extraction options are enabled in this configuration. + + This method validates that at least one activation extraction flag is set + to True, ensuring that the configuration will actually extract some data + when used with a HookedModel. This is useful for validation before + running expensive extraction operations. + + Returns: + bool: True if at least one extraction option is enabled (any attribute + is True), False if all extraction options are disabled. + + Example: + >>> config = ExtractionConfig() # All defaults (mostly False) + >>> config.is_not_empty() + True # save_logits is True by default + + >>> config_empty = ExtractionConfig(save_logits=False) + >>> config_empty.is_not_empty() + False + + >>> config_active = ExtractionConfig(extract_resid_out=True) + >>> config_active.is_not_empty() + True + + Note: + This method checks all boolean extraction flags. 
It's recommended
+        to call this before expensive extraction operations to avoid
+        unnecessary computation when no activations would be extracted.
         """
         return any(
             [
@@ -167,13 +366,78 @@ def is_not_empty(self):
         )
 
     def to_dict(self):
-        return self.__dict__
+        """
+        Convert the configuration to a dictionary representation.
+
+        Returns:
+            dict: Dictionary containing all configuration attributes as key-value pairs.
+                  Useful for serialization, logging, or passing to other functions
+                  that expect dictionary arguments.
+
+        Example:
+            >>> config = ExtractionConfig(extract_resid_out=True, save_logits=True)
+            >>> config_dict = config.to_dict()
+            >>> print(config_dict)
+            {'extract_embed': False, 'extract_resid_out': True, ...}
+        """
+        return self.__dict__
 
 
 class HookedModel:
     """
-    This class is a wrapper around the huggingface model that allows to extract the activations of the model. It is support
-    advanced mechanistic intepretability methods like ablation, patching, etc.
+    A comprehensive wrapper around Hugging Face transformer models for mechanistic interpretability.
+
+    This class provides advanced functionality for extracting internal activations from transformer
+    models and performing mechanistic interpretability methods such as ablation studies, activation
+    patching, and intervention analysis. It supports both language models and vision-language models
+    with automatic module detection and custom attention implementations.
+ + The HookedModel class serves as the primary interface for interpretability research, offering: + - Automatic model loading and configuration + - Activation extraction from any model component + - Support for intervention and ablation studies + - Custom attention implementations for better hook support + - Batch processing capabilities + - Multi-device support + + Key Features: + - Extract activations from residual streams, attention layers, MLP layers + - Support for attention pattern analysis and head-specific extractions + - Intervention capabilities for causal analysis + - Automatic tokenizer and processor handling + - Support for vision-language models with image processing + - Custom eager attention implementation for comprehensive hook support + + Attributes: + config (HookedModelConfig): Configuration object containing model settings + hf_model: The underlying Hugging Face model + hf_language_model: The language model component (for vision-language models) + model_config: Internal model configuration for hook management + hf_tokenizer: The model's tokenizer + processor: Optional processor for vision-language models + text_tokenizer: Text tokenizer component + input_handler: Handles input preprocessing based on model type + module_wrapper_manager: Manages custom module wrappers + intervention_manager: Handles intervention operations + + Example: + >>> # Basic model loading + >>> model = HookedModel.from_pretrained("gpt2") + >>> + >>> # Extract activations + >>> cache = model.extract_cache( + ... inputs, + ... target_token_positions=["last"], + ... extraction_config=ExtractionConfig(extract_resid_out=True) + ... ) + >>> + >>> # Perform interventions + >>> result = model.run_with_interventions(inputs, interventions) + + Note: + The model uses custom eager attention implementation by default to ensure + comprehensive hook support. This can be disabled by setting + attn_implementation="eager" in the configuration. 
""" def __init__(self, config: HookedModelConfig, log_file_path: Optional[str] = None): @@ -243,7 +506,19 @@ def __repr__(self): @classmethod def from_pretrained(cls, model_name: str, **kwargs): - return cls(HookedModelConfig(model_name=model_name, **kwargs)) + """ + Create a HookedModel instance from a pretrained model. + + Args: + model_name (str): Model name or path. + **kwargs: Additional arguments for HookedModelConfig. + + Returns: + HookedModel: Initialized model instance. + + Example: + >>> model = HookedModel.from_pretrained("gpt2") + """ def assert_module_exists(self, component: str): # Remove '.input' or '.output' from the component @@ -333,17 +608,59 @@ def use_language_model_only(self): logger.debug("HookedModel: Using only language model capabilities") def get_tokenizer(self): - return self.hf_tokenizer + """ + Get the primary tokenizer associated with this model. + + Returns the tokenizer that was loaded during model initialization. + For vision-language models, this may be a processor that includes + both text tokenization and image processing capabilities. + + Returns: + Union[transformers.PreTrainedTokenizer, transformers.ProcessorMixin]: + The tokenizer or processor associated with the model. + + Example: + >>> model = HookedModel.from_pretrained("gpt2") + >>> tokenizer = model.get_tokenizer() + >>> tokens = tokenizer("Hello world", return_tensors="pt") + + Note: + For text-only models, this returns a standard tokenizer. + For vision-language models, this may return a processor that + handles both text and image inputs. Use get_text_tokenizer() + if you specifically need the text tokenization component. + """ def get_text_tokenizer(self): - r""" - If the tokenizer is a processor, return just the tokenizer. If the tokenizer is a tokenizer, return the tokenizer - - Args: - None - + """ + Get the text tokenization component of the model's tokenizer. 
+ + For vision-language models that use a processor (which combines text + tokenization and image processing), this method extracts and returns + just the text tokenizer component. For text-only models, this returns + the same tokenizer as get_tokenizer(). + Returns: - tokenizer: the tokenizer of the model + transformers.PreTrainedTokenizer: The text tokenizer component. + + Raises: + ValueError: If the model uses a processor that doesn't have a + tokenizer attribute. + + Example: + >>> # For a vision-language model + >>> model = HookedModel.from_pretrained("llava-v1.6-mistral-7b-hf") + >>> text_tokenizer = model.get_text_tokenizer() + >>> tokens = text_tokenizer("Hello world", return_tensors="pt") + + >>> # For a text-only model (same as get_tokenizer()) + >>> model = HookedModel.from_pretrained("gpt2") + >>> text_tokenizer = model.get_text_tokenizer() + + Note: + This method is particularly useful when you need to perform + text-specific operations on vision-language models where the + primary tokenizer is actually a multimodal processor. """ if self.processor is not None: if not hasattr(self.processor, "tokenizer"): @@ -1251,31 +1568,108 @@ def extract_cache( # save_other_batch_elements: bool = False, **kwargs, ): - r""" - Method to extract the activations of the model from a specific dataset. Compute a forward pass for each batch of the dataloader and save the activations in the cache. - - Arguments: - - dataloader (iterable): dataloader with the dataset. Each element of the dataloader must be a dictionary that contains the inputs that the model expects (input_ids, attention_mask, pixel_values ...) - - extracted_token_position (Union[Union[str, int, Tuple[int, int]], List[Union[str, int, Tuple[int, int]]]]): list of tokens to extract the activations from (["last", "end-image", "start-image", "first", -1, (2,10)]). 
See TokenIndex.get_token_index for more details - - batch_saver (Callable): function to save in the cache the additional element from each elemtn of the batch (For example, the labels of the dataset) - - move_to_cpu_after_forward (bool): if True, move the activations to the cpu right after the any forward pass of the model - - dict_token_index (Optional[torch.Tensor]): If provided, specifies the index in the vocabulary for which to compute gradients of logits with respect to input embeddings. Requires extraction_config.extract_input_embeddings_for_grad to be True. - - **kwargs: additional arguments to control hooks generation, basically accept any argument handled by the `.forward` method (i.e. ablation_queries, patching_queries, extract_resid_in) - + """ + Extract internal model activations from a dataset using forward passes. + + This is the primary method for extracting activations from the model for + interpretability analysis. It processes each batch in the dataloader, + performs forward passes while capturing specified activations, and + aggregates results into a comprehensive activation cache. + + The method supports flexible extraction configurations, allowing users to + specify exactly which model components to monitor and which token positions + to extract from. It can handle both text-only and multimodal inputs. + + Args: + dataloader (Iterable[Dict]): An iterable containing batches of model inputs. + Each element must be a dictionary containing the inputs that the model + expects (e.g., input_ids, attention_mask, pixel_values for VLMs). + Common format: {"input_ids": tensor, "attention_mask": tensor, ...} + + target_token_positions (Union[List[Union[str, int, Tuple[int, int]]], ...]): + Specification of which token positions to extract activations from. + Supports multiple formats: + - Strings: "last", "first", "end-image", "start-image", "all", etc. 
+ - Integers: Specific token indices (e.g., -1 for last token) + - Tuples: Token ranges (e.g., (2, 10) for positions 2 through 10) + - Mixed lists: Combinations of the above types + See TokenIndex.get_token_index for complete specification. + + extraction_config (ExtractionConfig, optional): Configuration object + specifying which activations to extract from the model. + Controls extraction from residual streams, attention mechanisms, + MLP layers, etc. Defaults to ExtractionConfig() (basic config). + + interventions (List[Intervention], optional): List of intervention + objects to apply during forward passes. Enables ablation studies, + activation patching, and other causal analysis methods. + Defaults to None (no interventions). + + batch_saver (Callable, optional): Function to extract and save additional + information from each batch element (e.g., labels, metadata). + Should take a batch element and return a dictionary of items to save. + Defaults to lambda x: None (no additional saving). + + move_to_cpu_after_forward (bool, optional): Whether to move extracted + activations to CPU immediately after each forward pass. Helps + manage GPU memory usage for large datasets. Defaults to True. + + **kwargs: Additional keyword arguments passed to the forward method. + Can include ablation_queries, patching_queries, and other parameters + for controlling hook generation and intervention behavior. + Returns: - final_cache: dictionary with the activations of the model. The keys are the names of the activations and the values are the activations themselve - - Examples: - >>> dataloader = [{"input_ids": torch.tensor([[101, 1234, 1235, 102]]), "attention_mask": torch.tensor([[1, 1, 1, 1]]), "labels": torch.tensor([1])}, ...] 
- >>> model.extract_cache(dataloader, extracted_token_position=["last"], batch_saver=lambda x: {"labels": x["labels"]}) - {'resid_out_0': tensor([[[0.1, 0.2, 0.3, 0.4]]], grad_fn=), 'labels': tensor([1]), 'mapping_index': {'last': [0]}} + ActivationCache: A comprehensive cache object containing all extracted + activations organized by activation type and layer. The cache includes: + - Activation tensors keyed by component names (e.g., 'resid_out_0') + - Token position mappings indicating which positions were extracted + - Additional batch elements saved via batch_saver + - Metadata about the extraction process + + Example: + >>> # Basic usage: extract residual stream outputs + >>> dataloader = [ + ... {"input_ids": torch.tensor([[101, 1234, 1235, 102]]), + ... "attention_mask": torch.tensor([[1, 1, 1, 1]])}, + ... # ... more batches + ... ] + >>> + >>> config = ExtractionConfig(extract_resid_out=True, save_input_ids=True) + >>> cache = model.extract_cache( + ... dataloader, + ... target_token_positions=["last"], + ... extraction_config=config + ... ) + >>> print(cache.keys()) # ['resid_out_0', 'resid_out_1', ..., 'input_ids'] + + >>> # Advanced usage: extract attention patterns with interventions + >>> config = ExtractionConfig( + ... extract_attn_pattern=True, + ... extract_resid_out=True, + ... attn_heads=[{"layer": 0, "head": 5}] + ... ) + >>> interventions = [some_intervention_object] + >>> cache = model.extract_cache( + ... dataloader, + ... target_token_positions=["last", (5, 10)], + ... extraction_config=config, + ... interventions=interventions, + ... batch_saver=lambda x: {"labels": x.get("labels", None)} + ... 
) + + Note: + - Large datasets may require careful memory management via move_to_cpu_after_forward + - The extraction_config must have is_not_empty() == True to extract anything + - Token position specifications are flexible and support various analysis needs + - Interventions enable causal analysis but may slow down extraction + + See Also: + - ExtractionConfig: For configuring which activations to extract + - TokenIndex.get_token_index: For token position specification details + - ActivationCache: For working with the returned activation data """ logger.info("HookedModel: Extracting cache") - - # get the function to save in the cache the additional element from the batch sime - - logger.info("HookedModel: Forward pass started") all_cache = ActivationCache() # a list of dictoionaries, each dictionary contains the activations of the model for a batch (so a dict of tensors) attn_pattern = ( ActivationCache() diff --git a/easyroutine/logger.py b/easyroutine/logger.py index a380239..1855082 100644 --- a/easyroutine/logger.py +++ b/easyroutine/logger.py @@ -16,7 +16,33 @@ def warning_once(message: str): """ Logs a warning message only once per runtime session. - Subsequent calls with the same message will be ignored. + + This function prevents duplicate warning messages from cluttering the logs by + tracking which messages have already been logged. Subsequent calls with the + same message will be silently ignored. + + Args: + message (str): The warning message to log. Only the first occurrence + of this exact message will be logged during the program's runtime. + + Returns: + None: This function has no return value. 
+ + Side Effects: + - Logs the warning message to the easyroutine logger (first occurrence only) + - Adds the message to an internal set to track logged messages + + Example: + >>> warning_once("This is a warning") + # Logs: WARNING - This is a warning + >>> warning_once("This is a warning") + # No output - message already logged + >>> warning_once("Different warning") + # Logs: WARNING - Different warning + + Note: + The tracking of logged messages persists for the entire program runtime. + Messages are compared for exact string equality. """ if message not in _logged_once_messages: _logged_once_messages.add(message) @@ -27,8 +53,32 @@ def warning_once(message: str): def setup_default_logging(): """ - Set up default logging for easyroutine. This ensures that INFO-level logs are printed - to the console using RichHandler, while allowing the user to override this configuration. + Set up default logging configuration for the easyroutine package. + + This function initializes the easyroutine logger with sensible defaults, + ensuring that INFO-level logs are displayed to the console using RichHandler + for enhanced formatting. The configuration can be overridden by calling + setup_logging() with custom parameters. + + The default configuration includes: + - Log level: INFO + - Handler: RichHandler with rich tracebacks and markup support + - Formatter: Simple message format (RichHandler handles its own formatting) + - Propagation: Disabled to prevent duplicate messages in root logger + + Returns: + None: This function configures the logger in-place. + + Side Effects: + - Configures the global easyroutine logger + - Adds a console handler if none exists + - Sets logger level to INFO + - Disables log propagation to root logger + + Note: + This function only adds handlers if the logger doesn't already have any, + preventing duplicate handlers from being added if called multiple times. + This function is called automatically when the logger module is imported. 
""" if not logger.hasHandlers(): # Avoid adding multiple handlers logger.setLevel(logging.INFO) # Default level (user can change) @@ -50,16 +100,57 @@ def setup_default_logging(): def setup_logging(level="INFO", file=None, console=True, fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s"): """ - Configure logging for easyroutine. - + Configure comprehensive logging settings for the easyroutine package. + + This function provides full control over the logging configuration, allowing + users to customize logging levels, output destinations, and message formatting. + It replaces any existing handlers with the new configuration. + Args: - level (str): Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). Default: INFO. - file (str, optional): Path to log file. If None, logs are not saved to a file. - console (bool): Whether to log to the console. Default: True. - fmt (str): Log message format. Default: standard logging format. + level (str, optional): The logging level to set. Must be one of: + 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'. + Defaults to 'INFO'. Case-insensitive. + file (str, optional): Path to a log file where messages should be saved. + If None, no file logging is performed. If provided, creates a + FileHandler to write logs to this file. Defaults to None. + console (bool, optional): Whether to enable console logging output. + If True, logs are displayed in the terminal using RichHandler. + Defaults to True. + fmt (str, optional): Log message format string for file output. + Uses standard Python logging format specifiers. Note that console + output uses RichHandler's built-in formatting. Defaults to a + standard format with timestamp, logger name, level, and message. + + Returns: + None: This function configures the logger in-place. 
+
+ Side Effects:
+ - Clears all existing handlers from the easyroutine logger
+ - Sets the logger level according to the 'level' parameter
+ - Adds a FileHandler if 'file' is specified
+ - Adds a RichHandler for console output if 'console' is True
+ - Logs a confirmation message about the new configuration
+
+ Example:
+ >>> # Basic setup with file logging
+ >>> setup_logging(level="DEBUG", file="debug.log")
+
+ >>> # Console-only logging with default format (no file)
+ >>> setup_logging(level="WARNING", console=True, file=None)
+
+ >>> # Both file and console with custom format
+ >>> setup_logging(
+ ... level="INFO",
+ ... file="app.log",
+ ... console=True,
+ ... fmt="[%(asctime)s] %(levelname)s: %(message)s"
+ ... )

- Example Usage:
- setup_logging(level="DEBUG", file="easyroutine.log", console=True)
+ Note:
+ - Invalid level names will default to INFO level
+ - RichHandler provides enhanced console output with colors and formatting
+ - File handler uses the provided format, console uses RichHandler's format
+ - This function completely replaces the existing logging configuration
"""
# Clear any existing handlers (to prevent duplicates)
@@ -90,7 +181,28 @@ def setup_logging(level="INFO", file=None, console=True, fmt="%(asctime)s - %(na
def enable_debug_logging():
"""
- Enable debug logging for easyroutine. Prints all DEBUG-level logs.
+ Enable debug-level logging for the easyroutine package.
+
+ This convenience function sets the logger and all its handlers to DEBUG level,
+ which is the most verbose logging level. This will display all log messages
+ including DEBUG, INFO, WARNING, ERROR, and CRITICAL messages.
+
+ Returns:
+ None: This function modifies the logger configuration in-place.
+
+ Side Effects:
+ - Sets the easyroutine logger level to DEBUG
+ - Sets all existing handlers to DEBUG level
+ - Logs a debug message confirming the change
+
+ Example:
+ >>> enable_debug_logging()
+ # DEBUG - Debug logging enabled for easyroutine.
+ + Note: + This function operates on the existing handlers. If no handlers are + configured, you may need to call setup_logging() or setup_default_logging() + first to ensure log messages are displayed. """ logger.setLevel(logging.DEBUG) for handler in logger.handlers: @@ -99,7 +211,28 @@ def enable_debug_logging(): def enable_info_logging(): """ - Enable info logging for easyroutine. Prints all INFO-level logs. + Enable info-level logging for the easyroutine package. + + This convenience function sets the logger and all its handlers to INFO level, + which will display INFO, WARNING, ERROR, and CRITICAL messages while + filtering out DEBUG messages. + + Returns: + None: This function modifies the logger configuration in-place. + + Side Effects: + - Sets the easyroutine logger level to INFO + - Sets all existing handlers to INFO level + - Logs an info message confirming the change + + Example: + >>> enable_info_logging() + # INFO - Info logging enabled for easyroutine. + + Note: + This is typically the default logging level for most applications. + This function operates on existing handlers - ensure handlers are + configured before calling this function. """ logger.setLevel(logging.INFO) for handler in logger.handlers: @@ -108,7 +241,28 @@ def enable_info_logging(): def enable_warning_logging(): """ - Enable warning logging for easyroutine. Prints all WARNING-level logs. + Enable warning-level logging for the easyroutine package. + + This convenience function sets the logger and all its handlers to WARNING level, + which will display only WARNING, ERROR, and CRITICAL messages while + filtering out DEBUG and INFO messages. This is useful for quieter operation + where only important issues should be reported. + + Returns: + None: This function modifies the logger configuration in-place. 
+ + Side Effects: + - Sets the easyroutine logger level to WARNING + - Sets all existing handlers to WARNING level + - Logs a warning message confirming the change + + Example: + >>> enable_warning_logging() + # WARNING - Warning logging enabled for easyroutine. + + Note: + This level is useful for production environments where you want to + reduce log verbosity but still see potential issues. """ logger.setLevel(logging.WARNING) for handler in logger.handlers: @@ -118,7 +272,31 @@ def enable_warning_logging(): def disable_logging(): """ - Disable all logging for easyroutine. + Disable all logging output for the easyroutine package. + + This function effectively turns off all logging by setting the logger level + to a value higher than CRITICAL, ensuring that no log messages will be + displayed regardless of their severity. This is useful for silent operation + or when logging output is not desired. + + Returns: + None: This function modifies the logger configuration in-place. + + Side Effects: + - Sets the easyroutine logger level to CRITICAL + 1 (effectively off) + - Sets all existing handlers to CRITICAL + 1 level + - Attempts to log a final info message (which may not be displayed + due to the timing of when the level change takes effect) + + Example: + >>> disable_logging() + # All subsequent log messages will be suppressed + + Note: + To re-enable logging after calling this function, use one of the + enable_*_logging() functions or call setup_logging() with desired parameters. + The final "logging disabled" message may or may not appear depending on + when the level change takes effect relative to the logging call. 
""" logger.setLevel(logging.CRITICAL + 1) # Effectively turns off logging for handler in logger.handlers: diff --git a/easyroutine/utils.py b/easyroutine/utils.py index b077775..d6ff74f 100644 --- a/easyroutine/utils.py +++ b/easyroutine/utils.py @@ -3,10 +3,36 @@ def path_to_parents(levels=1): """ Change the current working directory to its parent directory. - This is equivalent to %cd ../ - level (int): Number of levels to go up in the directory tree. - for example, if level=2, the function will go up two levels. (i.e. %cd ../../) + This function navigates up the directory tree by the specified number of levels, + changing the current working directory to the parent (or ancestor) directory. + This is equivalent to executing 'cd ../' or 'cd ../../' commands depending on the levels. + + Args: + levels (int, optional): Number of levels to go up in the directory tree. + Defaults to 1. For example: + - levels=1: Go up one level (equivalent to 'cd ../') + - levels=2: Go up two levels (equivalent to 'cd ../../') + - levels=3: Go up three levels (equivalent to 'cd ../../../') + + Returns: + None: This function modifies the current working directory in-place. + + Side Effects: + - Changes the current working directory + - Prints the new working directory path to stdout + + Example: + >>> import os + >>> print(os.getcwd()) # /home/user/project/subdir + >>> path_to_parents(1) + Changed working directory to: /home/user/project + >>> path_to_parents(2) # From /home/user/project + Changed working directory to: /home/user + + Note: + The function assumes that the specified number of parent directories exist. + If not enough parent directories exist, the behavior is undefined. """ current_dir = os.getcwd() parent_dir = os.path.dirname(current_dir) @@ -22,7 +48,42 @@ def path_to_relative(relative_path): """ Change the current working directory to a relative path. - relative_path (str): The relative path to change the working directory to. 
+ This function navigates to a new directory specified by a relative path from + the current working directory. The function constructs the absolute path by + joining the current directory with the provided relative path. + + Args: + relative_path (str): The relative path to change the working directory to. + This can be: + - A subdirectory name (e.g., 'subdir') + - A path with multiple directories (e.g., 'subdir/nested') + - A path using '..' to go up levels (e.g., '../sibling') + - Any valid relative path string + + Returns: + None: This function modifies the current working directory in-place. + + Side Effects: + - Changes the current working directory + - Prints the new working directory path to stdout + + Example: + >>> import os + >>> print(os.getcwd()) # /home/user/project + >>> path_to_relative('data') + Changed working directory to: /home/user/project/data + >>> path_to_relative('../scripts') + Changed working directory to: /home/user/project/scripts + >>> path_to_relative('models/trained') + Changed working directory to: /home/user/project/scripts/models/trained + + Raises: + OSError: If the specified relative path does not exist or is not accessible. + NotADirectoryError: If the path exists but is not a directory. + + Note: + The function does not validate if the target directory exists before attempting + to change to it. If the directory doesn't exist, os.chdir() will raise an exception. """ current_dir = os.getcwd() new_dir = os.path.join(current_dir, relative_path)