Skip to content
Snippets Groups Projects
Unverified Commit 83d76abc authored by CescMessi's avatar CescMessi Committed by GitHub
Browse files

[Fix] Fix the incorrect device of inputs in get_model_complexity_info (#1130)

parent 2085046d
No related branches found
No related tags found
No related merge requests found
......@@ -725,14 +725,15 @@ def get_model_complexity_info(
raise ValueError('"input_shape" and "inputs" cannot be both set.')
if inputs is None:
device = next(model.parameters()).device
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
inputs = (torch.randn(1, *input_shape), )
inputs = (torch.randn(1, *input_shape).to(device), )
elif is_tuple_of(input_shape, tuple) and all([
is_tuple_of(one_input_shape, int)
for one_input_shape in input_shape # type: ignore
]): # tuple of tuple of int, construct multiple tensors
inputs = tuple([
torch.randn(1, *one_input_shape)
torch.randn(1, *one_input_shape).to(device)
for one_input_shape in input_shape # type: ignore
])
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment