Unverified Commit aee2f6a6 authored by Songyang Zhang, committed by GitHub

[Feature] Support model complexity computation (#779)


* [Feature] Add support for model complexity computation

* [Fix] fix lint error

* [Feature] update print_helper

* Update docstring

* update api, docs, fix lint

* fix lint

* update doc and add test

* update docstring

* update docstring

* update test

* Update mmengine/analysis/print_helper.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/analysis/complexity_analysis.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update docs/en/advanced_tutorials/model_analysis.md

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* update docs

* update docs

* update docs and docstring

* update docs

* update test with mmlogger

* Update mmengine/analysis/complexity_analysis.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_analysis/test_activation_count.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* update test according to review

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix lint

* fix test

* Apply suggestions from code review

* fix API document

* Update analysis.rst

* rename variables

* minor refinement

* Apply suggestions from code review

* fix lint

* replace tabulate with existing rich

* Apply suggestions from code review

* indent

* Update mmengine/analysis/complexity_analysis.py

* Update mmengine/analysis/complexity_analysis.py

* Update mmengine/analysis/complexity_analysis.py

---------

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
parent 1d97c070
Showing with 4223 additions and 0 deletions
# Model Complexity Analysis
We provide a tool to help with the complexity analysis of a network. The tool is built on the implementation of [fvcore](https://github.com/facebookresearch/fvcore), and we plan to support more custom operators in the future. Currently, it provides interfaces to compute the "parameters", "activations", and "FLOPs" of a given model, and supports printing the related statistics layer by layer, either following the network structure or as a table. The tool reports both operator-level and module-level FLOP counts. If you are interested in how the FLOPs of a single operator are measured accurately, please refer to [Flop Count](https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md).
## What's FLOPs
FLOP is not a well-defined metric in complexity analysis, so we follow [detectron2](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis) and count one fused multiply-add as one FLOP.
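For instance, a `nn.Linear(10, 10)` layer applied to a `(1, 10)` input performs `1 * 10 * 10 = 100` multiply-adds, i.e. 100 FLOPs under this convention (the bias addition is not counted). A minimal sketch that checks this with the `FlopAnalyzer` class exported by `mmengine.analysis`:
```python
import torch
from torch import nn

from mmengine.analysis import FlopAnalyzer

# One Linear(10 -> 10) layer on a (1, 10) input:
# 1 (batch) * 10 (in_features) * 10 (out_features) = 100 multiply-adds.
flops = FlopAnalyzer(nn.Linear(10, 10), (torch.randn(1, 10),))
print(flops.total())  # expected: 100
```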
## What's Activation
Activation is used to measure the feature quantity produced by one layer.
For example, given an input `inputs = torch.randn((1, 3, 10, 10))` and a convolution layer `conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)`, feeding `inputs` into `conv` produces an `output` with shape `(1, 10, 10, 10)`. The activation quantity of the `output` of this `conv` layer is therefore `1000 = 10 * 10 * 10`.
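The same count can be obtained programmatically; a minimal sketch with the `ActivationAnalyzer` class exported by `mmengine.analysis`:
```python
import torch
from torch import nn

from mmengine.analysis import ActivationAnalyzer

conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)
inputs = (torch.randn(1, 3, 10, 10),)

# The conv output has shape (1, 10, 10, 10), so the activation count of
# this layer is 10 * 10 * 10 = 1000.
acts = ActivationAnalyzer(conv, inputs)
print(acts.total())  # expected: 1000
```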
Let's start with the following examples.
## Usage Example 1: Model built with native nn.Module
### Code
```python
import torch
from torch import nn

from mmengine.analysis import get_model_complexity_info


# return a dict of analysis results, including:
# ['flops', 'flops_str', 'activations', 'activations_str',
#  'params', 'params_str', 'out_table', 'out_arch']
class InnerNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc1(self.fc2(x))


class TestNet(nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.inner = InnerNet()

    def forward(self, x):
        return self.fc1(self.fc2(self.inner(x)))


input_shape = (1, 10)
model = TestNet()

analysis_results = get_model_complexity_info(model, input_shape)

print(analysis_results['out_table'])
print(analysis_results['out_arch'])
print("Model Flops:{}".format(analysis_results['flops_str']))
print("Model Parameters:{}".format(analysis_results['params_str']))
```
### Description of Results
The returned result is a dict containing the following keys:
- `flops`: the total FLOPs of the model, e.g., 10000
- `flops_str`: the total FLOPs as a formatted string, e.g., 1.0G or 100M
- `params`: the total number of parameters, e.g., 10000
- `params_str`: the number of parameters as a formatted string, e.g., 1.0G or 100M
- `activations`: the total number of activations, e.g., 10000
- `activations_str`: the number of activations as a formatted string, e.g., 1.0G or 100M
- `out_table`: the statistics printed as a table
```
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ module              ┃ #parameters or shape ┃ #flops ┃ #activations ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━┩
│ model               │ 0.44K                │ 0.4K   │ 40           │
│  fc1                │  0.11K               │  100   │  10          │
│   fc1.weight        │   (10, 10)           │        │              │
│   fc1.bias          │   (10,)              │        │              │
│  fc2                │  0.11K               │  100   │  10          │
│   fc2.weight        │   (10, 10)           │        │              │
│   fc2.bias          │   (10,)              │        │              │
│  inner              │  0.22K               │  0.2K  │  20          │
│   inner.fc1         │   0.11K              │   100  │   10         │
│    inner.fc1.weight │    (10, 10)          │        │              │
│    inner.fc1.bias   │    (10,)             │        │              │
│   inner.fc2         │   0.11K              │   100  │   10         │
│    inner.fc2.weight │    (10, 10)          │        │              │
│    inner.fc2.bias   │    (10,)             │        │              │
└─────────────────────┴──────────────────────┴────────┴──────────────┘
```
- `out_arch`: the statistics printed layer by layer, following the network structure
```bash
TestNet(
  #params: 0.44K, #flops: 0.4K, #acts: 40
  (fc1): Linear(
    in_features=10, out_features=10, bias=True
    #params: 0.11K, #flops: 100, #acts: 10
  )
  (fc2): Linear(
    in_features=10, out_features=10, bias=True
    #params: 0.11K, #flops: 100, #acts: 10
  )
  (inner): InnerNet(
    #params: 0.22K, #flops: 0.2K, #acts: 20
    (fc1): Linear(
      in_features=10, out_features=10, bias=True
      #params: 0.11K, #flops: 100, #acts: 10
    )
    (fc2): Linear(
      in_features=10, out_features=10, bias=True
      #params: 0.11K, #flops: 100, #acts: 10
    )
  )
)
```
## Usage Example 2: Model built with mmengine
### Code
```python
import torch.nn.functional as F
import torchvision

from mmengine.model import BaseModel
from mmengine.analysis import get_model_complexity_info


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels=None, mode='tensor'):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
        elif mode == 'tensor':
            return x


input_shape = (3, 224, 224)
model = MMResNet50()

analysis_results = get_model_complexity_info(model, input_shape)

print("Model Flops:{}".format(analysis_results['flops_str']))
print("Model Parameters:{}".format(analysis_results['params_str']))
```
### Output
```bash
Model Flops:4.145G
Model Parameters:25.557M
```
## Interface
We provide more options to customize the output (see the sketch after this list):
- `model`: (nn.Module) the model to be analyzed
- `input_shape`: (tuple) the shape of the input, e.g., (3, 224, 224)
- `inputs`: (optional: torch.Tensor) if given, `input_shape` will be ignored
- `show_table`: (bool) whether to return the statistics in the form of a table, default: True
- `show_arch`: (bool) whether to return the statistics in the form of the network architecture string, default: True
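A minimal sketch of these options, reusing the `TestNet` model from Usage Example 1 (the keyword names follow the list above):
```python
import torch

from mmengine.analysis import get_model_complexity_info

model = TestNet()  # the model defined in Usage Example 1
inputs = torch.randn(1, 10)

# `input_shape` is not needed when a concrete tensor is passed via `inputs`.
analysis_results = get_model_complexity_info(
    model, inputs=inputs, show_table=False, show_arch=False)

print("Model Flops:{}".format(analysis_results['flops_str']))
print("Model Parameters:{}".format(analysis_results['params_str']))
```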
.. role:: hidden
   :class: hidden-section

mmengine.analysis
===================================

.. contents:: mmengine.analysis
   :depth: 2
   :local:
   :backlinks: top

.. currentmodule:: mmengine.analysis

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   ActivationAnalyzer
   FlopAnalyzer

.. autosummary::
   :toctree: generated
   :nosignatures:

   activation_count
   flop_count
   parameter_count
   parameter_count_table
   get_model_complexity_info
......@@ -53,6 +53,7 @@ You can switch between Chinese and English documents in the lower-left corner of
   advanced_tutorials/manager_mixin.md
   advanced_tutorials/cross_library.md
   advanced_tutorials/test_time_augmentation.md
   advanced_tutorials/model_analysis.md

.. toctree::
   :maxdepth: 1
......@@ -79,6 +80,7 @@ You can switch between Chinese and English documents in the lower-left corner of
   :maxdepth: 2
   :caption: API Reference

   mmengine.analysis <api/analysis>
   mmengine.registry <api/registry>
   mmengine.config <api/config>
   mmengine.runner <api/runner>
......
.. role:: hidden
   :class: hidden-section

mmengine.analysis
===================================

.. contents:: mmengine.analysis
   :depth: 2
   :local:
   :backlinks: top

.. currentmodule:: mmengine.analysis

.. autosummary::
   :toctree: generated
   :nosignatures:
   :template: classtemplate.rst

   ActivationAnalyzer
   FlopAnalyzer

.. autosummary::
   :toctree: generated
   :nosignatures:

   activation_count
   flop_count
   parameter_count
   parameter_count_table
   get_model_complexity_info
......@@ -81,6 +81,7 @@
   :maxdepth: 2
   :caption: API 文档

   mmengine.analysis <api/analysis>
   mmengine.registry <api/registry>
   mmengine.config <api/config>
   mmengine.runner <api/runner>
......
# Copyright (c) OpenMMLab. All rights reserved.
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
activation_count, flop_count,
parameter_count, parameter_count_table)
from .print_helper import get_model_complexity_info
__all__ = [
'FlopAnalyzer', 'ActivationAnalyzer', 'flop_count', 'activation_count',
'parameter_count', 'parameter_count_table', 'get_model_complexity_info'
]
# Copyright (c) OpenMMLab. All rights reserved.
import typing
from collections import defaultdict
from typing import Any, Counter, DefaultDict, Dict, Optional, Tuple, Union
import torch.nn as nn
from rich.console import Console
from rich.table import Table
from torch import Tensor
from .jit_analysis import JitModelAnalysis
from .jit_handles import (Handle, addmm_flop_jit, batchnorm_flop_jit,
bmm_flop_jit, conv_flop_jit, einsum_flop_jit,
elementwise_flop_counter, generic_activation_jit,
linear_flop_jit, matmul_flop_jit, norm_flop_counter)
# A dictionary that maps supported operations to their flop count jit handles.
_DEFAULT_SUPPORTED_FLOP_OPS: Dict[str, Handle] = {
'aten::addmm': addmm_flop_jit,
'aten::bmm': bmm_flop_jit,
'aten::_convolution': conv_flop_jit,
'aten::einsum': einsum_flop_jit,
'aten::matmul': matmul_flop_jit,
'aten::mm': matmul_flop_jit,
'aten::linear': linear_flop_jit,
# You might want to ignore BN flops due to inference-time fusion.
# Use `set_op_handle("aten::batch_norm", None)` to skip them.
'aten::batch_norm': batchnorm_flop_jit,
'aten::group_norm': norm_flop_counter(2),
'aten::layer_norm': norm_flop_counter(2),
'aten::instance_norm': norm_flop_counter(1),
'aten::upsample_nearest2d': elementwise_flop_counter(0, 1),
'aten::upsample_bilinear2d': elementwise_flop_counter(0, 4),
'aten::adaptive_avg_pool2d': elementwise_flop_counter(1, 0),
'aten::grid_sampler': elementwise_flop_counter(0, 4), # assume bilinear
}
# A dictionary that maps supported operations to
# their activation count handles.
_DEFAULT_SUPPORTED_ACT_OPS: Dict[str, Handle] = {
'aten::_convolution': generic_activation_jit('conv'),
'aten::addmm': generic_activation_jit(),
'aten::bmm': generic_activation_jit(),
'aten::einsum': generic_activation_jit(),
'aten::matmul': generic_activation_jit(),
'aten::linear': generic_activation_jit(),
}
class FlopAnalyzer(JitModelAnalysis):
"""Provides access to per-submodule model flop count obtained by tracing a
model with pytorch's jit tracing functionality.
By default, comes with standard flop counters for a few common operators.
Note:
- Flop is not a well-defined concept. We just produce our best
estimate.
- We count one fused multiply-add as one flop.
Handles for additional operators may be added, or the default ones
overwritten, using the ``.set_op_handle(name, func)`` method.
See the method documentation for details.
Flop counts can be obtained as:
- ``.total(module_name="")``: total flop count for the module
- ``.by_operator(module_name="")``: flop counts for the module, as a
Counter over different operator types
- ``.by_module()``: Counter of flop counts for all submodules
- ``.by_module_and_operator()``: dictionary indexed by submodule name,
of Counters over different operator types
An operator is treated as within a module if it is executed inside the
module's ``__call__`` method. Note that this does not include calls to
other methods of the module or explicit calls to ``module.forward(...)``.
Modified from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py
Args:
model (nn.Module): The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.
Examples:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({"addmm" : 10000, "conv" : 3000})
>>> flops.by_module()
Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
>>> flops.by_module_and_operator()
{"" : Counter({"addmm" : 10000, "conv" : 3000}),
"fc" : Counter({"addmm" : 10000}),
"conv" : Counter({"conv" : 3000}),
"act" : Counter()
}
"""
def __init__(
self,
model: nn.Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
super().__init__(model=model, inputs=inputs)
self.set_op_handle(**_DEFAULT_SUPPORTED_FLOP_OPS)
__init__.__doc__ = JitModelAnalysis.__init__.__doc__
class ActivationAnalyzer(JitModelAnalysis):
"""Provides access to per-submodule model activation count obtained by
tracing a model with pytorch's jit tracing functionality.
By default, comes with standard activation counters for convolutional and
dot-product operators. Handles for additional operators may be added, or
the default ones overwritten, using the ``.set_op_handle(name, func)``
method. See the method documentation for details. Activation counts can be
obtained as:
- ``.total(module_name="")``: total activation count for a module
- ``.by_operator(module_name="")``: activation counts for the module,
as a Counter over different operator types
- ``.by_module()``: Counter of activation counts for all submodules
- ``.by_module_and_operator()``: dictionary indexed by submodule name,
of Counters over different operator types
An operator is treated as within a module if it is executed inside the
module's ``__call__`` method. Note that this does not include calls to
other methods of the module or explicit calls to ``module.forward(...)``.
Modified from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py
Args:
model (nn.Module): The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.
Examples:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> acts = ActivationAnalyzer(model, inputs)
>>> acts.total()
1010
>>> acts.total("fc")
10
>>> acts.by_operator()
Counter({"conv" : 1000, "addmm" : 10})
>>> acts.by_module()
Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
>>> acts.by_module_and_operator()
{"" : Counter({"conv" : 1000, "addmm" : 10}),
"fc" : Counter({"addmm" : 10}),
"conv" : Counter({"conv" : 1000}),
"act" : Counter()
}
"""
def __init__(
self,
model: nn.Module,
inputs: Union[Tensor, Tuple[Tensor, ...]],
) -> None:
super().__init__(model=model, inputs=inputs)
self.set_op_handle(**_DEFAULT_SUPPORTED_ACT_OPS)
__init__.__doc__ = JitModelAnalysis.__init__.__doc__
def flop_count(
model: nn.Module,
inputs: Tuple[Any, ...],
supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
"""Given a model and an input to the model, compute the per-operator Gflops
of the given model.
Adopted from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py
Args:
model (nn.Module): The model to compute flop counts.
inputs (tuple): Inputs that are passed to `model` to count flops.
Inputs need to be in a tuple.
supported_ops (dict(str,Callable) or None) : provide additional
handlers for extra ops, or overwrite the existing handlers for
convolution and matmul and einsum. The key is operator name and
the value is a function that takes (inputs, outputs) of the op.
We count one Multiply-Add as one FLOP.
Returns:
tuple[defaultdict, Counter]: A dictionary that records the number of
gflops for each operation and a Counter that records the number of
unsupported operations.
"""
if supported_ops is None:
supported_ops = {}
flop_counter = FlopAnalyzer(model, inputs).set_op_handle(**supported_ops)
giga_flops = defaultdict(float)
for op, flop in flop_counter.by_operator().items():
giga_flops[op] = flop / 1e9
return giga_flops, flop_counter.unsupported_ops()
def activation_count(
model: nn.Module,
inputs: Tuple[Any, ...],
supported_ops: Optional[Dict[str, Handle]] = None,
) -> Tuple[DefaultDict[str, float], Counter[str]]:
"""Given a model and an input to the model, compute the total number of
activations of the model.
Adopted from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py
Args:
model (nn.Module): The model to compute activation counts.
inputs (tuple): Inputs that are passed to `model` to count activations.
Inputs need to be in a tuple.
supported_ops (dict(str,Callable) or None) : provide additional
handlers for extra ops, or overwrite the existing handlers for
convolution and matmul. The key is operator name and the value
is a function that takes (inputs, outputs) of the op.
Returns:
tuple[defaultdict, Counter]: A dictionary that records the number of
activations (in millions) for each operation and a Counter that
records the number of unsupported operations.
"""
if supported_ops is None:
supported_ops = {}
act_counter = ActivationAnalyzer(model,
inputs).set_op_handle(**supported_ops)
mega_acts = defaultdict(float)
for op, act in act_counter.by_operator().items():
mega_acts[op] = act / 1e6
return mega_acts, act_counter.unsupported_ops()
def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]:
"""Count parameters of a model and its submodules.
Adopted from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py
Args:
model (nn.Module): the model to count parameters.
Returns:
dict[str, int]: the key is either a parameter name or a module name.
The value is the number of elements in the parameter, or in all
parameters of the module. The key "" corresponds to the total
number of parameters of the model.
"""
count = defaultdict(int) # type: typing.DefaultDict[str, int]
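# Add each parameter's element count to every ancestor prefix of its
# name (e.g. a hypothetical "layer1.conv.weight" also contributes to
# "layer1.conv", "layer1" and the root ""), so per-module totals come
# for free.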
for name, param in model.named_parameters():
size = param.numel()
name = name.split('.')
for k in range(0, len(name) + 1):
prefix = '.'.join(name[:k])
count[prefix] += size
return count
def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
"""Format the parameter count of the model (and its submodules or
parameters)
Adopted from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py
Args:
model (nn.Module): the model to count parameters.
max_depth (int): maximum depth to recursively print submodules or
parameters
Returns:
str: the table to be printed
"""
count: typing.DefaultDict[str, int] = parameter_count(model)
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
param_shape: typing.Dict[str, typing.Tuple] = {
k: tuple(v.shape)
for k, v in model.named_parameters()
}
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
rows: typing.List[typing.Tuple] = []
def format_size(x: int) -> str:
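# Use the G / M / K suffixes once the raw count exceeds 1e8 / 1e5 / 1e2.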
if x > 1e8:
return f'{x / 1e9:.1f}G'
if x > 1e5:
return f'{x / 1e6:.1f}M'
if x > 1e2:
return f'{x / 1e3:.1f}K'
return str(x)
def fill(lvl: int, prefix: str) -> None:
if lvl >= max_depth:
return
for name, v in count.items():
if name.count('.') == lvl and name.startswith(prefix):
indent = ' ' * (lvl + 1)
if name in param_shape:
rows.append(
(indent + name, indent + str(param_shape[name])))
else:
rows.append((indent + name, indent + format_size(v)))
fill(lvl + 1, name + '.')
rows.append(('model', format_size(count.pop(''))))
fill(0, '')
table = Table(title=f'parameter count of {model.__class__.__name__}')
table.add_column('name')
table.add_column('#elements or shape')
for row in rows:
table.add_row(*row)
console = Console()
with console.capture() as capture:
console.print(table, end='')
return capture.get()
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
import typing
from collections import Counter, OrderedDict
from typing import Any, Callable, List, Optional, Union
import numpy as np
try:
from math import prod # type: ignore
except ImportError:
from numpy import prod # type: ignore
Handle = Callable[[List[Any], List[Any]], Union[typing.Counter[str], int]]
def get_shape(val: Any) -> Optional[List[int]]:
"""Get the shapes from a jit value object.
Args:
val (torch._C.Value): jit value object.
Returns:
list(int): return a list of ints.
"""
if val.isCompleteTensor():
return val.type().sizes()
else:
return None # type: ignore
"""
Below are flop/activation counters for various ops.
Every counter has the following signature:
Args:
inputs (list(torch._C.Value)):
The inputs of the op in the form of a list of jit object.
outputs (list(torch._C.Value)):
The outputs of the op in the form of a list of jit object.
Returns:
number: The number of flops/activations for the operation.
or Counter[str]
"""
def generic_activation_jit(op_name: Optional[str] = None) -> Handle:
"""This method returns a handle that counts the number of activation from
the output shape for the specified operation.
Args:
op_name (str): The name of the operation. If given, the handle will
return a counter using this name.
Returns:
Callable: An activation handle for the given operation.
"""
def _generic_activation_jit(
i: Any, outputs: List[Any]) -> Union[typing.Counter[str], int]:
"""This is a generic jit handle that counts the number of activations
for any operation given the output shape."""
out_shape = get_shape(outputs[0])
ac_count = prod(out_shape) # type: ignore
if op_name is None:
return ac_count # type: ignore
else:
return Counter({op_name: ac_count})
return _generic_activation_jit
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for fully connected layers."""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [get_shape(v) for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [batch size, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0] # type: ignore
assert len(input_shapes[1]) == 2, input_shapes[1] # type: ignore
batch_size, input_dim = input_shapes[0] # type: ignore
output_dim = input_shapes[1][1] # type: ignore
flops = batch_size * input_dim * output_dim
return flops
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for the aten::linear operator."""
# Inputs is a list of length 3; unlike aten::addmm, it is the first
# two elements that are relevant.
input_shapes = [get_shape(v) for v in inputs[0:2]]
# input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
# input_shapes[1]: [output_feature_dim, input_feature_dim]
assert input_shapes[0][-1] == input_shapes[1][-1] # type: ignore
flops = prod(input_shapes[0]) * input_shapes[1][0] # type: ignore
return flops
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for the bmm operation."""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor.
assert len(inputs) == 2, len(inputs)
input_shapes = [get_shape(v) for v in inputs]
n, c, t = input_shapes[0] # type: ignore
d = input_shapes[-1][-1] # type: ignore
flop = n * c * t * d
return flop
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> Union[int, Any]:
"""Count flops for convolution. Note only multiplication is counted.
Computation for addition and bias is ignored. Flops for a transposed
convolution are calculated as.
flops = (x_shape[2:] * prod(w_shape) * batch_size).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
transposed (bool): is the convolution transposed
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
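# A transposed convolution's flop count scales with the *input* spatial
# size; a regular convolution's scales with the *output* spatial size.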
conv_shape = (x_shape if transposed else out_shape)[2:]
flop = batch_size * prod(w_shape) * prod(conv_shape)
return flop
def conv_flop_jit(inputs: List[Any],
outputs: List[Any]) -> typing.Counter[str]:
"""Count flops for convolution."""
# Inputs of Convolution should be a list of length 12 or 13.
# They represent:
# 0) input tensor, 1) convolution filter, 2) bias, 3) stride, 4) padding,
# 5) dilation, 6) transposed, 7) out_pad, 8) groups, 9) benchmark_cudnn,
# 10) deterministic_cudnn and 11) user_enabled_cudnn.
# starting with #40737 it will be 12) user_enabled_tf32
assert len(inputs) == 12 or len(inputs) == 13, len(inputs)
x, w = inputs[:2]
x_shape, w_shape, out_shape = (get_shape(x), get_shape(w),
get_shape(outputs[0]))
transposed = inputs[6].toIValue()
# use a custom name instead of "_convolution"
return Counter({
'conv':
conv_flop_count(
x_shape, # type: ignore
w_shape, # type: ignore
out_shape, # type: ignore
transposed=transposed) # type: ignore
})
def einsum_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for the einsum operation."""
# Inputs of einsum should be a list of length 2+.
# Inputs[0] stores the equation used for einsum.
# Inputs[1] stores the list of input shapes.
assert len(inputs) >= 2, len(inputs)
equation = inputs[0].toIValue()
# Get rid of white space in the equation string.
equation = equation.replace(' ', '')
input_shapes_jit = inputs[1].node().inputs()
input_shapes = [get_shape(v) for v in input_shapes_jit]
# Re-map equation so that same equation with different alphabet
# representations will look the same.
letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys()
mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)}
equation = equation.translate(mapping)
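# e.g. 'nct,ncp->ntp' canonicalizes to 'abc,abd->acd'.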
if equation == 'abc,abd->acd':
n, c, t = input_shapes[0] # type: ignore
p = input_shapes[-1][-1] # type: ignore
flop = n * c * t * p
return flop
elif equation == 'abc,adc->adb':
n, t, g = input_shapes[0] # type: ignore
c = input_shapes[-1][1] # type: ignore
flop = n * t * g * c
return flop
else:
np_arrs = [np.zeros(s) for s in input_shapes]
optim = np.einsum_path(equation, *np_arrs, optimize='optimal')[1]
for line in optim.split('\n'):
if 'optimized flop' in line.lower():
# divided by 2 because we count MAC
# (multiply-add counted as one flop)
flop = float(np.floor(float(line.split(':')[-1]) / 2))
return flop
raise NotImplementedError('Unsupported einsum operation.')
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Union[int, Any]:
"""Count flops for matmul."""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [get_shape(v) for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][ # type: ignore
-2], input_shapes # type: ignore
flop = prod(input_shapes[0]) * input_shapes[-1][-1] # type: ignore
return flop
def norm_flop_counter(affine_arg_index: int) -> Handle:
"""
Args:
affine_arg_index: index of the affine argument in inputs
"""
def norm_flop_jit(inputs: List[Any],
outputs: List[Any]) -> Union[int, Any]:
"""Count flops for norm layers."""
# Inputs[0] contains the shape of the input.
input_shape = get_shape(inputs[0])
has_affine = get_shape(inputs[affine_arg_index]) is not None
assert 2 <= len(input_shape) <= 5, input_shape # type: ignore
# 5 is just a rough estimate
flop = prod(input_shape) * (5 if has_affine else 4) # type: ignore
return flop
return norm_flop_jit
def batchnorm_flop_jit(inputs: List[Any],
outputs: List[Any]) -> Union[int, Any]:
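"""Count flops for the aten::batch_norm operator.
In training mode batch statistics are also computed, so it is charged
like the other norm layers; in eval mode only the per-element
normalization (and optional affine) cost is counted.
"""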
training = inputs[5].toIValue()
assert isinstance(training,
bool), 'Signature of aten::batch_norm has changed!'
if training:
return norm_flop_counter(1)(inputs, outputs) # pyre-ignore
has_affine = get_shape(inputs[1]) is not None
input_shape = prod(get_shape(inputs[0])) # type: ignore
return input_shape * (2 if has_affine else 1)
def elementwise_flop_counter(input_scale: float = 1,
output_scale: float = 0) -> Handle:
"""Count flops by.
input_tensor.numel() * input_scale +
output_tensor.numel() * output_scale
Args:
input_scale: scale of the input tensor (first argument)
output_scale: scale of the output tensor (first element in outputs)
"""
def elementwise_flop(inputs: List[Any],
outputs: List[Any]) -> Union[int, Any]:
ret = 0
if input_scale != 0:
shape = get_shape(inputs[0])
ret += input_scale * prod(shape) # type: ignore
if output_scale != 0:
shape = get_shape(outputs[0])
ret += output_scale * prod(shape) # type: ignore
return ret
return elementwise_flop
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_activation_count.py
# pyre-ignore-all-errors[2]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
from numpy import prod
from mmengine.analysis import ActivationAnalyzer, activation_count
from mmengine.analysis.jit_handles import Handle
class SmallConvNet(nn.Module):
"""A network with three conv layers.
This is used for testing convolution layers for activation count.
"""
def __init__(self, input_dim: int) -> None:
super().__init__()
conv_dim1 = 8
conv_dim2 = 4
conv_dim3 = 2
self.conv1 = nn.Conv2d(input_dim, conv_dim1, 1, 1)
self.conv2 = nn.Conv2d(conv_dim1, conv_dim2, 1, 2)
self.conv3 = nn.Conv2d(conv_dim2, conv_dim3, 1, 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x
def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
x = self.conv1(x)
count1 = prod(list(x.size()))
x = self.conv2(x)
count2 = prod(list(x.size()))
x = self.conv3(x)
count3 = prod(list(x.size()))
return count1, count2, count3
class TestActivationAnalyzer(unittest.TestCase):
"""Unittest for activation_count."""
def setUp(self) -> None:
# nn.Linear uses a different operator based on version, so make sure
# we are testing the right thing.
lin = nn.Linear(10, 10)
lin_x: torch.Tensor = torch.randn(10, 10)
trace = torch.jit.trace(lin, (lin_x, ))
node_kinds = [node.kind() for node in trace.graph.nodes()]
assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds
if 'aten::addmm' in node_kinds:
self.lin_op = 'addmm'
else:
self.lin_op = 'linear'
def test_conv2d(self) -> None:
"""Test the activation count for convolutions."""
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
ac_dict, _ = activation_count(conv_net, (x, ))
gt_count = sum(conv_net.get_gt_activation(x))
gt_dict = defaultdict(float)
gt_dict['conv'] = gt_count / 1e6
self.assertDictEqual(
gt_dict,
ac_dict,
'conv_net with 3 layers failed to pass the activation count test.',
)
def test_linear(self) -> None:
"""Test the activation count for fully connected layer."""
batch_size = 1
input_dim = 10
output_dim = 20
linear = nn.Linear(input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
ac_dict, _ = activation_count(linear, (x, ))
gt_count = batch_size * output_dim
gt_dict = defaultdict(float)
gt_dict[self.lin_op] = gt_count / 1e6
self.assertEqual(gt_dict, ac_dict,
'FC layer failed to pass the activation count test.')
def test_supported_ops(self) -> None:
"""Test the activation count for user provided handles."""
def dummy_handle(inputs: List[Any],
outputs: List[Any]) -> typing.Counter[str]:
return Counter({'conv': 100})
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
sp_ops: Dict[str, Handle] = {'aten::_convolution': dummy_handle}
ac_dict, _ = activation_count(conv_net, (x, ), sp_ops)
gt_dict = defaultdict(float)
conv_layers = 3
gt_dict['conv'] = 100 * conv_layers / 1e6
self.assertDictEqual(
gt_dict,
ac_dict,
'conv_net with 3 layers failed to pass the activation count test.',
)
def test_activation_count_class(self) -> None:
"""Tests ActivationAnalyzer."""
batch_size = 1
input_dim = 10
output_dim = 20
netLinear = nn.Linear(input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
gt_count = batch_size * output_dim
gt_dict = Counter({
'': gt_count,
})
acts_counter = ActivationAnalyzer(netLinear, (x, ))
self.assertEqual(acts_counter.by_module(), gt_dict)
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
acts_counter = ActivationAnalyzer(conv_net, (x, ))
gt_counts = conv_net.get_gt_activation(x)
gt_dict = Counter({
'': sum(gt_counts),
'conv1': gt_counts[0],
'conv2': gt_counts[1],
'conv3': gt_counts[2],
})
self.assertDictEqual(gt_dict, acts_counter.by_module())
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_param_count.py
import unittest
from torch import nn
from mmengine.analysis.complexity_analysis import (parameter_count,
parameter_count_table)
class NetWithReuse(nn.Module):
def __init__(self, reuse: bool = False) -> None:
super().__init__()
self.conv1 = nn.Conv2d(100, 100, 3)
self.conv2 = nn.Conv2d(100, 100, 3)
if reuse:
self.conv2.weight = self.conv1.weight
class NetWithDupPrefix(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(100, 100, 3)
self.conv111 = nn.Conv2d(100, 100, 3)
class TestParamCount(unittest.TestCase):
def test_param(self) -> None:
net = NetWithReuse()
count = parameter_count(net)
self.assertEqual(count[''], 180200)
self.assertEqual(count['conv2'], 90100)
def test_param_with_reuse(self) -> None:
net = NetWithReuse(reuse=True)
count = parameter_count(net)
self.assertEqual(count[''], 90200)
self.assertEqual(count['conv2'], 100)
def test_param_with_same_prefix(self) -> None:
net = NetWithDupPrefix()
table = parameter_count_table(net)
c = ['conv111.weight' in line for line in table.split('\n')]
self.assertEqual(
sum(c), 1)  # conv111.weight appears only once, even though conv1 is a prefix of conv111