[New feature] Add Support For Sparse Attention #408
Walkthrough

Adds a new sparse-attention subsystem and weight-sparsity features: configs, registries, a runtime SparseAttentionModule, a flash_softmax_skip method, RULER calibration (dataset + calibrator), HuggingFace plugins, HF export augmentation, a CLI example, and extensive unit/GPU tests and utilities.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor U as User/CLI
    participant M as HF/PyTorch Model
    participant S as SparseAttentionMode
    participant C as Conversion
    participant SA as SparseAttentionModule
    participant SM as Sparse Method
    participant ST as StatsManager
    U->>S: sparsify(model, config)
    S->>C: convert_to_sparse_attention_model(...)
    C->>M: replace attention with SparseAttentionModule
    Note over M,SA: Model now contains sparse-aware attention
    U->>M: forward(input)
    M->>SA: attention forward
    SA->>SM: apply_sparsity(q,k,v, scores)
    SM-->>SA: masked_scores + stats
    SA->>ST: collect(stats)
    SA->>M: continue forward with masked_scores
```
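Read as a script, this first diagram corresponds to a short driver. The sketch below is hypothetical: the import paths are taken from the files added in this PR, while the model path and the config literal are placeholders rather than the real schema.

```python
# Hypothetical driver for the conversion + forward flow in the diagram above.
# Import paths follow this PR's file layout; the config dict and model path are placeholders.
import torch
from transformers import AutoModelForCausalLM

from modelopt.torch.sparsity.attention_sparsity.model_sparsify import sparsify
from modelopt.torch.sparsity.attention_sparsity.nn import SparseAttentionModule

model = AutoModelForCausalLM.from_pretrained("<model-dir>", torch_dtype=torch.bfloat16)
model.eval()

# sparsify(model, config) swaps eligible attention layers for SparseAttentionModule wrappers.
model = sparsify(model, config={"method": "flash_softmax_skip"})  # placeholder config

num_wrapped = sum(isinstance(m, SparseAttentionModule) for m in model.modules())
print(f"{num_wrapped} attention modules wrapped")
```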
```mermaid
sequenceDiagram
    autonumber
    actor U as User
    participant Cal as Calibration Driver
    participant DS as RULER Dataset
    participant Tok as Tokenizer
    participant Dyn as DynamicThresholdCalibrator
    participant SA as SparseAttentionModule
    U->>Cal: calibrate_sparse_attention(model, config)
    Cal->>DS: build_calibration_dataset() (if forward_loop not provided)
    Cal->>Tok: AutoTokenizer.from_pretrained(...)
    Cal->>Dyn: run calibrator with forward_loop
    loop thresholds x samples
        Dyn->>SA: set threshold, enable calibration mode
        SA->>Model: forward sample -> stats
        SA-->>Dyn: collect per-sample sparsity stats
    end
    Dyn->>Dyn: fit scale_factor = a*(1/length) via regression
    Dyn->>SA: apply calibrated scale_factor to modules
    Dyn-->>U: return calibration results
```
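The fit step in this diagram is a one-parameter least-squares problem. The snippet below is a standalone sketch of that regression on made-up numbers; the real calibrator derives its data points from the per-sample sparsity stats it collects.

```python
# Standalone sketch: fit scale_factor = a * (1 / length) by least squares through the origin.
# The sample data is fabricated for illustration only.
import torch

lengths = torch.tensor([1024.0, 2048.0, 4096.0, 8192.0])
observed_scale = torch.tensor([0.80, 0.42, 0.21, 0.11])  # hypothetical per-length scale factors

x = 1.0 / lengths
a = torch.dot(x, observed_scale) / torch.dot(x, x)  # closed-form slope through the origin

predicted = a * x
print(f"fitted a = {a.item():.2f}, max residual = {(observed_scale - predicted).abs().max().item():.3f}")
```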
```mermaid
sequenceDiagram
    autonumber
    actor U as User
    participant Exp as export_hf_checkpoint
    participant M as Model
    participant G as _get_sparse_attention_config
    U->>Exp: export_hf_checkpoint(model, dst)
    Exp->>G: scan model for sparse modules
    alt sparse config found
        G-->>Exp: sparse_attention_config dict
        Exp->>Exp: inject into config.json and write
    else
        G-->>Exp: {}
    end
    Exp-->>U: files written
```
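The injection step in this export diagram amounts to merging one extra section into the checkpoint's config.json. A minimal sketch of just that step follows; the key name comes from the diagram, while the helper name and everything around it are illustrative.

```python
# Illustrative helper: add a sparse_attention_config section to an exported config.json.
# Weight export and the module scan are omitted; only the config merge is shown.
import json
from pathlib import Path


def inject_sparse_attention_config(export_dir: str, sparse_attention_config: dict) -> None:
    config_path = Path(export_dir) / "config.json"
    config = json.loads(config_path.read_text())
    if sparse_attention_config:  # empty dict means no enabled sparse modules were found
        config["sparse_attention_config"] = sparse_attention_config
    config_path.write_text(json.dumps(config, indent=2))
```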
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120+ minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 12
🧹 Nitpick comments (2)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
125-131: Enforce integer block sizes.
The validator still accepts floats (e.g., 128.5), contradicting the "block size" contract. Reject non-integral values up front.

```diff
 @field_validator("br", "bc")
 @classmethod
 def validate_block_size(cls, v):
     """Validate block sizes are positive integers."""
-    if v <= 0:
-        raise ValueError(f"Block size must be positive, got {v}")
+    if not isinstance(v, int):
+        raise ValueError(f"Block size must be an integer, got type {type(v).__name__}")
+    if v <= 0:
+        raise ValueError(f"Block size must be positive, got {v}")
     return v
```

tests/_test_utils/torch_sparsity/sparse_attention_common.py (1)
203-205: Capture the restored module returned by mto.restore_from_modelopt_state
restore_from_modelopt_state may hand back a different module object (especially when starting from a model-like tuple). Ignoring its return risks operating on the unmodified instance. Reassign the helper's model_restored to the returned module before loading weights. Apply this diff:

```diff
-    mto.restore_from_modelopt_state(model_restored, state_dict)
+    model_restored = mto.restore_from_modelopt_state(model_restored, state_dict)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (40)
- examples/llm_sparse_attention/README.md (1 hunks)
- examples/llm_sparse_attention/hf_spar_attn.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (2 hunks)
- modelopt/torch/sparsity/attention_sparsity/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/config.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/mode.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/__init__.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/config.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/magnitude.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/mode.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/module.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/searcher.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/sparsification.py (1 hunks)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (36)
modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1)
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py (3): SparseAttentionMethod (23-49), get_sparse_method (88-120), register_sparse_method (56-85)

tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (3)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (5): SimpleTransformerEncoderLayer (49-69), sparsify_model_and_forward (150-182), get_input (44-46), get_input (67-69), get_input (88-90)
- modelopt/torch/export/unified_export_hf.py (1): _get_sparse_attention_config (340-460)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (1): disable_sparse_attention (278-305)

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (3)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (2): SimpleAttentionModel (27-46), SimpleTransformerEncoderLayer (49-69)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (2): disable_sparse_attention (278-305), enable_sparse_attention (308-335)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)

modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1): register_sparse_attention_on_the_fly (55-100)

modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (2)
- modelopt/torch/nas/utils.py (1): enabled (226-227)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1): set_calibration_mode (58-60)

modelopt/torch/sparsity/weight_sparsity/config.py (1)
- modelopt/torch/opt/config.py (2): ModeloptBaseConfig (59-147), get_kwargs_for_create_model_with_rules (322-383)

modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2): apply_sparsity (241-284), name (287-289)

modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (4)
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1): calibrate (151-197)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1): calibrate_sparse_attention (112-177)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1): DynamicThresholdCalibrator (31-308)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1): RulerDatasetBuilder (181-602)

tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (4)
- tests/_test_utils/torch_model/transformers_models.py (1): create_tiny_llama_dir (117-131)
- modelopt/torch/export/unified_export_hf.py (1): export_hf_checkpoint (621-677)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (1): disable_sparse_attention (278-305)

modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
- modelopt/torch/opt/dynamic.py (3): DynamicModule (338-914), _DMRegistryCls (917-1124), config (1265-1278)
- modelopt/torch/quantization/utils.py (1): replace_function (300-307)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionAttributeConfig (34-156)
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py (2): get_sparse_method (88-120), apply_sparsity (27-44)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3): SparseAttentionStatsManager (7-125), get_summary (74-92), collect (43-72)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1): apply_sparsity (241-284)

tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (3)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/opt/conversion.py (1): modelopt_state (444-486)

modelopt/torch/sparsity/attention_sparsity/config.py (1)
- modelopt/torch/opt/config.py (3): ModeloptBaseConfig (59-147), ModeloptField (50-53), keys (132-134)

modelopt/torch/sparsity/weight_sparsity/__init__.py (1)
- modelopt/torch/sparsity/weight_sparsity/module.py (1): SparseModule (31-87)

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1)
- modelopt/torch/sparsity/attention_sparsity/config.py (3): FlashSoftmaxSkipConfig (322-346), SparseAttentionAttributeConfig (34-156), SparseAttentionConfig (292-319)

examples/llm_sparse_attention/hf_spar_attn.py (5)
- modelopt/torch/export/unified_export_hf.py (1): export_hf_checkpoint (621-677)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (4): SparseAttentionModule (29-201), get_stats (130-139), disable (121-123), enable (117-119)
- modelopt/torch/utils/memory_monitor.py (1): launch_memory_monitor (134-151)
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1): sparsify (35-148)

modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (5)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): CalibrationConfig (159-247)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (2): DynamicThresholdCalibrator (31-308), calibrate (77-219)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2): RulerDatasetBuilder (181-602), build_calibration_dataset (236-262)
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1): calibrate (151-197)

modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (6)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3): SparseAttentionStatsManager (7-125), set_calibration_mode (94-106), get_calibration_stats (118-125)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1): forward_loop (97-107)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (1): forward_loop (162-164)
- tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1): forward_loop (189-196)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1): set_calibration_mode (58-60)

modelopt/torch/sparsity/attention_sparsity/mode.py (4)
- modelopt/torch/opt/config.py (1): ModeloptBaseConfig (59-147)
- modelopt/torch/opt/mode.py (2): ModeDescriptor (56-259), _ModeRegistryCls (267-344)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (3): convert_to_sparse_attention_model (47-76), restore_sparse_attention_model (175-195), update_sparse_attention_metadata (234-275)

modelopt/torch/sparsity/weight_sparsity/magnitude.py (2)
- modelopt/torch/sparsity/weight_sparsity/module.py (1): SparseModule (31-87)
- modelopt/torch/sparsity/weight_sparsity/searcher.py (3): BaseSparseSearcher (31-84), _check_weight_size (59-61), _compute_mask (64-66)

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1)
- modelopt/torch/opt/mode.py (2): _ModeRegistryCls (267-344), get_from_any (326-336)

modelopt/torch/export/unified_export_hf.py (3)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/opt/dynamic.py (3): original_cls (867-873), get_original_cls_by_level (905-914), level (158-160)

modelopt/torch/sparsity/attention_sparsity/conversion.py (5)
- modelopt/torch/opt/conversion.py (4): ModelLikeModule (318-330), ModeloptStateManager (63-311), init_modellike (326-330), state_version (135-137)
- modelopt/torch/utils/network.py (1): get_unwrapped_name (599-612)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6): SparseAttentionModule (29-201), set_from_attribute_config (61-107), _setup (141-153), disable (121-123), enable (117-119), is_enabled (126-128)
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2): register_sparse_attention_on_the_fly (55-100), _setup (33-38)

modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (4)
- modelopt/torch/opt/conversion.py (3): ModeloptStateManager (63-311), apply_mode (342-429), is_converted (102-127)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2): calibrate_sparse_attention (112-177), forward_loop (97-107)
- modelopt/torch/sparsity/attention_sparsity/config.py (1): SparseAttentionConfig (292-319)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1): calibrate (77-219)

modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (2)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1): SparseAttentionStatsManager (7-125)

modelopt/torch/sparsity/weight_sparsity/module.py (2)
- modelopt/torch/opt/dynamic.py (4): DynamicModule (338-914), _register_hparam (383-423), _register_temp_attribute (476-513), _register_dynamic_attribute (425-474)
- modelopt/torch/opt/hparam.py (1): Hparam (48-275)

modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (4)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/sparsity/weight_sparsity/magnitude.py (3): get_nmprune_info (29-35), _check_weight_size (134-144), _compute_mask (146-148)
- modelopt/torch/sparsity/weight_sparsity/module.py (1): SparseModule (31-87)
- modelopt/torch/sparsity/weight_sparsity/searcher.py (5): BaseSparseSearcher (31-84), default_search_config (37-39), _check_weight_size (59-61), _compute_mask (64-66), _named_sparsifiable_modules (68-76)

modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (3)
- modelopt/torch/opt/plugins/megatron.py (1): _MegatronMLP (120-142)
- modelopt/torch/sparsity/weight_sparsity/module.py (1): SparseModule (31-87)
- modelopt/torch/opt/config.py (1): register_default (227-237)

tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/opt/conversion.py (2): modelopt_state (444-486), restore_from_modelopt_state (510-567)

tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (9): SimpleAttentionModel (27-46), SimpleTransformerEncoder (72-90), SimpleTransformerEncoderLayer (49-69), get_test_configs (142-147), save_restore_test (185-218), sparsify_model_and_forward (150-182), get_input (44-46), get_input (67-69), get_input (88-90)

modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2)
- modelopt/torch/opt/dynamic.py (3): DynamicModule (338-914), get_original_cls_by_level (905-914), level (158-160)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (2): SparseAttentionModule (29-201), _setup (141-153)

modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2)
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py (4): SparseAttentionMethod (23-49), register_sparse_method (56-85), apply_sparsity (27-44), name (48-49)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1): set_calibration_mode (94-106)

modelopt/torch/sparsity/weight_sparsity/sparsification.py (3)
- modelopt/torch/opt/conversion.py (2): apply_mode (342-429), get_mode (432-441)
- modelopt/torch/opt/searcher.py (1): BaseSearcher (52-271)
- modelopt/torch/utils/network.py (1): unwrap_model (430-454)

modelopt/torch/sparsity/weight_sparsity/mode.py (4)
- modelopt/torch/opt/dynamic.py (2): config (1265-1278), convert_to_dynamic (1138-1187)
- modelopt/torch/opt/conversion.py (1): ApplyModeError (314-315)
- modelopt/torch/opt/mode.py (2): ModeDescriptor (56-259), _ModeRegistryCls (267-344)
- modelopt/torch/utils/network.py (2): compare_dict (423-427), unwrap_model (430-454)

tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (3)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (5): SimpleAttentionModel (27-46), forward_loop (162-164), forward (39-41), forward (63-64), forward (84-85)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (3): RulerDatasetBuilder (181-602), _generate_target_lengths (28-56), build_calibration_dataset (236-262)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)

tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (4)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (4): SimpleTransformerEncoderLayer (49-69), get_input (44-46), get_input (67-69), get_input (88-90)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2): RulerDatasetBuilder (181-602), build_calibration_dataset (236-262)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1): SparseAttentionModule (29-201)
- modelopt/torch/opt/conversion.py (2): modelopt_state (444-486), restore_from_modelopt_state (510-567)

modelopt/torch/sparsity/weight_sparsity/searcher.py (4)
- modelopt/torch/opt/searcher.py (1): BaseSearcher (52-271)
- modelopt/torch/utils/logging.py (1): print_rank_0 (92-95)
- modelopt/torch/sparsity/weight_sparsity/module.py (2): SparseModule (31-87), set_mask (68-87)
- modelopt/torch/sparsity/weight_sparsity/magnitude.py (1): get_nmprune_info (29-35)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
```python
    args.pyt_ckpt_path,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

# Set pad token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Move model to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print("Model moved to CUDA")

# Apply sparse attention to the model (with calibration if configured)
model = sparsify_model(model, args)
```
Call model.eval() before running inference
The model stays in training mode after loading (dropout, layernorm stats, etc.), so the sparse-attention calibration and the generation comparison run with stochastic behavior. That breaks verification and produces inconsistent results. Move the model to eval mode before sparsifying/exporting.
```diff
 model = AutoModelForCausalLM.from_pretrained(
     args.pyt_ckpt_path,
     attn_implementation="eager",
     torch_dtype=torch.bfloat16,
 )
 tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

 # Set pad token if not set
 if tokenizer.pad_token is None:
     tokenizer.pad_token = tokenizer.eos_token

 # Move model to GPU if available
 if torch.cuda.is_available():
     model = model.cuda()
     print("Model moved to CUDA")

+model.eval()
+
 # Apply sparse attention to the model (with calibration if configured)
 model = sparsify_model(model, args)
```

📝 Committable suggestion
```python
model = AutoModelForCausalLM.from_pretrained(
    args.pyt_ckpt_path,
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

# Set pad token if not set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Move model to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print("Model moved to CUDA")

model.eval()

# Apply sparse attention to the model (with calibration if configured)
model = sparsify_model(model, args)
```
🤖 Prompt for AI Agents
In examples/llm_sparse_attention/hf_spar_attn.py around lines 273 to 290, the
model remains in training mode which allows dropout and other stochastic
behaviors during sparse-attention calibration and generation; call model.eval()
after loading (and after moving to CUDA if applicable) and before calling
sparsify_model or running any calibration/generation to ensure deterministic
behavior during verification.
```python
# Collect all enabled sparse attention modules
sparse_modules = []
for name, module in model.named_modules():
    if isinstance(module, SparseAttentionModule) and module.is_enabled:
        sparse_modules.append((name, module))

if not sparse_modules:
    return {}

sparse_config = {
    "config_groups": {},
    "producer": {
        "name": "modelopt",
        "version": __version__,
    },
}

# Check first module for global calibration parameters
# (all modules share the same calibration parameters)
first_module = sparse_modules[0][1]
method_instance = first_module._sparse_method_instance
threshold_scale_factor = getattr(method_instance, "threshold_scale_factor", None)

if threshold_scale_factor is not None:
    # Model was calibrated: add global calibration parameters
    sparse_config["threshold_scale_factor"] = float(threshold_scale_factor)

    target_sparsity = getattr(method_instance, "target_sparsity", None)
    if target_sparsity is not None:
        sparse_config["target_sparsity"] = float(target_sparsity)

# Group modules by configuration
# Key: (sparse_algo, threshold_repr), Value: list of module class names
config_to_targets = {}

for name, module in sparse_modules:
    method_instance = module._sparse_method_instance

    # Extract sparse algorithm name from method name
    # e.g., "flash_softmax_skip" -> "softmax_skip"
    method_name = method_instance.name
    if method_name.startswith("flash_"):
        sparse_algo = method_name[6:]  # Remove "flash_" prefix
    else:
        sparse_algo = method_name

    # Get module's original class name for targets
    # Get the class name before SparseAttentionModule wrapping
    original_cls = module.get_original_cls_by_level(level=0)
    target_class_name = original_cls.__name__

    # Build config key for grouping
    if threshold_scale_factor is None:
        # Not calibrated: include threshold in grouping
        threshold_config = getattr(method_instance, "threshold_config", None)
        if isinstance(threshold_config, dict):
            # Convert dict to tuple for hashable key
            threshold_repr = tuple(sorted(threshold_config.items()))
        else:
            threshold_repr = threshold_config
    else:
        # Calibrated: no threshold in per-layer config
        threshold_repr = None

    config_key = (sparse_algo, threshold_repr)

    if config_key not in config_to_targets:
        config_to_targets[config_key] = {
            "sparse_algo": sparse_algo,
            "threshold_config": threshold_config if threshold_scale_factor is None else None,
            "targets": set(),
        }

    config_to_targets[config_key]["targets"].add(target_class_name)
```
Ensure sparse modules are fully initialized before export.
module.is_enabled can return True even before _setup() runs, so _sparse_method_instance might not exist yet. Accessing it here will raise AttributeError for modules that were converted but never forwarded. Call module._setup() (or otherwise guarantee initialization) before dereferencing _sparse_method_instance.
🤖 Prompt for AI Agents
In modelopt/torch/export/unified_export_hf.py around lines 364 to 438, some
SparseAttentionModule instances may report is_enabled=True before they have been
initialized, so accessing module._sparse_method_instance can raise
AttributeError; call module._setup() (or otherwise ensure initialization) for
each module in the sparse_modules collection before dereferencing
_sparse_method_instance, and guard access with a check that the attribute exists
after setup (skip or log modules that still lack the attribute) so export
proceeds only with fully-initialized modules.
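One way to realize this suggestion, sketched against the module-collection loop shown above (the private `_setup()` / `_sparse_method_instance` names are the ones referenced in the comment, and the exact placement is an assumption):

```python
# Sketch of the suggested guard: initialize lazily, then keep only modules that
# actually expose _sparse_method_instance; warn about (and skip) the rest.
import warnings

initialized = []
for name, module in sparse_modules:
    if not hasattr(module, "_sparse_method_instance"):
        module._setup()  # modules converted but never run may not be set up yet
    if hasattr(module, "_sparse_method_instance"):
        initialized.append((name, module))
    else:
        warnings.warn(f"Skipping uninitialized sparse attention module: {name}")
sparse_modules = initialized
```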
```python
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
random.seed(seed)
```
Avoid resetting Python’s global RNG
Calling random.seed(seed) here mutates the process-wide RNG state every time the builder is constructed. Any code relying on randomness after calibration (e.g., sampling batches, data augmentation) becomes silently deterministic, which is a major cross-cutting bug. Instantiate a dedicated random.Random (and thread it through the helpers) instead of reseeding the global RNG.
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py around
lines 233 to 235, the constructor currently calls random.seed(seed) which
mutates the global RNG; instead create a dedicated RNG instance (e.g.,
self.random = random.Random(seed)) and remove the global seed call, then replace
any uses of the global random module in this class and pass this self.random
into any helper functions or sub-objects that need randomness so they call
methods on the instance (self.random.random(), self.random.choice(), etc.)
rather than the module-level functions.
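A compact sketch of the dedicated-RNG approach the comment describes (the class below is trimmed down and renamed for illustration; it is not the real RulerDatasetBuilder):

```python
# Sketch: keep randomness on an instance-level RNG so the global random state is untouched.
import random


class RulerDatasetBuilderSketch:
    def __init__(self, tokenizer, seed: int = 42):
        self.tokenizer = tokenizer
        self.rng = random.Random(seed)  # dedicated RNG instead of random.seed(seed)

    def _pick_position(self, num_positions: int) -> int:
        # helpers draw from self.rng rather than module-level random functions
        return self.rng.randrange(num_positions)
```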
```python
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            module_name = get_unwrapped_name(name, model)
            if module_name in state_dict:
                module_state = state_dict[module_name]

                # Restore method and config
                if "method" in module_state:
                    module._method = module_state["method"]
                if "method_config" in module_state:
                    # Restore config attributes
                    # Separate method instance attributes from module attributes
                    method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}

                    for key, val in module_state["method_config"].items():
                        if key not in method_instance_attrs:
                            # Set on module
                            setattr(module, f"_{key}", val)

                # Re-setup with restored config
                module._setup()

                # Restore method instance attributes after _setup
                if "method_config" in module_state:
                    for key, val in module_state["method_config"].items():
                        if key in {"threshold_scale_factor", "target_sparsity"}:
                            setattr(module._sparse_method_instance, key, val)
```
Reinitialize sparse method when restoring state
After loading metadata["sparse_attention_state"], we overwrite _method and _method_config, yet _sparse_method_instance is left pointing to the pre-existing object that still carries the old configuration. Any calibrated thresholds or even method swaps captured in metadata therefore never reach the runtime path—only threshold_scale_factor/target_sparsity get patched manually. Call _init_sparse_method() (before _setup()) once the module-level attributes are restored so the method instance is rebuilt from the restored config. Otherwise, restored models diverge from the saved sparse behavior.
Apply this diff:
- if "method_config" in module_state:
- for key, val in module_state["method_config"].items():
- if key not in method_instance_attrs:
- # Set on module
- setattr(module, f"_{key}", val)
-
- # Re-setup with restored config
- module._setup()
+ if "method_config" in module_state:
+ for key, val in module_state["method_config"].items():
+ if key not in method_instance_attrs:
+ # Set on module
+ setattr(module, f"_{key}", val)
+
+ # Rebuild sparse method with restored config before final setup
+ module._init_sparse_method()
+ # Re-setup with restored config
+ module._setup()📝 Committable suggestion
```python
def restore_sparse_attention_state(model: nn.Module, state_dict: dict[str, Any]):
    """Restore sparse attention state from state dict.

    Args:
        model: Model with sparse attention modules
        state_dict: Saved state dictionary
    """
    for name, module in model.named_modules():
        if isinstance(module, SparseAttentionModule):
            module_name = get_unwrapped_name(name, model)
            if module_name in state_dict:
                module_state = state_dict[module_name]

                # Restore method and config
                if "method" in module_state:
                    module._method = module_state["method"]
                if "method_config" in module_state:
                    # Restore config attributes
                    # Separate method instance attributes from module attributes
                    method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}

                    for key, val in module_state["method_config"].items():
                        if key not in method_instance_attrs:
                            # Set on module
                            setattr(module, f"_{key}", val)

                # Rebuild sparse method with restored config before final setup
                module._init_sparse_method()
                # Re-setup with restored config
                module._setup()

                # Restore method instance attributes after _setup
                if "method_config" in module_state:
                    for key, val in module_state["method_config"].items():
                        if key in {"threshold_scale_factor", "target_sparsity"}:
                            setattr(module._sparse_method_instance, key, val)
```
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/attention_sparsity/conversion.py around lines 198 to
232, when restoring a SparseAttentionModule you set module._method and
module-level config but do not recreate the method instance, leaving
module._sparse_method_instance with stale state; after restoring module-level
attributes from module_state (i.e. after setting module._method and any
module._* config fields) call module._init_sparse_method() to reinitialize the
method instance, then proceed to module._setup() and finally patch any
instance-specific attributes (threshold_scale_factor, target_sparsity) if
present in module_state so the runtime method reflects the restored config.
```python
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)

    sparse_state_dict = {
        k: v
        for k, v in self.state_dict(prefix="", keep_vars=True).items()
        if k == "_weight_mask"
    }

    sharded_axis_dict = self._get_shard_axis_dict(sparse_state_dict)

    if sparse_state_dict:
        sharded_state_dict.update(
            **make_sharded_tensors_for_checkpoint(
                sparse_state_dict, prefix, sharded_axis_dict, sharded_offsets
            )
        )
```
Propagate metadata when calling the base sharded_state_dict.
The override discards the caller-supplied metadata, breaking checkpoint writers that depend on it.
```diff
-def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
```

📝 Committable suggestion
```python
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
    sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

    sparse_state_dict = {
        k: v
        for k, v in self.state_dict(prefix="", keep_vars=True).items()
        if k == "_weight_mask"
    }

    sharded_axis_dict = self._get_shard_axis_dict(sparse_state_dict)

    if sparse_state_dict:
        sharded_state_dict.update(
            **make_sharded_tensors_for_checkpoint(
                sparse_state_dict, prefix, sharded_axis_dict, sharded_offsets
            )
        )
```
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py around lines 32
to 48, the override of sharded_state_dict discards the caller-supplied metadata
when calling the base implementation; update the call to
super().sharded_state_dict to pass the metadata through (e.g.,
super().sharded_state_dict(prefix, sharded_offsets, metadata=metadata) or
equivalent named parameter), leaving the rest of the logic unchanged so
checkpoint writers that rely on metadata continue to receive it.
```python
is_nm_prune, n, m = asp.get_nmprune_info(config_sanitized["pattern"])
assert is_nm_prune and n == 2 and m == 4, (
    f"Unsupported pattern {self.config['pattern']} for sparsity"
)
```
Fix misuse of self.config in assertion.
When the pattern is invalid, the assertion message accesses self.config, but self.config is not populated until after sanitize_search_config returns. That triggers an AttributeError instead of the intended assertion, hiding the real cause from the user. Please reference config_sanitized (or the incoming config) in the message instead.
Apply this diff:
```diff
-assert is_nm_prune and n == 2 and m == 4, (
-    f"Unsupported pattern {self.config['pattern']} for sparsity"
-)
+assert is_nm_prune and n == 2 and m == 4, (
+    f"Unsupported pattern {config_sanitized['pattern']} for sparsity"
+)
```

📝 Committable suggestion
```python
is_nm_prune, n, m = asp.get_nmprune_info(config_sanitized["pattern"])
assert is_nm_prune and n == 2 and m == 4, (
    f"Unsupported pattern {config_sanitized['pattern']} for sparsity"
)
```
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/searcher.py around lines 51 to 54,
the assertion error message wrongly references self.config which isn't set yet
and can raise AttributeError; change the assertion message to reference
config_sanitized (or the incoming config) instead so the correct pattern value
is reported when the assertion fails, e.g. use config_sanitized["pattern"] in
the f-string and keep the existing is_nm_prune and n/m checks unchanged.
```python
def invert(hessian: torch.Tensor) -> torch.Tensor:
    """Invert a Hessian matrix."""
    try:
        hessian_inv = torch.linalg.cholesky(hessian)
        hessian_inv = torch.cholesky_inverse(hessian_inv)
        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
    except RuntimeError:
        cols = hessian.size(0)
        eps = 1e-6 * torch.eye(cols).to(hessian.device)
        hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian + eps))

    return hessian_inv


def prepare(
    tensor: torch.Tensor, hessian: torch.Tensor, hessian_damp: float
) -> tuple[torch.Tensor, torch.Tensor]:
    """Prepare the inverse Hessian matrix."""
    weight = tensor.detach().clone()
    # move the hessian matrix from CPU to GPU for acceleration
    hessian = hessian.to(weight.device)
    if len(weight.size()) == 4:
        weight = weight.flatten(1)

    zero = torch.diag(hessian) == 0
    hessian[zero, zero] = 1
    weight[:, zero] = 0

    damp = hessian_damp * torch.mean(torch.diag(hessian))
    cols = weight.size(1)
    diag = torch.arange(cols)
    hessian[diag, diag] += damp

    hessian_inv = invert(hessian)

    # remove the Hessian matrix to save GPU memory
    del hessian
    torch.cuda.empty_cache()

    return weight, hessian_inv


@torch.no_grad()
def create_sgpt_mask(
    tensor: torch.Tensor, hessian: torch.Tensor, config: SearchConfig
) -> torch.Tensor:
    """Create a sparse mask for the given tensor."""
    shape = tensor.size()
    weight, hessian_inv = prepare(tensor, hessian, config["hessian_damp"])
    rows, cols = weight.size()
    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)

    is_nm_prune, n, m = get_nmprune_info(config["pattern"])
    col_bs = config["col_block_size"]
    row_bs = config["row_block_size"]
    # if row_bs is not specified, prune the whole weight block
    if row_bs == -1:
        row_bs = rows

    for r1 in range(0, rows, row_bs):
        r2 = min(r1 + row_bs, rows)
        # the mask of the weights not to be pruned
        w_rows = weight[r1:r2].float()

        # pruning the weight block W[row:row+row_bs, i1:i1+col_bs]
        for i1 in range(0, cols, col_bs):
            i2 = min(i1 + col_bs, cols)
            w_blk = w_rows[:, i1:i2].clone()
            q_blk = torch.zeros_like(w_blk)
            # the error of the weights to be pruned
            delta_blk = torch.zeros_like(w_blk)
            hinv_blk = hessian_inv[i1:i2, i1:i2]
            hinv_diag_blk = hessian_inv_diag[i1:i2]

            errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
            if torch.isnan(errors_blk).any():
                print("nan in errors_blk.")

            mask_blk = torch.zeros_like(w_blk, dtype=torch.bool)

            for j in range(i2 - i1):
                # compute the error of the weights to be pruned
                w = w_blk[:, j]
                d = hinv_diag_blk[j]
                if is_nm_prune and j % m == 0:
                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
                    mask_blk.scatter_(
                        1, j + torch.topk(errors_blk, n, dim=1, largest=False)[1], True
                    )

                q = w.clone()
                q[mask_blk[:, j]] = 0
                q_blk[:, j] = q

                # update the remaining weights in the col_bs block to compensate the error caused by pruning W[:, j]
                err = (w - q) / d
                w_blk[:, j:] -= err.unsqueeze(1).matmul(hinv_blk[j, j:].unsqueeze(0))
                delta_blk[:, j] = err

            # compensate the error caused by pruning W[:, i: i + col_bs] with the weights update in W[:, i + col_bs:]
            w_rows[:, i1:i2] = q_blk
            w_rows[:, i2:] -= delta_blk.matmul(hessian_inv[i1:i2, i2:])
```
Fix incorrect Hessian “inverse” usage
invert() currently returns the Cholesky factor of the inverse (upper‑triangular) instead of the actual inverse matrix. Downstream code (e.g., Lines 105‑134) then treats it as a full inverse—leading to inconsistent scaling (hinv_diag_blk is the square root of the true diagonal) and using triangular blocks where the dense inverse is required. The resulting masks/weight updates are numerically wrong and can severely degrade pruning quality. Please return the genuine inverse and consume it consistently (no extra square root assumptions):
```diff
@@
-    try:
-        hessian_inv = torch.linalg.cholesky(hessian)
-        hessian_inv = torch.cholesky_inverse(hessian_inv)
-        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
+    try:
+        chol = torch.linalg.cholesky(hessian)
+        hessian_inv = torch.cholesky_inverse(chol)
@@
-    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
+    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
@@
-            errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
+            errors_blk = (w_blk**2) / (hinv_diag_blk + 1e-9)
@@
-                if is_nm_prune and j % m == 0:
-                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
+                if is_nm_prune and j % m == 0:
+                    errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] + 1e-9)
@@
-                err = (w - q) / d
+                err = (w - q) / (d + 1e-9)
```

Adjust any related calculations to use the full inverse (no extra Cholesky).
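For reference, the behavior the suggested invert() should have can be sanity-checked on a small positive-definite matrix; the snippet below is a standalone check, not part of the patch:

```python
# Standalone check: cholesky + cholesky_inverse yields the genuine inverse,
# so diag(hessian_inv) is the true inverse diagonal (no extra square root involved).
import torch

torch.manual_seed(0)
a = torch.randn(8, 8, dtype=torch.float64)
hessian = a @ a.T + 1e-3 * torch.eye(8, dtype=torch.float64)  # symmetric positive definite

chol = torch.linalg.cholesky(hessian)
hessian_inv = torch.cholesky_inverse(chol)

identity = torch.eye(8, dtype=torch.float64)
assert torch.allclose(hessian @ hessian_inv, identity, atol=1e-8)
```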
| if mod.hessian.device.type == "cuda": | ||
| if cls._is_memory_sufficient(mod.hessian.device.index, 0.8): | ||
| mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device) | ||
| else: | ||
| target_device = "cpu" | ||
| mod.hessian = mod.hessian.to("cpu") | ||
| mod.hessian += inp.matmul(inp.t()).to(target_device) | ||
|
|
Avoid double-counting Hessian updates on CUDA
When enough GPU memory is available, the hook adds inp.matmul(inp.t()) twice: once inside the CUDA branch and once again right after, inflating Hessian statistics and corrupting the mask search. Add the update exactly once and only move to CPU when needed:
```diff
-if mod.hessian.device.type == "cuda":
-    if cls._is_memory_sufficient(mod.hessian.device.index, 0.8):
-        mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device)
-    else:
-        target_device = "cpu"
-        mod.hessian = mod.hessian.to("cpu")
-mod.hessian += inp.matmul(inp.t()).to(target_device)
+update = inp.matmul(inp.t())
+if mod.hessian.device.type == "cuda" and not cls._is_memory_sufficient(
+    mod.hessian.device.index, 0.8
+):
+    mod.hessian = mod.hessian.to("cpu")
+mod.hessian += update.to(mod.hessian.device)
```

This keeps the accumulation correct while still supporting the CPU fallback.
📝 Committable suggestion
```python
# compute the rank-1 update once
update = inp.matmul(inp.t())
# if on GPU but memory is insufficient, move Hessian to CPU
if mod.hessian.device.type == "cuda" and not cls._is_memory_sufficient(
    mod.hessian.device.index, 0.8
):
    mod.hessian = mod.hessian.to("cpu")
# apply the update on whichever device mod.hessian currently resides
mod.hessian += update.to(mod.hessian.device)
```
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/sparsegpt.py around lines 268 to 275,
the Hessian update is being applied twice on CUDA (once inside the CUDA branch
and once after), which double-counts; change the logic so you compute the new
outer-product once, decide the target_device beforehand (move mod.hessian to CPU
only if memory is insufficient), and then perform a single mod.hessian +=
outer_product.to(target_device); ensure you do not add inside the CUDA branch
and only set target_device and move mod.hessian when necessary.
```python
def sparsify(
    model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
    """Sparsify a given model and search for they optimal sparsified weights.

    Args:
        model: A standard model that contains standard building blocks to be sparsified in-place.
        mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
            config indicating the desired mode(s) (and configurations) for the convert
            process. Modes set up the model for different algorithms for model optimization. The
            following modes are available:

            * :class:`"sparse_magnitude"<modelopt.torch.sparsity.mode.SparseMagnitudeModeDescriptor>`:
              The ``model`` will be sparsified according to the magnitude of weights in each
              layer. The mode's config is described in
              :class:`SparseMagnitudeConfig<modelopt.torch.sparsity.config.SparseMagnitudeConfig>`.
            * :class:`"sparsegpt"<modelopt.torch.sparsity.mode.SparseGPTModeDescriptor>`:
              The ``model`` will be sparsified and weights are updated optimally using an Hessian
              approximation of the loss function (see SparseGPT paper for details). The mode's
              config is described in
              :class:`SparseGPTConfig<modelopt.torch.sparsity.config.SparseGPTConfig>`.

            If the mode argument is specified as a dictionary, the keys should indicate the mode and
            the values specify the per-mode configuration. If not provided, then default
            configuration would be used.
        config: Additional optional arguments to configure the search. Currently, we support:

            * ``verbose``: Whether to print detailed search stats during search.
            * ``forward_loop``: A ``Callable`` that takes a model as input and runs a forward loop
              on it. It is recommended to choose the data loader used inside the forward loop
              carefully to reduce the runtime. Cannot be provided at the same time as
              ``data_loader`` and ``collect_func``.
            * ``data_loader``: An iterator yielding batches of data for calibrating the
              normalization layers in the model or compute gradient scores. It is recommended to use
              the same data loader as for training but with significantly fewer iterations. Cannot
              be provided at the same time as ``forward_loop``.
            * ``collect_func``: A ``Callable`` that takes a batch of data from the data loader as
              input and returns the input to ``model.forward()`` as described in
              :meth:`run_forward_loop <modelopt.torch.utils.network.run_forward_loop>`. Cannot
              be provided at the same time as ``forward_loop``.

        .. note::

            Additional configuration options may be added by individual algorithms. Please
            refer to the documentation of the individual algorithms for more information.

    Returns: A sparsified model

    .. note::

        The given model is sparsified in-place. The returned model is thus a reference to the same
        model instance as the input model.
    """
    # apply sparsity to the model
    model = apply_mode(model, mode, registry=SparsityModeRegistry)

    # retrieve searcher class
    searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")

    # run search+sparsification algorithm
    searcher = searcher_cls()
    searcher.search(model, {}, (), config)

    # return the sparsified model
    return model
```
Return the search results as promised.
The signature advertises tuple[nn.Module, dict[str, Any]], yet the implementation returns only the model. Callers that unpack (model, state) will crash with “not enough values to unpack,” making this a functional bug. Please propagate the search results (e.g., the SearchStateDict returned by searcher.search) alongside the model so the runtime matches the declared API.
A direct fix would be:
```diff
-    searcher.search(model, {}, (), config)
-
-    # return the sparsified model
-    return model
+    search_state = searcher.search(model, {}, (), config)
+
+    # return the sparsified model and search details
+    return model, search_state
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Original:

```python
def sparsify(
    model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
    """Sparsify a given model and search for they optimal sparsified weights.

    Args:
        model: A standard model that contains standard building blocks to be sparsified in-place.
        mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
            config indicating the desired mode(s) (and configurations) for the convert
            process. Modes set up the model for different algorithms for model optimization. The
            following modes are available:

            * :class:`"sparse_magnitude"<modelopt.torch.sparsity.mode.SparseMagnitudeModeDescriptor>`:
              The ``model`` will be sparsified according to the magnitude of weights in each
              layer. The mode's config is described in
              :class:`SparseMagnitudeConfig<modelopt.torch.sparsity.config.SparseMagnitudeConfig>`.
            * :class:`"sparsegpt"<modelopt.torch.sparsity.mode.SparseGPTModeDescriptor>`:
              The ``model`` will be sparsified and weights are updated optimally using an Hessian
              approximation of the loss function (see SparseGPT paper for details). The mode's
              config is described in
              :class:`SparseGPTConfig<modelopt.torch.sparsity.config.SparseGPTConfig>`.

            If the mode argument is specified as a dictionary, the keys should indicate the mode and
            the values specify the per-mode configuration. If not provided, then default
            configuration would be used.
        config: Additional optional arguments to configure the search. Currently, we support:

            * ``verbose``: Whether to print detailed search stats during search.
            * ``forward_loop``: A ``Callable`` that takes a model as input and runs a forward loop
              on it. It is recommended to choose the data loader used inside the forward loop
              carefully to reduce the runtime. Cannot be provided at the same time as
              ``data_loader`` and ``collect_func``.
            * ``data_loader``: An iterator yielding batches of data for calibrating the
              normalization layers in the model or compute gradient scores. It is recommended to use
              the same data loader as for training but with significantly fewer iterations. Cannot
              be provided at the same time as ``forward_loop``.
            * ``collect_func``: A ``Callable`` that takes a batch of data from the data loader as
              input and returns the input to ``model.forward()`` as described in
              :meth:`run_forward_loop <modelopt.torch.utils.network.run_forward_loop>`. Cannot
              be provided at the same time as ``forward_loop``.

            .. note::
                Additional configuration options may be added by individual algorithms. Please
                refer to the documentation of the individual algorithms for more information.

    Returns: A sparsified model

    .. note::
        The given model is sparsified in-place. The returned model is thus a reference to the same
        model instance as the input model.
    """
    # apply sparsity to the model
    model = apply_mode(model, mode, registry=SparsityModeRegistry)
    # retrieve searcher class
    searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")
    # run search+sparsification algorithm
    searcher = searcher_cls()
    searcher.search(model, {}, (), config)
    # return the sparsified model
    return model
```

Suggested:

```python
def sparsify(
    model: nn.Module, mode: ModeLike, config: SearchConfig | None = None
) -> tuple[nn.Module, dict[str, Any]]:
    """Sparsify a given model and search for the optimal sparsified weights.

    Args:
        model: A standard model that contains standard building blocks to be sparsified in-place.
        mode: A (list of) string(s) or Mode(s) or a list of tuples containing the mode and its
            config indicating the desired mode(s) (and configurations) for the convert
            process. ...
        config: Additional optional arguments to configure the search. ...

    Returns: A sparsified model and the search state dict

    .. note::
        The given model is sparsified in-place. The returned model is thus a reference to the same
        model instance as the input model.
    """
    # apply sparsity to the model
    model = apply_mode(model, mode, registry=SparsityModeRegistry)
    # retrieve searcher class
    searcher_cls: type[BaseSearcher] = getattr(get_mode(model), "search_algorithm")
    # run search+sparsification algorithm
    searcher = searcher_cls()
    search_state = searcher.search(model, {}, (), config)
    # return the sparsified model and search details
    return model, search_state
```
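For callers, the updated contract would then support tuple unpacking. A minimal sketch of the intended call site, assuming the function is importable from the module under review and that a small Linear passes the searcher's weight-size checks (both assumptions, not verified here):

```python
import torch.nn as nn

from modelopt.torch.sparsity.weight_sparsity.sparsification import sparsify

model = nn.Sequential(nn.Linear(256, 256))  # placeholder; any sparsifiable model works
# With the suggested change, unpacking matches the declared return annotation.
model, search_state = sparsify(model, mode="sparse_magnitude")
print(type(search_state))  # search details, e.g. for logging or checkpointing
```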
🤖 Prompt for AI Agents
In modelopt/torch/sparsity/weight_sparsity/sparsification.py around lines 32-97,
the function signature promises to return tuple[nn.Module, dict[str, Any]] but
currently only returns the model; capture the result of searcher.search (e.g.,
state = searcher.search(...)), ensure it is a dict (coerce to {} if None), and
return (model, state) so callers that unpack (model, state) won’t fail and the
runtime matches the declared API.
Signed-off-by: Kai Xu <[email protected]>
Signed-off-by: Kai Xu <[email protected]>
Signed-off-by: Kai Xu <[email protected]>
Signed-off-by: Kai Xu <[email protected]>
Force-pushed from 2ab2a02 to 023159a.
Actionable comments posted: 5
♻️ Duplicate comments (9)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)
59-71: Count unexpected phases under "unknown". Unrecognized phases are still discarded, so the "unknown" bucket never actually captures them. Please fold any missing key back to "unknown" before incrementing, e.g.:

```diff
-        phase = stats.get("phase", "unknown")
-        if phase in self.aggregated_stats["phase_counts"]:
-            self.aggregated_stats["phase_counts"][phase] += 1
+        phase = stats.get("phase", "unknown")
+        if phase not in self.aggregated_stats["phase_counts"]:
+            phase = "unknown"
+        self.aggregated_stats["phase_counts"][phase] += 1
```

examples/llm_sparse_attention/hf_spar_attn.py (1)
284-289: Switch the model to eval() before sparsifying. Dropout and other train-mode layers stay active, making calibration and verification nondeterministic. Call model.eval() right after the CUDA move (and before sparsify_model) so inference runs in deterministic eval mode:

```diff
     if torch.cuda.is_available():
         model = model.cuda()
         print("Model moved to CUDA")
+    model.eval()
+
     # Apply sparse attention to the model (with calibration if configured)
```

modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1)
32-47: Pass metadata through the sharded_state_dict override. The override still drops the caller-supplied metadata, so checkpoint writers depending on it lose information. Forward the argument to the base implementation.

```diff
-        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets)
+        sharded_state_dict = super().sharded_state_dict(
+            prefix, sharded_offsets, metadata
+        )
```

modelopt/torch/sparsity/weight_sparsity/searcher.py (1)
46-54: Fix assertion to avoid touching unset self.config. At this stage self.config isn't initialized, so the assertion raises AttributeError instead of surfacing the bad pattern. Use config_sanitized (or the incoming config) in the message instead.

```diff
-        assert is_nm_prune and n == 2 and m == 4, (
-            f"Unsupported pattern {self.config['pattern']} for sparsity"
-        )
+        assert is_nm_prune and n == 2 and m == 4, (
+            f"Unsupported pattern {config_sanitized['pattern']} for sparsity"
+        )
```
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1)

233-600: Use a dedicated RNG instead of reseeding Python's global state. random.seed(seed) and the module-level random.* calls mutate the process-wide RNG every time the builder runs, causing unrelated code to become deterministically seeded. Initialize a per-instance generator and route all randomness through it.

```diff
-        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
-        random.seed(seed)
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
+        self._rng = random.Random(seed)
@@
-        random.shuffle(all_samples)
+        self._rng.shuffle(all_samples)
```

Apply the same replacement for every subsequent random.* usage (e.g., randint, sample, choice, choices, shuffle, etc.) so they all call self._rng.
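For illustration, a self-contained sketch of the per-instance RNG pattern being requested (generic names, not the builder's actual attributes):

```python
import random

class Builder:
    def __init__(self, seed: int = 0):
        # Own the randomness instead of reseeding the process-wide RNG.
        self._rng = random.Random(seed)

    def pick(self, items: list[str]) -> str:
        return self._rng.choice(items)

b = Builder(seed=42)
print(b.pick(["a", "b", "c"]))  # deterministic for this instance only
random.random()                 # the global RNG state is left untouched
```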
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

18-114: Select the latest method version numerically, not lexicographically.
sorted(...)[-1] will pick v9 over v10, so callers defaulting to "latest" silently regress once version numbers reach two digits. Please keep the fallback deterministic but parse the numeric suffix first.

```diff
@@
-from abc import ABC, abstractmethod
-
-import torch
+from abc import ABC, abstractmethod
+import re
+
+import torch
@@
-    if not version:
-        version = sorted(method_versions.keys())[-1]
+    if not version:
+        def _version_key(tag: str) -> tuple[int, str]:
+            match = re.match(r"^v(\d+)$", tag)
+            return (int(match.group(1)) if match else -1, tag)
+
+        version = max(method_versions.keys(), key=_version_key)
```
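A quick standalone illustration of the failure mode (plain Python, independent of the registry code):

```python
versions = ["v1", "v9", "v10"]
print(sorted(versions)[-1])                     # 'v9'  -- lexicographic order misranks v10
print(max(versions, key=lambda t: int(t[1:])))  # 'v10' -- numeric parsing picks the true latest
```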
modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (2)

32-134: Return the true Hessian inverse and fix downstream scaling. invert() hands back the Cholesky factor of the inverse, not the inverse itself, so every consumer (e.g., hinv_diag_blk) reads square roots and the pruning math collapses. That cascades into dividing by hinv_diag_blk**2 and by d without stabilizers, severely distorting error estimates. Please return the actual inverse and consume it consistently:

```diff
-def invert(hessian: torch.Tensor) -> torch.Tensor:
-    """Invert a Hessian matrix."""
-    try:
-        hessian_inv = torch.linalg.cholesky(hessian)
-        hessian_inv = torch.cholesky_inverse(hessian_inv)
-        hessian_inv = torch.linalg.cholesky(hessian_inv, upper=True)
-    except RuntimeError:
-        cols = hessian.size(0)
-        eps = 1e-6 * torch.eye(cols).to(hessian.device)
-        hessian_inv = torch.cholesky_inverse(torch.linalg.cholesky(hessian + eps))
-
-    return hessian_inv
+def invert(hessian: torch.Tensor) -> torch.Tensor:
+    """Invert a Hessian matrix."""
+    try:
+        chol = torch.linalg.cholesky(hessian)
+        return torch.cholesky_inverse(chol)
+    except RuntimeError:
+        cols = hessian.size(0)
+        eps = 1e-6 * torch.eye(cols, device=hessian.device, dtype=hessian.dtype)
+        chol = torch.linalg.cholesky(hessian + eps)
+        return torch.cholesky_inverse(chol)
@@
-    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
+    hessian_inv_diag = torch.diagonal(hessian_inv, dim1=0, dim2=1)
@@
-    errors_blk = (w_blk**2) / (hinv_diag_blk**2 + 1e-9)
+    errors_blk = (w_blk**2) / (hinv_diag_blk + 1e-9)
@@
-    if is_nm_prune and j % m == 0:
-        errors_blk = (w_blk[:, j : j + m] ** 2) / (hinv_diag_blk[j : j + m] ** 2 + 1e-9)
+    if is_nm_prune and j % m == 0:
+        errors_blk = (w_blk[:, j : j + m] ** 2) / (
+            hinv_diag_blk[j : j + m] + 1e-9
+        )
@@
-    err = (w - q) / d
+    err = (w - q) / (d + 1e-9)
```
266-275: Avoid double-counting the Hessian update on CUDA. The CUDA branch adds inp.matmul(inp.t()) once inside the branch and then again unconditionally, inflating statistics whenever GPU memory is sufficient.

```diff
-        if mod.hessian.device.type == "cuda":
-            if cls._is_memory_sufficient(mod.hessian.device.index, 0.8):
-                mod.hessian += inp.matmul(inp.t()).to(mod.hessian.device)
-            else:
-                target_device = "cpu"
-                mod.hessian = mod.hessian.to("cpu")
-        mod.hessian += inp.matmul(inp.t()).to(target_device)
+        update = inp.matmul(inp.t())
+        if mod.hessian.device.type == "cuda" and not cls._is_memory_sufficient(
+            mod.hessian.device.index, 0.8
+        ):
+            target_device = "cpu"
+            mod.hessian = mod.hessian.to("cpu")
+        mod.hessian += update.to(target_device)
```
modelopt/torch/sparsity/attention_sparsity/conversion.py (1)

198-232: Rebuild sparse method instance after restoring configuration. As flagged in a previous review comment, after restoring module._method and module._method_config from the saved state (lines 212-222), the existing module._sparse_method_instance retains its pre-restoration configuration. Calling module._setup() (line 225) without first rebuilding the method instance means:

- The method instance is never reconstructed with the restored configuration
- Only threshold_scale_factor and target_sparsity are manually patched afterward (lines 228-231)
- Other restored method configuration fields remain stale in the instance

This causes the restored model's runtime behavior to diverge from the saved sparse attention configuration.

Apply this diff to rebuild the method instance with the restored configuration:

```diff
         if "method_config" in module_state:
             # Restore config attributes
             # Separate method instance attributes from module attributes
             method_instance_attrs = {"threshold_scale_factor", "target_sparsity"}
             for key, val in module_state["method_config"].items():
                 if key not in method_instance_attrs:
                     # Set on module
                     setattr(module, f"_{key}", val)
 
+        # Rebuild sparse method with restored config before final setup
+        module._init_sparse_method()
+
         # Re-setup with restored config
         module._setup()
 
         # Restore method instance attributes after _setup
         if "method_config" in module_state:
             for key, val in module_state["method_config"].items():
                 if key in {"threshold_scale_factor", "target_sparsity"}:
                     setattr(module._sparse_method_instance, key, val)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (41)
- examples/llm_sparse_attention/README.md (1 hunks)
- examples/llm_sparse_attention/hf_spar_attn.py (1 hunks)
- examples/llm_sparse_attention/requirements.txt (1 hunks)
- modelopt/torch/export/unified_export_hf.py (2 hunks)
- modelopt/torch/sparsity/attention_sparsity/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/config.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/conversion.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/mode.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1 hunks)
- modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/__init__.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/config.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/magnitude.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/mode.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/module.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/searcher.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (1 hunks)
- modelopt/torch/sparsity/weight_sparsity/sparsification.py (1 hunks)
- tests/_test_utils/torch_sparsity/sparse_attention_common.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (1 hunks)
- tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1 hunks)
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/llm_sparse_attention/requirements.txt
🚧 Files skipped from review as they are similar to previous changes (12)
- modelopt/torch/sparsity/attention_sparsity/nn/__init__.py
- modelopt/torch/sparsity/attention_sparsity/calibration/__init__.py
- modelopt/torch/sparsity/weight_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/mode.py
- modelopt/torch/sparsity/attention_sparsity/methods/__init__.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_mode.py
- modelopt/torch/sparsity/weight_sparsity/plugins/__init__.py
- tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_config.py
- modelopt/torch/sparsity/weight_sparsity/sparsification.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/sparsity/weight_sparsity/module.py
🧰 Additional context used
🧬 Code graph analysis (26)
modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
set_calibration_mode(58-60)
tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (3)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionConfig(292-319)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)modelopt/torch/opt/conversion.py (1)
modelopt_state(444-486)
modelopt/torch/sparsity/weight_sparsity/__init__.py (1)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
SparseModule(31-87)
modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (3)
modelopt/torch/opt/conversion.py (3)
ModeloptStateManager(63-311)apply_mode(342-429)is_converted(102-127)modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (2)
calibrate_sparse_attention(112-177)forward_loop(97-107)modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionConfig(292-319)
examples/llm_sparse_attention/hf_spar_attn.py (5)
modelopt/torch/export/unified_export_hf.py (1)
export_hf_checkpoint(622-682)modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionConfig(292-319)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (4)
SparseAttentionModule(29-201)get_stats(130-139)disable(121-123)enable(117-119)modelopt/torch/utils/memory_monitor.py (1)
launch_memory_monitor(134-151)modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
sparsify(35-148)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
modelopt/torch/opt/config.py (3)
ModeloptBaseConfig(59-147)ModeloptField(50-53)keys(132-134)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrate.py (5)
modelopt/torch/sparsity/attention_sparsity/config.py (1)
CalibrationConfig(159-247)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (2)
DynamicThresholdCalibrator(31-308)calibrate(77-219)modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
RulerDatasetBuilder(181-602)build_calibration_dataset(236-262)modelopt/torch/sparsity/attention_sparsity/model_sparsify.py (1)
calibrate(151-197)
modelopt/torch/sparsity/weight_sparsity/plugins/megatron.py (3)
modelopt/torch/opt/plugins/megatron.py (1)
_MegatronMLP(120-142)modelopt/torch/sparsity/weight_sparsity/module.py (1)
SparseModule(31-87)modelopt/torch/opt/config.py (1)
register_default(227-237)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
modelopt/torch/opt/dynamic.py (3)
DynamicModule(338-914)_DMRegistryCls(917-1124)config(1265-1278)modelopt/torch/quantization/utils.py (1)
replace_function(300-307)modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionAttributeConfig(34-156)modelopt/torch/sparsity/attention_sparsity/methods/registry.py (2)
get_sparse_method(88-120)apply_sparsity(27-44)modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
SparseAttentionStatsManager(7-125)get_summary(74-92)collect(43-72)modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
apply_sparsity(241-284)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
SimpleAttentionModel(27-46)SimpleTransformerEncoderLayer(49-69)modelopt/torch/sparsity/attention_sparsity/conversion.py (2)
disable_sparse_attention(278-305)enable_sparse_attention(308-335)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)
modelopt/torch/sparsity/weight_sparsity/magnitude.py (2)
modelopt/torch/sparsity/weight_sparsity/module.py (1)
SparseModule(31-87)modelopt/torch/sparsity/weight_sparsity/searcher.py (3)
BaseSparseSearcher(31-84)_check_weight_size(59-61)_compute_mask(64-66)
modelopt/torch/sparsity/weight_sparsity/sparsegpt.py (4)
modelopt/torch/utils/logging.py (1)
print_rank_0(92-95)modelopt/torch/sparsity/weight_sparsity/magnitude.py (3)
get_nmprune_info(29-35)_check_weight_size(134-144)_compute_mask(146-148)modelopt/torch/sparsity/weight_sparsity/module.py (1)
SparseModule(31-87)modelopt/torch/sparsity/weight_sparsity/searcher.py (5)
BaseSparseSearcher(31-84)default_search_config(37-39)_check_weight_size(59-61)_compute_mask(64-66)_named_sparsifiable_modules(68-76)
modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (1)
modelopt/torch/utils/random.py (3)
random(59-61)shuffle(148-150)choice(64-89)
tests/gpu/torch/sparsity/attention_sparsity/test_basic_gpu.py (1)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (9)
SimpleAttentionModel(27-46)SimpleTransformerEncoder(72-90)SimpleTransformerEncoderLayer(49-69)get_test_configs(142-147)save_restore_test(185-218)sparsify_model_and_forward(150-182)get_input(44-46)get_input(67-69)get_input(88-90)
tests/gpu/torch/sparsity/attention_sparsity/test_calibration_gpu.py (5)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
SimpleTransformerEncoderLayer(49-69)forward_loop(162-164)get_input(44-46)get_input(67-69)get_input(88-90)modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (2)
RulerDatasetBuilder(181-602)build_calibration_dataset(236-262)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (1)
forward_loop(210-214)modelopt/torch/opt/conversion.py (2)
modelopt_state(444-486)restore_from_modelopt_state(510-567)
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (1)
register_sparse_attention_on_the_fly(55-100)
modelopt/torch/sparsity/attention_sparsity/conversion.py (5)
modelopt/torch/opt/conversion.py (4)
ModelLikeModule(318-330)ModeloptStateManager(63-311)init_modellike(326-330)state_version(135-137)modelopt/torch/utils/network.py (1)
get_unwrapped_name(599-612)modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionConfig(292-319)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (6)
SparseAttentionModule(29-201)set_from_attribute_config(61-107)_setup(141-153)disable(121-123)enable(117-119)is_enabled(126-128)modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (2)
register_sparse_attention_on_the_fly(55-100)_setup(33-38)
modelopt/torch/sparsity/weight_sparsity/mode.py (6)
modelopt/torch/opt/dynamic.py (3)
config(1265-1278)DynamicSpace(1127-1368)convert_to_dynamic(1138-1187)modelopt/torch/opt/config.py (1)
ModeloptBaseConfig(59-147)modelopt/torch/opt/conversion.py (1)
ApplyModeError(314-315)modelopt/torch/opt/mode.py (2)
ModeDescriptor(56-259)_ModeRegistryCls(267-344)modelopt/torch/utils/network.py (2)
compare_dict(423-427)unwrap_model(430-454)modelopt/torch/sparsity/weight_sparsity/config.py (1)
ExportSparseConfig(50-51)
tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (3)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
SimpleTransformerEncoderLayer(49-69)sparsify_model_and_forward(150-182)get_input(44-46)get_input(67-69)get_input(88-90)modelopt/torch/export/unified_export_hf.py (1)
_get_sparse_attention_config(341-461)modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
disable_sparse_attention(278-305)
tests/gpu/torch/sparsity/attention_sparsity/test_export_gpu.py (4)
tests/_test_utils/torch_model/transformers_models.py (1)
create_tiny_llama_dir(117-131)modelopt/torch/export/unified_export_hf.py (1)
export_hf_checkpoint(622-682)modelopt/torch/sparsity/attention_sparsity/config.py (1)
SparseAttentionConfig(292-319)modelopt/torch/sparsity/attention_sparsity/conversion.py (1)
disable_sparse_attention(278-305)
modelopt/torch/sparsity/weight_sparsity/searcher.py (5)
modelopt/torch/opt/searcher.py (1)
BaseSearcher(52-271)modelopt/torch/utils/logging.py (1)
print_rank_0(92-95)modelopt/torch/sparsity/weight_sparsity/module.py (2)
SparseModule(31-87)set_mask(68-87)modelopt/torch/sparsity/weight_sparsity/magnitude.py (1)
get_nmprune_info(29-35)modelopt/torch/trace/symbols.py (1)
named_modules(444-447)
modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (3)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)modelopt/torch/sparsity/attention_sparsity/nn/stats_manager.py (3)
SparseAttentionStatsManager(7-125)set_calibration_mode(94-106)get_calibration_stats(118-125)modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (1)
set_calibration_mode(58-60)
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)
modelopt/torch/sparsity/attention_sparsity/methods/flash_softmax_skip.py (2)
apply_sparsity(241-284)name(287-289)
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_calibration.py (4)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (5)
SimpleAttentionModel(27-46)forward_loop(162-164)forward(39-41)forward(63-64)forward(84-85)modelopt/torch/sparsity/attention_sparsity/calibration/calibrator.py (1)
DynamicThresholdCalibrator(31-308)modelopt/torch/sparsity/attention_sparsity/calibration/dataset.py (3)
RulerDatasetBuilder(181-602)_generate_target_lengths(28-56)build_calibration_dataset(236-262)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py (3)
modelopt/torch/opt/dynamic.py (3)
DynamicModule(338-914)get_original_cls_by_level(905-914)level(158-160)modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (2)
SparseAttentionModule(29-201)_setup(141-153)modelopt/torch/export/layer_utils.py (1)
is_attention(318-320)
tests/_test_utils/torch_sparsity/sparse_attention_common.py (2)
modelopt/torch/sparsity/attention_sparsity/nn/sparse_attention.py (1)
SparseAttentionModule(29-201)modelopt/torch/opt/conversion.py (2)
modelopt_state(444-486)restore_from_modelopt_state(510-567)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (5)
modelopt/torch/sparsity/weight_sparsity/__init__.py (1)
18-23: Public API wiring looks solid. The package initializer cleanly surfaces the expected registries and modules, matching the surrounding structure. Nice work keeping the namespace consistent.
modelopt/torch/sparsity/attention_sparsity/plugins/__init__.py (1)
18-22: Plugin export hub looks good. Centralizing the HuggingFace registration hook here keeps the plugin surface tidy and discoverable.
tests/unit/torch/sparsity/attention_sparsity/test_sparse_attention_conversion.py (1)
36-96: Great coverage on conversion flows. These cases hit replacement, toggling, and pattern selection, giving solid confidence in the conversion logic.
tests/unit/torch/sparsity/attention_sparsity/test_export_config.py (1)
32-203: Excellent validation of export config serialization. This suite thoroughly covers default, phase-aware, disabled, calibrated, and metadata paths, so regressions in _get_sparse_attention_config should surface quickly.

tests/_test_utils/torch_sparsity/sparse_attention_common.py (1)
1-218: Well-structured test utilities. This module provides clean, reusable test infrastructure for sparse attention testing:

- Three test model classes with consistent APIs (forward and get_input class methods)
- Comprehensive configuration constants covering different sparse attention scenarios
- Helper functions that properly implement forward passes during calibration (lines 162-164) and validation (lines 175-180)
- Robust save/restore testing with output verification (lines 207-217)

These utilities establish good patterns that should be followed by the test code in test_calibration_gpu.py; a rough sketch of the model-helper pattern follows.
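A minimal sketch of that forward/get_input pattern (hypothetical class, not the actual SimpleAttentionModel implementation):

```python
import torch
import torch.nn as nn

class TinyAttentionModel(nn.Module):
    """Illustrative test model exposing a get_input helper alongside forward."""

    def __init__(self, d_model: int = 256, nhead: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return out

    @classmethod
    def get_input(cls, d_model: int = 256, seq_len: int = 64) -> torch.Tensor:
        # Matching helper so tests can build valid inputs without knowing model internals.
        return torch.randn(1, seq_len, d_model)

model = TinyAttentionModel()
y = model(TinyAttentionModel.get_input())
assert y.shape == (1, 64, 256)
```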
| "--backend", | ||
| type=str, | ||
| default="pytorch", | ||
| choices=["pytorch", "triton"], | ||
| help="Backend to use for sparse attention computation (default: pytorch)", |
Remove the unsupported triton backend option.
SparseAttentionAttributeConfig.validate_backend hard-rejects any backend other than "pytorch", so choosing "triton" causes an immediate crash. Until Triton support lands, constrain the CLI to the actually supported backend (and mirror the README update):
- parser.add_argument(
- "--backend",
- type=str,
- default="pytorch",
- choices=["pytorch", "triton"],
- help="Backend to use for sparse attention computation (default: pytorch)",
- )
+ parser.add_argument(
+ "--backend",
+ type=str,
+ default="pytorch",
+ choices=["pytorch"],
+ help="Backend to use for sparse attention computation (currently only 'pytorch').",
+ )📝 Committable suggestion
| "--backend", | |
| type=str, | |
| default="pytorch", | |
| choices=["pytorch", "triton"], | |
| help="Backend to use for sparse attention computation (default: pytorch)", | |
| parser.add_argument( | |
| "--backend", | |
| type=str, | |
| default="pytorch", | |
| choices=["pytorch"], | |
| help="Backend to use for sparse attention computation (currently only 'pytorch').", | |
| ) |
🤖 Prompt for AI Agents
In examples/llm_sparse_attention/hf_spar_attn.py around lines 326 to 330, the
CLI parser currently includes "triton" as a choice for the --backend option even
though SparseAttentionAttributeConfig.validate_backend only accepts "pytorch";
remove "triton" from the choices and update the default/choices to only include
"pytorch" so the CLI cannot select an unsupported backend and crash at
validation time.
- `--backend`: Backend for computation - pytorch or triton (default: pytorch)
- `--seq_len`: Maximum sequence length (default: 2048)
Align backend docs with actual support.
The options table advertises a triton backend, but SparseAttentionAttributeConfig.validate_backend rejects anything other than pytorch. Following the README literally causes the CLI to fail. Please either document it as “pytorch only” (and mention Triton as future work) or update the implementation to accept the advertised backend.
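For context, the kind of validator the comment refers to might look roughly like this sketch (the actual SparseAttentionAttributeConfig implementation may use a different base class and message):

```python
from pydantic import BaseModel, field_validator

class AttentionBackendConfig(BaseModel):
    backend: str = "pytorch"

    @field_validator("backend")
    @classmethod
    def validate_backend(cls, v: str) -> str:
        # Reject anything other than the currently supported backend.
        if v != "pytorch":
            raise ValueError(f"Unsupported backend: {v!r}; only 'pytorch' is supported")
        return v

AttentionBackendConfig(backend="pytorch")   # ok
# AttentionBackendConfig(backend="triton")  # would raise ValueError, matching the failure described above
```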
🤖 Prompt for AI Agents
examples/llm_sparse_attention/README.md lines 119-120: the README claims a
`triton` backend but the code (SparseAttentionAttributeConfig.validate_backend)
only accepts `pytorch`, causing the CLI to fail; either (A) update the README to
list backend as "pytorch (only)" and add a short note that Triton is future
work, or (B) modify validate_backend to accept `"triton"` as a valid value and
wire through any backend-specific switches/placeholders so the CLI won't reject
`--backend triton` (ensure tests/usage handle the triton path or clearly stub
it).
```python
    mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
    mat, _ = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask
```
Fix device mismatch when constructing mask
mask = torch.IntTensor(...) always allocates on CPU. When the weight tensor lives on CUDA, assigning patterns[pmax] (CUDA) into this CPU mask throws Expected all tensors to be on the same device. Allocate the mask on the same device as the weights instead.
Apply this diff:
- mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
+ mask = torch.ones(matrix.shape, dtype=patterns.dtype, device=matrix.device).view(-1, m)

📝 Committable suggestion
Original:

```python
    mask = torch.IntTensor(matrix.shape).fill_(1).view(-1, m)
    mat, _ = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask
```

Suggested:

```python
    mask = torch.ones(matrix.shape, dtype=patterns.dtype, device=matrix.device).view(-1, m)
    mat, _ = reshape_1d(matrix, m)
    pmax = torch.argmax(torch.matmul(mat.abs(), patterns.t()), dim=1)
    mask[:] = patterns[pmax[:]]
    mask = mask.view(matrix.shape)
    return mask
```
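To make the pattern-selection idea in the quoted snippet concrete, here is a standalone sketch (a hypothetical helper, independent of the module's reshape_1d and precomputed patterns):

```python
import itertools
import torch

def best_2of4_mask(weight: torch.Tensor) -> torch.Tensor:
    """Pick, per group of 4 weights, the 2:4 mask that preserves the most magnitude."""
    m, n = 4, 2
    patterns = torch.tensor(
        [p for p in itertools.product((0.0, 1.0), repeat=m) if sum(p) == n],
        dtype=weight.dtype,
        device=weight.device,  # allocate on the weight's device to avoid the CPU/CUDA mismatch
    )
    groups = weight.abs().reshape(-1, m)   # (num_groups, 4)
    scores = groups @ patterns.t()         # magnitude kept by each candidate pattern
    best = scores.argmax(dim=1)            # index of the best pattern per group
    return patterns[best].reshape(weight.shape)

w = torch.randn(8, 16)
mask = best_2of4_mask(w)
assert mask.reshape(-1, 4).sum(dim=1).eq(2).all()  # exactly two weights kept per group of four
```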
```python
    def forward_loop(model):
        # Simple forward loop for calibration
        pass
```
Empty forward loops prevent proper calibration.
Throughout this test file, all forward_loop functions contain only pass statements, meaning no forward passes are executed during calibration. Calibration requires running the model on data to collect statistics (e.g., attention score distributions) to tune thresholds. Without executing forward passes, the calibration process cannot collect meaningful statistics, potentially causing:
- Calibration to silently fail or use default/untuned thresholds
- Tests to pass incorrectly without validating actual calibration behavior
- Divergence from the intended calibration workflow
Reference implementation from tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py (lines 209-213) shows the correct pattern:
```python
def forward_loop(model):
    """Simple forward loop for calibration."""
    test_input = torch.randint(0, 32000, (1, 64), device="cuda")
    with torch.no_grad():
        model(test_input)
```

Similarly, the helper function sparsify_model_and_forward in tests/_test_utils/torch_sparsity/sparse_attention_common.py (lines 162-164) demonstrates proper forward pass execution.
Apply this diff to fix the forward loops. For example, in test_calibration_simple_model:
def forward_loop(model):
- # Simple forward loop for calibration
- pass
+ # Simple forward loop for calibration
+ test_input = SimpleTransformerEncoderLayer.get_input(d_model=256, seq_len=64).cuda()
+ with torch.no_grad():
+ model(test_input)

Apply similar changes to all other empty forward_loop functions in this file, ensuring the input dimensions match the model configuration (e.g., d_model=256 for models in TestCalibrationGPU and TestCalibrationEndToEnd).
Also applies to: 173-174, 204-205, 231-232, 279-280, 325-326, 373-374
```python
@pytest.fixture(scope="module")
def tinyllama_model():
    """Load TinyLlama model for testing."""
    try:
        model = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16,
            device_map="cuda",
        )
        return model
    except Exception as e:
        pytest.skip(f"Could not load TinyLlama model: {e}")
```
Reset TinyLlama fixture between tests to avoid stale sparsified state.
sparse_attn.sparsify mutates the passed model in place and marks it as converted via the Modelopt state manager. Because this fixture is module-scoped, the very first test leaves the shared TinyLlama instance already sparsified. Every subsequent test then reuses that mutated object, so new configs/backends (and especially calibration forward loops) are silently skipped. The later cases therefore never exercise the code paths they intend to validate. Please return a fresh model per test—e.g., switch the fixture to function scope or clone/reset the model before sparsifying—so each scenario runs against an unsparsified baseline.
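A minimal sketch of the function-scoped variant (same model id as the quoted fixture; reloading per test trades runtime for isolation, and deep-copying a CPU copy per test would be another option):

```python
import pytest
import torch
from transformers import AutoModelForCausalLM

@pytest.fixture(scope="function")
def tinyllama_model():
    """Load a fresh, unsparsified TinyLlama instance for every test."""
    try:
        return AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            torch_dtype=torch.float16,
            device_map="cuda",
        )
    except Exception as e:
        pytest.skip(f"Could not load TinyLlama model: {e}")
```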
🤖 Prompt for AI Agents
In tests/gpu/torch/sparsity/attention_sparsity/test_integration_gpu.py around
lines 32 to 44, the TinyLlama fixture is module-scoped so the model is mutated
by sparse_attn.sparsify and reused across tests; change the fixture to return a
fresh unsparsified model per test by switching scope="module" to
scope="function" (or, alternatively, load/clone/reset the model inside the
fixture before returning) so each test gets an independent instance and
sparsification state does not leak between tests.
What does this PR do?
This PR introduces a new sparse attention module for LLMs that reduces computational complexity by selectively computing only the most important attention scores. This feature enables speedup and memory reduction during inference, especially for long sequences, while maintaining model quality.
Type of change: New feature
Overview: The goal of this PR is to add a new ModelOpt module that supports sparse attention, improving inference efficiency and lowering cost by skipping unnecessary attention computations. The sparse attention module provides a flexible foundation for future attention sparsity techniques.
Core Design
SparseAttentionModule: Central module managing sparse attention behavior
Extensible Method Registry: Plugin architecture for sparse attention algorithms
DSA (planned): Future support for additional sparse attention methods
Statistics Manager: Collects and reports sparsity metrics during inference
Calibration Support: Automatic threshold tuning to achieve target sparsity ratios
HuggingFace Unified Checkpoint Export: Export models with sparse_attention_config metadata
Multi-Backend Kernel Support:
Usage
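The description leaves this section empty; as a rough sketch only (import paths, signatures, and the export keyword are inferred from the example script and tests in this PR, not a confirmed public API), the intended flow looks roughly like:

```python
import torch
from transformers import AutoModelForCausalLM

from modelopt.torch.export import export_hf_checkpoint
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig, sparsify

model = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16
).cuda().eval()

# Replace eligible attention modules with SparseAttentionModule (default flash_softmax_skip settings).
model = sparsify(model, SparseAttentionConfig())

# Run inference as usual; sparsity statistics are collected by the stats manager.
with torch.no_grad():
    model(torch.randint(0, 32000, (1, 64), device="cuda"))

# Export a HuggingFace checkpoint; sparse_attention_config metadata is written into config.json.
export_hf_checkpoint(model, export_dir="./sparse_attention_ckpt")
```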
Testing
Unit and GPU tests passed.
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Tests
Chores