Skip to content
Snippets Groups Projects
Unverified Commit dc931fd2 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Initialize nested modules in ddp which define 'init_weights' method (#1045)

parent fd84c210
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ import torch.nn as nn
from mmengine.dist import master_only
from mmengine.logging import MMLogger, print_log
from .weight_init import initialize, update_init_info
from .wrappers.utils import is_model_wrapper
class BaseModule(nn.Module, metaclass=ABCMeta):
......@@ -123,6 +124,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
initialize(self, other_cfgs)
for m in self.children():
if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
m = m.module
if hasattr(m, 'init_weights'):
m.init_weights()
# users may overload the `init_weights`
......
......@@ -4,6 +4,7 @@ import logging
import os.path as osp
import tempfile
from unittest import TestCase
from unittest.mock import patch
import torch
from torch import nn
......@@ -145,7 +146,6 @@ class TestBaseModule(TestCase):
├──conv1d (FooConv1d, weight=3, bias=4)
├──reg (nn.Linear, weight=1, bias=2)
"""
self.model.init_weights()
assert torch.equal(
......@@ -222,6 +222,22 @@ class TestBaseModule(TestCase):
self.assertTrue((ori_layer_weight != model.linear.linear.weight).any())
self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())
class FakeDDP(nn.Module):
def __init__(self, module) -> None:
super().__init__()
self.module = module
# Test initialization of nested modules in DDPModule which define
# `init_weights`.
with patch('mmengine.model.base_module.is_model_wrapper',
lambda x: isinstance(x, FakeDDP)):
model = FOOMODELS.build(model_cfg)
model.ddp = FakeDDP(CustomLinear())
model.init_weights()
self.assertTrue((model.ddp.module.linear.weight == 1).all())
self.assertTrue((model.ddp.module.linear.bias == 2).all())
def test_dump_init_info(self):
import os
import shutil
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment