
feat: extending cache activation runner options for tokenization, batching, and shuffling#612

Open
chanind wants to merge 8 commits into main from harmize-cache-activations-options

Conversation

@chanind
Collaborator

@chanind chanind commented Jan 3, 2026

This PR adds disable_concat_sequences and sequence_separator_token options to the cache activations runner, so it's possible to cache activations with the same options that can be used for the standard LLM training runner. In addition, this PR adds a shuffle_across_sequences option to allow the runner to fully shuffle activations regardless of which sequence they come from.

This PR also fixes a bug where the dataset saved to disk would not be shuffled even if shuffle=True was set; only the dataset uploaded to Hugging Face and returned by the runner was shuffled.

@claude

claude bot commented Jan 3, 2026

PR Review: Harmonize cache activation runner options with LLM runner

Summary

This PR adds three new configuration options to CacheActivationsRunnerConfig to align with the LLM training runner (disable_concat_sequences, sequence_separator_token, exclude_special_tokens) and fixes a bug where shuffled datasets weren't being saved to disk. Overall, this is a solid improvement that enhances consistency across the codebase.

@claude

claude bot commented Jan 3, 2026

Code Quality & Best Practices ✅

Strengths:

  • Clean, well-structured code that follows existing patterns in the codebase
  • Good use of existing utilities (get_special_token_ids)
  • Proper handling of type conversions (bool → None/list, list → torch.Tensor)
  • Typo fix: "includeing" → "including" in cache_activations_runner.py:237

Minor Suggestions:

  1. Code duplication: The logic for handling exclude_special_tokens conversion (lines 32-40 in cache_activations_runner.py) is duplicated from activations_store.py:116-124. Consider extracting this into a shared utility function to reduce duplication and improve maintainability.

  2. Validation: The validation in config.py:573-576 for exclude_special_tokens list could be stricter. Consider adding validation to ensure token IDs are non-negative integers.
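If the duplicated conversion were extracted as suggested above, the shared helper might look roughly like this sketch. The function name is hypothetical, and a plain list stands in for the torch.Tensor the real code produces, to keep the sketch dependency-free:

```python
# Hypothetical shared helper for normalizing `exclude_special_tokens`
# (name and exact shape are illustrative, not from the codebase).
from typing import Optional, Sequence, Union

def resolve_exclude_special_tokens(
    exclude_special_tokens: Union[bool, Sequence[int]],
    special_token_ids: Sequence[int],
) -> Optional[list[int]]:
    """Map the config value to the form the store expects:
    False -> None (no filtering), True -> all special token ids,
    an explicit sequence -> that sequence as a list."""
    if exclude_special_tokens is True:
        return list(special_token_ids)
    if exclude_special_tokens is False:
        return None
    return list(exclude_special_tokens)
```

Both the cache runner and the activations store could then call this one helper before wrapping the result in a tensor.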

@claude

claude bot commented Jan 3, 2026

Potential Bugs & Issues ⚠️

Shuffle Implementation Concern (lines 304-317 in cache_activations_runner.py):

The shuffle fix has a potential failure point. If the deletion loop fails partway through (e.g., permission error, disk full), the dataset could be left in an inconsistent state:

for item in final_cached_activation_path.iterdir():
    if item.name != ".shuffled":
        if item.is_dir():
            shutil.rmtree(item)  # Could fail here
        else:
            item.unlink()

Recommendation: Use a more atomic approach with backup/rollback, or at minimum wrap operations in try-except with logging to handle failures gracefully and avoid data corruption.
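A minimal sketch of the try-except-with-logging option, with each deletion wrapped so a failure is at least reported before propagating (the function name and logger setup are illustrative, not from the codebase):

```python
import logging
import shutil
import tempfile
from pathlib import Path

logger = logging.getLogger(__name__)

def remove_unshuffled_files(final_cached_activation_path: Path) -> None:
    """Delete everything except the .shuffled staging dir, logging any failure."""
    for item in final_cached_activation_path.iterdir():
        if item.name == ".shuffled":
            continue
        try:
            if item.is_dir():
                shutil.rmtree(item)
            else:
                item.unlink()
        except OSError:
            logger.exception(
                "Failed to remove %s; cached dataset may be left inconsistent", item
            )
            raise

# Demo on a throwaway directory layout
root = Path(tempfile.mkdtemp())
(root / ".shuffled").mkdir()
(root / "data-00000-of-00001.arrow").write_text("old data")
(root / "cache").mkdir()
remove_unshuffled_files(root)
remaining = sorted(p.name for p in root.iterdir())
```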

@claude

claude bot commented Jan 3, 2026

Performance Considerations 📊

Shuffle Performance:

  • The shuffle operation now involves: shuffle → save to temp → delete old → move temp → reload
  • This doubles the I/O operations for shuffled datasets
  • For large datasets (multi-GB), this could be significant, but it's an acceptable tradeoff for correctness
  • Positive: The implementation correctly reloads from disk (line 317), ensuring consistency

Memory Efficiency:

  • The exclude_special_tokens tensor is correctly placed on CPU before being passed to ActivationsStore
  • No unnecessary copies or device transfers detected

@claude

claude bot commented Jan 3, 2026

Security Concerns 🔒

Low Risk Overall:

  • Path operations are safe - final_cached_activation_path / item.name uses items from iterdir() which returns relative names, preventing path traversal
  • The exclude_special_tokens validation is good, though bounds checking for token IDs could be added to prevent edge cases with invalid token IDs

@claude

claude bot commented Jan 3, 2026

Test Coverage 🧪

Excellent coverage for the shuffle fix:
✅ Comprehensive test in test_cache_activations_runner_shuffled_saved_to_disk
✅ Tests verify both returned dataset AND on-disk dataset match
✅ Tests verify data is actually shuffled (not just relabeled)

Missing coverage:
No integration tests for the new config parameters (disable_concat_sequences, sequence_separator_token, exclude_special_tokens) in the cache activations runner context. While these are tested elsewhere (e.g., test_tokenization_and_batching.py), it would be valuable to have at least one integration test showing they work correctly when passed through CacheActivationsRunnerConfig.

No error handling tests for the shuffle operation (e.g., what happens if disk is full during shuffle?)

Recommendation: Add a simple integration test to verify the new parameters work end-to-end.

@claude

claude bot commented Jan 3, 2026

Overall Assessment

Score: 8.5/10

This is a well-implemented PR that addresses real issues:

  1. ✅ Fixes a genuine bug with shuffle not persisting to disk
  2. ✅ Improves consistency across runners
  3. ✅ Good test coverage for the shuffle fix
  4. ✅ Follows existing code patterns
  5. ✅ Clear documentation in PR description

Recommendations before merge:

  1. High Priority: Add error handling around the shuffle disk operations (lines 304-317) to prevent data loss on failure
  2. Medium Priority: Add integration test for the new config parameters
  3. Low Priority: Extract the exclude_special_tokens conversion logic to a shared utility to reduce code duplication

The bug fix alone makes this PR valuable, and the new parameters are a nice addition for consistency. With the error handling improvement for the shuffle operations, this would be ready to merge!

@claude

claude bot commented Jan 3, 2026

Code Review

Thank you for this PR! The changes look solid overall. Here's my detailed feedback:

✅ Positive Aspects

  1. Good alignment goal: Harmonizing options between the cache activations runner and LLM training runner improves consistency
  2. Important bug fix: The shuffle bug fix is valuable - ensuring the shuffled dataset is persisted to disk matches expected behavior
  3. Comprehensive test coverage: The new test test_cache_activations_runner_shuffled_saved_to_disk thoroughly validates the shuffle fix
  4. Clean parameter passing: The new parameters are properly threaded through to ActivationsStore
  5. Minor improvement: Fixed typo "includeing" → "including" in sae_lens/cache_activations_runner.py:227

🐛 Potential Issues

Critical: Missing exclude_special_tokens parameter

The PR description mentions adding exclude_special_tokens, but this parameter is not actually added to the code:

  • ✅ Added to config: disable_concat_sequences, sequence_separator_token
  • ❌ Missing from config: exclude_special_tokens
  • The parameter exists in LanguageModelSAERunnerConfig (sae_lens/config.py:281) and ActivationsStore.__init__ (sae_lens/training/activations_store.py:223)
  • It should be added to CacheActivationsRunnerConfig and passed to _mk_activations_store in sae_lens/cache_activations_runner.py:53

Recommendation: Add the missing parameter or update the PR description to remove mention of it.

Performance: Shuffle implementation could be optimized

The shuffle fix in sae_lens/cache_activations_runner.py:292-307 works but involves:

  1. Saving to a temp directory
  2. Deleting all original files
  3. Moving files from temp to original location
  4. Reloading the dataset

Potential concerns:

  • Could be slow for large datasets (double I/O)
  • Risk of data loss if interrupted mid-operation
  • The shutil.move across directory boundaries might copy instead of rename

Alternative approach: Consider using dataset.save_to_disk() with overwrite=True if supported, or investigate if there's a cleaner HuggingFace datasets API for this.

📝 Code Quality

  1. Device assignment (sae_lens/cache_activations_runner.py:31): Moving device = torch.device("cpu") to a separate line is fine but doesn't add value - consider reverting to inline for consistency
  2. Test quality: The tests are well-structured with clear assertions
  3. Documentation: The new config parameters have good docstrings

🧪 Test Coverage

Good coverage for shuffle fix, but consider adding:

  1. Test that verifies disable_concat_sequences is properly used
  2. Test that verifies sequence_separator_token is properly used
  3. Test for the shuffle operation with edge cases (empty dataset, single item, etc.)

🔒 Security

No security concerns identified.

⚡ Performance Considerations

Beyond the shuffle implementation noted above:

  • The reload of the dataset after shuffling (line 307) adds overhead
  • For very large datasets, this could be significant

Summary

Blocking issue: The exclude_special_tokens parameter mentioned in the PR description is not implemented.

Suggested action items:

  1. Add exclude_special_tokens to CacheActivationsRunnerConfig and pass it to ActivationsStore
  2. Consider optimizing the shuffle-to-disk implementation
  3. Add tests for the new disable_concat_sequences and sequence_separator_token parameters

Overall, this is a useful contribution once the missing parameter is addressed!

@codecov

codecov bot commented Jan 3, 2026

Codecov Report

❌ Patch coverage is 86.36364% with 6 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| sae_lens/cache_activations_runner.py | 84.61% | 6 Missing ⚠️ |


@claude

claude bot commented Jan 3, 2026

Code Review for PR #612

Summary

This PR adds two new configuration options (disable_concat_sequences and sequence_separator_token) to harmonize the cache activations runner with the LLM training runner, and fixes a bug where shuffled datasets were not being saved to disk. Overall, the implementation is solid with good test coverage.


Positive Aspects

  1. Good Test Coverage: The new test test_cache_activations_runner_shuffled_saved_to_disk thoroughly validates the shuffle-to-disk fix, checking both tokens and activations.

  2. Configuration Consistency: Adding disable_concat_sequences and sequence_separator_token to CacheActivationsRunnerConfig properly aligns with the LLM runner options, improving API consistency.

  3. Documentation: The docstring additions in config.py:499-500 clearly explain the new parameters.

  4. Bug Fix: Addresses a real issue where shuffle=True only affected the returned dataset, not the saved one.


Issues and Concerns

1. File System Race Condition Risk (Medium Priority)

Location: cache_activations_runner.py:294-309

The shuffle-to-disk implementation has a potential issue with atomicity. If the process crashes between deleting old files and moving new files, the dataset could be left in a corrupted state.

Recommendation: Use atomic rename operations to ensure the original data remains intact until the new version is ready. Consider renaming the old directory to a backup, then renaming the new directory to the final location, then removing the backup.


2. Missing Test for New Config Parameters (Medium Priority)

While the shuffle fix is well-tested, there are no tests that verify disable_concat_sequences and sequence_separator_token actually work correctly when passed to the cache runner.

Recommendation: Add a test that verifies these parameters are properly passed to the ActivationsStore.


3. Performance Implications

The shuffle-to-disk operation now involves saving the entire dataset to a temp location, deleting original files, moving all files from temp to final location, and reloading the dataset.

For large datasets (multi-GB), this could be slow and disk-intensive.

Recommendation: Add a log message about the shuffle-to-disk operation so users understand why it might take time for large datasets.


Security Concerns

No significant security issues identified. The file operations use proper path handling via Path objects.


Best Practices and Style

Good:

  • Import organization is correct
  • Type hints are properly maintained
  • Code follows existing patterns
  • Test quality is comprehensive

Minor:

  • The device variable at line 31 could be inlined since it is only used once
  • The typo fix (includeing to including) is good but unrelated to the main PR purpose

Recommendations Summary

Must Address:

  1. Fix the file system race condition in the shuffle-to-disk implementation
  2. Add tests for the new disable_concat_sequences and sequence_separator_token parameters

Should Consider:

  1. Add a log message during the shuffle-to-disk operation for user feedback
  2. Document performance implications for large datasets

Overall Assessment

This is a well-implemented PR that addresses a real bug and improves API consistency. The main concern is the file operation safety in the shuffle-to-disk logic, which should be addressed before merging to prevent potential data corruption in edge cases.

Estimated Risk: Medium (due to file system operations)
Code Quality: Good
Test Coverage: Good (but needs one additional test for new parameters)

@claude

claude bot commented Jan 3, 2026

Code Review - PR #612

Summary

This PR adds disable_concat_sequences and sequence_separator_token configuration options to harmonize the cache activations runner with the LLM training runner, and fixes a bug where shuffled datasets weren't being saved to disk. Overall, this is a well-implemented PR with excellent test coverage.

Positive Highlights

1. Excellent Test Coverage

The PR includes comprehensive tests covering:

  • The original shuffling bug fix (test_cache_activations_runner_shuffled_saved_to_disk)
  • The new shuffle_across_sequences functionality
  • Reproducibility of shuffling with seeds
  • Configuration validation

The tests are well-structured and use appropriate assertions to verify behavior.

2. Bug Fix - Shuffled Datasets Not Saved

The fix for the shuffling bug (lines 294-311 in cache_activations_runner.py) correctly addresses the issue where dataset.shuffle() only shuffled the in-memory dataset but didn't persist it to disk. The solution of saving to a temporary location and then moving files is sound.

3. Good Documentation

  • Docstrings are clear and explain the new parameters well
  • The shuffle_across_sequences parameter has detailed explanations in both the config docstring and inline comments

Issues and Concerns

1. Critical: Potential Data Loss in Shuffle Logic ⚠️

Location: sae_lens/cache_activations_runner.py:300-306

The logic that removes files before moving the shuffled dataset could potentially cause data loss if something goes wrong during the move operation:

for item in final_cached_activation_path.iterdir():
    is_arrow_file = (
        item.name.startswith("data-") and item.suffix == ".arrow"
    )
    is_dataset_metadata = item.name in ("dataset_info.json", "state.json")
    if is_arrow_file or is_dataset_metadata:
        item.unlink()  # Delete before move completes

Recommendation: Use atomic move operations or implement a two-phase commit:

  1. Move all files from .shuffled to a different temp directory
  2. Only after all moves succeed, remove the old files
  3. Or use shutil.move with overwrite if the library supports it

Alternative approach:

# Move current dataset files to a backup location first
backup_path = final_cached_activation_path / ".backup"
backup_path.mkdir(exist_ok=True)
for item in final_cached_activation_path.iterdir():
    is_arrow_file = item.name.startswith("data-") and item.suffix == ".arrow"
    is_dataset_metadata = item.name in ("dataset_info.json", "state.json")
    if is_arrow_file or is_dataset_metadata:
        shutil.move(str(item), str(backup_path / item.name))
# Move shuffled files into place
for item in shuffled_path.iterdir():
    shutil.move(str(item), str(final_cached_activation_path / item.name))
# Clean up the backup only once all moves have succeeded
shutil.rmtree(backup_path)

2. Moderate: Inconsistent Shuffling Behavior

Location: sae_lens/cache_activations_runner.py:352-357 and lines 291-293

The PR introduces two types of shuffling:

  1. Sequence-level shuffling (line 293): dataset.shuffle(seed=self.cfg.seed) - uses HuggingFace's shuffle
  2. Position-level shuffling (line 354): Uses torch.randperm with a separate generator

Issue: These two shuffling operations use different random number generators:

  • Sequence shuffling uses HuggingFace's shuffle with the seed
  • Position shuffling uses a torch Generator seeded in __init__

Recommendation: Document this behavior clearly or consider whether both should use the same RNG for true reproducibility across different runs with distributed buffers.

3. Minor: Generator State Not Preserved Across Buffers

Location: sae_lens/cache_activations_runner.py:92

self._shuffle_generator = torch.Generator().manual_seed(self.cfg.seed)

The shuffle generator is initialized once, but since it's called in _create_shard (which is called multiple times in the loop at line 269), the generator state evolves. This means:

  • If you process 10 buffers, buffer 0 gets one random permutation, buffer 1 gets a different one (with evolved state)
  • This is likely the intended behavior but should be documented

Question: Is this the desired behavior, or should each buffer get an independent shuffle with the same seed? If the latter, you'd need to reset or create a new generator for each buffer.

4. Minor: Magic String in File Cleanup

Location: sae_lens/cache_activations_runner.py:304

is_dataset_metadata = item.name in ("dataset_info.json", "state.json")

These filenames are also referenced in _consolidate_shards (lines 197, 230). Consider extracting these as class constants:

DATASET_INFO_FILE = "dataset_info.json"
DATASET_STATE_FILE = "state.json"

5. Minor: Missing Import

Location: sae_lens/cache_activations_runner.py:3

The PR uses shutil.move (line 308) but I don't see it being imported. However, shutil is imported at line 3, so this is fine. Just noting that the specific functions being used aren't explicitly imported.

6. Documentation: Unclear Interaction Between Shuffle Options

The interaction between shuffle, shuffle_across_sequences, and the sequence-level dataset shuffle could be clearer:

  1. When shuffle=True and shuffle_across_sequences=False: Only sequences are shuffled
  2. When shuffle=True and shuffle_across_sequences=True: Both sequences AND positions within buffers are shuffled

This should be documented more explicitly in the docstring.

7. Performance: Double Save Operation

Location: sae_lens/cache_activations_runner.py:294-311

When shuffling is enabled, the dataset is saved twice:

  1. First in _consolidate_shards (line 287)
  2. Then again after shuffling (line 297)

Impact: This doubles the I/O time for large datasets.

Recommendation: Consider adding a parameter to _consolidate_shards to skip the initial save when shuffling will be applied, or shuffle before consolidation if possible.

Code Quality

Strengths:

  • Clear variable names
  • Good separation of concerns
  • Type hints are used consistently
  • Tests follow existing patterns in the codebase

Minor Style Issues:

  • Line 229: Typo fixed from "includeing" to "including" - good catch!

Security Concerns

No significant security issues identified. The file operations are appropriately scoped to the specified directory.

Performance Considerations

  1. Double I/O on shuffle (mentioned above)
  2. Memory usage during shuffle: The shuffle operation loads the entire dataset into memory temporarily. For very large datasets, this could be problematic. Consider documenting memory requirements.
  3. Random permutation: torch.randperm(n_activations) creates a full permutation array in memory, which is O(n) memory. This is fine for the use case but worth noting.

Recommendations Summary

High Priority:

  1. ⚠️ Fix the potential data loss issue in the shuffle file operations (use safer atomic operations)

Medium Priority:

  1. Document the interaction between different shuffle modes more clearly
  2. Consider the double-save performance issue for large datasets
  3. Clarify whether generator state evolution across buffers is intentional

Low Priority:

  1. Extract magic strings for dataset metadata filenames
  2. Add memory usage notes to documentation for large datasets

Overall Assessment

This is a solid PR that:

  • ✅ Adds useful functionality to harmonize configs
  • ✅ Fixes a real bug with shuffling
  • ✅ Includes excellent test coverage
  • ⚠️ Has one critical issue with file operations that should be addressed
  • ⚠️ Has some documentation and clarity issues that would benefit from improvement

The code is generally well-written and follows the project's conventions. With the file operation safety issue addressed, this would be ready to merge.

@chanind chanind changed the title feat: harmonize cache activation runner options with LLM runner, and fix shuffling feat: harmonize cache activation runner options with LLM runner, and improve shuffling Jan 4, 2026
@chanind chanind changed the title feat: harmonize cache activation runner options with LLM runner, and improve shuffling feat: extending cache activation runner options for tokenization, batching, and shuffling Jan 4, 2026
@claude

claude bot commented Jan 4, 2026

Code Review - PR #612

Thanks for this PR! Overall this is a solid implementation that adds useful features. Here's my detailed review:

✅ Strengths

  1. Excellent Test Coverage: The PR includes comprehensive tests covering:

    • Shuffle persistence to disk (test_cache_activations_runner_shuffled_saved_to_disk)
    • Shuffle across sequences functionality (test_cache_activations_runner_shuffle_across_sequences)
    • Reproducibility (test_cache_activations_runner_shuffle_across_sequences_reproducible)
    • Config validation (test_cache_activations_runner_shuffle_across_sequences_requires_shuffle)
  2. Good Documentation: Config parameters are well-documented in docstrings

  3. Config Validation: Appropriate validation in config.py:573-577 ensuring shuffle_across_sequences requires shuffle=True

🐛 Potential Issues

1. File Handling Race Condition Risk (sae_lens/cache_activations_runner.py:296-311)

The shuffling save-to-disk implementation has a potential issue:

shuffled_path = final_cached_activation_path / ".shuffled"
dataset.save_to_disk(str(shuffled_path))
# Remove old unshuffled data and replace with shuffled
for item in final_cached_activation_path.iterdir():
    # ... delete files
for item in shuffled_path.iterdir():
    shutil.move(str(item), str(final_cached_activation_path / item.name))
shuffled_path.rmdir()

Issues:

  • If the process crashes between deleting old files and moving new ones, data is lost
  • No atomic operation guarantee
  • .shuffled directory might conflict if multiple processes run

Recommendation: Use a more robust approach:

import tempfile

# Create a uniquely named temp directory next to the final path
with tempfile.TemporaryDirectory(dir=final_cached_activation_path.parent) as tmp_dir:
    shuffled_path = Path(tmp_dir) / "shuffled"
    dataset.save_to_disk(str(shuffled_path))
    # Atomic rename: stash the original as a backup, then swap in the shuffled copy
    backup_path = final_cached_activation_path.parent / f"{final_cached_activation_path.name}.backup"
    final_cached_activation_path.rename(backup_path)
    try:
        shuffled_path.rename(final_cached_activation_path)
        shutil.rmtree(backup_path)
    except Exception:
        backup_path.rename(final_cached_activation_path)  # rollback
        raise

2. Generator State Issue (sae_lens/cache_activations_runner.py:92, 354)

The _shuffle_generator is initialized once in __init__ and reused across all shards:

self._shuffle_generator = torch.Generator().manual_seed(self.cfg.seed)

Issue: Each call to torch.randperm() advances the generator state, which means:

  • Shuffle order depends on the number of shards processed
  • Not truly reproducible if buffer sizes change
  • Different runs with different n_buffers will produce different shuffles even with same seed

Recommendation: Either:

  • Reset the generator for each shard with a deterministic seed derived from shard index
  • Or document this behavior clearly that reproducibility requires identical buffer configuration
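The first option could be sketched like this, with Python's random standing in for torch.Generator and all names hypothetical: each shard seeds a fresh generator from the base seed and its own index, so its permutation never depends on how many shards ran before it.

```python
import random

def shard_permutation(base_seed: int, shard_idx: int, n: int) -> list[int]:
    # Seed a fresh generator from (base_seed, shard_idx) so the permutation
    # for a given shard is reproducible regardless of processing order.
    rng = random.Random(base_seed * (2**32) + shard_idx)
    perm = list(range(n))
    rng.shuffle(perm)
    return perm

p_first = shard_permutation(42, 3, 8)
# Processing another shard in between does not perturb shard 3's permutation.
shard_permutation(42, 0, 8)
p_again = shard_permutation(42, 3, 8)
```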

3. Memory Usage Consideration (sae_lens/cache_activations_runner.py:293-311)

The shuffle implementation loads the entire dataset, shuffles it, saves it, deletes old files, and reloads:

dataset = dataset.shuffle(seed=self.cfg.seed)  # loads to memory
dataset.save_to_disk(str(shuffled_path))
# ... file operations
dataset = Dataset.load_from_disk(str(final_cached_activation_path))  # loads again

Issue: This could be memory-intensive for large cached activation datasets

Recommendation: Consider using dataset.shuffle(seed=self.cfg.seed, keep_in_memory=False) and check if you can avoid the reload at the end

4. Missing Import (sae_lens/cache_activations_runner.py:3)

shutil is imported but only used in the new shuffle code. Verify this import exists (it does in the diff, so this is fine).

🔧 Code Quality Suggestions

1. Magic String (sae_lens/cache_activations_runner.py:296)

.shuffled is a magic string. Consider defining as a constant:

SHUFFLE_TEMP_DIR = ".shuffled"

2. Typo Fix Already Applied

Good catch fixing "includeing" → "including" in the comment at line 229!

3. Shuffle Across Sequences Semantics

The current implementation in _create_shard shuffles at the buffer level (within each shard), but then the entire dataset is shuffled again at the dataset level. This means:

  • Activations are shuffled within each buffer during shard creation
  • Then all sequences are shuffled again when consolidating shards

Question: Is this the intended behavior? Should shuffle_across_sequences only apply within buffers, or globally? The current implementation seems to shuffle both ways, which might be redundant or confusing.

Recommendation: Clarify in the config docstring exactly when/where shuffling occurs with shuffle_across_sequences=True

4. Missing Edge Case Test

Consider adding a test for what happens when shuffle_across_sequences=True but the buffer has an odd number of activations that don't perfectly divide by context_size.

🔒 Security

No security concerns identified. The code doesn't introduce any injection vulnerabilities or unsafe file operations beyond the atomicity concern mentioned above.

⚡ Performance

  1. Redundant Disk I/O: The shuffle-save-reload pattern adds overhead. Could potentially be optimized.
  2. shuffle_across_sequences: The torch.randperm() call is efficient, but creating the permutation for large buffers could be memory-intensive.

📝 Documentation Suggestions

  1. Add a note in CacheActivationsRunnerConfig docstring explaining:

    • The difference between shuffle (sequence-level) and shuffle_across_sequences (activation-level)
    • That shuffle_across_sequences affects reproducibility based on buffer configuration
    • Performance implications of both options
  2. Consider adding an example in the docstring showing when you'd want to use disable_concat_sequences and sequence_separator_token

Summary

Recommended Actions Before Merge:

  1. 🔴 Critical: Fix the file handling atomicity issue in the shuffle save-to-disk implementation
  2. 🟡 Important: Clarify/fix the generator state reproducibility issue
  3. 🟡 Important: Clarify the semantics of shuffling at buffer vs dataset level
  4. 🟢 Nice-to-have: Address the memory optimization suggestions

The feature itself is valuable and the tests are comprehensive. With the above fixes, this will be a solid addition to the codebase!

@claude

claude bot commented Jan 4, 2026

Code Review for PR #612

Summary

This PR adds important tokenization and shuffling options to the cache activations runner, bringing it in line with the standard LLM training runner. The implementation is generally solid with good test coverage, but there are a few areas that could be improved.

✅ Strengths

Code Quality

  • Well-documented: Clear docstrings explain the new parameters (disable_concat_sequences, sequence_separator_token, shuffle_across_sequences)
  • Good error handling: The atomic swap logic for shuffling (cache_activations_runner.py:312-325) includes proper rollback on failure
  • Clear naming: Constants like _TMP_SHARDS_DIR, _SHUFFLED_DIR, _BACKUP_SUFFIX make the code more maintainable
  • Type hints: Good use of type annotations throughout

Test Coverage

  • Comprehensive testing: 223 lines of new tests covering multiple scenarios
  • Edge cases covered: Tests for reproducibility, validation, and disk persistence
  • Good test structure: Tests verify both in-memory and on-disk behavior

Bug Fix

  • Important fix: Addresses the bug where shuffle=True only shuffled the returned dataset, not the one saved to disk

⚠️ Issues & Recommendations

1. Potential Data Corruption Risk (High Priority)

Location: cache_activations_runner.py:312-325

The atomic swap implementation has a subtle but critical issue. The rollback logic attempts to rename directories back, but shutil.rmtree() is used for cleanup. If the process is interrupted between rename operations, you could end up with orphaned directories.

Recommendation:

  • Consider using tempfile.TemporaryDirectory for the shuffled path, which ensures cleanup even if the process crashes
  • Add logging at each step of the atomic swap to help debug issues
  • Consider adding a cleanup function that runs on startup to remove any orphaned .shuffled or .backup directories from previous failed runs

Example:

# Check for and clean up orphaned directories on startup
for suffix in [_SHUFFLED_DIR, _BACKUP_SUFFIX]:
    orphaned = final_cached_activation_path.parent / (
        final_cached_activation_path.name + suffix
    )
    if orphaned.exists():
        logger.warning(f"Cleaning up orphaned directory: {orphaned}")
        shutil.rmtree(orphaned)

2. Performance Concern: Double Disk I/O (Medium Priority)

Location: cache_activations_runner.py:298-328

When shuffle=True, the code:

  1. Loads dataset into memory and shuffles (line 298)
  2. Saves shuffled dataset to disk at a temporary location (line 310)
  3. Performs atomic rename (lines 314-315)
  4. Reloads the entire dataset from disk (line 328)

The reload at line 328 seems unnecessary if the in-memory dataset object is already shuffled and valid.

Recommendation:
Remove the reload step unless there's a specific reason the dataset object becomes invalid after the rename operation:

# Only reload if necessary (e.g., if dataset references the old path internally)
# Otherwise, continue using the already-shuffled in-memory dataset

3. Validation Logic Inconsistency (Low Priority)

Location: config.py:573-577

The validation correctly prevents shuffle_across_sequences=True when shuffle=False, but the error message could be more informative about why this constraint exists.

Recommendation:
Enhance the error message:

raise ValueError(
    "shuffle_across_sequences=True requires shuffle=True. "
    "Set shuffle=True to enable shuffling across sequences. "
    "Note: shuffle_across_sequences performs within-buffer shuffling, "
    "while shuffle performs sequence-level shuffling of the final dataset."
)

4. Missing Edge Case in Tests (Low Priority)

The tests don't verify what happens if the shuffle operation fails partway through (e.g., disk full, permissions error during atomic swap).

Recommendation:
Add a test that simulates failure during the atomic swap to ensure the rollback logic works correctly.
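One way to sketch such a test, using a hypothetical minimal version of the swap (`swap_in_shuffled` is illustrative, not the codebase's function) and injecting failure by pointing it at a missing shuffled directory so the first rename raises:

```python
import shutil
import tempfile
from pathlib import Path

def swap_in_shuffled(final: Path, shuffled: Path) -> None:
    backup = final.with_name(final.name + ".backup")
    final.rename(backup)
    try:
        shuffled.rename(final)
        shutil.rmtree(backup)
    except OSError:
        backup.rename(final)  # rollback: restore the original dataset
        raise

root = Path(tempfile.mkdtemp())
final = root / "dataset"
final.mkdir()
(final / "data-00000.arrow").write_text("original")

# Inject failure: the shuffled directory does not exist, so rename raises.
try:
    swap_in_shuffled(final, root / "does-not-exist")
except OSError:
    pass

# After the failed swap, the original data must be intact.
assert (final / "data-00000.arrow").read_text() == "original"
```

A real test in the suite could instead patch the rename or fill the disk, but the shape is the same: force the swap to fail, then assert the on-disk dataset is untouched.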

5. Documentation Clarification (Low Priority)

Location: config.py:481-482

The docstring for shuffle_across_sequences is good, but could be clearer about when to use it vs. regular shuffle.

Recommendation:

shuffle_across_sequences (bool): Whether to shuffle individual activations across all sequence 
    positions within each buffer, breaking sequence boundaries. This is useful for training 
    scenarios where sequence order is not important. When True, treats each buffer as a flat 
    2D array and shuffles activation positions while keeping token_ids paired with activations.
    Requires shuffle=True.

🔒 Security Considerations

No significant security concerns. The code:

  • ✅ Doesn't execute arbitrary code
  • ✅ Uses safe file operations (no shell injection risks)
  • ✅ Properly handles paths (using Path objects)
  • ✅ No hardcoded credentials or secrets

⚡ Performance Considerations

  1. Memory usage: The shuffle operation loads the entire dataset into memory (line 298). For very large datasets, this could be problematic. Consider documenting memory requirements.

  2. Disk space: During shuffling, you temporarily need ~2x the dataset size (original + shuffled copy). This should be documented.

  3. Shuffle across sequences timing: The torch.randperm at cache_activations_runner.py:371 is called per buffer, which is good for memory but means the seed-based reproducibility only works when buffer configuration is identical.
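The buffer-boundary caveat in point 3 can be illustrated with a pure-Python analogue, using one seeded random.Random as a stand-in for the per-buffer torch.randperm calls (the helper name is illustrative):

```python
import random


def shuffle_in_buffers(items: list[int], buffer_size: int, seed: int) -> list[int]:
    """Shuffle items within consecutive buffers using one seeded RNG,
    analogous to calling torch.randperm once per buffer."""
    rng = random.Random(seed)
    out: list[int] = []
    for start in range(0, len(items), buffer_size):
        buf = items[start : start + buffer_size]
        rng.shuffle(buf)
        out.extend(buf)
    return out


data = list(range(12))

# Same seed and same buffer size -> identical order.
assert shuffle_in_buffers(data, 4, seed=0) == shuffle_in_buffers(data, 4, seed=0)

# Same seed but a different buffer size -> a different order, because the
# RNG stream is consumed along different buffer boundaries.
assert shuffle_in_buffers(data, 4, seed=0) != shuffle_in_buffers(data, 6, seed=0)

# Either way, the result is still a permutation of the input.
assert sorted(shuffle_in_buffers(data, 6, seed=0)) == data
```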

📝 Minor Nits

  1. Typo fix already included: Good catch on "not includeing" → "not including" at line 234
  2. Consistent formatting: Code follows project style well
  3. Import organization: Clean and well-organized

🎯 Overall Assessment

This is a solid PR that adds valuable functionality with good test coverage. The main concerns are:

  1. The potential for orphaned directories if the atomic swap fails
  2. Unnecessary dataset reload after shuffling
  3. Documentation could be slightly clearer about memory/disk requirements

Recommendation: Approve with minor changes. The issues identified are relatively minor and the core functionality is sound. I'd suggest addressing the data corruption risk (#1) before merging; the others can be addressed in follow-up PRs if needed.

Nice work on the comprehensive test coverage! 🎉

@claude

claude bot commented Jan 4, 2026

Code Review - PR #612

Thank you for this comprehensive PR adding tokenization options and shuffling improvements to the cache activations runner! Overall, the implementation is solid with excellent test coverage. Here are my findings:


Strengths

  1. Excellent Test Coverage: The PR includes comprehensive tests covering:

    • Basic shuffle-to-disk behavior (test_cache_activations_runner_shuffled_saved_to_disk)
    • Cross-sequence shuffling with validation (test_cache_activations_runner_shuffle_across_sequences)
    • Reproducibility with seeded shuffling (test_cache_activations_runner_shuffle_across_sequences_reproducible)
    • Input validation (test_cache_activations_runner_shuffle_across_sequences_requires_shuffle)
  2. Good Documentation: Clear docstrings and inline comments explaining the shuffling behavior and atomic swap mechanism.

  3. Atomic Operations: The backup/restore mechanism for disk operations is well-implemented with proper error handling.

  4. Validation: Good use of __post_init__ validation to ensure shuffle_across_sequences requires shuffle=True.


🐛 Potential Bugs & Issues

Critical: Memory Usage in shuffle_across_sequences

Location: sae_lens/cache_activations_runner.py:306-345

The shuffle_across_sequences implementation loads the entire dataset into memory:

all_data = dataset[:]
acts = all_data[hook_name]  # (n_seq, context_size, d_in)

Issue: For large datasets, this could cause OOM errors. The PR description mentions this is for "cache activations runner," but there's no documentation about memory requirements or limits.

Recommendation:

  • Add a comment/docstring warning about memory requirements
  • Consider adding a size check or warning when the dataset exceeds a certain threshold
  • Alternatively, implement chunked shuffling for very large datasets

Example:

# Estimate memory usage
estimated_memory_gb = (acts.shape[0] * self.context_size * self.cfg.d_in * 4) / 1e9
if estimated_memory_gb > 10:  # configurable threshold
    logger.warning(
        f"shuffle_across_sequences will load ~{estimated_memory_gb:.1f}GB into memory. "
        f"Consider using sequence-level shuffle for very large datasets."
    )

Minor: Inconsistent Path Handling

Location: sae_lens/cache_activations_runner.py:363, 381

You're mixing Path objects and string conversions:

dataset.save_to_disk(str(shuffled_path))  # Line 363
dataset = Dataset.load_from_disk(str(final_cached_activation_path))  # Line 381

Recommendation: datasets' save_to_disk and load_from_disk accept os.PathLike, so the str() conversions are unnecessary; passing the Path objects directly would be more consistent with the rest of the codebase.


Edge Case: Empty Dataset Handling

Location: sae_lens/cache_activations_runner.py:306-310

What happens if dataset is empty or has 0 sequences?

all_data = dataset[:]
acts = all_data[hook_name]  # Could be shape (0, context_size, d_in)
n_seq = acts.shape[0]  # n_seq = 0

The reshaping operations would fail with n_seq=0.

Recommendation: Add a guard:

if len(dataset) == 0:
    logger.warning("Dataset is empty, skipping shuffle_across_sequences")
    return dataset

Performance Considerations

  1. Triple Disk I/O for Shuffled Datasets

    When shuffle=True, the workflow is:

    • Write consolidated dataset to disk
    • Load entire dataset into memory
    • Save shuffled dataset to disk
    • Rename/swap directories
    • Reload dataset from disk

    Impact: For large datasets, this is 3x disk I/O. Consider whether this could be optimized by shuffling during consolidation rather than after.

  2. Einops Overhead

    The use of einops.rearrange is elegant and readable, but for very large tensors, it creates intermediate copies. This is likely acceptable given the context, but worth noting.

  3. Token ID Conversion

    Location: sae_lens/cache_activations_runner.py:342

    "token_ids": token_ids_shuffled.to(torch.int32),

This conversion is safe even when token_ids_shuffled is already int32: torch.Tensor.to returns self unchanged (no copy) when the dtype and device already match, so no conditional guard is needed here.

🔒 Security Concerns

No significant security issues identified. The code:

  • Doesn't execute arbitrary user input
  • Handles file paths safely with Path objects
  • Has proper error handling in the atomic swap mechanism
  • Doesn't expose sensitive data

The backup/restore mechanism properly cleans up temporary directories on both success and failure paths.


📝 Code Quality & Best Practices

Good Practices

  • ✅ Constants defined at module level (_TMP_SHARDS_DIR, etc.)
  • ✅ Proper use of type hints
  • ✅ Descriptive variable names
  • ✅ Good separation of concerns (shuffle vs shuffle_across_sequences)

Minor Improvements

  1. Magic Numbers

    Location: sae_lens/cache_activations_runner.py:319

    Consider extracting the seed handling into a helper or adding a comment explaining why we create a new generator rather than using torch.manual_seed().

  2. Error Messages

    Location: sae_lens/cache_activations_runner.py:372-378

The broad except Exception: catches all ordinary errors (it does not catch KeyboardInterrupt or SystemExit, which derive from BaseException rather than Exception). Consider catching more specific exceptions, or at least re-raising with context:

    except Exception as e:
        logger.error(f"Failed to swap shuffled dataset: {e}")
        # ... rollback logic ...
        raise
  3. Docstring for shuffle_across_sequences

    The docstring in config.py:482 is good, but could benefit from mentioning the memory implications and when to use this vs regular shuffle.


🧪 Test Coverage

Excellent coverage overall! The tests thoroughly validate:

  • ✅ Shuffle behavior and persistence to disk
  • ✅ Token-activation pairing preservation
  • ✅ Reproducibility with seeds
  • ✅ Configuration validation

Suggestions for Additional Tests:

  1. Large dataset behavior: Test with a dataset size that would reveal the memory loading behavior (even if mocked)
  2. Edge case: Test with n_seq=1 (single sequence)
  3. Error recovery: Test the backup/restore mechanism by simulating a failure during the atomic swap
  4. Interaction test: Verify behavior when disable_concat_sequences=True and shuffle_across_sequences=True are both set
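Several of these tests reduce to checking the same invariant: activations and token ids must be permuted by one shared index order so each activation keeps its paired token. A toy sketch of that check (all data and names are illustrative):

```python
import random

# Toy stand-ins: one "activation" row per token position.
acts = [[float(i), float(i) + 0.5] for i in range(8)]   # (n_positions, d_in)
token_ids = list(range(100, 108))                        # paired token per position

# Shuffle across positions with a single shared permutation so each
# activation row keeps its paired token id.
perm = list(range(len(acts)))
random.Random(0).shuffle(perm)
acts_shuffled = [acts[i] for i in perm]
tokens_shuffled = [token_ids[i] for i in perm]

# Pairing check: activation row i was built from index int(acts[i][0]),
# so its token id must still be 100 + that index after shuffling.
for row, tok in zip(acts_shuffled, tokens_shuffled):
    assert tok == 100 + int(row[0])
```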

📋 Summary

This is a well-implemented PR with excellent test coverage and clear documentation. The main concern is the memory usage for shuffle_across_sequences with large datasets, which should at minimum be documented and ideally have a warning mechanism.

Recommendations Priority:

  1. 🔴 High: Add memory usage warning/documentation for shuffle_across_sequences
  2. 🟡 Medium: Add empty dataset guard in shuffle_across_sequences
  3. 🟡 Medium: Improve error message context in the atomic swap exception handler
  4. 🟢 Low: Path object consistency (str conversion)
  5. 🟢 Low: Conditional dtype conversion optimization

Great work on this feature! The atomic swap mechanism and comprehensive testing are particularly well done. 🎉
