diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a38bbf5292686598c5c3b98a901d8aa3aa96b0b
--- /dev/null
+++ b/mmengine/model/base_module.py
@@ -0,0 +1,167 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import logging
+import warnings
+from abc import ABCMeta
+from collections import defaultdict
+from logging import FileHandler
+
+import torch.nn as nn
+
+from mmengine.dist import master_only
+from mmengine.logging import MMLogger, print_log
+
+
+class BaseModule(nn.Module, metaclass=ABCMeta):
+    """Base module for all modules in openmmlab. ``BaseModule`` is a wrapper of
+    ``torch.nn.Module`` with additional functionality of parameter
+    initialization. Compared with ``torch.nn.Module``, ``BaseModule`` mainly
+    adds three attributes.
+
+    - ``init_cfg``: the config to control the initialization.
+    - ``init_weights``: The function of parameter initialization and recording
+      initialization information.
+    - ``_params_init_info``: Used to track the parameter initialization
+      information. This attribute only exists during executing the
+      ``init_weights``.
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
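+
+    Example:
+        A minimal sketch of typical usage. ``ToyNet`` is a hypothetical
+        subclass defined only for illustration.
+
+        >>> import torch.nn as nn
+        >>> class ToyNet(BaseModule):
+        >>>     def __init__(self, init_cfg=None):
+        >>>         super().__init__(init_cfg)
+        >>>         self.linear = nn.Linear(2, 2)
+        >>> cfg = dict(type='Constant', layer='Linear', val=1., bias=2.)
+        >>> net = ToyNet(init_cfg=cfg)
+        >>> net.init_weights()  # fills weights with 1 and biases with 2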
+    """
+
+    def __init__(self, init_cfg=None):
+        """Initialize BaseModule, inherited from `torch.nn.Module`"""
+
+        # NOTE init_cfg can be defined at different levels, but init_cfg
+        # at lower levels has a higher priority.
+
+        super().__init__()
+        # define the default value of init_cfg here instead of hard-coding
+        # it in the init_weights() function
+        self._is_init = False
+
+        self.init_cfg = copy.deepcopy(init_cfg)
+
+        # Backward compatibility in derived classes
+        # if pretrained is not None:
+        #     warnings.warn('DeprecationWarning: pretrained is a deprecated \
+        #         key, please consider using init_cfg')
+        #     self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+
+    @property
+    def is_init(self):
+        return self._is_init
+
+    def init_weights(self):
+        """Initialize the weights."""
+
+        is_top_level_module = False
+        # check if it is top-level module
+        if not hasattr(self, '_params_init_info'):
+            # `_params_init_info` records the initialization information
+            # of the parameters. The key is the :obj:`nn.Parameter` of the
+            # model and the value is a dict containing
+            # - init_info (str): The string that describes the initialization.
+            # - tmp_mean_value (FloatTensor): The mean of the parameter,
+            #       which indicates whether the parameter has been modified.
+            # This attribute will be deleted after all parameters are
+            # initialized.
+            self._params_init_info = defaultdict(dict)
+            is_top_level_module = True
+
+            # Initialize `_params_init_info`. When the `tmp_mean_value` of a
+            # parameter is detected to have changed, the related
+            # initialization information will be updated.
+            for name, param in self.named_parameters():
+                self._params_init_info[param][
+                    'init_info'] = f'The value is the same before and ' \
+                                   f'after calling `init_weights` ' \
+                                   f'of {self.__class__.__name__} '
+                self._params_init_info[param][
+                    'tmp_mean_value'] = param.data.mean()
+
+            # pass `params_init_info` to all submodules
+            # All submodules share the same `params_init_info`,
+            # so it will be updated when parameters are
+            # modified at any level of the model.
+            for sub_module in self.modules():
+                sub_module._params_init_info = self._params_init_info
+
+        logger = MMLogger.get_current_instance()
+        logger_name = logger.instance_name
+
+        from .utils.weight_init import initialize, update_init_info
+        module_name = self.__class__.__name__
+        if not self._is_init:
+            if self.init_cfg:
+                print_log(
+                    f'initialize {module_name} with init_cfg {self.init_cfg}',
+                    logger=logger_name,
+                    level=logging.DEBUG)
+                initialize(self, self.init_cfg)
+                if isinstance(self.init_cfg, dict):
+                    # prevent the parameters of
+                    # the pre-trained model
+                    # from being overwritten by
+                    # the `init_weights`
+                    if self.init_cfg['type'] == 'Pretrained':
+                        return
+
+            for m in self.children():
+                if hasattr(m, 'init_weights'):
+                    m.init_weights()
+                    # users may override the `init_weights` of a sub-module
+                    update_init_info(
+                        m,
+                        init_info=f'Initialized by '
+                        f'user-defined `init_weights`'
+                        f' in {m.__class__.__name__} ')
+
+            self._is_init = True
+        else:
+            warnings.warn(f'init_weights of {self.__class__.__name__} has '
+                          f'been called more than once.')
+
+        if is_top_level_module:
+            self._dump_init_info()
+
+            for sub_module in self.modules():
+                del sub_module._params_init_info
+
+    @master_only
+    def _dump_init_info(self):
+        """Dump the initialization information to a file named
+        `initialization.log.json` in workdir.
+
+        Args:
+            logger_name (str): The name of logger.
+        """
+
+        logger = MMLogger.get_current_instance()
+        logger_name = logger.instance_name
+        with_file_handler = False
+        # dump the information to the logger file if there is a `FileHandler`
+        for handler in logger.handlers:
+            if isinstance(handler, FileHandler):
+                handler.stream.write(
+                    'Name of parameter - Initialization information\n')
+                for name, param in self.named_parameters():
+                    handler.stream.write(
+                        f'\n{name} - {param.shape}: '
+                        f"\n{self._params_init_info[param]['init_info']} \n")
+                handler.stream.flush()
+                with_file_handler = True
+        if not with_file_handler:
+            for name, param in self.named_parameters():
+                print_log(
+                    f'\n{name} - {param.shape}: '
+                    f"\n{self._params_init_info[param]['init_info']} \n ",
+                    logger=logger_name)
+
+    def __repr__(self):
+        s = super().__repr__()
+        if self.init_cfg:
+            s += f'\ninit_cfg={self.init_cfg}'
+        return s
diff --git a/mmengine/model/utils/weight_init.py b/mmengine/model/utils/weight_init.py
new file mode 100644
index 0000000000000000000000000000000000000000..1289d7f117f9dbdcb7d20e44c4b59b466fc989e0
--- /dev/null
+++ b/mmengine/model/utils/weight_init.py
@@ -0,0 +1,670 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import math
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+from mmengine.logging.logger import MMLogger, print_log
+from mmengine.registry import WEIGHT_INITIALIZERS, build_from_cfg
+
+
+def update_init_info(module, init_info):
+    """Update the `_params_init_info` in the module if the value of parameters
+    are changed.
+
+    Args:
+        module (:obj:`nn.Module`): The PyTorch module with a user-defined
+            attribute `_params_init_info` which records the initialization
+            information.
+        init_info (str): The string that describes the initialization.
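+
+    Example:
+        Illustrative only; in normal use this bookkeeping is created and
+        consumed by ``BaseModule.init_weights``.
+
+        >>> from collections import defaultdict
+        >>> model = nn.Linear(2, 2)
+        >>> model._params_init_info = defaultdict(dict)
+        >>> for param in model.parameters():
+        >>>     info = model._params_init_info[param]
+        >>>     info['init_info'] = ''
+        >>>     info['tmp_mean_value'] = param.data.mean()
+        >>> nn.init.constant_(model.weight, 1.)
+        >>> update_init_info(model, init_info='Filled weight with 1')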
+    """
+    assert hasattr(
+        module,
+        '_params_init_info'), f'Can not find `_params_init_info` in {module}'
+    for name, param in module.named_parameters():
+
+        assert param in module._params_init_info, (
+            f'Find a new :obj:`Parameter` '
+            f'named `{name}` during executing the '
+            f'`init_weights` of '
+            f'`{module.__class__.__name__}`. '
+            f'Please do not add or '
+            f'replace parameters during executing '
+            f'the `init_weights`. ')
+
+        # The parameter has been changed while executing the
+        # `init_weights` of the module
+        mean_value = param.data.mean()
+        if module._params_init_info[param]['tmp_mean_value'] != mean_value:
+            module._params_init_info[param]['init_info'] = init_info
+            module._params_init_info[param]['tmp_mean_value'] = mean_value
+
+
+def constant_init(module, val, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.constant_(module.weight, val)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.xavier_uniform_(module.weight, gain=gain)
+        else:
+            nn.init.xavier_normal_(module.weight, gain=gain)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.normal_(module.weight, mean, std)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def trunc_normal_init(module: nn.Module,
+                      mean: float = 0,
+                      std: float = 1,
+                      a: float = -2,
+                      b: float = 2,
+                      bias: float = 0) -> None:
+    if hasattr(module, 'weight') and module.weight is not None:
+        trunc_normal_(module.weight, mean, std, a, b)  # type: ignore
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)  # type: ignore
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+    if hasattr(module, 'weight') and module.weight is not None:
+        nn.init.uniform_(module.weight, a, b)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(module,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 bias=0,
+                 distribution='normal'):
+    assert distribution in ['uniform', 'normal']
+    if hasattr(module, 'weight') and module.weight is not None:
+        if distribution == 'uniform':
+            nn.init.kaiming_uniform_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+        else:
+            nn.init.kaiming_normal_(
+                module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
+    if hasattr(module, 'bias') and module.bias is not None:
+        nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    kaiming_init(
+        module,
+        a=1,
+        mode='fan_in',
+        nonlinearity='leaky_relu',
+        bias=bias,
+        distribution='uniform')
+
+
+def bias_init_with_prob(prior_prob):
+    """Initialize conv/fc bias value according to a given probability value.
+
+    The returned bias satisfies ``sigmoid(bias) == prior_prob``, i.e.
+    ``bias = -log((1 - prior_prob) / prior_prob)``.
+    """
+    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+    return bias_init
+
+
+def _get_bases_name(m):
+    return [b.__name__ for b in m.__class__.__bases__]
+
+
+class BaseInit:
+
+    def __init__(self, *, bias=0, bias_prob=None, layer=None):
+        self.wholemodule = False
+        if not isinstance(bias, (int, float)):
+            raise TypeError(f'bias must be a number, but got a {type(bias)}')
+
+        if bias_prob is not None:
+            if not isinstance(bias_prob, float):
+                raise TypeError('bias_prob type must be float, '
+                                f'but got {type(bias_prob)}')
+
+        if layer is not None:
+            if not isinstance(layer, (str, list)):
+                raise TypeError('layer must be a str or a list of str, '
+                                f'but got a {type(layer)}')
+        else:
+            layer = []
+
+        if bias_prob is not None:
+            self.bias = bias_init_with_prob(bias_prob)
+        else:
+            self.bias = bias
+        self.layer = [layer] if isinstance(layer, str) else layer
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Constant')
+class ConstantInit(BaseInit):
+    """Initialize module parameters with constant values.
+
+    Args:
+        val (int | float): the value to fill the module's weights with.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
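+
+    Example:
+        A standalone sketch; initializers are normally built from
+        ``init_cfg`` by :func:`initialize`.
+
+        >>> init = ConstantInit(val=1., bias=2., layer='Conv2d')
+        >>> module = nn.Conv2d(3, 3, 1)
+        >>> init(module)  # fills weights with 1 and biases with 2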
+    """
+
+    def __init__(self, val, **kwargs):
+        super().__init__(**kwargs)
+        self.val = val
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                constant_init(m, self.val, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    constant_init(m, self.val, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: val={self.val}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Xavier')
+class XavierInit(BaseInit):
+    r"""Initialize module parameters with values according to the method
+    described in `Understanding the difficulty of training deep feedforward
+    neural networks - Glorot, X. & Bengio, Y. (2010).
+    <http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf>`_
+
+    Args:
+        gain (int | float): an optional scaling factor. Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'``
+            or ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
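+
+    Example:
+        >>> # standalone use; normally built from ``init_cfg``
+        >>> init = XavierInit(gain=1., distribution='uniform', layer='Linear')
+        >>> init(nn.Linear(3, 3))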
+    """
+
+    def __init__(self, gain=1, distribution='normal', **kwargs):
+        super().__init__(**kwargs)
+        self.gain = gain
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                xavier_init(m, self.gain, self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    xavier_init(m, self.gain, self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: gain={self.gain}, ' \
+               f'distribution={self.distribution}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Normal')
+class NormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
+
+    Args:
+        mean (int | float): the mean of the normal distribution.
+            Defaults to 0.
+        std (int | float): the standard deviation of the normal distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
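+
+    Example:
+        >>> # standalone use; normally built from ``init_cfg``
+        >>> init = NormalInit(mean=0., std=0.01, layer='Conv2d')
+        >>> init(nn.Conv2d(3, 16, 3))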
+    """
+
+    def __init__(self, mean=0, std=1, **kwargs):
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                normal_init(m, self.mean, self.std, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    normal_init(m, self.mean, self.std, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: mean={self.mean},' \
+               f' std={self.std}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='TruncNormal')
+class TruncNormalInit(BaseInit):
+    r"""Initialize module parameters with the values drawn from the normal
+    distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values
+    outside :math:`[a, b]`.
+    Args:
+        mean (float): the mean of the normal distribution. Defaults to 0.
+        std (float):  the standard deviation of the normal distribution.
+            Defaults to 1.
+        a (float): The minimum cutoff value. Defaults to -2.
+        b (float): The maximum cutoff value. Defaults to 2.
+        bias (float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
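+
+    Example:
+        >>> # standalone use; normally built from ``init_cfg``
+        >>> init = TruncNormalInit(mean=0., std=0.02, a=-0.04, b=0.04,
+        >>>                        layer='Linear')
+        >>> init(nn.Linear(4, 4))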
+    """
+
+    def __init__(self,
+                 mean: float = 0,
+                 std: float = 1,
+                 a: float = -2,
+                 b: float = 2,
+                 **kwargs) -> None:
+        super().__init__(**kwargs)
+        self.mean = mean
+        self.std = std
+        self.a = a
+        self.b = b
+
+    def __call__(self, module: nn.Module) -> None:
+
+        def init(m):
+            if self.wholemodule:
+                trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                  self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    trunc_normal_init(m, self.mean, self.std, self.a, self.b,
+                                      self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a}, b={self.b},' \
+               f' mean={self.mean}, std={self.std}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Uniform')
+class UniformInit(BaseInit):
+    r"""Initialize module parameters with values drawn from the uniform
+    distribution :math:`\mathcal{U}(a, b)`.
+
+    Args:
+        a (int | float): the lower bound of the uniform distribution.
+            Defaults to 0.
+        b (int | float): the upper bound of the uniform distribution.
+            Defaults to 1.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
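+
+    Example:
+        >>> # standalone use; normally built from ``init_cfg``
+        >>> init = UniformInit(a=0., b=1., layer='Linear')
+        >>> init(nn.Linear(2, 2))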
+    """
+
+    def __init__(self, a=0, b=1, **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.b = b
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                uniform_init(m, self.a, self.b, self.bias)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    uniform_init(m, self.a, self.b, self.bias)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a},' \
+               f' b={self.b}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Kaiming')
+class KaimingInit(BaseInit):
+    r"""Initialize module parameters with the values according to the method
+    described in `Delving deep into rectifiers: Surpassing human-level
+    performance on ImageNet classification - He, K. et al. (2015).
+    <https://www.cv-foundation.org/openaccess/content_iccv_2015/
+    papers/He_Delving_Deep_into_ICCV_2015_paper.pdf>`_
+
+    Args:
+        a (int | float): the negative slope of the rectifier used after this
+            layer (only used with ``'leaky_relu'``). Defaults to 0.
+        mode (str):  either ``'fan_in'`` or ``'fan_out'``. Choosing
+            ``'fan_in'`` preserves the magnitude of the variance of the weights
+            in the forward pass. Choosing ``'fan_out'`` preserves the
+            magnitudes in the backwards pass. Defaults to ``'fan_out'``.
+        nonlinearity (str): the non-linear function (`nn.functional` name),
+            recommended to use only with ``'relu'`` or ``'leaky_relu'`` .
+            Defaults to 'relu'.
+        bias (int | float): the value to fill the bias. Defaults to 0.
+        bias_prob (float, optional): the probability for bias initialization.
+            Defaults to None.
+        distribution (str): distribution either be ``'normal'`` or
+            ``'uniform'``. Defaults to ``'normal'``.
+        layer (str | list[str], optional): the layer(s) to be initialized.
+            Defaults to None.
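+
+    Example:
+        >>> # standalone use; normally built from ``init_cfg``
+        >>> init = KaimingInit(mode='fan_out', nonlinearity='relu',
+        >>>                    layer='Conv2d')
+        >>> init(nn.Conv2d(3, 16, 3))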
+    """
+
+    def __init__(self,
+                 a=0,
+                 mode='fan_out',
+                 nonlinearity='relu',
+                 distribution='normal',
+                 **kwargs):
+        super().__init__(**kwargs)
+        self.a = a
+        self.mode = mode
+        self.nonlinearity = nonlinearity
+        self.distribution = distribution
+
+    def __call__(self, module):
+
+        def init(m):
+            if self.wholemodule:
+                kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                             self.bias, self.distribution)
+            else:
+                layername = m.__class__.__name__
+                basesname = _get_bases_name(m)
+                if len(set(self.layer) & set([layername] + basesname)):
+                    kaiming_init(m, self.a, self.mode, self.nonlinearity,
+                                 self.bias, self.distribution)
+
+        module.apply(init)
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: a={self.a}, mode={self.mode}, ' \
+               f'nonlinearity={self.nonlinearity}, ' \
+               f'distribution={self.distribution}, bias={self.bias}'
+        return info
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Caffe2Xavier')
+class Caffe2XavierInit(KaimingInit):
+    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+    # Acknowledgment to FAIR's internal code
+    def __init__(self, **kwargs):
+        super().__init__(
+            a=1,
+            mode='fan_in',
+            nonlinearity='leaky_relu',
+            distribution='uniform',
+            **kwargs)
+
+    def __call__(self, module):
+        super().__call__(module)
+
+
+@WEIGHT_INITIALIZERS.register_module(name='Pretrained')
+class PretrainedInit:
+    """Initialize module by loading a pretrained model.
+
+    Args:
+        checkpoint (str): the checkpoint file of the pretrained model to be
+            loaded.
+        prefix (str, optional): the prefix of a sub-module in the pretrained
+            model. It is for loading a part of the pretrained model to
+            initialize. For example, if we would like to only load the
+            backbone of a detector model, we can set ``prefix='backbone.'``.
+            Defaults to None.
+        map_location (str, optional): map tensors onto the specified
+            location, e.g. ``'cpu'``. Defaults to None.
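+
+    Example:
+        An illustrative sketch; the torchvision checkpoint below is only an
+        example and loading it requires downloading the weights.
+
+        >>> init = PretrainedInit(checkpoint='torchvision://resnet50',
+        >>>                       prefix='backbone.')
+        >>> # init(model) would load the ``backbone.`` weights of the
+        >>> # checkpoint into ``model``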
+    """
+
+    def __init__(self, checkpoint, prefix=None, map_location=None):
+        self.checkpoint = checkpoint
+        self.prefix = prefix
+        self.map_location = map_location
+
+    def __call__(self, module):
+        from mmengine.runner.checkpoint import (_load_checkpoint_with_prefix,
+                                                load_checkpoint,
+                                                load_state_dict)
+        logger = MMLogger.get_instance('mmengine')
+        if self.prefix is None:
+            print_log(f'load model from: {self.checkpoint}', logger=logger)
+            load_checkpoint(
+                module,
+                self.checkpoint,
+                map_location=self.map_location,
+                strict=False,
+                logger=logger)
+        else:
+            print_log(
+                f'load {self.prefix} in model from: {self.checkpoint}',
+                logger=logger)
+            state_dict = _load_checkpoint_with_prefix(
+                self.prefix, self.checkpoint, map_location=self.map_location)
+            load_state_dict(module, state_dict, strict=False, logger=logger)
+
+        if hasattr(module, '_params_init_info'):
+            update_init_info(module, init_info=self._get_init_info())
+
+    def _get_init_info(self):
+        info = f'{self.__class__.__name__}: load from {self.checkpoint}'
+        return info
+
+
+def _initialize(module, cfg, wholemodule=False):
+    func = build_from_cfg(cfg, WEIGHT_INITIALIZERS)
+    # The `wholemodule` flag is used for the override mode: `override` has no
+    # `layer` key, so the initializer is applied to the whole sub-module
+    # referred to by `name` in `override`.
+    func.wholemodule = wholemodule
+    func(module)
+
+
+def _initialize_override(module, override, cfg):
+    if not isinstance(override, (dict, list)):
+        raise TypeError('override must be a dict or a list of dict, '
+                        f'but got {type(override)}')
+
+    override = [override] if isinstance(override, dict) else override
+
+    for override_ in override:
+
+        cp_override = copy.deepcopy(override_)
+        name = cp_override.pop('name', None)
+        if name is None:
+            raise ValueError('`override` must contain the key "name", '
+                             f'but got {cp_override}')
+        # if override only has name key, it means use args in init_cfg
+        if not cp_override:
+            cp_override.update(cfg)
+        # if override has the name key and other args but no type key,
+        # it will raise an error
+        elif 'type' not in cp_override.keys():
+            raise ValueError(
+                f'`override` need "type" key, but got {cp_override}')
+
+        if hasattr(module, name):
+            _initialize(getattr(module, name), cp_override, wholemodule=True)
+        else:
+            raise RuntimeError(f'module does not have attribute {name}, '
+                               f'but init_cfg is {cp_override}.')
+
+
+def initialize(module, init_cfg):
+    r"""Initialize a module.
+    Args:
+        module (``torch.nn.Module``): the module will be initialized.
+        init_cfg (dict | list[dict]): initialization configuration dict to
+            define initializer. OpenMMLab has implemented 8 initializers,
+            including ``Constant``, ``Xavier``, ``Normal``, ``TruncNormal``,
+            ``Uniform``, ``Kaiming``, ``Caffe2Xavier``, and ``Pretrained``.
+
+    Example:
+        >>> module = nn.Linear(2, 3, bias=True)
+        >>> init_cfg = dict(type='Constant', layer='Linear', val=1, bias=2)
+        >>> initialize(module, init_cfg)
+        >>> module = nn.Sequential(nn.Conv1d(3, 1, 3), nn.Linear(1,2))
+        >>> # define key ``'layer'`` for initializing layer with different
+        >>> # configuration
+        >>> init_cfg = [dict(type='Constant', layer='Conv1d', val=1),
+        >>>     dict(type='Constant', layer='Linear', val=2)]
+        >>> initialize(module, init_cfg)
+        >>> # define key``'override'`` to initialize some specific part in
+        >>> # module
+        >>> class FooNet(nn.Module):
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.feat = nn.Conv2d(3, 16, 3)
+        >>>         self.reg = nn.Conv2d(16, 10, 3)
+        >>>         self.cls = nn.Conv2d(16, 5, 3)
+        >>> model = FooNet()
+        >>> init_cfg = dict(type='Constant', val=1, bias=2, layer='Conv2d',
+        >>>     override=dict(type='Constant', name='reg', val=3, bias=4))
+        >>> initialize(model, init_cfg)
+        >>> model = ResNet(depth=50)
+        >>> # Initialize weights with the pretrained model.
+        >>> init_cfg = dict(type='Pretrained',
+        >>>     checkpoint='torchvision://resnet50')
+        >>> initialize(model, init_cfg)
+        >>> # Initialize weights of a sub-module with the specific part of
+        >>> # a pretrained model by using "prefix".
+        >>> url = 'http://download.openmmlab.com/mmdetection/v2.0/retinanet/'\
+        >>>     'retinanet_r50_fpn_1x_coco/'\
+        >>>     'retinanet_r50_fpn_1x_coco_20200130-c2398f9e.pth'
+        >>> init_cfg = dict(type='Pretrained',
+        >>>     checkpoint=url, prefix='backbone.')
+    """
+    if not isinstance(init_cfg, (dict, list)):
+        raise TypeError('init_cfg must be a dict or a list of dict, '
+                        f'but got {type(init_cfg)}')
+
+    if isinstance(init_cfg, dict):
+        init_cfg = [init_cfg]
+
+    for cfg in init_cfg:
+        # should deeply copy the original config because cfg may be used by
+        # other modules, e.g., one init_cfg shared by multiple bottleneck
+        # blocks, the expected cfg will be changed after pop and will change
+        # the initialization behavior of other modules
+        cp_cfg = copy.deepcopy(cfg)
+        override = cp_cfg.pop('override', None)
+        _initialize(module, cp_cfg)
+
+        if override is not None:
+            cp_cfg.pop('layer', None)
+            _initialize_override(module, override, cp_cfg)
+        else:
+            # All attributes in module have same initialization.
+            pass
+
+
+def _no_grad_trunc_normal_(tensor: Tensor, mean: float, std: float, a: float,
+                           b: float) -> Tensor:
+    # Method based on
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+    # Modified from
+    # https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+    def norm_cdf(x):
+        # Computes standard normal cumulative distribution function
+        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+    if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+            'The distribution of values may be incorrect.',
+            stacklevel=2)
+
+    with torch.no_grad():
+        # Values are generated by using a truncated uniform distribution and
+        # then using the inverse CDF for the normal distribution.
+        # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)
+
+        # Uniformly fill tensor with values from [lower, upper], then translate
+        # to [2lower-1, 2upper-1].
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tensor.erfinv_()
+
+        # Transform to proper mean, std
+        tensor.mul_(std * math.sqrt(2.))
+        tensor.add_(mean)
+
+        # Clamp to ensure it's in the proper range
+        tensor.clamp_(min=a, max=b)
+        return tensor
+
+
+def trunc_normal_(tensor: Tensor,
+                  mean: float = 0.,
+                  std: float = 1.,
+                  a: float = -2.,
+                  b: float = 2.) -> Tensor:
+    r"""Fills the input Tensor with values drawn from a truncated
+    normal distribution. The values are effectively drawn from the
+    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+    with values outside :math:`[a, b]` redrawn until they are within
+    the bounds. The method used for generating the random values works
+    best when :math:`a \leq \text{mean} \leq b`.
+    Modified from
+    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
+
+    Args:
+        tensor (``torch.Tensor``): an n-dimensional `torch.Tensor`.
+        mean (float): the mean of the normal distribution.
+        std (float): the standard deviation of the normal distribution.
+        a (float): the minimum cutoff value.
+        b (float): the maximum cutoff value.
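+
+    Example:
+        >>> # basic usage; all generated values fall within [a, b]
+        >>> w = torch.empty(3, 5)
+        >>> w = trunc_normal_(w, mean=0., std=1., a=-2., b=2.)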
+    """
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
diff --git a/mmengine/utils/version_utils.py b/mmengine/utils/version_utils.py
index 963c45a2e8a86a88413ab6c18c22481fb9831985..77c41f608439f85aa29f8a6c9bd148b04d0c5973 100644
--- a/mmengine/utils/version_utils.py
+++ b/mmengine/utils/version_utils.py
@@ -41,7 +41,7 @@ def digit_version(version_str: str, length: int = 4):
             release.extend([val, 0])
 
     elif version.is_postrelease:
-        release.extend([1, version.post])
+        release.extend([1, version.post])  # type: ignore
     else:
         release.extend([0, 0])
     return tuple(release)
diff --git a/tests/test_model/test_base_module.py b/tests/test_model/test_base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..a253e0b3810b7a7fcd2e2684005fff1af8e433fd
--- /dev/null
+++ b/tests/test_model/test_base_module.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+from torch import nn
+
+from mmengine.logging.logger import MMLogger
+from mmengine.model.base_module import BaseModule
+from mmengine.registry import Registry, build_from_cfg
+
+COMPONENTS = Registry('component')
+FOOMODELS = Registry('model')
+
+Logger = MMLogger.get_current_instance()
+
+
+@COMPONENTS.register_module()
+class FooConv1d(BaseModule):
+
+    def __init__(self, init_cfg=None):
+        super().__init__(init_cfg)
+        self.conv1d = nn.Conv1d(4, 1, 4)
+
+    def forward(self, x):
+        return self.conv1d(x)
+
+
+@COMPONENTS.register_module()
+class FooConv2d(BaseModule):
+
+    def __init__(self, init_cfg=None):
+        super().__init__(init_cfg)
+        self.conv2d = nn.Conv2d(3, 1, 3)
+
+    def forward(self, x):
+        return self.conv2d(x)
+
+
+@COMPONENTS.register_module()
+class FooLinear(BaseModule):
+
+    def __init__(self, init_cfg=None):
+        super().__init__(init_cfg)
+        self.linear = nn.Linear(3, 4)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
+@COMPONENTS.register_module()
+class FooLinearConv1d(BaseModule):
+
+    def __init__(self, linear=None, conv1d=None, init_cfg=None):
+        super().__init__(init_cfg)
+        if linear is not None:
+            self.linear = build_from_cfg(linear, COMPONENTS)
+        if conv1d is not None:
+            self.conv1d = build_from_cfg(conv1d, COMPONENTS)
+
+    def forward(self, x):
+        x = self.linear(x)
+        return self.conv1d(x)
+
+
+@FOOMODELS.register_module()
+class FooModel(BaseModule):
+
+    def __init__(self,
+                 component1=None,
+                 component2=None,
+                 component3=None,
+                 component4=None,
+                 init_cfg=None) -> None:
+        super().__init__(init_cfg)
+        if component1 is not None:
+            self.component1 = build_from_cfg(component1, COMPONENTS)
+        if component2 is not None:
+            self.component2 = build_from_cfg(component2, COMPONENTS)
+        if component3 is not None:
+            self.component3 = build_from_cfg(component3, COMPONENTS)
+        if component4 is not None:
+            self.component4 = build_from_cfg(component4, COMPONENTS)
+
+        # Its type is not BaseModule, so it can be initialized
+        # with the "override" key.
+        self.reg = nn.Linear(3, 4)
+
+
+class TestBaseModule(TestCase):
+
+    def setUp(self) -> None:
+        self.BaseModule = BaseModule()
+        self.model_cfg = dict(
+            type='FooModel',
+            init_cfg=[
+                dict(type='Constant', val=1, bias=2, layer='Linear'),
+                dict(type='Constant', val=3, bias=4, layer='Conv1d'),
+                dict(type='Constant', val=5, bias=6, layer='Conv2d')
+            ],
+            component1=dict(type='FooConv1d'),
+            component2=dict(type='FooConv2d'),
+            component3=dict(type='FooLinear'),
+            component4=dict(
+                type='FooLinearConv1d',
+                linear=dict(type='FooLinear'),
+                conv1d=dict(type='FooConv1d')))
+
+        self.model = build_from_cfg(self.model_cfg, FOOMODELS)
+
+    def test_is_init(self):
+        assert self.BaseModule.is_init is False
+
+    def test_init_weights(self):
+        """
+        Config
+        model (FooModel, Linear: weight=1, bias=2, Conv1d: weight=3, bias=4,
+                        Conv2d: weight=5, bias=6)
+        ├──component1 (FooConv1d)
+        ├──component2 (FooConv2d)
+        ├──component3 (FooLinear)
+        ├──component4 (FooLinearConv1d)
+            ├──linear (FooLinear)
+            ├──conv1d (FooConv1d)
+        ├──reg (nn.Linear)
+        Parameters after initialization
+        model (FooModel)
+        ├──component1 (FooConv1d, weight=3, bias=4)
+        ├──component2 (FooConv2d, weight=5, bias=6)
+        ├──component3 (FooLinear, weight=1, bias=2)
+        ├──component4 (FooLinearConv1d)
+            ├──linear (FooLinear, weight=1, bias=2)
+            ├──conv1d (FooConv1d, weight=3, bias=4)
+        ├──reg (nn.Linear, weight=1, bias=2)
+        """
+
+        self.model.init_weights()
+
+        assert torch.equal(
+            self.model.component1.conv1d.weight,
+            torch.full(self.model.component1.conv1d.weight.shape, 3.0))
+        assert torch.equal(
+            self.model.component1.conv1d.bias,
+            torch.full(self.model.component1.conv1d.bias.shape, 4.0))
+        assert torch.equal(
+            self.model.component2.conv2d.weight,
+            torch.full(self.model.component2.conv2d.weight.shape, 5.0))
+        assert torch.equal(
+            self.model.component2.conv2d.bias,
+            torch.full(self.model.component2.conv2d.bias.shape, 6.0))
+        assert torch.equal(
+            self.model.component3.linear.weight,
+            torch.full(self.model.component3.linear.weight.shape, 1.0))
+        assert torch.equal(
+            self.model.component3.linear.bias,
+            torch.full(self.model.component3.linear.bias.shape, 2.0))
+        assert torch.equal(
+            self.model.component4.linear.linear.weight,
+            torch.full(self.model.component4.linear.linear.weight.shape, 1.0))
+        assert torch.equal(
+            self.model.component4.linear.linear.bias,
+            torch.full(self.model.component4.linear.linear.bias.shape, 2.0))
+        assert torch.equal(
+            self.model.component4.conv1d.conv1d.weight,
+            torch.full(self.model.component4.conv1d.conv1d.weight.shape, 3.0))
+        assert torch.equal(
+            self.model.component4.conv1d.conv1d.bias,
+            torch.full(self.model.component4.conv1d.conv1d.bias.shape, 4.0))
+        assert torch.equal(self.model.reg.weight,
+                           torch.full(self.model.reg.weight.shape, 1.0))
+        assert torch.equal(self.model.reg.bias,
+                           torch.full(self.model.reg.bias.shape, 2.0))
+
+    def test_dump_init_info(self):
+        import os
+        import shutil
+        dump_dir = 'tests/test_model/test_dump_info'
+        if not (os.path.exists(dump_dir) and os.path.isdir(dump_dir)):
+            os.makedirs(dump_dir)
+        for filename in os.listdir(dump_dir):
+            file_path = os.path.join(dump_dir, filename)
+            if os.path.isfile(file_path) or os.path.islink(file_path):
+                os.unlink(file_path)
+            elif os.path.isdir(file_path):
+                shutil.rmtree(file_path)
+
+        MMLogger.get_instance('logger1')  # add logger without FileHandler
+        model1 = build_from_cfg(self.model_cfg, FOOMODELS)
+        model1.init_weights()
+        assert len(os.listdir(dump_dir)) == 0
+        log_path = os.path.join(dump_dir, 'out.log')
+        MMLogger.get_instance(
+            'logger2', log_file=log_path)  # add logger with FileHandler
+        model2 = build_from_cfg(self.model_cfg, FOOMODELS)
+        model2.init_weights()
+        assert len(os.listdir(dump_dir)) == 1
+        assert os.stat(log_path).st_size != 0
+        shutil.rmtree(dump_dir)