From 32cf37b1ea16a95c3a34b1d0b0d572877a99e306 Mon Sep 17 00:00:00 2001 From: Mitesh Kumar Singh Date: Wed, 1 Jul 2020 13:37:34 -0700 Subject: [PATCH] Fix model output shape retrieval during flops calculation Summary: Use get_input_shape function (added recently in D21702925 (https://github.com/facebookresearch/ClassyVision/commit/f0528a17caf981ab20b7b508bd419cdd1fdd143e)) to get the shape of output. Updated the function name to get_shape to include both input and output. Differential Revision: D22327937 fbshipit-source-id: 321deffde974f267edcf9e1b7ca773d4df7444b7 --- classy_vision/generic/profiler.py | 18 +++++++++--------- test/generic_profiler_test.py | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py index 4efc23e0a8..4b1044e54a 100644 --- a/classy_vision/generic/profiler.py +++ b/classy_vision/generic/profiler.py @@ -70,18 +70,18 @@ def profile( return profiler -def get_input_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]: +def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]: """ - Some layer may take tuple/list/dict/list[dict] as input in forward function. We - recursively query tensor shape. + Some layer may take/generate tuple/list/dict/list[dict] as input/output in forward function. + We recursively query tensor shape. """ if isinstance(x, (list, tuple)): - assert len(x) > 0, "input x of tuple/list type must have at least one element" - return [get_input_shape(xi) for xi in x] + assert len(x) > 0, "x of tuple/list type must have at least one element" + return [get_shape(xi) for xi in x] elif isinstance(x, dict): - return {k: get_input_shape(v) for k, v in x.items()} + return {k: get_shape(v) for k, v in x.items()} else: - assert isinstance(x, torch.Tensor), "input x is expected to be a torch tensor" + assert isinstance(x, torch.Tensor), "x is expected to be a torch tensor" return x.size() @@ -346,8 +346,8 @@ def flops(self, x): message = [ f"module type: {typestr}", - f"input size: {get_input_shape(x)}", - f"output size: {list(y.size())}", + f"input size: {get_shape(x)}", + f"output size: {get_shape(y)}", f"params(M): {count_params(layer) / 1e6}", f"flops(M): {int(flops) / 1e6}", ] diff --git a/test/generic_profiler_test.py b/test/generic_profiler_test.py index 465479edef..86bb2cc72b 100644 --- a/test/generic_profiler_test.py +++ b/test/generic_profiler_test.py @@ -13,7 +13,7 @@ compute_activations, compute_flops, count_params, - get_input_shape, + get_shape, ) from classy_vision.models import build_model @@ -181,15 +181,15 @@ def test_flops_calculation(self): class TestHelperFunctions(unittest.TestCase): - def test_get_input_shape(self) -> None: + def test_get_shape(self) -> None: list_x = [torch.zeros(2, 4), torch.zeros(3, 3)] - shapes = get_input_shape(list_x) + shapes = get_shape(list_x) expected_shapes = [torch.zeros(2, 4).size(), torch.zeros(3, 3).size()] for shape, expected in zip(shapes, expected_shapes): self.assertEqual(shape, expected) dict_x = {"x1": torch.zeros(2, 4), "x2": torch.zeros(3, 3)} - shapes = get_input_shape(dict_x) + shapes = get_shape(dict_x) expected_shapes = { "x1": torch.zeros(2, 4).size(), "x2": torch.zeros(3, 3).size(), @@ -201,7 +201,7 @@ def test_get_input_shape(self) -> None: {"x1": torch.zeros(2, 4), "x2": torch.zeros(3, 3)}, {"x1": torch.zeros(3, 4), "x2": torch.zeros(4, 5)}, ] - shapes = get_input_shape(list_dict_x) + shapes = get_shape(list_dict_x) expected_shapes = [ {"x1": torch.zeros(2, 4).size(), "x2": torch.zeros(3, 3).size()}, {"x1": torch.zeros(3, 4).size(), "x2": torch.zeros(4, 5).size()},