Skip to content
Snippets Groups Projects
Unverified Commit f76218a4 authored by sjiang95's avatar sjiang95 Committed by GitHub
Browse files

[Enhance] Make the parameters of get_model_complexity_info() friendly (#1056)


* print_helper: optimize inputs of get_model_complexity_info

Signed-off-by: default avatarShengjiang QUAN <qsj287068067@126.com>

* directly throw error

When "input_shape" and "inputs" are both `None` or both set,
throw ValueError.

Signed-off-by: default avatarShengjiang QUAN <qsj287068067@126.com>

---------

Signed-off-by: default avatarShengjiang QUAN <qsj287068067@126.com>
parent 5e1ed7aa
No related branches found
No related tags found
No related merge requests found
...@@ -675,7 +675,7 @@ def complexity_stats_table( ...@@ -675,7 +675,7 @@ def complexity_stats_table(
def get_model_complexity_info( def get_model_complexity_info(
model: nn.Module, model: nn.Module,
input_shape: tuple, input_shape: tuple = None,
inputs: Optional[torch.Tensor] = None, inputs: Optional[torch.Tensor] = None,
show_table: bool = True, show_table: bool = True,
show_arch: bool = True, show_arch: bool = True,
...@@ -696,6 +696,11 @@ def get_model_complexity_info( ...@@ -696,6 +696,11 @@ def get_model_complexity_info(
Returns: Returns:
dict: The complexity information of the model. dict: The complexity information of the model.
""" """
if input_shape is None and inputs is None:
raise ValueError('One of "input_shape" and "inputs" should be set.')
elif input_shape is not None and inputs is not None:
raise ValueError('"input_shape" and "inputs" cannot be both set.')
if inputs is None: if inputs is None:
inputs = (torch.randn(1, *input_shape), ) inputs = (torch.randn(1, *input_shape), )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment