feat: RGS for wanda++ #2537

Open · wants to merge 5 commits into main
104 changes: 104 additions & 0 deletions test/prototype/test_wanda_pp.py
@@ -0,0 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch import nn
from torch.testing._internal.common_pruning import SimpleLinear
from torch.testing._internal.common_utils import TestCase

from torchao.prototype.sparsity.pruner.wanda_pp import WandaPlusPlusSparsifier


class TestWandaPlusPlusSparsifier(TestCase):
"""Test Wanda++ Sparsifier"""

def _setup_model_and_sparsifier(self, model, sparsifier, block_configs):
"""Helper to setup model with calibration and forward pass"""
sparsifier.prepare(model, config=None)

# Setup calibration for each block
for block_name, input_shape in block_configs.items():
for _ in range(5):
calibration_input = torch.randn(1, *input_shape)
sparsifier.store_calibration_input(block_name, calibration_input)

def _verify_sparsity(self, layer, expected, tolerance=0.02):
"""Helper to verify sparsity level"""
actual = (layer.weight == 0).float().mean()
assert abs(actual - expected) < tolerance, (
f"Expected ~{expected} sparsity, got {actual}"
)

def test_prepare_and_squash(self):
"""Test preparation and cleanup inherit from Wanda"""
model = SimpleLinear()
sparsifier = WandaPlusPlusSparsifier()
sparsifier.prepare(model, config=None)

# Should inherit Wanda's preparation
assert hasattr(sparsifier.groups[0]["module"], "activation_post_process")

sparsifier.squash_mask()
assert not hasattr(sparsifier.groups[0]["module"], "activation_post_process")

def test_one_layer_sparsity(self):
"""Test single layer sparsification"""
model = nn.Sequential(nn.Linear(4, 1))
model[0].weight.data = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)

sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
self._setup_model_and_sparsifier(model, sparsifier, {"layer_0": (4,)})

sparsifier.set_context(model[0], "layer_0")
model(torch.tensor([[100, 10, 1, 0.1]], dtype=torch.float32))
sparsifier.step()
sparsifier.squash_mask()

self._verify_sparsity(model[0], 0.5)

def test_multi_layer_sparsification(self):
"""Test multi-layer sparsification"""
model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)

block_configs = {"layer_0": (128,), "layer_2": (200,)}
self._setup_model_and_sparsifier(model, sparsifier, block_configs)

model(torch.randn(100, 128))

# Sparsify each linear layer
for layer, block_name in [(model[0], "layer_0"), (model[2], "layer_2")]:
sparsifier.set_context(layer, block_name)
sparsifier.step()
self._verify_sparsity(layer, 0.5)

sparsifier.squash_mask()

def test_two_layer_mlp_unstructured_custom_config(self):
"""Test custom config for selective sparsification"""
model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
config = [{"tensor_fqn": "0.weight"}]

sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=config)

# Only setup calibration for first layer
for _ in range(5):
sparsifier.store_calibration_input("layer_0", torch.randn(1, 128))

sparsifier.set_context(model[0], "layer_0")
model(torch.randn(100, 128))
sparsifier.step()

self._verify_sparsity(model[0], 0.5)
self._verify_sparsity(model[2], 0.0)
sparsifier.squash_mask()


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions torchao/prototype/sparsity/pruner/__init__.py
@@ -14,4 +14,5 @@
"BiasHook",
"FakeStructuredSparsity",
"SaliencyPruner",
"WandaPlusPlusSparsifier",
]
113 changes: 113 additions & 0 deletions torchao/prototype/sparsity/pruner/wanda_pp.py
@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from torchao.sparsity import WandaSparsifier

__all__ = ["WandaPlusPlusSparsifier"]


# TODO: Implement Regional Optimization (RO)
# TODO: Add a `prepare` function that builds quantization configs, as WandaSparsifier does
class WandaPlusPlusSparsifier(WandaSparsifier):
r"""Wanda++ sparsifier extending Wanda with regional gradients

Wanda++ (Pruning by Weights and activations with Regional Gradients), proposed in
https://arxiv.org/abs/2503.04992, extends the Wanda method by incorporating
regional gradients for more accurate pruning criteria.

The sparsifier removes weights based on the Regional Gradient Score (RGS):
S_ij = (α * G_ij + ||X_j||_2) * |W_ij|

where:
- G_ij: Regional gradient computed from L^l_RGS(X^l_n) = ||f^l(X^l_n)||_2
- f^l: l-th decoder block function
- X^l_n: n-th input sample to the l-th decoder block
- α: Scaling factor for regional gradients (default: 100 from paper)

Args:
alpha: Regional gradient scaling factor (default: 100 from paper)
calibration_samples: Number of samples for gradient computation (default: 32 from paper)
**kwargs: Arguments passed to WandaSparsifier
"""

def __init__(self, alpha: float = 100.0, calibration_samples: int = 32, **kwargs):
super().__init__(**kwargs)
self.defaults.update(
{"alpha": alpha, "calibration_samples": calibration_samples}
)
self._calibration_inputs = {}
self._current_decoder_block = None
self._current_block_name = None

def store_calibration_input(
self, block_name: str, input_tensor: torch.Tensor
) -> None:
"""Store calibration inputs for regional gradient computation"""
if block_name not in self._calibration_inputs:
self._calibration_inputs[block_name] = []

if (
len(self._calibration_inputs[block_name])
< self.defaults["calibration_samples"]
):
self._calibration_inputs[block_name].append(input_tensor.detach().clone())

def set_context(self, decoder_block: nn.Module, block_name: str) -> None:
"""Set decoder block and block name for regional gradient computation"""
self._current_decoder_block = decoder_block
self._current_block_name = block_name

def update_mask(
self, module: nn.Module, tensor_name: str, sparsity_level: float, **kwargs
) -> None:
"""Update mask using regional gradients (RO)"""

# Step 1: get the tensor and the mask from the parametrizations
mask = getattr(module.parametrizations, tensor_name)[0].mask
tensor = getattr(module.parametrizations, tensor_name).original

        # Step 2: Compute the Wanda++ pruning metric (RGS)
pruning_metric = self._compute_wandapp_metric(module, tensor, tensor_name)

# Step 3: Apply sparsity using WandaSparsifier
self._apply_sparsity_pattern(mask, pruning_metric, sparsity_level, kwargs)

def _compute_wandapp_metric(
self, module: nn.Module, tensor: torch.Tensor, tensor_name: str
) -> torch.Tensor:
"""Compute RO : (α * G_ij + ||X_j||_2) * |W_ij|"""
activation_norm_per_channel = module.activation_post_process.norm
regional_gradients = self._compute_regional_gradients(module, tensor_name)

return (
self.defaults["alpha"] * regional_gradients
+ activation_norm_per_channel.unsqueeze(0)
) * tensor.abs()

def _compute_regional_gradients(
self, module: nn.Module, tensor_name: str
) -> torch.Tensor:
"""Compute regional gradients from calibration inputs"""

inputs = self._calibration_inputs.get(self._current_block_name)
target_param = getattr(module.parametrizations, tensor_name).original
accumulated_gradients = torch.zeros_like(target_param)

self._current_decoder_block.eval()

# Compute L2-norm regional gradients
for input_tensor in inputs:
self._current_decoder_block.zero_grad()
with torch.enable_grad():
output = self._current_decoder_block(input_tensor)
torch.norm(output, p=2).backward()
if target_param.grad is not None:
accumulated_gradients += target_param.grad.abs()

return accumulated_gradients / len(inputs)
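
For reference, a minimal end-to-end sketch of how the new sparsifier is driven, mirroring the flow exercised in test_wanda_pp.py (prepare, store calibration inputs, set_context, forward, step, squash_mask); the layer sizes and sample counts below come from the tests and are not otherwise prescribed by the PR:

import torch
from torch import nn

from torchao.prototype.sparsity.pruner.wanda_pp import WandaPlusPlusSparsifier

model = nn.Sequential(nn.Linear(128, 200), nn.ReLU(), nn.Linear(200, 10))
sparsifier = WandaPlusPlusSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=None)

# A few calibration samples per block feed the regional-gradient pass
for _ in range(5):
    sparsifier.store_calibration_input("layer_0", torch.randn(1, 128))
    sparsifier.store_calibration_input("layer_2", torch.randn(1, 200))

# One forward pass populates the per-channel activation norms
model(torch.randn(100, 128))

# Prune one block at a time, pointing the sparsifier at the block whose
# regional gradients should be used
for layer, block_name in [(model[0], "layer_0"), (model[2], "layer_2")]:
    sparsifier.set_context(layer, block_name)
    sparsifier.step()

sparsifier.squash_mask()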
11 changes: 11 additions & 0 deletions torchao/sparsity/wanda.py
@@ -96,6 +96,17 @@ def update_mask( # type: ignore[override]
# Step 2: Calculate Wx
pruning_metric = torch.abs(tensor) * activation_norm_per_channel

# Step 3: Apply sparsity pattern
self._apply_sparsity_pattern(mask, pruning_metric, sparsity_level, kwargs)

def _apply_sparsity_pattern(
self,
mask: torch.Tensor,
pruning_metric: torch.Tensor,
sparsity_level: float,
kwargs: dict,
) -> None:
"""Apply sparsity pattern based on pruning metric"""
# defaults for unstructured sparsity
block_size = pruning_metric.numel()
num_specified = int(block_size * sparsity_level)
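
The remainder of this hunk is not shown above. As a rough, hypothetical sketch of what unstructured masking of this form generally does (not the PR's actual code), the positions with the smallest metric values are the ones masked out:

import torch

def apply_unstructured_sparsity(
    mask: torch.Tensor, pruning_metric: torch.Tensor, sparsity_level: float
) -> None:
    # Hypothetical helper, for illustration only: zero the mask at the
    # positions with the smallest pruning-metric values.
    num_to_prune = int(pruning_metric.numel() * sparsity_level)
    if num_to_prune == 0:
        return
    _, prune_idx = torch.topk(pruning_metric.flatten(), num_to_prune, largest=False)
    mask.data.view(-1)[prune_idx] = False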