Skip to content
Snippets Groups Projects
Unverified Commit ab8b5168 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Fix] Init weights after build model. (#164)

* [Fix] Init weights after build model.

* add unit tests and docstring
parent 87da7599
No related branches found
No related tags found
No related merge requests found
......@@ -658,6 +658,10 @@ class Runner:
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
"""Build model.
If ``model`` is a dict, it will be used to build a nn.Module object
and initialize the weights if it has ``init_weights`` method.
Else, if ``model`` is a nn.Module object it will be returned directly.
An example of ``model``::
model = dict(type='ResNet')
......@@ -673,7 +677,11 @@ class Runner:
if isinstance(model, nn.Module):
return model
elif isinstance(model, dict):
return MODELS.build(model)
model = MODELS.build(model)
# init weights
if hasattr(model, 'init_weights'):
model.init_weights()
return model
else:
raise TypeError('model should be a nn.Module object or dict, '
f'but got {model}')
......
......@@ -511,6 +511,25 @@ class TestRunner(TestCase):
model = runner.build_model(dict(type='ToyModel1'))
self.assertIsInstance(model, ToyModel1)
# test init weights
@MODELS.register_module()
class ToyModel2(ToyModel):
def __init__(self):
super().__init__()
self.initiailzed = False
def init_weights(self):
self.initiailzed = True
model = runner.build_model(dict(type='ToyModel2'))
self.assertTrue(model.initiailzed)
# test init weights with model object
_model = ToyModel2()
model = runner.build_model(_model)
self.assertFalse(model.initiailzed)
def test_wrap_model(self):
# TODO: test on distributed environment
# custom model wrapper
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment