diff --git a/mmengine/analysis/print_helper.py b/mmengine/analysis/print_helper.py
index 7c03fb78be055d356fbf0805ced412312894cb0c..cd6092e98d7312254d942b4d1f8cd69fa0c130aa 100644
--- a/mmengine/analysis/print_helper.py
+++ b/mmengine/analysis/print_helper.py
@@ -12,6 +12,7 @@ from rich.console import Console
 from rich.table import Table
 from torch import nn
 
+from mmengine.utils import is_tuple_of
 from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
                                   parameter_count)
 
@@ -675,19 +676,38 @@ def complexity_stats_table(
 
 def get_model_complexity_info(
     model: nn.Module,
-    input_shape: Optional[tuple] = None,
-    inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
+    input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
+                       None] = None,
+    inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
+                  None] = None,
     show_table: bool = True,
     show_arch: bool = True,
 ):
     """Interface to get the complexity of a model.
 
+    The parameter `inputs` is fed to the forward method of the model.
+    If `inputs` is not specified, `input_shape` is required and
+    it will be used to construct a dummy input fed to the model.
+    If the forward of the model requires two or more inputs, `inputs`
+    should be a tuple of tensors or `input_shape` should be a tuple
+    of tuples, each of which will be used to construct a dummy input.
+
+    Examples:
+        >>> # the forward of the model accepts only one input
+        >>> input_shape = (3, 224, 224)
+        >>> get_model_complexity_info(model, input_shape=input_shape)
+        >>> # the forward of the model accepts two or more inputs
+        >>> input_shape = ((3, 224, 224), (3, 10))
+        >>> get_model_complexity_info(model, input_shape=input_shape)
+
     Args:
         model (nn.Module): The model to analyze.
-        input_shape (tuple, optional): The input shape of the model.
-            If inputs is not specified, the input_shape should be set.
+        input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...]], None]):
+            The input shape of the model.
+            If `inputs` is not specified, `input_shape` should be set.
             Defaults to None.
-        inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
+        inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\
+            optional):
             The input tensor(s) of the model. If not given the input tensor
             will be generated automatically with the given input_shape.
             Defaults to None.
@@ -705,7 +725,21 @@ def get_model_complexity_info(
         raise ValueError('"input_shape" and "inputs" cannot be both set.')
 
     if inputs is None:
-        inputs = (torch.randn(1, *input_shape), )
+        if is_tuple_of(input_shape, int):  # tuple of int, construct one tensor
+            inputs = (torch.randn(1, *input_shape), )
+        elif is_tuple_of(input_shape, tuple) and all([
+                is_tuple_of(one_input_shape, int)
+                for one_input_shape in input_shape  # type: ignore
+        ]):  # tuple of tuple of int, construct multiple tensors
+            inputs = tuple([
+                torch.randn(1, *one_input_shape)
+                for one_input_shape in input_shape  # type: ignore
+            ])
+        else:
+            raise ValueError(
+                '"input_shape" should be either a `tuple of int` (to '
+                'construct one input tensor) or a `tuple of tuple of int` '
+                '(to construct multiple input tensors).')
 
     flop_handler = FlopAnalyzer(model, inputs)
     activation_handler = ActivationAnalyzer(model, inputs)
diff --git a/tests/test_analysis/test_print_helper.py b/tests/test_analysis/test_print_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..14366583d583edfa3b88c97cdfe579375e9ad2e3
--- /dev/null
+++ b/tests/test_analysis/test_print_helper.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+import pytest
+import torch
+import torch.nn as nn
+
+from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count
+from mmengine.analysis.print_helper import get_model_complexity_info
+from mmengine.utils import digit_version
+from mmengine.utils.dl_utils import TORCH_VERSION
+
+
+class NetAcceptOneTensor(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.l1 = nn.Linear(in_features=5, out_features=6)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self.l1(x)
+        return out
+
+
+class NetAcceptTwoTensors(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.l1 = nn.Linear(in_features=5, out_features=6)
+        self.l2 = nn.Linear(in_features=7, out_features=6)
+
+    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+        out = self.l1(x1) + self.l2(x2)
+        return out
+
+
+class NetAcceptOneTensorAndOneScalar(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.l1 = nn.Linear(in_features=5, out_features=6)
+        self.l2 = nn.Linear(in_features=5, out_features=6)
+
+    def forward(self, x1: torch.Tensor, r) -> torch.Tensor:
+        out = r * self.l1(x1) + (1 - r) * self.l2(x1)
+        return out
+
+
+def test_get_model_complexity_info():
+    input1 = torch.randn(1, 9, 5)
+    input_shape1 = (9, 5)
+    input2 = torch.randn(1, 9, 7)
+    input_shape2 = (9, 7)
+    scalar = 0.3
+
+    # test a network that accepts one tensor as input
+    model = NetAcceptOneTensor()
+    complexity_info = get_model_complexity_info(model=model, inputs=input1)
+    flops = FlopAnalyzer(model=model, inputs=input1).total()
+    params = parameter_count(model=model)['']
+    assert complexity_info['flops'] == flops
+    assert complexity_info['params'] == params
+
+    complexity_info = get_model_complexity_info(
+        model=model, input_shape=input_shape1)
+    flops = FlopAnalyzer(
+        model=model, inputs=(torch.randn(1, *input_shape1), )).total()
+    assert complexity_info['flops'] == flops
+
+    # test a network that accepts two tensors as input
+    model = NetAcceptTwoTensors()
+    complexity_info = get_model_complexity_info(
+        model=model, inputs=(input1, input2))
+    flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total()
+    params = parameter_count(model=model)['']
+    assert complexity_info['flops'] == flops
+    assert complexity_info['params'] == params
+
+    complexity_info = get_model_complexity_info(
+        model=model, input_shape=(input_shape1, input_shape2))
+    inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2))
+    flops = FlopAnalyzer(model=model, inputs=inputs).total()
+    assert complexity_info['flops'] == flops
+
+    # test a network that accepts one tensor and one scalar as input
+    model = NetAcceptOneTensorAndOneScalar()
+    # For pytorch<1.9, a scalar input is not acceptable for torch.jit,
+    # wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501
+    scalar = torch.tensor([
+        scalar
+    ]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar
+    complexity_info = get_model_complexity_info(
+        model=model, inputs=(input1, scalar))
+    flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total()
+    params = parameter_count(model=model)['']
+    assert complexity_info['flops'] == flops
+    assert complexity_info['params'] == params
+
+    # `get_model_complexity_info()` should raise `ValueError`
+    # when neither `inputs` nor `input_shape` is specified
+    with pytest.raises(ValueError, match='should be set'):
+        get_model_complexity_info(model)
+
+    # `get_model_complexity_info()` should raise `ValueError`
+    # when both `inputs` and `input_shape` are specified
+    model = NetAcceptOneTensor()
+    with pytest.raises(ValueError, match='cannot be both set'):
+        get_model_complexity_info(
+            model, inputs=input1, input_shape=input_shape1)
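For context, a minimal usage sketch of the multi-input path added by this patch. The `TwoStreamNet` model, its layer sizes, and the printed output below are illustrative assumptions and not part of the change; only the call signatures and the `flops`/`params` result keys come from the patched code and its new tests.

# Illustrative only: a toy two-input model analyzed with the new
# tuple-of-tuples `input_shape` support (not part of the patch).
import torch
import torch.nn as nn

from mmengine.analysis.print_helper import get_model_complexity_info


class TwoStreamNet(nn.Module):
    """Toy model whose forward takes two tensors."""

    def __init__(self) -> None:
        super().__init__()
        self.branch1 = nn.Linear(16, 8)
        self.branch2 = nn.Linear(4, 8)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.branch1(x1) + self.branch2(x2)


model = TwoStreamNet()

# One shape tuple per positional argument of forward(); the helper builds
# torch.randn(1, *shape) for each shape and feeds them to the model.
info = get_model_complexity_info(model, input_shape=((16, ), (4, )))
print(info['flops'], info['params'])

# Equivalent call with explicit tensors; note the leading batch dimension
# that the `input_shape` form would otherwise add automatically.
info = get_model_complexity_info(
    model, inputs=(torch.randn(1, 16), torch.randn(1, 4)))
print(info['flops'], info['params'])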