From dc931fd2c094642304d9734144c0d601afbcc796 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 5 Apr 2023 10:33:24 +0800
Subject: [PATCH] [Fix] Initialize nested modules in ddp which define
 'init_weights' method (#1045)

---
 mmengine/model/base_module.py        |  3 +++
 tests/test_model/test_base_module.py | 18 +++++++++++++++++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
index 1167bdf2..6912a353 100644
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -11,6 +11,7 @@ import torch.nn as nn
 from mmengine.dist import master_only
 from mmengine.logging import MMLogger, print_log
 from .weight_init import initialize, update_init_info
+from .wrappers.utils import is_model_wrapper
 
 
 class BaseModule(nn.Module, metaclass=ABCMeta):
@@ -123,6 +124,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                 initialize(self, other_cfgs)
 
             for m in self.children():
+                if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
+                    m = m.module
                 if hasattr(m, 'init_weights'):
                     m.init_weights()
                     # users may overload the `init_weights`
diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py
index 23c87860..b882c48d 100644
--- a/tests/test_model/test_base_module.py
+++ b/tests/test_model/test_base_module.py
@@ -4,6 +4,7 @@ import logging
 import os.path as osp
 import tempfile
 from unittest import TestCase
+from unittest.mock import patch
 
 import torch
 from torch import nn
@@ -145,7 +146,6 @@ class TestBaseModule(TestCase):
             ├──conv1d (FooConv1d, weight=3, bias=4)
         ├──reg (nn.Linear, weight=1, bias=2)
         """
-
         self.model.init_weights()
 
         assert torch.equal(
@@ -222,6 +222,22 @@ class TestBaseModule(TestCase):
         self.assertTrue((ori_layer_weight != model.linear.linear.weight).any())
         self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())
 
+        class FakeDDP(nn.Module):
+
+            def __init__(self, module) -> None:
+                super().__init__()
+                self.module = module
+
+        # Test that `init_weights` unwraps a child model wrapper (e.g. DDP)
+        # and initializes the wrapped module that defines `init_weights`.
+        with patch('mmengine.model.base_module.is_model_wrapper',
+                   lambda x: isinstance(x, FakeDDP)):
+            model = FOOMODELS.build(model_cfg)
+            model.ddp = FakeDDP(CustomLinear())
+            model.init_weights()
+            self.assertTrue((model.ddp.module.linear.weight == 1).all())
+            self.assertTrue((model.ddp.module.linear.bias == 2).all())
+
     def test_dump_init_info(self):
         import os
         import shutil
-- 
GitLab