Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions classy_vision/generic/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,18 @@ def profile(
return profiler


def get_input_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
"""
Some layer may take tuple/list/dict/list[dict] as input in forward function. We
recursively query tensor shape.
Some layer may take/generate tuple/list/dict/list[dict] as input/output in forward function.
We recursively query tensor shape.
"""
if isinstance(x, (list, tuple)):
assert len(x) > 0, "input x of tuple/list type must have at least one element"
return [get_input_shape(xi) for xi in x]
assert len(x) > 0, "x of tuple/list type must have at least one element"
return [get_shape(xi) for xi in x]
elif isinstance(x, dict):
return {k: get_input_shape(v) for k, v in x.items()}
return {k: get_shape(v) for k, v in x.items()}
else:
assert isinstance(x, torch.Tensor), "input x is expected to be a torch tensor"
assert isinstance(x, torch.Tensor), "x is expected to be a torch tensor"
return x.size()


Expand Down Expand Up @@ -346,8 +346,8 @@ def flops(self, x):

message = [
f"module type: {typestr}",
f"input size: {get_input_shape(x)}",
f"output size: {list(y.size())}",
f"input size: {get_shape(x)}",
f"output size: {get_shape(y)}",
f"params(M): {count_params(layer) / 1e6}",
f"flops(M): {int(flops) / 1e6}",
]
Expand Down
10 changes: 5 additions & 5 deletions test/generic_profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
compute_activations,
compute_flops,
count_params,
get_input_shape,
get_shape,
)
from classy_vision.models import build_model

Expand Down Expand Up @@ -181,15 +181,15 @@ def test_flops_calculation(self):


class TestHelperFunctions(unittest.TestCase):
def test_get_input_shape(self) -> None:
def test_get_shape(self) -> None:
list_x = [torch.zeros(2, 4), torch.zeros(3, 3)]
shapes = get_input_shape(list_x)
shapes = get_shape(list_x)
expected_shapes = [torch.zeros(2, 4).size(), torch.zeros(3, 3).size()]
for shape, expected in zip(shapes, expected_shapes):
self.assertEqual(shape, expected)

dict_x = {"x1": torch.zeros(2, 4), "x2": torch.zeros(3, 3)}
shapes = get_input_shape(dict_x)
shapes = get_shape(dict_x)
expected_shapes = {
"x1": torch.zeros(2, 4).size(),
"x2": torch.zeros(3, 3).size(),
Expand All @@ -201,7 +201,7 @@ def test_get_input_shape(self) -> None:
{"x1": torch.zeros(2, 4), "x2": torch.zeros(3, 3)},
{"x1": torch.zeros(3, 4), "x2": torch.zeros(4, 5)},
]
shapes = get_input_shape(list_dict_x)
shapes = get_shape(list_dict_x)
expected_shapes = [
{"x1": torch.zeros(2, 4).size(), "x2": torch.zeros(3, 3).size()},
{"x1": torch.zeros(3, 4).size(), "x2": torch.zeros(4, 5).size()},
Expand Down