Skip to content
Snippets Groups Projects
Unverified Commit f3189918 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Fix] Fix logged and current weight not on the same device. (#382)

parent 16ef54c4
No related branches found
No related tags found
No related merge requests found
......@@ -80,7 +80,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
f'after calling `init_weights` ' \
f'of {self.__class__.__name__} '
self._params_init_info[param][
'tmp_mean_value'] = param.data.mean()
'tmp_mean_value'] = param.data.mean().cpu()
# pass `params_init_info` to all submodules
# All submodules share the same `params_init_info`,
......
......@@ -41,7 +41,7 @@ def update_init_info(module, init_info):
# The parameter has been changed during executing the
# `init_weights` of module
mean_value = param.data.mean()
mean_value = param.data.mean().cpu()
if module._params_init_info[param]['tmp_mean_value'] != mean_value:
module._params_init_info[param]['init_info'] = init_info
module._params_init_info[param]['tmp_mean_value'] = mean_value
......
......@@ -776,8 +776,7 @@ class Runner:
def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel:
"""Build model.
If ``model`` is a dict, it will be used to build a nn.Module object
and initialize the weights if it has ``init_weights`` method.
If ``model`` is a dict, it will be used to build a nn.Module object.
Else, if ``model`` is a nn.Module object it will be returned directly.
An example of ``model``::
......@@ -796,7 +795,6 @@ class Runner:
return model
elif isinstance(model, dict):
model = MODELS.build(model)
# init weights
return model # type: ignore
else:
raise TypeError('model should be a nn.Module object or dict, '
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment