From bcab813242cf0be9ea98d8ffbe628c1fc2a477c2 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Mon, 13 Jun 2022 13:51:07 +0800
Subject: [PATCH] [Feature] Add ModuleList, Sequential and ModuleDict (#299)

* add module list

* add module list

* fix docstring
---
 mmengine/model/__init__.py           |   5 +-
 mmengine/model/base_module.py        |  53 +++++++++
 tests/test_model/test_base_module.py | 163 ++++++++++++++++++++++++++-
 3 files changed, 218 insertions(+), 3 deletions(-)

diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index 47e2356c..f1aedde2 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -2,7 +2,7 @@
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -12,5 +12,6 @@ __all__ = [
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
     'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params'
+    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
+    'Sequential'
 ]
diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
index 89f140dc..380a804c 100644
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -5,6 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional
 
 import torch.nn as nn
 
@@ -165,3 +166,55 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         if self.init_cfg:
             s += f'\ninit_cfg={self.init_cfg}'
         return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in OpenMMLab.
+
+    Ensures that all modules in ``Sequential`` can have an initialization
+    strategy different from that of the outer model.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in OpenMMLab.
+
+    Ensures that all modules in ``ModuleList`` can have an initialization
+    strategy different from that of the outer model.
+
+    Args:
+        modules (iterable, optional): An iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in OpenMMLab.
+
+    Ensures that all modules in ``ModuleDict`` can have an initialization
+    strategy different from that of the outer model.
+
+    Args:
+        modules (dict, optional): A mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """

+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py
index a253e0b3..b9d3c4f7 100644
--- a/tests/test_model/test_base_module.py
+++ b/tests/test_model/test_base_module.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 
 from mmengine.logging.logger import MMLogger
-from mmengine.model.base_module import BaseModule
+from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
 from mmengine.registry import Registry, build_from_cfg
 
 COMPONENTS = Registry('component')
@@ -195,3 +195,164 @@ class TestBaseModule(TestCase):
         assert len(os.listdir(dump_dir)) == 1
         assert os.stat(log_path).st_size != 0
         shutil.rmtree(dump_dir)
+
+
+class TestModuleList(TestCase):
+
+    def test_modulelist_weight_init(self):
+        models_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(layers)
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+
+
+class TestModuleDict(TestCase):
+
+    def test_moduledict_weight_init(self):
+        models_cfg = dict(
+            foo_conv_1d=dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            foo_conv_2d=dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        )
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(layers)
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+
+
+class TestSequential(TestCase):
+
+    def test_sequential_model_weight_init(self):
+        seq_model_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(*layers)
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(
+            *layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
--
GitLab
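
Usage note (not part of the patch): the snippet below is a minimal sketch of how the three containers added here are meant to be used, assuming mmengine with this change applied. The layers are ordinary torch.nn modules chosen only for illustration, and 'Constant' refers to the same constant initializer config used by the tests above.

    import torch.nn as nn

    from mmengine.model import ModuleDict, ModuleList, Sequential

    # Each container accepts an optional `init_cfg`; `init_weights()` (inherited
    # from BaseModule) applies it to every child layer matching the `layer` key.
    seq = Sequential(
        nn.Conv1d(1, 1, 1),
        nn.Conv1d(1, 1, 1),
        init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.))

    mlist = ModuleList(
        [nn.Conv2d(1, 1, 1), nn.Conv2d(1, 1, 1)],
        init_cfg=dict(type='Constant', layer='Conv2d', val=2., bias=3.))

    mdict = ModuleDict(
        dict(conv=nn.Conv1d(1, 1, 1)),
        init_cfg=dict(type='Constant', layer='Conv1d', val=4., bias=5.))

    for container in (seq, mlist, mdict):
        container.init_weights()

    # After init_weights(), seq's Conv1d weights are all 0. and biases all 1.
    assert (seq[0].weight == 0.).all() and (seq[0].bias == 1.).all()

As the tests assert, a child that carries its own `init_cfg` keeps it: the container-level `init_cfg` only governs children without one, which is why the "inner init_cfg has higher priority" blocks still expect the values 0.–3. rather than 4. and 5. The explicit `BaseModule.__init__` / `nn.Sequential.__init__` calls in the patch (instead of a single `super().__init__()`) are presumably there because the two parent constructors take different arguments.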