diff --git a/torchao/sparsity/wanda.py b/torchao/sparsity/wanda.py
index 7ad12a2d55..801a38c4b7 100644
--- a/torchao/sparsity/wanda.py
+++ b/torchao/sparsity/wanda.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import warnings
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 from torch import nn
@@ -48,8 +48,8 @@ def __init__(
         )
         super().__init__(defaults=defaults)
 
-    # `typing.Dict[, ]` to avoid runtime subscripting errors.
-    def prepare(self, model: nn.Module, config: List[Dict]) -> None:
+    # builtin `dict[, ]` subscripting works at runtime on Python 3.9+.
+    def prepare(self, model: nn.Module, config: list[dict]) -> None:
         # activation: use PerChannelNormObserver
         # use no-op placeholder weight observer
         if config is None:
@@ -88,35 +88,38 @@ def update_mask( # type: ignore[override]
         by comparing this metric across the whole current layer.
         """
 
-        # Step 1: get the tensor and the mask from the parametrizations
+        # Step 1: get the attributes (tensor and mask) from the parametrizations
         mask = getattr(module.parametrizations, tensor_name)[0].mask
         tensor = getattr(module.parametrizations, tensor_name).original
         activation_norm_per_channel = module.activation_post_process.norm
 
-        # Step 2: Calculate Wx
+        # Step 2: calculate the pruning metric: |weight| * ||activation||
         pruning_metric = torch.abs(tensor) * activation_norm_per_channel
 
-        # defaults for unstructured sparsity
+        # Step 3: count the weight elements (defaults for unstructured sparsity)
         block_size = pruning_metric.numel()
+
+        # Step 4: number of elements to prune = num_elements * sparsity_level
         num_specified = int(block_size * sparsity_level)
-        # if set to use semi-structured, ignore sparsity_level
+        # if semi-structured, ignore sparsity_level; prune half of each block (e.g. 2:4)
        if kwargs.get("semi_structured_block_size", None) is not None:
             block_size = kwargs["semi_structured_block_size"]
             num_specified = block_size // 2
 
-        # get indicies to prune
+        # Step 5: flatten into blocks and get the indices of the lowest-metric elements
         pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[
             :, :num_specified
         ]
-        # update mask
+
+        # Step 6: zero out the selected indices in the mask
         mask.data.view(-1, block_size).scatter_(
             1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
         )
 
     def squash_mask(
         self,
-        params_to_keep: Optional[Tuple[str, ...]] = None,
-        params_to_keep_per_layer: Optional[Dict[str, Tuple[str, ...]]] = None,
+        params_to_keep: Optional[tuple[str, ...]] = None,
+        params_to_keep_per_layer: Optional[dict[str, tuple[str, ...]]] = None,
         *args,
         **kwargs,
     ):
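
For reviewers, a minimal standalone sketch of the update_mask logic in the last hunk, assuming a 2-D Linear weight and a precomputed per-input-channel activation norm. The function name wanda_update_mask and the toy inputs are illustrative, not part of the torchao API; the block/argsort/scatter logic mirrors the diff.

import torch
from typing import Optional

def wanda_update_mask(
    weight: torch.Tensor,            # (out_features, in_features)
    activation_norm: torch.Tensor,   # (in_features,), per-input-channel ||X||
    sparsity_level: float,
    semi_structured_block_size: Optional[int] = None,
) -> torch.Tensor:
    # Step 2: Wanda metric |weight| * ||activation||, broadcast across rows
    pruning_metric = weight.abs() * activation_norm

    # Steps 3-4: unstructured by default -- one block spanning the whole tensor
    block_size = pruning_metric.numel()
    num_specified = int(block_size * sparsity_level)

    # semi-structured: prune half of every block (block_size=4 gives 2:4)
    if semi_structured_block_size is not None:
        block_size = semi_structured_block_size
        num_specified = block_size // 2

    # Step 5: indices of the lowest-metric elements within each block
    pruning_inds = pruning_metric.view(-1, block_size).argsort(dim=1)[
        :, :num_specified
    ]

    # Step 6: scatter zeros into a fresh mask at those indices
    mask = torch.ones_like(weight)
    mask.view(-1, block_size).scatter_(
        1, pruning_inds, torch.zeros_like(pruning_inds, dtype=mask.dtype)
    )
    return mask

# 50% unstructured vs. 2:4 semi-structured on a toy layer
w = torch.randn(8, 16)
x_norm = torch.rand(16)
assert wanda_update_mask(w, x_norm, 0.5).mean() == 0.5
assert wanda_update_mask(w, x_norm, 0.0, semi_structured_block_size=4).mean() == 0.5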