diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a8dfbd..40e22c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # ptflops versions log
 
+## v 0.7.4
+- Switch to aten by default.
+- Add ignore and custom modules for aten.
+- Add an option to disable counting of functional-style operations in the pytorch backend.
+
 ## v 0.7.3
 - Add aten backend to collect the amount of flops on aten level.
 
diff --git a/README.md b/README.md
index 54eddb6..2e6dcd1 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ print per-layer computational cost of a given network.
 `ptflops` has two backends, `pytorch` and `aten`. `pytorch` backend is a legacy one; it considers `nn.Modules` only. However,
 it's still useful, since it provides better per-layer analytics for CNNs. In all other cases
 it's recommended to use the `aten` backend, which considers aten operations and therefore covers more model architectures (including transformers).
+The default backend is `aten`. Please don't use the `pytorch` backend for transformer architectures.
 
 ## `aten` backend
 ### Operations considered:
@@ -19,6 +20,9 @@ it's still useful, since it provides better per-layer analytics for CNNs. In all other cases
 - Use `verbose=True` to see the operations which were not considered during complexity computation.
 - This backend prints per-module statistics only for modules directly nested into the root `nn.Module`.
 Deeper modules at the second level of nesting are not shown in the per-layer statistics.
+- `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
+for research purposes. For instance, one can drop all convolutions from the counting process
+by specifying `ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution]`.
 
 ## `pytorch` backend
 ### Supported layers:
@@ -41,7 +45,9 @@ Experimental support:
 - This backend doesn't take into account some of the `torch.nn.functional.*` and `tensor.*`
 operations. Therefore unsupported operations are not contributing to the final complexity estimation.
 See `ptflops/pytorch_ops.py:FUNCTIONAL_MAPPING,TENSOR_OPS_MAPPING` to check supported ops.
-- `ptflops` launches a given model on a random tensor and estimates amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct non-trivial input one can use the `input_constructor` argument of the `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next this dict would be passed to the model as a keyword arguments.
+Sometimes counting functional-style operations conflicts with hooks for `nn.Module` (for instance, custom ones). In that case, counting these ops can be disabled by
+passing `backend_specific_config={"count_functional": False}`.
+- `ptflops` launches a given model on a random tensor and estimates the amount of computations during inference. Complicated models can have several inputs, some of them could be optional. To construct a non-trivial input one can use the `input_constructor` argument of `get_model_complexity_info`. `input_constructor` is a function that takes the input spatial resolution as a tuple and returns a dict with named input arguments of the model. Next, this dict is passed to the model as keyword arguments.
 - `verbose` parameter allows to get information about modules that don't contribute to the final numbers.
 - `ignore_modules` option forces `ptflops` to ignore the listed modules. This can be useful
 for research purposes. For instance, one can drop all convolutions from the counting process
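To make the README additions above concrete, here is a minimal usage sketch of dropping convolutions with the `aten` backend. `torchvision.models.resnet18` is only an illustrative model, and the import paths are an assumption based on the package layout:

```python
import torch
import torchvision.models as models

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND

# Illustrative model; any nn.Module works here.
net = models.resnet18()

# Exclude all convolutions from the count, as described in the README above.
macs, params = get_model_complexity_info(
    net, (3, 224, 224),
    backend=FLOPS_BACKEND.ATEN,
    ignore_modules=[torch.ops.aten.convolution, torch.ops.aten._convolution],
    as_strings=True,
    print_per_layer_stat=False)

print(f'MACs without convolutions: {macs}, params: {params}')
```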
diff --git a/ptflops/aten_engine.py b/ptflops/aten_engine.py
index e5e6457..c149ca8 100644
--- a/ptflops/aten_engine.py
+++ b/ptflops/aten_engine.py
@@ -10,8 +10,9 @@
 import sys
 import traceback
 from collections import defaultdict
+from copy import deepcopy
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 from torch.utils._python_dispatch import TorchDispatchMode
@@ -23,12 +24,15 @@ class FlopCounterMode(TorchDispatchMode):
 
     def __init__(self, module=None, verbose=False, print_per_layer_stat=False,
-                 output_params=None):
+                 output_params=None, custom_hooks={}, ignored_ops=[]):
         self.verbose = verbose
         if output_params is None:
             output_params = defaultdict(dict)
 
         self.output_params = output_params
         self.print_fn = partial(print, **self.output_params['print_params'])
+        self.all_ops = deepcopy(ATEN_OPS_MAPPING)
+        self.all_ops.update(custom_hooks)
+        self.ignored_ops = ignored_ops
 
         self.print_per_layer_stat = print_per_layer_stat
         self.flop_counts = defaultdict(lambda: defaultdict(int))
@@ -82,8 +86,11 @@ def normalize_tuple(x):
         out = func(*args, **kwargs)
 
         func_packet = func._overloadpacket
-        if func_packet in ATEN_OPS_MAPPING:
-            flop_count = ATEN_OPS_MAPPING[func_packet](args, normalize_tuple(out))
+
+        if func_packet in self.ignored_ops:
+            self.print_fn(f'Warning: {func_packet} operation is ignored')
+        elif func_packet in self.all_ops:
+            flop_count = self.all_ops[func_packet](args, normalize_tuple(out))
             for par in self.parents:
                 self.flop_counts[par][func_packet] += flop_count
         elif self.verbose:
@@ -99,8 +106,9 @@ def get_flops_aten(model, input_res,
                    custom_modules_hooks={},
                    output_precision=2,
                    flops_units: Optional[str] = 'GMac',
-                   param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
-                                                              Union[int, None]]:
+                   param_units: Optional[str] = 'M',
+                   extra_config: Dict = {}) -> Tuple[Union[int, None],
+                                                     Union[int, None]]:
 
     params_sum = get_model_parameters_number(model)
     model.eval()
@@ -119,7 +127,8 @@ def get_flops_aten(model, input_res,
         batch = torch.ones(()).new_empty((1, *input_res))
 
     try:
-        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params)
+        counter = FlopCounterMode(model, verbose, print_per_layer_stat, output_params,
+                                  custom_modules_hooks, ignore_modules)
         with counter:
             if isinstance(batch, dict):
                 _ = model(**batch)
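The new `custom_hooks` and `ignored_ops` arguments of `FlopCounterMode` are keyed by aten overload packets, and each hook receives `(args, outputs)` and returns a flop count, matching the dispatch code above. A minimal sketch through the public API (the model mirrors the `simple_model_mm` test fixture; the cost formula is an illustrative assumption, not the built-in `ATEN_OPS_MAPPING` convention):

```python
import torch
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND


class MatMul(nn.Module):
    """Mirrors the simple_model_mm fixture from the tests."""
    def forward(self, x):
        return x.matmul(x.t())


def mm_cost(inputs, outputs):
    # Hook signature matches ATEN_OPS_MAPPING entries: (args, outputs) -> flops.
    a, b = inputs[0], inputs[1]
    return a.shape[0] * a.shape[1] * b.shape[1]


macs, params = get_model_complexity_info(
    MatMul(), (10,),
    backend=FLOPS_BACKEND.ATEN,
    custom_modules_hooks={torch.ops.aten.mm: mm_cost},
    as_strings=False,
    print_per_layer_stat=False)
```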
diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py
index d43e403..4c10526 100644
--- a/ptflops/flops_counter.py
+++ b/ptflops/flops_counter.py
@@ -29,13 +29,15 @@ def get_model_complexity_info(model: nn.Module,
                               input_constructor: Optional[Callable[[Tuple], Dict]] = None,
                               ost: TextIO = sys.stdout,
                               verbose: bool = False,
-                              ignore_modules: List[nn.Module] = [],
-                              custom_modules_hooks: Dict[nn.Module, Any] = {},
-                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.PYTORCH,
+                              ignore_modules: List[Union[nn.Module, Any]] = [],
+                              custom_modules_hooks: Dict[Union[nn.Module, Any], Any] = {},
+                              backend: Union[str, FLOPS_BACKEND] = FLOPS_BACKEND.ATEN,
                               flops_units: Optional[str] = None,
                               param_units: Optional[str] = None,
-                              output_precision: int = 2) -> Tuple[Union[str, int, None],
-                                                                  Union[str, int, None]]:
+                              output_precision: int = 2,
+                              backend_specific_config: Dict = {}) -> Tuple[
+                                  Union[str, int, None],
+                                  Union[str, int, None]]:
     """
     Analyzes the input model and collects the amounts of parameters and MACs
     required to make a forward pass of the model.
@@ -61,10 +63,11 @@ def get_model_complexity_info(model: nn.Module,
     :type ost: TextIO
     :param verbose: Parameter to control printing of extra information and warnings.
     :type verbose: bool
-    :param ignore_modules: A list of torch.nn.Module modules to ignore.
-    :type ignore_modules: nn.Module
-    :param custom_modules_hooks: A dict that contains custom hooks on torch modules.
-    :type custom_modules_hooks: Dict[nn.Module, Any]
+    :param ignore_modules: A list of torch.nn.Module or torch.ops.aten modules to ignore.
+    :type ignore_modules: List[Union[nn.Module, Any]]
+    :param custom_modules_hooks: A dict that contains custom hooks for torch.nn.Module or
+        torch.ops.aten modules.
+    :type custom_modules_hooks: Dict[Union[nn.Module, Any], Any]
     :param backend: Backend that is used for evaluating model complexity.
     :type backend: FLOPS_BACKEND
     :param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
@@ -74,6 +77,8 @@ def get_model_complexity_info(model: nn.Module,
     :param output_precision: Floating point precision for representing MACs/params in
         given units.
     :type output_precision: int
+    :param backend_specific_config: Extra configuration for a specific backend.
+    :type backend_specific_config: dict
 
     Returns:
         Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
@@ -85,14 +90,16 @@ def get_model_complexity_info(model: nn.Module,
     assert isinstance(model, nn.Module)
 
     if FLOPS_BACKEND(backend) == FLOPS_BACKEND.PYTORCH:
-        flops_count, params_count = get_flops_pytorch(model, input_res,
-                                                      print_per_layer_stat,
-                                                      input_constructor, ost,
-                                                      verbose, ignore_modules,
-                                                      custom_modules_hooks,
-                                                      output_precision=output_precision,
-                                                      flops_units=flops_units,
-                                                      param_units=param_units)
+        flops_count, params_count = \
+            get_flops_pytorch(model, input_res,
+                              print_per_layer_stat,
+                              input_constructor, ost,
+                              verbose, ignore_modules,
+                              custom_modules_hooks,
+                              output_precision=output_precision,
+                              flops_units=flops_units,
+                              param_units=param_units,
+                              extra_config=backend_specific_config)
    elif FLOPS_BACKEND(backend) == FLOPS_BACKEND.ATEN:
        flops_count, params_count = get_flops_aten(model, input_res,
                                                   print_per_layer_stat,
@@ -101,7 +108,8 @@ def get_model_complexity_info(model: nn.Module,
                                                    custom_modules_hooks,
                                                    output_precision=output_precision,
                                                    flops_units=flops_units,
-                                                   param_units=param_units)
+                                                   param_units=param_units,
+                                                   extra_config=backend_specific_config)
     else:
         raise ValueError('Wrong backend name')
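Since `backend` now defaults to `FLOPS_BACKEND.ATEN`, a plain call picks up the aten engine without any extra arguments; a short sketch with an illustrative model:

```python
import torch.nn as nn

from ptflops import get_model_complexity_info

# No backend argument: as of this change, the aten engine is used by default.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
macs, params = get_model_complexity_info(model, (3, 224, 224),
                                         as_strings=False,
                                         print_per_layer_stat=False)
```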
diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py
index bc24572..8585ef7 100644
--- a/ptflops/pytorch_engine.py
+++ b/ptflops/pytorch_engine.py
@@ -9,7 +9,7 @@
 import sys
 import traceback
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -27,8 +27,9 @@ def get_flops_pytorch(model, input_res,
                       custom_modules_hooks={},
                       output_precision=2,
                       flops_units: Optional[str] = 'GMac',
-                      param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
-                                                                 Union[int, None]]:
+                      param_units: Optional[str] = 'M',
+                      extra_config: Dict = {}) -> Tuple[Union[int, None],
+                                                        Union[int, None]]:
     global CUSTOM_MODULES_MAPPING
     CUSTOM_MODULES_MAPPING = custom_modules_hooks
     flops_model = add_flops_counting_methods(model)
@@ -45,15 +46,18 @@ def get_flops_pytorch(model, input_res,
     except StopIteration:
         batch = torch.ones(()).new_empty((1, *input_res))
 
+    enable_func_ops_patching = extra_config.get('count_functional', True)
     torch_functional_flops = []
     torch_tensor_ops_flops = []
-    patch_functional(torch_functional_flops)
-    patch_tensor_ops(torch_tensor_ops_flops)
+    if enable_func_ops_patching:
+        patch_functional(torch_functional_flops)
+        patch_tensor_ops(torch_tensor_ops_flops)
 
     def reset_environment():
         flops_model.stop_flops_count()
-        unpatch_functional()
-        unpatch_tensor_ops()
+        if enable_func_ops_patching:
+            unpatch_functional()
+            unpatch_tensor_ops()
         global CUSTOM_MODULES_MAPPING
         CUSTOM_MODULES_MAPPING = {}
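A sketch of what `count_functional` controls in the pytorch backend: with patching disabled, a model consisting only of functional/tensor operations contributes zero MACs, mirroring `test_torch_ignore_func` in the tests below:

```python
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND


class MatMul(nn.Module):
    def forward(self, x):
        return x.matmul(x.t())


# Functional/tensor-op patching is disabled, so nothing is counted for this
# module-free model; regular nn.Module hooks would still apply if present.
macs, params = get_model_complexity_info(
    MatMul(), (10,),
    backend=FLOPS_BACKEND.PYTORCH,
    as_strings=False,
    print_per_layer_stat=False,
    backend_specific_config={'count_functional': False})

assert macs == 0
```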
diff --git a/tests/common_test.py b/tests/common_test.py
index e648b7d..2c94fd4 100644
--- a/tests/common_test.py
+++ b/tests/common_test.py
@@ -11,6 +11,17 @@ class TestOperations:
     def default_input_image_size(self):
         return (3, 224, 224)
 
+    @pytest.fixture
+    def simple_model_mm(self):
+        class CustomModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.matmul(x.t())
+
+        return CustomModel()
+
     @pytest.mark.parametrize("backend", [FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN])
     def test_conv(self, default_input_image_size, backend: FLOPS_BACKEND):
         net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))
@@ -53,7 +64,8 @@ def input_constructor(input_res):
         macs, params = get_model_complexity_info(net, (3,),
                                                  input_constructor=input_constructor,
                                                  as_strings=False,
-                                                 print_per_layer_stat=False)
+                                                 print_per_layer_stat=False,
+                                                 backend=FLOPS_BACKEND.PYTORCH)
 
         assert (macs, params) == (8, 8)
 
@@ -73,7 +85,8 @@ def input_constructor(input_res):
             get_model_complexity_info(CustomLinear(), (3,),
                                       input_constructor=input_constructor,
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
 
         assert (macs, params) == (8, 8)
 
@@ -89,7 +102,8 @@ def forward(self, x):
         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
         assert macs > 0
 
@@ -99,22 +113,52 @@ def forward(self, x):
         macs, params = \
             get_model_complexity_info(CustomModel(), (3, 10, 10),
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
         assert params == 0
         assert macs > 0
 
-    def test_ten_matmul(self):
-        class CustomModel(nn.Module):
-            def __init__(self):
-                super().__init__()
+    def test_ten_matmul(self, simple_model_mm):
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ),
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend=FLOPS_BACKEND.PYTORCH)
 
-            def forward(self, x):
-                return x.matmul(x.t())
+        assert params == 0
+        assert macs > 0
 
+    def test_aten_ignore(self, simple_model_mm):
+        ignored_list = [torch.ops.aten.matmul, torch.ops.aten.mm]
         macs, params = \
-            get_model_complexity_info(CustomModel(), (10, ),
+            get_model_complexity_info(simple_model_mm, (10, ),
                                       backend=FLOPS_BACKEND.ATEN,
                                       as_strings=False,
-                                      print_per_layer_stat=False)
+                                      print_per_layer_stat=False,
+                                      ignore_modules=ignored_list)
 
         assert params == 0
-        assert macs > 0
+        assert macs == 0
+
+    def test_aten_custom(self, simple_model_mm):
+        reference = 42
+        custom_hooks = {torch.ops.aten.mm: lambda inputs, outputs: reference}
+
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ), backend=FLOPS_BACKEND.ATEN,
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      custom_modules_hooks=custom_hooks)
+
+        assert params == 0
+        assert macs == reference
+
+    def test_torch_ignore_func(self, simple_model_mm):
+        macs, params = \
+            get_model_complexity_info(simple_model_mm, (10, ),
+                                      backend=FLOPS_BACKEND.PYTORCH,
+                                      as_strings=False,
+                                      print_per_layer_stat=False,
+                                      backend_specific_config={'count_functional': False})
+
+        assert params == 0
+        assert macs == 0
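Finally, the parametrized pattern from `test_conv` can be reused to sanity-check a model against both backends; a sketch:

```python
import torch.nn as nn

from ptflops import get_model_complexity_info
from ptflops.flops_counter import FLOPS_BACKEND

net = nn.Sequential(nn.Conv2d(3, 2, 3, bias=True))

# Both backends should agree on a plain convolution.
for backend in (FLOPS_BACKEND.PYTORCH, FLOPS_BACKEND.ATEN):
    macs, params = get_model_complexity_info(net, (3, 224, 224),
                                             backend=backend,
                                             as_strings=False,
                                             print_per_layer_stat=False)
    print(f'{backend}: {macs} MACs, {params} params')
```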