Skip to content
Snippets Groups Projects
Unverified Commit fafb476e authored by sjiang95's avatar sjiang95 Committed by GitHub
Browse files

[Feature] get_model_complexity_info() supports multiple inputs (#1065)

parent 43165160
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ from rich.console import Console ...@@ -12,6 +12,7 @@ from rich.console import Console
from rich.table import Table from rich.table import Table
from torch import nn from torch import nn
from mmengine.utils import is_tuple_of
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer, from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
parameter_count) parameter_count)
...@@ -675,19 +676,38 @@ def complexity_stats_table( ...@@ -675,19 +676,38 @@ def complexity_stats_table(
def get_model_complexity_info( def get_model_complexity_info(
model: nn.Module, model: nn.Module,
input_shape: Optional[tuple] = None, input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None, None] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
None] = None,
show_table: bool = True, show_table: bool = True,
show_arch: bool = True, show_arch: bool = True,
): ):
"""Interface to get the complexity of a model. """Interface to get the complexity of a model.
The parameter `inputs` are fed to the forward method of model.
If `inputs` is not specified, the `input_shape` is required and
it will be used to construct the dummy input fed to model.
If the forward of model requires two or more inputs, the `inputs`
should be a tuple of tensor or the `input_shape` should be a tuple
of tuple which each element will be constructed into a dumpy input.
Examples:
>>> # the forward of model accepts only one input
>>> input_shape = (3, 224, 224)
>>> get_model_complexity_info(model, input_shape=input_shape)
>>> # the forward of model accepts two or more inputs
>>> input_shape = ((3, 224, 224), (3, 10))
>>> get_model_complexity_info(model, input_shape=input_shape)
Args: Args:
model (nn.Module): The model to analyze. model (nn.Module): The model to analyze.
input_shape (tuple, optional): The input shape of the model. input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...]], None]):
If inputs is not specified, the input_shape should be set. The input shape of the model.
If "inputs" is not specified, the "input_shape" should be set.
Defaults to None. Defaults to None.
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]): inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\
optional]):
The input tensor(s) of the model. If not given the input tensor The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape. will be generated automatically with the given input_shape.
Defaults to None. Defaults to None.
...@@ -705,7 +725,21 @@ def get_model_complexity_info( ...@@ -705,7 +725,21 @@ def get_model_complexity_info(
raise ValueError('"input_shape" and "inputs" cannot be both set.') raise ValueError('"input_shape" and "inputs" cannot be both set.')
if inputs is None: if inputs is None:
inputs = (torch.randn(1, *input_shape), ) if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
inputs = (torch.randn(1, *input_shape), )
elif is_tuple_of(input_shape, tuple) and all([
is_tuple_of(one_input_shape, int)
for one_input_shape in input_shape # type: ignore
]): # tuple of tuple of int, construct multiple tensors
inputs = tuple([
torch.randn(1, *one_input_shape)
for one_input_shape in input_shape # type: ignore
])
else:
raise ValueError(
'"input_shape" should be either a `tuple of int` (to construct'
'one input tensor) or a `tuple of tuple of int` (to construct'
'multiple input tensors).')
flop_handler = FlopAnalyzer(model, inputs) flop_handler = FlopAnalyzer(model, inputs)
activation_handler = ActivationAnalyzer(model, inputs) activation_handler = ActivationAnalyzer(model, inputs)
......
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn as nn
from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count
from mmengine.analysis.print_helper import get_model_complexity_info
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
class NetAcceptOneTensor(nn.Module):
def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.l1(x)
return out
class NetAcceptTwoTensors(nn.Module):
def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=7, out_features=6)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
out = self.l1(x1) + self.l2(x2)
return out
class NetAcceptOneTensorAndOneScalar(nn.Module):
def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=5, out_features=6)
def forward(self, x1: torch.Tensor, r) -> torch.Tensor:
out = r * self.l1(x1) + (1 - r) * self.l2(x1)
return out
def test_get_model_complexity_info():
input1 = torch.randn(1, 9, 5)
input_shape1 = (9, 5)
input2 = torch.randn(1, 9, 7)
input_shape2 = (9, 7)
scalar = 0.3
# test a network that accepts one tensor as input
model = NetAcceptOneTensor()
complexity_info = get_model_complexity_info(model=model, inputs=input1)
flops = FlopAnalyzer(model=model, inputs=input1).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params
complexity_info = get_model_complexity_info(
model=model, input_shape=input_shape1)
flops = FlopAnalyzer(
model=model, inputs=(torch.randn(1, *input_shape1), )).total()
assert complexity_info['flops'] == flops
# test a network that accepts two tensors as input
model = NetAcceptTwoTensors()
complexity_info = get_model_complexity_info(
model=model, inputs=(input1, input2))
flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params
complexity_info = get_model_complexity_info(
model=model, input_shape=(input_shape1, input_shape2))
inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2))
flops = FlopAnalyzer(model=model, inputs=inputs).total()
assert complexity_info['flops'] == flops
# test a network that accepts one tensor and one scalar as input
model = NetAcceptOneTensorAndOneScalar()
# For pytorch<1.9, a scalar input is not acceptable for torch.jit,
# wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501
scalar = torch.tensor([
scalar
]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar
complexity_info = get_model_complexity_info(
model=model, inputs=(input1, scalar))
flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params
# `get_model_complexity_info()` should throw `ValueError`
# when neithor `inputs` nor `input_shape` is specified
with pytest.raises(ValueError, match='should be set'):
get_model_complexity_info(model)
# `get_model_complexity_info()` should throw `ValueError`
# when both `inputs` and `input_shape` are specified
model = NetAcceptOneTensor()
with pytest.raises(ValueError, match='cannot be both set'):
get_model_complexity_info(
model, inputs=input1, input_shape=input_shape1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment