
Conversation

@kaix-nv (Contributor) commented Oct 7, 2025

What does this PR do?

This PR introduces a new sparse attention module for LLMs that reduces computational complexity by computing only the most important attention scores. The feature speeds up inference and reduces memory use, especially for long sequences, while maintaining model quality.

Type of change: New feature

Overview:
The goal of this PR is to add a new ModelOpt module that supports sparse attention, improving inference efficiency and lowering cost by skipping unnecessary attention computations. The sparse attention module provides a flexible foundation for future attention-sparsity techniques.

Core Design

  • SparseAttentionModule: Central module managing sparse attention behavior

  • Extensible Method Registry: Plugin architecture for sparse attention algorithms

    • Skip-softmax (implemented): Threshold-based masking for efficient attention computation (a minimal sketch follows this list)
    • DSA (planned): Future support for additional sparse attention methods

  • Statistics Manager: Collects and reports sparsity metrics during inference

  • Calibration Support: Automatic threshold tuning to achieve target sparsity ratios

  • HuggingFace Unified Checkpoint Export: Export models with sparse_attention_config metadata

  • Multi-Backend Kernel Support:

    • PyTorch: Reference implementation with softmax patching
    • Triton: Optimized custom kernels
    • FlashInfer: Integration for high-performance attention (planned)
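
For intuition, here is a minimal PyTorch sketch of the threshold-based masking idea behind skip-softmax. It is a reference illustration only, assuming a simple per-element probability threshold; the method in this PR operates block-wise and Flash-Attention-aware, and the function below is not part of the ModelOpt API.

import torch

def skip_softmax_reference(scores: torch.Tensor, threshold: float = 1e-3) -> torch.Tensor:
    """Illustrative sketch (not the ModelOpt kernel): drop attention
    probabilities below a threshold and renormalize the survivors."""
    # scores: [batch, heads, q_len, k_len] raw attention logits
    probs = torch.softmax(scores, dim=-1)
    keep = probs >= threshold
    sparse_probs = probs * keep
    # Renormalize so each query row still sums to (approximately) 1
    return sparse_probs / sparse_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)

# Toy usage: with many keys, most probabilities fall below the threshold
scores = torch.randn(1, 2, 8, 1024)
out = skip_softmax_reference(scores)
print("achieved sparsity:", (out == 0).float().mean().item())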

Usage

Example: apply skip-softmax with calibration to Qwen/Qwen3-4B via the example script and export a HF checkpoint:
python hf_spar_attn.py --pyt_ckpt_path Qwen/Qwen3-4B --sparse_attn skip_softmax_calib --backend pytorch --export_dir exported_model 
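
For orientation, a rough Python sketch of the same flow follows; the import paths, function signatures, and defaults below are assumptions based on the files added in this PR (hf_spar_attn.py is the authoritative reference).

# Hypothetical Python sketch -- import paths and signatures are assumptions.
import torch
from transformers import AutoModelForCausalLM

from modelopt.torch.export.unified_export_hf import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import model_sparsify

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B", attn_implementation="eager", torch_dtype=torch.bfloat16
)
model = model.cuda().eval()

# 1) Replace attention modules with SparseAttentionModule (skip-softmax method).
model = model_sparsify.sparsify(model)

# 2) Optionally calibrate thresholds toward a target sparsity (RULER-based).
model = model_sparsify.calibrate(model)

# 3) Export a HF unified checkpoint carrying sparse_attention_config metadata.
export_hf_checkpoint(model, "exported_model")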

Testing

Unit and GPU tests passed.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Introduces Sparse Attention for LLMs with a Flash Attention–aware softmax skip method, configurable thresholds, stats, and phase-aware behavior.
    • Adds calibration workflow (RULER-based) and simple forward-loop support.
    • Integrates with HuggingFace models on-the-fly; supports export of sparse settings into model configs.
    • Provides a command-line tool and Python APIs to sparsify, calibrate, enable/disable, and save/restore.
    • Adds weight sparsity tooling (magnitude and SparseGPT) with Megatron compatibility.
  • Documentation

    • New example guide with installation, quick start, configuration, calibration, export, and performance notes.
  • Tests

    • Extensive unit and GPU tests for conversion, calibration, export, and end-to-end generation.
  • Chores

    • Example requirements updated.

@kaix-nv kaix-nv requested review from a team as code owners October 7, 2025 22:46
@kaix-nv kaix-nv requested review from cjluo-nv and realAsma October 7, 2025 22:46
coderabbitai bot (Contributor) commented Oct 7, 2025

Walkthrough

Adds a new sparse-attention subsystem and weight-sparsity features: configs, registries, runtime SparseAttentionModule, flash_softmax_skip method, RULER calibration (dataset + calibrator), HuggingFace plugins, HF export augmentation, CLI example, and extensive unit/GPU tests and utilities.

Changes

• Examples: LLM Sparse Attention
  Files: examples/llm_sparse_attention/README.md, examples/llm_sparse_attention/hf_spar_attn.py, examples/llm_sparse_attention/requirements.txt
  New example README, CLI script, and requirements demonstrating applying, calibrating, verifying, and exporting sparse attention for HuggingFace models.

• HF Export Integration
  Files: modelopt/torch/export/unified_export_hf.py
  Adds _get_sparse_attention_config(model) and injects sparse_attention_config into the exported HF config when sparse modules are present.

• Sparse Attention: Package & Mode
  Files: modelopt/torch/sparsity/attention_sparsity/__init__.py, .../mode.py, .../model_sparsify.py
  New package init, mode descriptor, and top-level sparsify/calibrate entrypoints wired into the modelopt mode system.

• Sparse Attention: Configuration
  Files: modelopt/torch/sparsity/attention_sparsity/config.py
  Pydantic-based config types, presets, calibration config, and validation for sparse attention settings.

• Sparse Attention: Conversion & Management
  Files: modelopt/torch/sparsity/attention_sparsity/conversion.py
  Detection, on-the-fly replacement of attention modules, attribute/pattern application, enable/disable, metadata/state update and restore, and reporting utilities.

• Sparse Attention: Methods & Registry
  Files: modelopt/torch/sparsity/attention_sparsity/methods/__init__.py, .../methods/registry.py, .../methods/flash_softmax_skip.py
  New methods registry, abstract method API, and flash_softmax_skip implementation with block-wise masking, phase handling, calibration hooks, and registration.

• Sparse Attention: NN Modules & Stats
  Files: modelopt/torch/sparsity/attention_sparsity/nn/__init__.py, .../nn/sparse_attention.py, .../nn/stats_manager.py
  SparseAttentionModule that patches softmax at runtime, SparseAttentionRegistry, and SparseAttentionStatsManager for aggregated/per-sample stats.

• Sparse Attention: Calibration
  Files: modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py, .../calibration/calibrate.py, .../calibration/calibrator.py, .../calibration/dataset.py
  RULER dataset builder, forward-loop creator, DynamicThresholdCalibrator (length-based regression), and calibrate_sparse_attention driver applying per-module calibration.

• Sparse Attention: Plugins (HuggingFace)
  Files: modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py, .../plugins/huggingface.py
  Re-export plugin entry and register_sparse_attention_on_the_fly to wrap HF attention modules dynamically.

• Weight Sparsity: Package & Configs
  Files: modelopt/torch/sparsity/weight_sparsity/__init__.py, .../config.py
  Package initializer and default/export sparsity configs (SparseMagnitudeConfig, SparseGPTConfig, ExportSparseConfig).

• Weight Sparsity: Core Module & Mode
  Files: modelopt/torch/sparsity/weight_sparsity/module.py, .../mode.py, .../sparsification.py
  SparseModule, SpDMRegistry, mode descriptors/registry for sparsity, conversion/restore/update/export entrypoints, and high-level sparsify/export APIs.

• Weight Sparsity: Algorithms & Searchers
  Files: modelopt/torch/sparsity/weight_sparsity/searcher.py, .../magnitude.py, .../sparsegpt.py
  Base searcher, 2:4 ASP magnitude mask utilities and MagnitudeSearcher, SparseGPT searcher with Hessian collection and block-wise pruning.

• Weight Sparsity: Plugins (Megatron)
  Files: modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py, .../plugins/megatron.py
  Optional Megatron integration for sparse mask sharding and plugin rules.

• Test Utilities
  Files: tests/_test_utils/torch_sparsity/sparse_attention_common.py
  Test models, sample configs, and helpers for sparsify/forward and save/restore verification.

• GPU Tests: Sparse Attention
  Files: tests/gpu/torch/sparsity/attention_sparsity/*.py
  Multiple CUDA-only tests covering forward/backward, mixed dtypes, calibration (RULER), export verification, HF integration, and end-to-end flows.

• Unit Tests: Sparse Attention
  Files: tests/unit/torch/sparsity/attention_sparsity/*.py
  Unit tests for config validation, calibration logic, conversion, export config extraction, and mode registry presence.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor U as User/CLI
  participant M as HF/PyTorch Model
  participant S as SparseAttentionMode
  participant C as Conversion
  participant SA as SparseAttentionModule
  participant SM as Sparse Method
  participant ST as StatsManager

  U->>S: sparsify(model, config)
  S->>C: convert_to_sparse_attention_model(...)
  C->>M: replace attention with `SparseAttentionModule`
  Note over M,SA: Model now contains sparse-aware attention
  U->>M: forward(input)
  M->>SA: attention forward
  SA->>SM: apply_sparsity(q,k,v, scores)
  SM-->>SA: masked_scores + stats
  SA->>ST: collect(stats)
  SA->>M: continue forward with masked_scores
sequenceDiagram
  autonumber
  actor U as User
  participant Cal as Calibration Driver
  participant DS as RULER Dataset
  participant Tok as Tokenizer
  participant Dyn as DynamicThresholdCalibrator
  participant SA as SparseAttentionModule

  U->>Cal: calibrate_sparse_attention(model, config)
  Cal->>DS: build_calibration_dataset() (if forward_loop not provided)
  Cal->>Tok: AutoTokenizer.from_pretrained(...)
  Cal->>Dyn: run calibrator with forward_loop
  loop thresholds x samples
    Dyn->>SA: set threshold, enable calibration mode
    SA->>Model: forward sample -> stats
    SA-->>Dyn: collect per-sample sparsity stats
  end
  Dyn->>Dyn: fit scale_factor = a*(1/length) via regression
  Dyn->>SA: apply calibrated scale_factor to modules
  Dyn-->>U: return calibration results
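
To make the regression step above concrete, here is a small sketch of fitting scale_factor = a * (1/length); the sample lengths, thresholds, and closed-form least-squares fit are illustrative assumptions, not values from the calibrator.

import torch

# Hypothetical per-sample calibration results: for each sequence length,
# the threshold that reached the target sparsity during the sweep.
lengths = torch.tensor([1024.0, 2048.0, 4096.0, 8192.0])
best_thresholds = torch.tensor([4.0e-4, 2.1e-4, 0.9e-4, 0.5e-4])

# Model: threshold ~= a * (1 / length). Least-squares fit for the scalar a.
x = 1.0 / lengths
a = float((x * best_thresholds).sum() / (x * x).sum())
print(f"fitted threshold_scale_factor a = {a:.4f}")

# At inference time a length-dependent threshold would then be derived as a / seq_len.
print("predicted threshold at 16384 tokens:", a / 16384)
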
sequenceDiagram
  autonumber
  actor U as User
  participant Exp as export_hf_checkpoint
  participant M as Model
  participant G as _get_sparse_attention_config

  U->>Exp: export_hf_checkpoint(model, dst)
  Exp->>G: scan model for sparse modules
  alt sparse config found
    G-->>Exp: sparse_attention_config dict
    Exp->>Exp: inject into config.json and write
  else
    G-->>Exp: {}
  end
  Exp-->>U: files written
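
For illustration only, a Python-dict sketch of what the injected sparse_attention_config metadata might look like; the group key, numeric values, and target class name are assumptions rather than the final schema.

# Illustrative shape only -- keys/values are assumptions, not the exported schema.
sparse_attention_config = {
    "producer": {"name": "modelopt", "version": "x.y.z"},
    # Present only when the model was calibrated:
    "threshold_scale_factor": 0.35,
    "target_sparsity": 0.5,
    "config_groups": {
        "group_0": {
            "sparse_algo": "softmax_skip",
            "targets": ["LlamaAttention"],
        },
    },
}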

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120+ minutes

Poem

A rabbit taps keys with a whisker’s intent,
Blocks of attention now sparsely are spent.
RULER feeds riddles, thresholds take flight,
Flash skips the softmax, masking feels light.
Export and tests hum — hop, hop — model delight. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly specifies the primary change (adding support for sparse attention) and aligns with the PR's objectives of introducing SparseAttentionModule, calibration, and export features, making it clear and descriptive.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 90.36%, which is sufficient; the required threshold is 80.00%.



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 12

🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/config.py (1)

125-131: Enforce integer block sizes.
The validator still accepts floats (e.g., 128.5), contradicting the “block size” contract. Reject non-integral values up front.

     @field_validator("br", "bc")
     @classmethod
     def validate_block_size(cls, v):
         """Validate block sizes are positive integers."""
-        if v <= 0:
-            raise ValueError(f"Block size must be positive, got {v}")
+        if not isinstance(v, int):
+            raise ValueError(f"Block size must be an integer, got type {type(v).__name__}")
+        if v <= 0:
+            raise ValueError(f"Block size must be positive, got {v}")
         return v
tests/_test_utils/torch_sparsity/sparse_attention_common.py (1)

203-205: Capture the restored module returned by mto.restore_from_modelopt_state

restore_from_modelopt_state may hand back a different module object (especially when starting from a modellike tuple). Ignoring its return risks operating on the unmodified instance. Reassign the helper’s model_restored to the returned module before loading weights.

Apply this diff:

-    mto.restore_from_modelopt_state(model_restored, state_dict)
+    model_restored = mto.restore_from_modelopt_state(model_restored, state_dict)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1537885 and 2ab2a02.

📒 Files selected for processing (40)
  • examples/llm_sparse_attention/README.md (1 hunks)
  • examples/llm_sparse_attention/hf_spar_attn.py (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (2 hunks)
  • modelopt/torch/sparsity/attention_sparsity/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/config.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/conversion.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/mode.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/__init__.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/config.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/magnitude.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/mode.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/module.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/searcher.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/sparsification.py (1 hunks)
  • tests/_test_utils/torch_sparsity/sparse_attention_common.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (36)
modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (3)
  • SparseAttentionMethod (23-49)
  • get_sparse_method (88-120)
  • register_sparse_method (56-85)
tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
  • SimpleTransformerEncoderLayer (49-69)
  • sparsify_model_and_forward (150-182)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
modelopt/torch/export/unified_export_hf.py (1)
  • _get_sparse_attention_config (340-460)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
  • disable_sparse_attention (278-305)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
  • SimpleAttentionModel (27-46)
  • SimpleTransformerEncoderLayer (49-69)
modelopt/torch/sparsity/attention_sparsity/conversion.py (2)
  • disable_sparse_attention (278-305)
  • enable_sparse_attention (308-335)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1)
  • register_sparse_attention_on_the_fly (55-100)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (2)
modelopt/torch/nas/utils.py (1)
  • enabled (226-227)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • set_calibration_mode (58-60)
modelopt/torch/sparsity/weight_sparsity/config.py (1)
modelopt/torch/opt/config.py (2)
  • ModeloptBaseConfig (59-147)
  • get_kwargs_for_create_model_with_rules (322-383)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2)
  • apply_sparsity (241-284)
  • name (287-289)
modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (4)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
  • calibrate (151-197)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
  • calibrate_sparse_attention (112-177)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
  • DynamicThresholdCalibrator (31-308)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1)
  • RulerDatasetBuilder (181-602)
tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (4)
tests/_test_utils/torch_model/transformers_models.py (1)
  • create_tiny_llama_dir (117-131)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (621-677)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
  • disable_sparse_attention (278-305)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • _DMRegistryCls (917-1124)
  • config (1265-1278)
modelopt/torch/quantization/utils.py (1)
  • replace_function (300-307)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionAttributeConfig (34-156)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (2)
  • get_sparse_method (88-120)
  • apply_sparsity (27-44)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
  • SparseAttentionStatsManager (7-125)
  • get_summary (74-92)
  • collect (43-72)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • apply_sparsity (241-284)
tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/opt/conversion.py (1)
  • modelopt_state (444-486)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
modelopt/torch/opt/config.py (3)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
  • keys (132-134)
modelopt/torch/sparsity/weight_sparsity/__init__.py (1)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1)
modelopt/torch/sparsity/attention_sparsity/config.py (3)
  • FlashSoftmaxSkipConfig (322-346)
  • SparseAttentionAttributeConfig (34-156)
  • SparseAttentionConfig (292-319)
examples/llm_sparse_attention/hf_spar_attn.py (5)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (621-677)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (4)
  • SparseAttentionModule (29-201)
  • get_stats (130-139)
  • disable (121-123)
  • enable (117-119)
modelopt/torch/utils/memory_monitor.py (1)
  • launch_memory_monitor (134-151)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
  • sparsify (35-148)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (5)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • CalibrationConfig (159-247)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (2)
  • DynamicThresholdCalibrator (31-308)
  • calibrate (77-219)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
  • RulerDatasetBuilder (181-602)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
  • calibrate (151-197)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (6)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
  • SparseAttentionStatsManager (7-125)
  • set_calibration_mode (94-106)
  • get_calibration_stats (118-125)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1)
  • forward_loop (97-107)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (1)
  • forward_loop (162-164)
tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1)
  • forward_loop (189-196)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • set_calibration_mode (58-60)
modelopt/torch/sparsity/attention_sparsity/mode.py (4)
modelopt/torch/opt/config.py (1)
  • ModeloptBaseConfig (59-147)
modelopt/torch/opt/mode.py (2)
  • ModeDescriptor (56-259)
  • _ModeRegistryCls (267-344)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/conversion.py (3)
  • convert_to_sparse_attention_model (47-76)
  • restore_sparse_attention_model (175-195)
  • update_sparse_attention_metadata (234-275)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (2)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/sparsity/weight_sparsity/searcher.py (3)
  • BaseSparseSearcher (31-84)
  • _check_weight_size (59-61)
  • _compute_mask (64-66)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1)
modelopt/torch/opt/mode.py (2)
  • _ModeRegistryCls (267-344)
  • get_from_any (326-336)
modelopt/torch/export/unified_export_hf.py (3)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/opt/dynamic.py (3)
  • original_cls (867-873)
  • get_original_cls_by_level (905-914)
  • level (158-160)
modelopt/torch/sparsity/attention_sparsity/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
  • SparseAttentionModule (29-201)
  • set_from_attribute_config (61-107)
  • _setup (141-153)
  • disable (121-123)
  • enable (117-119)
  • is_enabled (126-128)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2)
  • register_sparse_attention_on_the_fly (55-100)
  • _setup (33-38)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (4)
modelopt/torch/opt/conversion.py (3)
  • ModeloptStateManager (63-311)
  • apply_mode (342-429)
  • is_converted (102-127)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)
  • calibrate_sparse_attention (112-177)
  • forward_loop (97-107)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
  • calibrate (77-219)
modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (2)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)
  • SparseAttentionStatsManager (7-125)
modelopt/torch/sparsity/weight_sparsity/module.py (2)
modelopt/torch/opt/dynamic.py (4)
  • DynamicModule (338-914)
  • _register_hparam (383-423)
  • _register_temp_attribute (476-513)
  • _register_dynamic_attribute (425-474)
modelopt/torch/opt/hparam.py (1)
  • Hparam (48-275)
modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (3)
  • get_nmprune_info (29-35)
  • _check_weight_size (134-144)
  • _compute_mask (146-148)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/sparsity/weight_sparsity/searcher.py (5)
  • BaseSparseSearcher (31-84)
  • default_search_config (37-39)
  • _check_weight_size (59-61)
  • _compute_mask (64-66)
  • _named_sparsifiable_modules (68-76)
modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (3)
modelopt/torch/opt/plugins/megatron.py (1)
  • _MegatronMLP (120-142)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/opt/config.py (1)
  • register_default (227-237)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/opt/conversion.py (2)
  • modelopt_state (444-486)
  • restore_from_modelopt_state (510-567)
tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (9)
  • SimpleAttentionModel (27-46)
  • SimpleTransformerEncoder (72-90)
  • SimpleTransformerEncoderLayer (49-69)
  • get_test_configs (142-147)
  • save_restore_test (185-218)
  • sparsify_model_and_forward (150-182)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • get_original_cls_by_level (905-914)
  • level (158-160)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (2)
  • SparseAttentionModule (29-201)
  • _setup (141-153)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (4)
  • SparseAttentionMethod (23-49)
  • register_sparse_method (56-85)
  • apply_sparsity (27-44)
  • name (48-49)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)
  • set_calibration_mode (94-106)
modelopt/torch/sparsity/weight_sparsity/sparsification.py (3)
modelopt/torch/opt/conversion.py (2)
  • apply_mode (342-429)
  • get_mode (432-441)
modelopt/torch/opt/searcher.py (1)
  • BaseSearcher (52-271)
modelopt/torch/utils/network.py (1)
  • unwrap_model (430-454)
modelopt/torch/sparsity/weight_sparsity/mode.py (4)
modelopt/torch/opt/dynamic.py (2)
  • config (1265-1278)
  • convert_to_dynamic (1138-1187)
modelopt/torch/opt/conversion.py (1)
  • ApplyModeError (314-315)
modelopt/torch/opt/mode.py (2)
  • ModeDescriptor (56-259)
  • _ModeRegistryCls (267-344)
modelopt/torch/utils/network.py (2)
  • compare_dict (423-427)
  • unwrap_model (430-454)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
  • SimpleAttentionModel (27-46)
  • forward_loop (162-164)
  • forward (39-41)
  • forward (63-64)
  • forward (84-85)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (3)
  • RulerDatasetBuilder (181-602)
  • _generate_target_lengths (28-56)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (4)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (4)
  • SimpleTransformerEncoderLayer (49-69)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
  • RulerDatasetBuilder (181-602)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/opt/conversion.py (2)
  • modelopt_state (444-486)
  • restore_from_modelopt_state (510-567)
modelopt/torch/sparsity/weight_sparsity/searcher.py (4)
modelopt/torch/opt/searcher.py (1)
  • BaseSearcher (52-271)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/sparsity/weight_sparsity/module.py (2)
  • SparseModule (31-87)
  • set_mask (68-87)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (1)
  • get_nmprune_info (29-35)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs

Comment on lines +273 to +290
    args.pyt_ckpt_path,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

# Set pad token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Move model to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print("Model moved to CUDA")

# Apply sparse attention to the model (with calibration if configured)
model = sparsify_model(model, args)


⚠️ Potential issue | 🟠 Major

Call model.eval() before running inference

The model stays in training mode after loading (dropout, layernorm stats, etc.), so the sparse-attention calibration and the generation comparison run with stochastic behavior. That breaks verification and produces inconsistent results. Move the model to eval mode before sparsifying/exporting.

     model = AutoModelForCausalLM.from_pretrained(
         args.pyt_ckpt_path,
         attn_implementation="eager",
         torch_dtype=torch.bfloat16,
     )
     tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)
 
     # Set pad token if not set
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
     # Move model to GPU if available
     if torch.cuda.is_available():
         model = model.cuda()
         print("Model moved to CUDA")
 
+    model.eval()
+
     # Apply sparse attention to the model (with calibration if configured)
     model = sparsify_model(model, args)
🤖 Prompt for AI Agents
In examples/llm_sparse_attention/hf_spar_attn.py around lines 273 to 290, the
model remains in training mode which allows dropout and other stochastic
behaviors during sparse-attention calibration and generation; call model.eval()
after loading (and after moving to CUDA if applicable) and before calling
sparsify_model or running any calibration/generation to ensure deterministic
behavior during verification.

Comment on lines +364 to +439
# Collect all enabled sparse attention modules
sparse_modules = []
for name, module in model.named_modules():
    if isinstance(module, SparseAttentionModule) and module.is_enabled:
        sparse_modules.append((name, module))

if not sparse_modules:
    return {}

sparse_config = {
    "config_groups": {},
    "producer": {
        "name": "modelopt",
        "version": __version__,
    },
}

# Check first module for global calibration parameters
# (all modules share the same calibration parameters)
first_module = sparse_modules[0][1]
method_instance = first_module._sparse_method_instance
threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None)

if threshold_scale_factor is not None:
    # Model was calibrated: add global calibration parameters
    sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)

    target_sparsity = getattr(method_instance, "target_sparsity", None)
    if target_sparsity is not None:
        sparse_config["target_sparsity"] = float(target_sparsity)

# Group modules by configuration
# Key: (sparse_algo, threshold_repr), Value: list of module class names
config_to_targets = {}

for name, module in sparse_modules:
    method_instance = module._sparse_method_instance

    # Extract sparse algorithm name from method name
    # e.g., "flash_softmax_skip" -> "softmax_skip"
    method_name = method_instance.name
    if method_name.startswith("flash_"):
        sparse_algo = method_name[6:]  # Remove "flash_" prefix
    else:
        sparse_algo = method_name

    # Get module's original class name for targets
    # Get the class name before SparseAttentionModule wrapping
    original_cls = module.get_original_cls_by_level(level=0)
    target_class_name = original_cls.__name__

    # Build config key for grouping
    if threshold_scale_factor is None:
        # Not calibrated: include threshold in grouping
        threshold_config = getattr(method_instance, "threshold_config", None)
        if isinstance(threshold_config, dict):
            # Convert dict to tuple for hashable key
            threshold_repr = tuple(sorted(threshold_config.items()))
        else:
            threshold_repr = threshold_config
    else:
        # Calibrated: no threshold in per-layer config
        threshold_repr = None

    config_key = (sparse_algo, threshold_repr)

    if config_key not in config_to_targets:
        config_to_targets[config_key] = {
            "sparse_algo": sparse_algo,
            "threshold_config": threshold_config if threshold_scale_factor is None else None,
            "targets": set(),
        }

    config_to_targets[config_key]["targets"].add(target_class_name)


⚠️ Potential issue | 🟠 Major

Ensure sparse modules are fully initialized before export.

module.is_enabled can return True even before _setup() runs, so _sparse_method_instance might not exist yet. Accessing it here will raise AttributeError for modules that were converted but never forwarded. Call module._setup() (or otherwise guarantee initialization) before dereferencing _sparse_method_instance.

🤖 Prompt for AI Agents
In modelopt/torch/export/unified_export_hf.py around lines 364 to 438, some
SparseAttentionModule instances may report is_enabled=True before they have been
initialized, so accessing module._sparse_method_instance can raise
AttributeError; call module._setup() (or otherwise ensure initialization) for
each module in the sparse_modules collection before dereferencing
_sparse_method_instance, and guard access with a check that the attribute exists
after setup (skip or log modules that still lack the attribute) so export
proceeds only with fully-initialized modules.
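
A minimal sketch of the suggested guard, assuming the sparse_modules list from the quoted snippet; it is illustrative, not the committed fix.

import warnings

# Sketch only: lazily initialize modules that were never forwarded, then
# keep just the ones that actually expose a sparse method instance.
initialized = []
for name, module in sparse_modules:
    if not hasattr(module, "_sparse_method_instance"):
        module._setup()
    if hasattr(module, "_sparse_method_instance"):
        initialized.append((name, module))
    else:
        warnings.warn(f"Skipping uninitialized sparse attention module: {name}")
sparse_modules = initialized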

Comment on lines +233 to +235
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
random.seed(seed)


⚠️ Potential issue | 🟠 Major

Avoid resetting Python’s global RNG

Calling random.seed(seed) here mutates the process-wide RNG state every time the builder is constructed. Any code relying on randomness after calibration (e.g., sampling batches, data augmentation) becomes silently deterministic, which is a major cross-cutting bug. Instantiate a dedicated random.Random (and thread it through the helpers) instead of reseeding the global RNG.

🤖 Prompt for AI Agents
In modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py around
lines 233 to 235, the constructor currently calls random.seed(seed) which
mutates the global RNG; instead create a dedicated RNG instance (e.g.,
self.random = random.Random(seed)) and remove the global seed call, then replace
any uses of the global random module in this class and pass this self.random
into any helper functions or sub-objects that need randomness so they call
methods on the instance (self.random.random(), self.random.choice(), etc.)
rather than the module-level functions.
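
A tiny sketch of the dedicated-RNG approach described above; the attribute name and abbreviated constructor are assumptions.

import random

class RulerDatasetBuilder:  # abbreviated sketch, not the real class
    def __init__(self, tokenizer_name_or_path: str, seed: int = 42):
        # Keep randomness local instead of reseeding the process-wide RNG.
        self.rng = random.Random(seed)

    def _pick(self, items):
        # Use the instance RNG everywhere instead of module-level random.*
        return self.rng.choice(items)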

Comment on lines +198 to +232
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            module_name = get_unwrapped_name(name, model)
            if module_name in state_dict:
                module_state = state_dict[module_name]

                # Restore method and config
                if "method" in module_state:
                    module._method = module_state["method"]
                if "method_config" in module_state:
                    # Restore config attributes
                    # Separate method instance attributes from module attributes
                    method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}

                    for key, val in module_state["method_config"].items():
                        if key not in method_instance_attrs:
                            # Set on module
                            setattr(module, f"_{key}", val)

                # Re-setup with restored config
                module._setup()

                # Restore method instance attributes after _setup
                if "method_config" in module_state:
                    for key, val in module_state["method_config"].items():
                        if key in {"threshold_scale_factor", "target_sparsity"}:
                            setattr(module._sparse_method_instance, key, val)


⚠️ Potential issue | 🟠 Major

Reinitialize sparse method when restoring state

After loading metadata["sparse_attention_state"], we overwrite _method and _method_config, yet _sparse_method_instance is left pointing to the pre-existing object that still carries the old configuration. Any calibrated thresholds or even method swaps captured in metadata therefore never reach the runtime path—only threshold_scale_factor/target_sparsity get patched manually. Call _init_sparse_method() (before _setup()) once the module-level attributes are restored so the method instance is rebuilt from the restored config. Otherwise, restored models diverge from the saved sparse behavior.

Apply this diff:

-        if "method_config" in module_state:
-            for key, val in module_state["method_config"].items():
-                if key not in method_instance_attrs:
-                    # Set on module
-                    setattr(module, f"_{key}", val)
-
-        # Re-setup with restored config
-        module._setup()
+        if "method_config" in module_state:
+            for key, val in module_state["method_config"].items():
+                if key not in method_instance_attrs:
+                    # Set on module
+                    setattr(module, f"_{key}", val)
+
+        # Rebuild sparse method with restored config before final setup
+        module._init_sparse_method()
+        # Re-setup with restored config
+        module._setup()
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/attention_sparsity/conversion.py around lines 198 to
232, when restoring a SparseAttentionModule you set module._method and
module-level config but do not recreate the method instance, leaving
module._sparse_method_instance with stale state; after restoring module-level
attributes from module_state (i.e. after setting module._method and any
module._* config fields) call module._init_sparse_method() to reinitialize the
method instance, then proceed to module._setup() and finally patch any
instance-specific attributes (threshold_scale_factor, target_sparsity) if
present in module_state so the runtime method reflects the restored config.

Comment on lines +32 to +48
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)

    sparse_state_dict = {
        k: v
        for k, v in self.state_dict(prefix="", keep_vars=True).items()
        if k == "_weight_mask"
    }

    sharded_axis_dict = self._get_shard_axis_dict(sparse_state_dict)

    if sparse_state_dict:
        sharded_state_dict.update(
            **make_sharded_tensors_for_checkpoint(
                sparse_state_dict, prefix, sharded_axis_dict, sharded_offsets
            )
        )

⚠️ Potential issue | 🟠 Major

Propagate metadata when calling the base sharded_state_dict.
The override discards the caller-supplied metadata, breaking checkpoint writers that depend on it.

-    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py around lines 32
to 48, the override of sharded_state_dict discards the caller-supplied metadata
when calling the base implementation; update the call to
super().sharded_state_dict to pass the metadata through (e.g.,
super().sharded_state_dict(prefix, sharded_offsets, metadata=metadata) or
equivalent named parameter), leaving the rest of the logic unchanged so
checkpoint writers that rely on metadata continue to receive it.

Comment on lines +51 to +54
is_nm_prune, n, m = asp.get_nmprune_info(config_sanitized["pattern"])
assert is_nm_prune and n == 2 and m == 4, (
    f"Unsupported pattern {self.config['pattern']} for sparsity"
)

⚠️ Potential issue | 🟠 Major

Fix misuse of self.config in assertion.

When the pattern is invalid, the assertion message accesses self.config, but self.config is not populated until after sanitize_search_config returns. That triggers an AttributeError instead of the intended assertion, hiding the real cause from the user. Please reference config_sanitized (or the incoming config) in the message instead.

Apply this diff:

-        assert is_nm_prune and n == 2 and m == 4, (
-            f"Unsupported pattern {self.config['pattern']} for sparsity"
-        )
+        assert is_nm_prune and n == 2 and m == 4, (
+            f"Unsupported pattern {config_sanitized['pattern']} for sparsity"
+        )
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/searcher.py around lines 51 to 54,
the assertion error message wrongly references self.config which isn't set yet
and can raise AttributeError; change the assertion message to reference
config_sanitized (or the incoming config) instead so the correct pattern value
is reported when the assertion fails, e.g. use config_sanitized["pattern"] in
the f-string and keep the existing is_nm_prune and n/m checks unchanged.

Comment on lines +32 to +133
def invert(hessian: torch.Tensor) -> torch.Tensor:
    """Invert a Hessian matrix."""
    try:
        hessian_inv = torch.linalg.cholesky(hessian)
        hessian_inv = torch.cholesky_inverse(hessian_inv)
        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
    except RuntimeError:
        cols = hessian.size(0)
        eps = 1e-6 * torch.eye(cols).to(hessian.device)
        hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian + eps))

    return hessian_inv


def prepare(
    tensor: torch.Tensor, hessian: torch.Tensor, hessian_damp: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare the inverse Hessian matrix."""
    weight = tensor.detach().clone()
    # move the hessian matrix from CPU to GPU for acceleration
    hessian = hessian.to(weight.device)
    if len(weight.size()) == 4:
        weight = weight.flatten(1)

    zero = torch.diag(hessian) == 0
    hessian[zero, zero] = 1
    weight[:, zero] = 0

    damp = hessian_damp * torch.mean(torch.diag(hessian))
    cols = weight.size(1)
    diag = torch.arange(cols)
    hessian[diag, diag] += damp

    hessian_inv = invert(hessian)

    # remove the Hessian matrix to save GPU memory
    del hessian
    torch.cuda.empty_cache()

    return weight, hessian_inv


@torch.no_grad()
def create_sgpt_mask(
    tensor: torch.Tensor, hessian: torch.Tensor, config: SearchConfig
) -> torch.Tensor:
    """Create a sparse mask for the given tensor."""
    shape = tensor.size()
    weight, hessian_inv = prepare(tensor, hessian, config["hessian_damp"])
    rows, cols = weight.size()
    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)

    is_nm_prune, n, m = get_nmprune_info(config["pattern"])
    col_bs = config["col_block_size"]
    row_bs = config["row_block_size"]
    # if row_bs is not specified, prune the whole weight block
    if row_bs == -1:
        row_bs = rows

    for r1 in range(0, rows, row_bs):
        r2 = min(r1 + row_bs, rows)
        # the mask of the weights not to be pruned
        w_rows = weight[r1:r2].float()

        # pruning the weight block W[row:row+row_bs, i1:i1+col_bs]
        for i1 in range(0, cols, col_bs):
            i2 = min(i1 + col_bs, cols)
            w_blk = w_rows[:, i1:i2].clone()
            q_blk = torch.zeros_like(w_blk)
            # the error of the weights to be pruned
            delta_blk = torch.zeros_like(w_blk)
            hinv_blk = hessian_inv[i1:i2, i1:i2]
            hinv_diag_blk = hessian_inv_diag[i1:i2]

            errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
            if torch.isnan(errors_blk).any():
                print("nan in errors_blk.")

            mask_blk = torch.zeros_like(w_blk, dtype=torch.bool)

            for j in range(i2 - i1):
                # compute the error of the weights to be pruned
                w = w_blk[:, j]
                d = hinv_diag_blk[j]
                if is_nm_prune and j % m == 0:
                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
                    mask_blk.scatter_(
                        1, j + torch.topk(errors_blk, n, dim=1, largest=False)[1], True
                    )

                q = w.clone()
                q[mask_blk[:, j]] = 0
                q_blk[:, j] = q

                # update the remaining weights in the col_bs block to compensate the error caused by pruning W[:, j]
                err = (w - q) / d
                w_blk[:, j:] -= err.unsqueeze(1).matmul(hinv_blk[j, j:].unsqueeze(0))
                delta_blk[:, j] = err

            # compensate the error caused by pruning W[:, i: i + col_bs] with the weights update in W[:, i + col_bs:]
            w_rows[:, i1:i2] = q_blk
            w_rows[:, i2:] -= delta_blk.matmul(hessian_inv[i1:i2, i2:])

⚠️ Potential issue | 🔴 Critical

Fix incorrect Hessian “inverse” usage

invert() currently returns the Cholesky factor of the inverse (upper‑triangular) instead of the actual inverse matrix. Downstream code (e.g., Lines 105‑134) then treats it as a full inverse—leading to inconsistent scaling (hinv_diag_blk is the square root of the true diagonal) and using triangular blocks where the dense inverse is required. The resulting masks/weight updates are numerically wrong and can severely degrade pruning quality. Please return the genuine inverse and consume it consistently (no extra square root assumptions):

@@
-    try:
-        hessian_inv = torch.linalg.cholesky(hessian)
-        hessian_inv = torch.cholesky_inverse(hessian_inv)
-        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
+    try:
+        chol = torch.linalg.cholesky(hessian)
+        hessian_inv = torch.cholesky_inverse(chol)
@@
-    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
+    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
@@
-            errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
+            errors_blk = (w_blk**2) / (hinv_diag_blk + 1e-9)
@@
-                if is_nm_prune and j % m == 0:
-                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
+                if is_nm_prune and j % m == 0:
+                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] + 1e-9)
@@
-                err = (w - q) / d
+                err = (w - q) / (d + 1e-9)

Adjust any related calculations to use the full inverse (no extra Cholesky).

Comment on lines +268 to +275
if mod.hessian.device.type == "cuda":
    if cls._is_memory_sufficient(mod.hessian.device.index, 0.8):
        mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device)
    else:
        target_device = "cpu"
        mod.hessian = mod.hessian.to("cpu")
mod.hessian += inp.matmul(inp.t()).to(target_device)


⚠️ Potential issue | 🔴 Critical

Avoid double-counting Hessian updates on CUDA

When enough GPU memory is available, the hook adds inp.matmul(inp.t()) twice: once inside the CUDA branch and once again right after, inflating Hessian statistics and corrupting the mask search. Add the update exactly once and only move to CPU when needed:

-            if mod.hessian.device.type == "cuda":
-                if cls._is_memory_sufficient(mod.hessian.device.index, 0.8):
-                    mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device)
-                else:
-                    target_device = "cpu"
-                    mod.hessian = mod.hessian.to("cpu")
-            mod.hessian += inp.matmul(inp.t()).to(target_device)
+            update = inp.matmul(inp.t())
+            if mod.hessian.device.type == "cuda" and not cls._is_memory_sufficient(
+                mod.hessian.device.index, 0.8
+            ):
+                mod.hessian = mod.hessian.to("cpu")
+            mod.hessian += update.to(mod.hessian.device)

This keeps the accumulation correct while still supporting the CPU fallback.

🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/sparsegpt.py around lines 268 to 275,
the Hessian update is being applied twice on CUDA (once inside the CUDA branch
and once after), which double-counts; change the logic so you compute the new
outer-product once, decide the target_device beforehand (move mod.hessian to CPU
only if memory is insufficient), and then perform a single mod.hessian +=
outer_product.to(target_device); ensure you do not add inside the CUDA branch
and only set target_device and move mod.hessian when necessary.
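Sketched in isolation, the single-update flow looks like this (standalone illustration; gpu_ok stands in for the _is_memory_sufficient check and is not the reviewed API):

import torch

def accumulate_hessian(hessian: torch.Tensor, inp: torch.Tensor, gpu_ok: bool) -> torch.Tensor:
    """Compute the update once, fall back to CPU only when memory is tight."""
    update = inp.matmul(inp.t())                   # rank-k update, computed exactly once
    if hessian.device.type == "cuda" and not gpu_ok:
        hessian = hessian.to("cpu")                # CPU fallback only when needed
    return hessian + update.to(hessian.device)     # single accumulation

x = torch.randn(4, 16)
h = accumulate_hessian(torch.zeros(4, 4), x, gpu_ok=True)
assert torch.allclose(h, x.matmul(x.t()))          # one update, not two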

Comment on lines +32 to +97
def sparsify(
model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
"""Sparsify a given model and search for they optimal sparsified weights.
Args:
model: A standard model that contains standard building blocks to be sparsified in-place.
mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
config indicating the desired mode(s) (and configurations) for the convert
process. Modes set up the model for different algorithms for model optimization. The
following modes are available:
* :class:`"sparse_magnitude"<modelopt.torch.sparsity.mode.SparseMagnitudeModeDescriptor>`:
The ``model`` will be sparsified according to the magnitude of weights in each
layer. The mode's config is described in
:class:`SparseMagnitudeConfig<modelopt.torch.sparsity.config.SparseMagnitudeConfig>`.
* :class:`"sparsegpt"<modelopt.torch.sparsity.mode.SparseGPTModeDescriptor>`:
The ``model`` will be sparsified and weights are updated optimally using an Hessian
approximation of the loss function (see SparseGPT paper for details). The mode's
config is described in
:class:`SparseGPTConfig<modelopt.torch.sparsity.config.SparseGPTConfig>`.
If the mode argument is specified as a dictionary, the keys should indicate the mode and
the values specify the per-mode configuration. If not provided, then default
configuration would be used.
config: Additional optional arguments to configure the search. Currently, we support:
* ``verbose``: Whether to print detailed search stats during search.
* ``forward_loop``: A ``Callable`` that takes a model as input and runs a forward loop
on it. It is recommended to choose the data loader used inside the forward loop
carefully to reduce the runtime. Cannot be provided at the same time as
``data_loader`` and ``collect_func``.
* ``data_loader``: An iterator yielding batches of data for calibrating the
normalization layers in the model or compute gradient scores. It is recommended to use
the same data loader as for training but with significantly fewer iterations. Cannot
be provided at the same time as ``forward_loop``.
* ``collect_func``: A ``Callable`` that takes a batch of data from the data loader as
input and returns the input to ``model.forward()`` as described in
:meth:`run_forward_loop <modelopt.torch.utils.network.run_forward_loop>`. Cannot
be provided at the same time as ``forward_loop``.
.. note::
Additional configuration options may be added by individual algorithms. Please
refer to the documentation of the individual algorithms for more information.
Returns: A sparsified model
.. note::
The given model is sparsified in-place. The returned model is thus a reference to the same
model instance as the input model.
"""
# apply sparsity to the model
model = apply_mode(model, mode, registry=SparsityModeRegistry)

# retrieve searcher class
searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")

# run search+sparsification algorithm
searcher = searcher_cls()
searcher.search(model, {}, (), config)

# return the sparsified model
return model

⚠️ Potential issue | 🔴 Critical

Return the search results as promised.

The signature advertises tuple[nn.Module, dict[str, Any]], yet the implementation returns only the model. Callers that unpack (model, state) will crash with “not enough values to unpack,” making this a functional bug. Please propagate the search results (e.g., the SearchStateDict returned by searcher.search) alongside the model so the runtime matches the declared API.

A direct fix would be:

-    searcher.search(model, {}, (), config)
-
-    # return the sparsified model
-    return model
+    search_state = searcher.search(model, {}, (), config)
+
+    # return the sparsified model and search details
+    return model, search_state
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def sparsify(
model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
"""Sparsify a given model and search for they optimal sparsified weights.
Args:
model: A standard model that contains standard building blocks to be sparsified in-place.
mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
config indicating the desired mode(s) (and configurations) for the convert
process. Modes set up the model for different algorithms for model optimization. The
following modes are available:
* :class:`"sparse_magnitude"<modelopt.torch.sparsity.mode.SparseMagnitudeModeDescriptor>`:
The ``model`` will be sparsified according to the magnitude of weights in each
layer. The mode's config is described in
:class:`SparseMagnitudeConfig<modelopt.torch.sparsity.config.SparseMagnitudeConfig>`.
* :class:`"sparsegpt"<modelopt.torch.sparsity.mode.SparseGPTModeDescriptor>`:
The ``model`` will be sparsified and weights are updated optimally using an Hessian
approximation of the loss function (see SparseGPT paper for details). The mode's
config is described in
:class:`SparseGPTConfig<modelopt.torch.sparsity.config.SparseGPTConfig>`.
If the mode argument is specified as a dictionary, the keys should indicate the mode and
the values specify the per-mode configuration. If not provided, then default
configuration would be used.
config: Additional optional arguments to configure the search. Currently, we support:
* ``verbose``: Whether to print detailed search stats during search.
* ``forward_loop``: A ``Callable`` that takes a model as input and runs a forward loop
on it. It is recommended to choose the data loader used inside the forward loop
carefully to reduce the runtime. Cannot be provided at the same time as
``data_loader`` and ``collect_func``.
* ``data_loader``: An iterator yielding batches of data for calibrating the
normalization layers in the model or compute gradient scores. It is recommended to use
the same data loader as for training but with significantly fewer iterations. Cannot
be provided at the same time as ``forward_loop``.
* ``collect_func``: A ``Callable`` that takes a batch of data from the data loader as
input and returns the input to ``model.forward()`` as described in
:meth:`run_forward_loop <modelopt.torch.utils.network.run_forward_loop>`. Cannot
be provided at the same time as ``forward_loop``.
.. note::
Additional configuration options may be added by individual algorithms. Please
refer to the documentation of the individual algorithms for more information.
Returns: A sparsified model
.. note::
The given model is sparsified in-place. The returned model is thus a reference to the same
model instance as the input model.
"""
# apply sparsity to the model
model = apply_mode(model, mode, registry=SparsityModeRegistry)
# retrieve searcher class
searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")
# run search+sparsification algorithm
searcher = searcher_cls()
searcher.search(model, {}, (), config)
# return the sparsified model
return model
def sparsify(
model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
"""Sparsify a given model and search for the optimal sparsified weights.
Args:
model: A standard model that contains standard building blocks to be sparsified in-place.
mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
config indicating the desired mode(s) (and configurations) for the convert
process. ...
config: Additional optional arguments to configure the search. ...
Returns: A sparsified model and the search state dict
.. note::
The given model is sparsified in-place. The returned model is thus a reference to the same
model instance as the input model.
"""
# apply sparsity to the model
model = apply_mode(model, mode, registry=SparsityModeRegistry)
# retrieve searcher class
searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")
# run search+sparsification algorithm
searcher = searcher_cls()
search_state = searcher.search(model, {}, (), config)
# return the sparsified model and search details
return model, search_state
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/sparsification.py around lines 32-97,
the function signature promises to return tuple[nn.Module, dict[str, Any]] but
currently only returns the model; capture the result of searcher.search (e.g.,
state = searcher.search(...)), ensure it is a dict (coerce to {} if None), and
return (model, state) so callers that unpack (model, state) won’t fail and the
runtime matches the declared API.
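To make the contract concrete, a caller written against the declared signature would look roughly like this (hypothetical usage; the import path and mode name are assumptions for illustration):

import torch.nn as nn
import modelopt.torch.sparsity as mts  # import path assumed for illustration

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))

# Callers written against the declared signature expect a (model, state) tuple;
# returning only the model breaks this unpacking at runtime.
model, search_state = mts.sparsify(model, mode="sparse_magnitude", config={"verbose": False})
print(type(search_state))  # expected: dict carrying the searcher's state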

@kaix-nv kaix-nv changed the title from "Kaix/skip softmax calib" to "Sparse Attention" Oct 7, 2025
@kaix-nv kaix-nv changed the title from "Sparse Attention" to "Add support for sparse attention" Oct 7, 2025
@kaix-nv kaix-nv changed the title from "Add support for sparse attention" to "[New feature] Add support for sparse attention" Oct 7, 2025
@kaix-nv kaix-nv changed the title from "[New feature] Add support for sparse attention" to "[New feature] Add Support For Sparse Attention" Oct 7, 2025
@kaix-nv kaix-nv requested a review from RalphMao October 8, 2025 05:46
@kaix-nv kaix-nv force-pushed the kaix/skip_softmax_calib branch from 2ab2a02 to 023159a on October 10, 2025 21:59

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (9)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)

59-71: Count unexpected phases under “unknown”.

Unrecognized phases are still discarded, so the “unknown” bucket never actually captures them. Please fold any missing key back to "unknown" before incrementing, e.g.:

-        phase = stats.get("phase", "unknown")
-        if phase in self.aggregated_stats["phase_counts"]:
-            self.aggregated_stats["phase_counts"][phase] += 1
+        phase = stats.get("phase", "unknown")
+        if phase not in self.aggregated_stats["phase_counts"]:
+            phase = "unknown"
+        self.aggregated_stats["phase_counts"][phase] += 1
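As a standalone illustration of the fold-to-"unknown" pattern (not the stats manager's actual API):

phase_counts = {"prefill": 0, "decode": 0, "unknown": 0}

def record_phase(stats: dict) -> None:
    phase = stats.get("phase", "unknown")
    if phase not in phase_counts:
        phase = "unknown"                     # unexpected phases are counted, not dropped
    phase_counts[phase] += 1

record_phase({"phase": "prefill"})
record_phase({"phase": "speculative"})        # unrecognized phase lands in "unknown"
assert phase_counts == {"prefill": 1, "decode": 0, "unknown": 1}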
examples/llm_sparse_attention/hf_spar_attn.py (1)

284-289: Switch the model to eval() before sparsifying.

Dropout and other train-mode layers stay active, making calibration and verification nondeterministic. Call model.eval() right after the CUDA move (and before sparsify_model) so inference runs in deterministic eval mode:

     if torch.cuda.is_available():
         model = model.cuda()
         print("Model moved to CUDA")
 
+    model.eval()
+
     # Apply sparse attention to the model (with calibration if configured)
modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1)

32-47: Pass metadata through the sharded_state_dict override.

The override still drops the caller-supplied metadata, so checkpoint writers depending on it lose information. Forward the argument to the base implementation.

-        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+        sharded_state_dict = super().sharded_state_dict(
+            prefix, sharded_offsets, metadata
+        )
modelopt/torch/sparsity/weight_sparsity/searcher.py (1)

46-54: Fix assertion to avoid touching unset self.config.

At this stage self.config isn’t initialized, so the assertion raises AttributeError instead of surfacing the bad pattern. Use config_sanitized (or the incoming config) in the message instead.

-        assert is_nm_prune and n == 2 and m == 4, (
-            f"Unsupported pattern {self.config['pattern']} for sparsity"
-        )
+        assert is_nm_prune and n == 2 and m == 4, (
+            f"Unsupported pattern {config_sanitized['pattern']} for sparsity"
+        )
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1)

233-600: Use a dedicated RNG instead of reseeding Python’s global state.

random.seed(seed) and the module-level random.* calls mutate the process-wide RNG every time the builder runs, causing unrelated code to become deterministically seeded. Initialize a per-instance generator and route all randomness through it.

-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-        random.seed(seed)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+        self._rng = random.Random(seed)
@@
-        random.shuffle(all_samples)
+        self._rng.shuffle(all_samples)

Apply the same replacement for every subsequent random.* usage (e.g., randint, sample, choice, choices, shuffle, etc.) so they all call self._rng.
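A minimal sketch of the dedicated-RNG pattern (illustrative class, not the builder's real signature): two instances seeded identically reproduce the same samples without ever calling random.seed() on the global generator.

import random

class CalibSampler:
    """Routes all randomness through a per-instance generator instead of the global RNG."""

    def __init__(self, seed: int) -> None:
        self._rng = random.Random(seed)

    def sample(self, items: list, k: int) -> list:
        shuffled = list(items)
        self._rng.shuffle(shuffled)           # uses the instance RNG only
        return shuffled[:k]

a = CalibSampler(seed=42).sample(list(range(10)), 3)
b = CalibSampler(seed=42).sample(list(range(10)), 3)
assert a == b                                 # reproducible per seed, global state untouched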

modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

18-114: Select the latest method version numerically, not lexicographically.

sorted(...)[-1] will pick v9 over v10, so callers defaulting to “latest” silently regress once version numbers reach two digits. Please keep the fallback deterministic but parse the numeric suffix first.

@@
-from abc import ABC, abstractmethod
-
-import torch
+from abc import ABC, abstractmethod
+import re
+
+import torch
@@
-    if not version:
-        version = sorted(method_versions.keys())[-1]
+    if not version:
+        def _version_key(tag: str) -> tuple[int, str]:
+            match = re.match(r"^v(\d+)$", tag)
+            return (int(match.group(1)) if match else -1, tag)
+
+        version = max(method_versions.keys(), key=_version_key)
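A quick demonstration of why the numeric key matters (standalone sketch, assuming tags of the form "v<int>"):

import re

versions = ["v1", "v9", "v10"]

def _version_key(tag: str) -> tuple[int, str]:
    match = re.match(r"^v(\d+)$", tag)
    return (int(match.group(1)) if match else -1, tag)

assert sorted(versions)[-1] == "v9"                  # lexicographic sort picks "v9"
assert max(versions, key=_version_key) == "v10"      # numeric key picks the true latest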
modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (2)

32-134: Return the true Hessian inverse and fix downstream scaling.

invert() hands back the Cholesky factor of the inverse, not the inverse itself, so every consumer (e.g., hinv_diag_blk) reads square roots and the pruning math collapses. That cascades into dividing by hinv_diag_blk**2 and by d without stabilizers, severely distorting error estimates.

Please return the actual inverse and consume it consistently:

@@
-def invert(hessian: torch.Tensor) -> torch.Tensor:
-    """Invert a Hessian matrix."""
-    try:
-        hessian_inv = torch.linalg.cholesky(hessian)
-        hessian_inv = torch.cholesky_inverse(hessian_inv)
-        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
-    except RuntimeError:
-        cols = hessian.size(0)
-        eps = 1e-6 * torch.eye(cols).to(hessian.device)
-        hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian + eps))
-
-    return hessian_inv
+def invert(hessian: torch.Tensor) -> torch.Tensor:
+    """Invert a Hessian matrix."""
+    try:
+        chol = torch.linalg.cholesky(hessian)
+        return torch.cholesky_inverse(chol)
+    except RuntimeError:
+        cols = hessian.size(0)
+        eps = 1e-6 * torch.eye(cols, device=hessian.device, dtype=hessian.dtype)
+        chol = torch.linalg.cholesky(hessian + eps)
+        return torch.cholesky_inverse(chol)
@@
-    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
+    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
@@
-            errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
+            errors_blk = (w_blk**2) / (hinv_diag_blk + 1e-9)
@@
-                if is_nm_prune and j % m == 0:
-                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
+                if is_nm_prune and j % m == 0:
+                    errors_blk = (w_blk[:, j : j + m] ** 2) / (
+                        hinv_diag_blk[j : j + m] + 1e-9
+                    )
@@
-                err = (w - q) / d
+                err = (w - q) / (d + 1e-9)

266-275: Avoid double-counting the Hessian update on CUDA.

The CUDA branch adds inp.matmul(inp.t()) once inside the branch and then again unconditionally, inflating statistics whenever GPU memory is sufficient.

@@
-            if mod.hessian.device.type == "cuda":
-                if cls._is_memory_sufficient(mod.hessian.device.index, 0.8):
-                    mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device)
-                else:
-                    target_device = "cpu"
-                    mod.hessian = mod.hessian.to("cpu")
-            mod.hessian += inp.matmul(inp.t()).to(target_device)
+            update = inp.matmul(inp.t())
+            if mod.hessian.device.type == "cuda" and not cls._is_memory_sufficient(
+                mod.hessian.device.index, 0.8
+            ):
+                target_device = "cpu"
+                mod.hessian = mod.hessian.to("cpu")
+            mod.hessian += update.to(target_device)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

198-232: Rebuild sparse method instance after restoring configuration.

As flagged in a previous review comment, after restoring module._method and module._method_config from the saved state (lines 212-222), the existing module._sparse_method_instance retains its pre-restoration configuration. Calling module._setup() (line 225) without first rebuilding the method instance means:

  • The method instance is never reconstructed with the restored configuration
  • Only threshold_scale_factor and target_sparsity are manually patched afterward (lines 228-231)
  • Other restored method configuration fields remain stale in the instance

This causes the restored model's runtime behavior to diverge from the saved sparse attention configuration.

Apply this diff to rebuild the method instance with the restored configuration:

         if "method_config" in module_state:
             # Restore config attributes
             # Separate method instance attributes from module attributes
             method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}

             for key, val in module_state["method_config"].items():
                 if key not in method_instance_attrs:
                     # Set on module
                     setattr(module, f"_{key}", val)

+        # Rebuild sparse method with restored config before final setup
+        module._init_sparse_method()
+
         # Re-setup with restored config
         module._setup()

         # Restore method instance attributes after _setup
         if "method_config" in module_state:
             for key, val in module_state["method_config"].items():
                 if key in {"threshold_scale_factor", "target_sparsity"}:
                     setattr(module._sparse_method_instance, key, val)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2ab2a02 and 023159a.

📒 Files selected for processing (41)
  • examples/llm_sparse_attention/README.md (1 hunks)
  • examples/llm_sparse_attention/hf_spar_attn.py (1 hunks)
  • examples/llm_sparse_attention/requirements.txt (1 hunks)
  • modelopt/torch/export/unified_export_hf.py (2 hunks)
  • modelopt/torch/sparsity/attention_sparsity/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/config.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/conversion.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/mode.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1 hunks)
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/__init__.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/config.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/magnitude.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/mode.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/module.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/searcher.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (1 hunks)
  • modelopt/torch/sparsity/weight_sparsity/sparsification.py (1 hunks)
  • tests/_test_utils/torch_sparsity/sparse_attention_common.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1 hunks)
  • tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1 hunks)
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • examples/llm_sparse_attention/requirements.txt
🚧 Files skipped from review as they are similar to previous changes (12)
  • modelopt/torch/sparsity/attention_sparsity/nn/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py
  • modelopt/torch/sparsity/weight_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/mode.py
  • modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py
  • modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py
  • tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
  • modelopt/torch/sparsity/weight_sparsity/sparsification.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/sparsity/weight_sparsity/module.py
🧰 Additional context used
🧬 Code graph analysis (26)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • set_calibration_mode (58-60)
tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/opt/conversion.py (1)
  • modelopt_state (444-486)
modelopt/torch/sparsity/weight_sparsity/__init__.py (1)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (3)
modelopt/torch/opt/conversion.py (3)
  • ModeloptStateManager (63-311)
  • apply_mode (342-429)
  • is_converted (102-127)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)
  • calibrate_sparse_attention (112-177)
  • forward_loop (97-107)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
examples/llm_sparse_attention/hf_spar_attn.py (5)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (622-682)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (4)
  • SparseAttentionModule (29-201)
  • get_stats (130-139)
  • disable (121-123)
  • enable (117-119)
modelopt/torch/utils/memory_monitor.py (1)
  • launch_memory_monitor (134-151)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
  • sparsify (35-148)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
modelopt/torch/opt/config.py (3)
  • ModeloptBaseConfig (59-147)
  • ModeloptField (50-53)
  • keys (132-134)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (5)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • CalibrationConfig (159-247)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (2)
  • DynamicThresholdCalibrator (31-308)
  • calibrate (77-219)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
  • RulerDatasetBuilder (181-602)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
  • calibrate (151-197)
modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (3)
modelopt/torch/opt/plugins/megatron.py (1)
  • _MegatronMLP (120-142)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/opt/config.py (1)
  • register_default (227-237)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • _DMRegistryCls (917-1124)
  • config (1265-1278)
modelopt/torch/quantization/utils.py (1)
  • replace_function (300-307)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionAttributeConfig (34-156)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (2)
  • get_sparse_method (88-120)
  • apply_sparsity (27-44)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
  • SparseAttentionStatsManager (7-125)
  • get_summary (74-92)
  • collect (43-72)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • apply_sparsity (241-284)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
  • SimpleAttentionModel (27-46)
  • SimpleTransformerEncoderLayer (49-69)
modelopt/torch/sparsity/attention_sparsity/conversion.py (2)
  • disable_sparse_attention (278-305)
  • enable_sparse_attention (308-335)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (2)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/sparsity/weight_sparsity/searcher.py (3)
  • BaseSparseSearcher (31-84)
  • _check_weight_size (59-61)
  • _compute_mask (64-66)
modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (4)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (3)
  • get_nmprune_info (29-35)
  • _check_weight_size (134-144)
  • _compute_mask (146-148)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
  • SparseModule (31-87)
modelopt/torch/sparsity/weight_sparsity/searcher.py (5)
  • BaseSparseSearcher (31-84)
  • default_search_config (37-39)
  • _check_weight_size (59-61)
  • _compute_mask (64-66)
  • _named_sparsifiable_modules (68-76)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1)
modelopt/torch/utils/random.py (3)
  • random (59-61)
  • shuffle (148-150)
  • choice (64-89)
tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (9)
  • SimpleAttentionModel (27-46)
  • SimpleTransformerEncoder (72-90)
  • SimpleTransformerEncoderLayer (49-69)
  • get_test_configs (142-147)
  • save_restore_test (185-218)
  • sparsify_model_and_forward (150-182)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (5)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
  • SimpleTransformerEncoderLayer (49-69)
  • forward_loop (162-164)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
  • RulerDatasetBuilder (181-602)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1)
  • forward_loop (210-214)
modelopt/torch/opt/conversion.py (2)
  • modelopt_state (444-486)
  • restore_from_modelopt_state (510-567)
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1)
  • register_sparse_attention_on_the_fly (55-100)
modelopt/torch/sparsity/attention_sparsity/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
  • ModelLikeModule (318-330)
  • ModeloptStateManager (63-311)
  • init_modellike (326-330)
  • state_version (135-137)
modelopt/torch/utils/network.py (1)
  • get_unwrapped_name (599-612)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
  • SparseAttentionModule (29-201)
  • set_from_attribute_config (61-107)
  • _setup (141-153)
  • disable (121-123)
  • enable (117-119)
  • is_enabled (126-128)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2)
  • register_sparse_attention_on_the_fly (55-100)
  • _setup (33-38)
modelopt/torch/sparsity/weight_sparsity/mode.py (6)
modelopt/torch/opt/dynamic.py (3)
  • config (1265-1278)
  • DynamicSpace (1127-1368)
  • convert_to_dynamic (1138-1187)
modelopt/torch/opt/config.py (1)
  • ModeloptBaseConfig (59-147)
modelopt/torch/opt/conversion.py (1)
  • ApplyModeError (314-315)
modelopt/torch/opt/mode.py (2)
  • ModeDescriptor (56-259)
  • _ModeRegistryCls (267-344)
modelopt/torch/utils/network.py (2)
  • compare_dict (423-427)
  • unwrap_model (430-454)
modelopt/torch/sparsity/weight_sparsity/config.py (1)
  • ExportSparseConfig (50-51)
tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
  • SimpleTransformerEncoderLayer (49-69)
  • sparsify_model_and_forward (150-182)
  • get_input (44-46)
  • get_input (67-69)
  • get_input (88-90)
modelopt/torch/export/unified_export_hf.py (1)
  • _get_sparse_attention_config (341-461)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
  • disable_sparse_attention (278-305)
tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (4)
tests/_test_utils/torch_model/transformers_models.py (1)
  • create_tiny_llama_dir (117-131)
modelopt/torch/export/unified_export_hf.py (1)
  • export_hf_checkpoint (622-682)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
  • SparseAttentionConfig (292-319)
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
  • disable_sparse_attention (278-305)
modelopt/torch/sparsity/weight_sparsity/searcher.py (5)
modelopt/torch/opt/searcher.py (1)
  • BaseSearcher (52-271)
modelopt/torch/utils/logging.py (1)
  • print_rank_0 (92-95)
modelopt/torch/sparsity/weight_sparsity/module.py (2)
  • SparseModule (31-87)
  • set_mask (68-87)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (1)
  • get_nmprune_info (29-35)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (3)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
  • SparseAttentionStatsManager (7-125)
  • set_calibration_mode (94-106)
  • get_calibration_stats (118-125)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
  • set_calibration_mode (58-60)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2)
  • apply_sparsity (241-284)
  • name (287-289)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (4)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
  • SimpleAttentionModel (27-46)
  • forward_loop (162-164)
  • forward (39-41)
  • forward (63-64)
  • forward (84-85)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
  • DynamicThresholdCalibrator (31-308)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (3)
  • RulerDatasetBuilder (181-602)
  • _generate_target_lengths (28-56)
  • build_calibration_dataset (236-262)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (3)
modelopt/torch/opt/dynamic.py (3)
  • DynamicModule (338-914)
  • get_original_cls_by_level (905-914)
  • level (158-160)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (2)
  • SparseAttentionModule (29-201)
  • _setup (141-153)
modelopt/torch/export/layer_utils.py (1)
  • is_attention (318-320)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
  • SparseAttentionModule (29-201)
modelopt/torch/opt/conversion.py (2)
  • modelopt_state (444-486)
  • restore_from_modelopt_state (510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/torch/sparsity/weight_sparsity/__init__.py (1)

18-23: Public API wiring looks solid.

The package initializer cleanly surfaces the expected registries and modules, matching the surrounding structure. Nice work keeping the namespace consistent.

modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)

18-22: Plugin export hub looks good.

Centralizing the HuggingFace registration hook here keeps the plugin surface tidy and discoverable.

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)

36-96: Great coverage on conversion flows.

These cases hit replacement, toggling, and pattern selection, giving solid confidence in the conversion logic.

tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1)

32-203: Excellent validation of export config serialization.

This suite thoroughly covers default, phase-aware, disabled, calibrated, and metadata paths, so regressions in _get_sparse_attention_config should surface quickly.

tests/_test_utils/torch_sparsity/sparse_attention_common.py (1)

1-218: Well-structured test utilities.

This module provides clean, reusable test infrastructure for sparse attention testing:

  • Three test model classes with consistent APIs (forward and get_input class methods)
  • Comprehensive configuration constants covering different sparse attention scenarios
  • Helper functions that properly implement forward passes during calibration (lines 162-164) and validation (lines 175-180)
  • Robust save/restore testing with output verification (lines 207-217)

These utilities establish good patterns that should be followed by the test code in test_calibration_gpu.py.

Comment on lines +326 to +330
"--backend",
type=str,
default="pytorch",
choices=["pytorch", "triton"],
help="Backend to use for sparse attention computation (default: pytorch)",

⚠️ Potential issue | 🟠 Major

Remove the unsupported triton backend option.

SparseAttentionAttributeConfig.validate_backend hard-rejects any backend other than "pytorch", so choosing "triton" causes an immediate crash. Until Triton support lands, constrain the CLI to the actually supported backend (and mirror the README update):

-    parser.add_argument(
-        "--backend",
-        type=str,
-        default="pytorch",
-        choices=["pytorch", "triton"],
-        help="Backend to use for sparse attention computation (default: pytorch)",
-    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="pytorch",
+        choices=["pytorch"],
+        help="Backend to use for sparse attention computation (currently only 'pytorch').",
+    )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"--backend",
type=str,
default="pytorch",
choices=["pytorch", "triton"],
help="Backend to use for sparse attention computation (default: pytorch)",
parser.add_argument(
"--backend",
type=str,
default="pytorch",
choices=["pytorch"],
help="Backend to use for sparse attention computation (currently only 'pytorch').",
)
🤖 Prompt for AI Agents
In examples/llm_sparse_attention/hf_spar_attn.py around lines 326 to 330, the
CLI parser currently includes "triton" as a choice for the --backend option even
though SparseAttentionAttributeConfig.validate_backend only accepts "pytorch";
remove "triton" from the choices and update the default/choices to only include
"pytorch" so the CLI cannot select an unsupported backend and crash at
validation time.
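For reference, restricting choices lets argparse itself reject the unsupported value before any ModelOpt validation runs (illustrative snippet):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    type=str,
    default="pytorch",
    choices=["pytorch"],
    help="Backend to use for sparse attention computation (currently only 'pytorch').",
)

print(parser.parse_args([]).backend)          # -> "pytorch"
# parser.parse_args(["--backend", "triton"])  # exits with "invalid choice: 'triton'"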

Comment on lines +119 to +120
- `--backend`: Backend for computation - pytorch or triton (default: pytorch)
- `--seq_len`: Maximum sequence length (default: 2048)

⚠️ Potential issue | 🟠 Major

Align backend docs with actual support.

The options table advertises a triton backend, but SparseAttentionAttributeConfig.validate_backend rejects anything other than pytorch. Following the README literally causes the CLI to fail. Please either document it as “pytorch only” (and mention Triton as future work) or update the implementation to accept the advertised backend.

🤖 Prompt for AI Agents
examples/llm_sparse_attention/README.md lines 119-120: the README claims a
`triton` backend but the code (SparseAttentionAttributeConfig.validate_backend)
only accepts `pytorch`, causing the CLI to fail; either (A) update the README to
list backend as "pytorch (only)" and add a short note that Triton is future
work, or (B) modify validate_backend to accept `"triton"` as a valid value and
wire through any backend-specific switches/placeholders so the CLI won't reject
`--backend triton` (ensure tests/usage handle the triton path or clearly stub
it).

Comment on lines +78 to +83
mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
mat, _ = reshape_1d(matrix, m)
pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask

⚠️ Potential issue | 🔴 Critical

Fix device mismatch when constructing mask

mask = torch.IntTensor(...) always allocates on CPU. When the weight tensor lives on CUDA, assigning patterns[pmax] (CUDA) into this CPU mask throws Expected all tensors to be on the same device. Allocate the mask on the same device as the weights instead.

Apply this diff:

-    mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
+    mask = torch.ones(matrix.shape, dtype=patterns.dtype, device=matrix.device).view(-1, m)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
mat, _ = reshape_1d(matrix, m)
pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask
mask = torch.ones(matrix.shape, dtype=patterns.dtype, device=matrix.device).view(-1, m)
mat, _ = reshape_1d(matrix, m)
pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
mask[:] = patterns[pmax[:]]
mask = mask.view(matrix.shape)
return mask
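Sketched end to end, the device-consistent construction looks like this (simplified stand-ins for reshape_1d and the 2:4 pattern table; move matrix to CUDA to exercise the GPU path):

import torch

m = 4
# Stand-in for the 2:4 pattern table in the reviewed file.
patterns = torch.tensor(
    [[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1],
     [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1]],
    dtype=torch.float32,
)

matrix = torch.randn(8, 16)                    # e.g. matrix = matrix.cuda() on GPU
patterns = patterns.to(matrix.device)

mask = torch.ones_like(matrix).view(-1, m)     # allocated on the weight's device
mat = matrix.abs().view(-1, m)                 # simplified reshape_1d
pmax = torch.argmax(mat.matmul(patterns.t()), dim=1)
mask[:] = patterns[pmax]
mask = mask.view(matrix.shape)
assert mask.device == matrix.device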

Comment on lines +135 to +137
def forward_loop(model):
# Simple forward loop for calibration
pass

⚠️ Potential issue | 🔴 Critical

Empty forward loops prevent proper calibration.

Throughout this test file, all forward_loop functions contain only pass statements, meaning no forward passes are executed during calibration. Calibration requires running the model on data to collect statistics (e.g., attention score distributions) to tune thresholds. Without executing forward passes, the calibration process cannot collect meaningful statistics, potentially causing:

  • Calibration to silently fail or use default/untuned thresholds
  • Tests to pass incorrectly without validating actual calibration behavior
  • Divergence from the intended calibration workflow

Reference implementation from tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (lines 209-213) shows the correct pattern:

def forward_loop(model):
    """Simple forward loop for calibration."""
    test_input = torch.randint(0, 32000, (1, 64), device="cuda")
    with torch.no_grad():
        model(test_input)

Similarly, the helper function sparsify_model_and_forward in tests/_test_utils/torch_sparsity/sparse_attention_common.py (lines 162-164) demonstrates proper forward pass execution.

Apply this diff to fix the forward loops. For example, in test_calibration_simple_model:

     def forward_loop(model):
-        # Simple forward loop for calibration
-        pass
+        # Simple forward loop for calibration
+        test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=64).cuda()
+        with torch.no_grad():
+            model(test_input)

Apply similar changes to all other empty forward_loop functions in this file, ensuring the input dimensions match the model configuration (e.g., d_model=256 for models in TestCalibrationGPU and TestCalibrationEndToEnd).

Also applies to: 173-174, 204-205, 231-232, 279-280, 325-326, 373-374

Comment on lines +32 to +44
@pytest.fixture(scope="module")
def tinyllama_model():
"""Load TinyLlama model for testing."""
try:
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
torch_dtype=torch.float16,
device_map="cuda",
)
return model
except Exception as e:
pytest.skip(f"Could not load TinyLlama model: {e}")


⚠️ Potential issue | 🟠 Major

Reset TinyLlama fixture between tests to avoid stale sparsified state.

sparse_attn.sparsify mutates the passed model in place and marks it as converted via the Modelopt state manager. Because this fixture is module-scoped, the very first test leaves the shared TinyLlama instance already sparsified. Every subsequent test then reuses that mutated object, so new configs/backends (and especially calibration forward loops) are silently skipped. The later cases therefore never exercise the code paths they intend to validate. Please return a fresh model per test—e.g., switch the fixture to function scope or clone/reset the model before sparsifying—so each scenario runs against an unsparsified baseline.

🤖 Prompt for AI Agents
In tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py around
lines 32 to 44, the TinyLlama fixture is module-scoped so the model is mutated
by sparse_attn.sparsify and reused across tests; change the fixture to return a
fresh unsparsified model per test by switching scope="module" to
scope="function" (or, alternatively, load/clone/reset the model inside the
fixture before returning) so each test gets an independent instance and
sparsification state does not leak between tests.
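A minimal sketch of the function-scoped variant (same model id as the test; loading per test is slower but guarantees an unsparsified baseline for every case):

import pytest
import torch
from transformers import AutoModelForCausalLM

@pytest.fixture
def tinyllama_model():
    """Load a fresh, unsparsified TinyLlama instance for every test."""
    try:
        return AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16,
            device_map="cuda",
        )
    except Exception as e:  # e.g. no network or no GPU available
        pytest.skip(f"Could not load TinyLlama model: {e}")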

@kaix-nv kaix-nv requested a review from mxinO October 13, 2025 17:40