From dc931fd2c094642304d9734144c0d601afbcc796 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 5 Apr 2023 10:33:24 +0800 Subject: [PATCH] [Fix] Initialize nested modules in ddp which define 'init_weights' method (#1045) --- mmengine/model/base_module.py | 3 +++ tests/test_model/test_base_module.py | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 1167bdf2..6912a353 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -11,6 +11,7 @@ import torch.nn as nn from mmengine.dist import master_only from mmengine.logging import MMLogger, print_log from .weight_init import initialize, update_init_info +from .wrappers.utils import is_model_wrapper class BaseModule(nn.Module, metaclass=ABCMeta): @@ -123,6 +124,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta): initialize(self, other_cfgs) for m in self.children(): + if is_model_wrapper(m) and not hasattr(m, 'init_weights'): + m = m.module if hasattr(m, 'init_weights'): m.init_weights() # users may overload the `init_weights` diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py index 23c87860..b882c48d 100644 --- a/tests/test_model/test_base_module.py +++ b/tests/test_model/test_base_module.py @@ -4,6 +4,7 @@ import logging import os.path as osp import tempfile from unittest import TestCase +from unittest.mock import patch import torch from torch import nn @@ -145,7 +146,6 @@ class TestBaseModule(TestCase): ├──conv1d (FooConv1d, weight=3, bias=4) ├──reg (nn.Linear, weight=1, bias=2) """ - self.model.init_weights() assert torch.equal( @@ -222,6 +222,22 @@ class TestBaseModule(TestCase): self.assertTrue((ori_layer_weight != model.linear.linear.weight).any()) self.assertTrue((ori_layer_bias != model.linear.linear.bias).any()) + class FakeDDP(nn.Module): + + def __init__(self, module) -> None: + super().__init__() + self.module = module + + # Test initialization of nested modules in DDPModule which define + # `init_weights`. + with patch('mmengine.model.base_module.is_model_wrapper', + lambda x: isinstance(x, FakeDDP)): + model = FOOMODELS.build(model_cfg) + model.ddp = FakeDDP(CustomLinear()) + model.init_weights() + self.assertTrue((model.ddp.module.linear.weight == 1).all()) + self.assertTrue((model.ddp.module.linear.bias == 2).all()) + def test_dump_init_info(self): import os import shutil -- GitLab