From f04fec736de32cc9c9bab3383aca12d85ca02a77 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 7 Jun 2022 22:13:53 +0800
Subject: [PATCH] [Feature]: add base model, ddp model wrapper and unit test
 (#268)

* add base model, ddp model and unit test

* add unit test

* fix unit test

* fix docstring

* fix cpu unit test

* refine base data preprocessor

* refine base data preprocessor

* refine interface of ddp module

* remove optimizer hook

* add forward

* fix as comment

* fix unit test

* fix as comment

* fix build optimizer wrapper

* rebase main and fix unit test

* stack_batch support stacking ndim tensor, add docstring for merge dict

* fix lint

* fix test loop

* make precision_context effective to data_preprocessor

* fix as comment

* fix as comment

* refine docstring

* change collate_data output typehints

* rename to_rgb to bgr_to_rgb and rgb_to_bgr

* support build basemodel with built DataPreprocessor

* fix as comment

* fix docstring
---
 mmengine/hooks/__init__.py                    |   5 +-
 mmengine/hooks/optimizer_hook.py              | 129 ---------
 mmengine/hooks/runtime_info_hook.py           |   2 +-
 mmengine/model/__init__.py                    |  15 +-
 mmengine/model/base_model/__init__.py         |   9 +
 mmengine/model/base_model/base_model.py       | 256 ++++++++++++++++++
 .../model/base_model/data_preprocessor.py     | 213 +++++++++++++++
 mmengine/model/base_module.py                 |   2 +-
 .../model/{utils/weight_init.py => utils.py}  | 131 +++++++++
 mmengine/model/wrappers/__init__.py           |   8 +-
 mmengine/model/wrappers/data_parallel.py      | 149 ----------
 mmengine/model/wrappers/distributed.py        | 123 +++++++++
 .../model/wrappers/seperate_distributed.py    | 124 +++++++++
 mmengine/runner/loops.py                      |  32 ++-
 mmengine/runner/runner.py                     |  74 ++---
 mmengine/utils/misc.py                        |   5 +-
 tests/test_hook/test_ema_hook.py              |  37 +--
 tests/test_hook/test_optimizer_hook.py        | 115 --------
 tests/test_hook/test_runtime_info_hook.py     |   7 +-
 tests/test_logging/test_message_hub.py        |   2 +-
 .../test_base_model/test_base_model.py        | 125 +++++++++
 .../test_base_model/test_data_preprocessor.py | 110 ++++++++
 .../test_wrappers/test_data_parallel.py       | 141 ----------
 .../test_wrappers/test_model_wrapper.py       | 161 +++++++++++
 tests/test_runner/test_runner.py              | 112 +++++---
 25 files changed, 1431 insertions(+), 656 deletions(-)
 delete mode 100644 mmengine/hooks/optimizer_hook.py
 create mode 100644 mmengine/model/base_model/__init__.py
 create mode 100644 mmengine/model/base_model/base_model.py
 create mode 100644 mmengine/model/base_model/data_preprocessor.py
 rename mmengine/model/{utils/weight_init.py => utils.py} (84%)
 delete mode 100644 mmengine/model/wrappers/data_parallel.py
 create mode 100644 mmengine/model/wrappers/distributed.py
 create mode 100644 mmengine/model/wrappers/seperate_distributed.py
 delete mode 100644 tests/test_hook/test_optimizer_hook.py
 create mode 100644 tests/test_model/test_base_model/test_base_model.py
 create mode 100644 tests/test_model/test_base_model/test_data_preprocessor.py
 delete mode 100644 tests/test_model/test_wrappers/test_data_parallel.py
 create mode 100644 tests/test_model/test_wrappers/test_model_wrapper.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index ecc72e6a..fe326332 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -6,7 +6,6 @@ from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .logger_hook import LoggerHook
 from .naive_visualization_hook import NaiveVisualizationHook
-from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .runtime_info_hook import RuntimeInfoHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -14,6 +13,6 @@ from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
-    'LoggerHook', 'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
+    'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', 'LoggerHook',
+    'NaiveVisualizationHook', 'EMAHook', 'RuntimeInfoHook'
 ]
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
deleted file mode 100644
index c00d9dea..00000000
--- a/mmengine/hooks/optimizer_hook.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import logging
-from typing import List, Optional, Sequence
-
-import torch
-from torch.nn.parameter import Parameter
-from torch.nn.utils import clip_grad
-
-from mmengine.registry import HOOKS
-from .hook import Hook
-
-DATA_BATCH = Optional[Sequence[dict]]
-
-
-@HOOKS.register_module()
-class OptimizerHook(Hook):
-    """A hook contains custom operations for the optimizer.
-
-    Args:
-        grad_clip (dict, optional): A config dict to control the clip_grad.
-            Defaults to None.
-        detect_anomalous_params (bool): This option is only used for
-            debugging which will slow down the training speed.
-            Detect anomalous parameters that are not included in
-            the computational graph with ``loss`` as the root.
-            There are two cases
-                - Parameters were not used during
-                  forward pass.
-                - Parameters were not used to produce
-                  loss.
-            Defaults to False.
-    """
-
-    priority = 'HIGH'
-
-    def __init__(self,
-                 grad_clip: Optional[dict] = None,
-                 detect_anomalous_params: bool = False) -> None:
-        self.grad_clip = grad_clip
-        self.detect_anomalous_params = detect_anomalous_params
-
-    def clip_grads(self, params: List[Parameter]) -> Optional[torch.Tensor]:
-        """Clip the gradients of parameters.
-
-        Args:
-            params (list[Parameter]): Model's parameters.
-
-        Returns:
-            Optional[torch.Tensor]: Total norm of the parameters if there is
-            at least one param requiring gradient, else None.
-        """
-        params = list(
-            filter(lambda p: p.requires_grad and p.grad is not None, params))
-        if len(params) > 0:
-            return clip_grad.clip_grad_norm_(params, **self.grad_clip)
-        return None
-
-    def after_train_iter(self,
-                         runner,
-                         batch_idx: int,
-                         data_batch: DATA_BATCH = None,
-                         outputs: Optional[dict] = None) -> None:
-        """All operations need to be finished after each training iteration.
-
-        This function will finish following 3 operations:
-
-        - Detect any anomalous parameters which are not included in the
-          training graph. (optional)
-
-        - Compute the gradient of model parameters.
-
-        - Clip the gradients of each parameter. (optional)
-
-        - Update model parameters with gradients.
-
-        Args:
-            runner (Runner): The runner of the training process.
-            batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[dict], optional): Data from dataloader.
-                In order to keep this interface consistent with other hooks,
-                we keep ``data_batch`` here. Defaults to None.
-            outputs (dict, optional): Outputs from model.
-                In order to keep this interface consistent with other hooks,
-                we keep ``outputs`` here. Defaults to None.
-        """
-        runner.optim_wrapper.zero_grad()
-        if self.detect_anomalous_params:
-            self.detect_anomalous_parameters(runner.outputs['loss'], runner)
-        runner.outputs['loss'].backward()
-
-        if self.grad_clip is not None:
-            grad_norm = self.clip_grads(runner.model.parameters())
-            if grad_norm is not None:
-                # Add grad norm to the logger
-                runner.message_hub.update_scalar('train/grad_norm',
-                                                 float(grad_norm))
-        runner.optim_wrapper.step()
-
-    def detect_anomalous_parameters(self, loss: torch.Tensor, runner) -> None:
-        """Detect anomalous parameters that are not included in the graph.
-
-        Args:
-            loss (torch.Tensor): The loss of current iteration.
-            runner (Runner): The runner of the training process.
-        """
-        logger = runner.logger
-        parameters_in_graph = set()
-        visited = set()
-
-        def traverse(grad_fn):
-            if grad_fn is None:
-                return
-            if grad_fn not in visited:
-                visited.add(grad_fn)
-                if hasattr(grad_fn, 'variable'):
-                    parameters_in_graph.add(grad_fn.variable)
-                parents = grad_fn.next_functions
-                if parents is not None:
-                    for parent in parents:
-                        grad_fn = parent[0]
-                        traverse(grad_fn)
-
-        traverse(loss.grad_fn)
-        for n, p in runner.model.named_parameters():
-            if p not in parameters_in_graph and p.requires_grad:
-                logger.log(
-                    level=logging.ERROR,
-                    msg=f'{n} with shape {p.size()} is not '
-                    f'in the computational graph \n')
diff --git a/mmengine/hooks/runtime_info_hook.py b/mmengine/hooks/runtime_info_hook.py
index 56186ff5..091ced50 100644
--- a/mmengine/hooks/runtime_info_hook.py
+++ b/mmengine/hooks/runtime_info_hook.py
@@ -59,7 +59,7 @@ class RuntimeInfoHook(Hook):
                          outputs: Optional[dict] = None) -> None:
         """Update ``log_vars`` in model outputs every iteration."""
         if outputs is not None:
-            for key, value in outputs['log_vars'].items():
+            for key, value in outputs.items():
                 runner.message_hub.update_scalar(f'train/{key}', value)
 
     def after_val_epoch(self,
diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index 082f9131..0b7f08e7 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -1,11 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
-from .wrappers import (MMDataParallel, MMDistributedDataParallel,
-                       is_model_wrapper)
+from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
+from .base_module import BaseModule
+from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs
+from .wrappers import (MMDistributedDataParallel,
+                       MMSeparateDistributedDataParallel, is_model_wrapper)
 
 __all__ = [
-    'MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper',
-    'StochasticWeightAverage', 'ExponentialMovingAverage',
-    'MomentumAnnealingEMA'
+    'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
+    'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
+    'BaseDataPreprocessor', 'ImgDataPreprocessor',
+    'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs',
+    'merge_dict', 'detect_anomalous_params'
 ]
diff --git a/mmengine/model/base_model/__init__.py b/mmengine/model/base_model/__init__.py
new file mode 100644
index 00000000..696c83ad
--- /dev/null
+++ b/mmengine/model/base_model/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base_model import BaseModel
+from .data_preprocessor import (BaseDataElement, BaseDataPreprocessor,
+                                ImgDataPreprocessor)
+
+__all__ = [
+    'BaseModel', 'BaseDataElement', 'ImgDataPreprocessor',
+    'BaseDataPreprocessor'
+]
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
new file mode 100644
index 00000000..ede27b7a
--- /dev/null
+++ b/mmengine/model/base_model/base_model.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import abstractmethod
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from mmengine.data import BaseDataElement
+from mmengine.optim import OptimWrapper
+from mmengine.registry import MODELS
+from mmengine.utils import is_list_of
+from ..base_module import BaseModule
+
+ForwardResults = Union[Dict[str, torch.Tensor], List[BaseDataElement],
+                       Tuple[torch.Tensor], torch.Tensor]
+
+
+class BaseModel(BaseModule):
+    """Base class for all algorithmic models.
+
+    BaseModel implements the basic functions of an algorithmic model, such as
+    weight initialization, batch input pre-processing (see more information in
+    :class:`BaseDataPreprocessor`), loss parsing, and model parameter updating.
+
+    Subclasses that inherit from BaseModel only need to implement the
+    ``forward`` method, which contains the logic to calculate losses and
+    predictions; the model can then be trained by the runner.
+
+    Examples:
+        >>> @MODELS.register_module()
+        >>> class ToyModel(BaseModel):
+        >>>
+        >>>     def __init__(self):
+        >>>         super().__init__()
+        >>>         self.backbone = nn.Sequential()
+        >>>         self.backbone.add_module('conv1', nn.Conv2d(3, 6, 5))
+        >>>         self.backbone.add_module('pool', nn.MaxPool2d(2, 2))
+        >>>         self.backbone.add_module('conv2', nn.Conv2d(6, 16, 5))
+        >>>         self.backbone.add_module('fc1', nn.Linear(16 * 5 * 5, 120))
+        >>>         self.backbone.add_module('fc2', nn.Linear(120, 84))
+        >>>         self.backbone.add_module('fc3', nn.Linear(84, 10))
+        >>>
+        >>>         self.criterion = nn.CrossEntropyLoss()
+        >>>
+        >>>     def forward(self, batch_inputs, data_samples, mode='tensor'):
+        >>>         data_samples = torch.stack(data_samples)
+        >>>         if mode == 'tensor':
+        >>>             return self.backbone(batch_inputs)
+        >>>         elif mode == 'predict':
+        >>>             feats = self.backbone(batch_inputs)
+        >>>             predictions = torch.argmax(feats, 1)
+        >>>             return predictions
+        >>>         elif mode == 'loss':
+        >>>             feats = self.backbone(batch_inputs)
+        >>>             loss = self.criterion(feats, data_samples)
+        >>>             return dict(loss=loss)
+
+    Args:
+        init_cfg (dict, optional): The weight initialized config for
+            :class:`BaseModule`.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`.
+
+    Attributes:
+        init_cfg (dict, optional): Initialization config dict.
+        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
+            pre-processing data sampled by dataloader to the format accepted by
+            :meth:`forward`.
+    """
+
+    def __init__(self,
+                 init_cfg: Optional[dict] = None,
+                 data_preprocessor: Optional[Union[dict, nn.Module]] = None):
+        super().__init__(init_cfg)
+        if data_preprocessor is None:
+            data_preprocessor = dict(type='BaseDataPreprocessor')
+        if isinstance(data_preprocessor, nn.Module):
+            self.data_preprocessor = data_preprocessor
+        elif isinstance(data_preprocessor, dict):
+            self.data_preprocessor = MODELS.build(data_preprocessor)
+        else:
+            raise TypeError('data_preprocessor should be a `dict` or '
+                            f'`nn.Module` instance, but got '
+                            f'{type(data_preprocessor)}')
+
+    def train_step(self, data: List[dict],
+                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+        """Implements the default model training process including
+        preprocessing, model forward propagation, loss calculation,
+        optimization, and back-propagation.
+
+        During non-distributed training. If subclasses do not override the
+        :meth:`train_step`, :class:`EpochBasedTrainLoop` or
+        :class:`IterBasedTrainLoop` will call this method to update model
+        parameters. The default parameter update process is as follows:
+
+        1. Calls ``self.data_processor(data, training=False) to collext
+          batch_inputs and corresponding data_samples(labels).
+        2. Calls ``self(batch_inputs, data_samples, mode='loss')`` to get raw
+          loss
+        3. Calls ``self.parse_losses`` to get ``parsed_losses`` tensor used to
+          backward and dict of loss tensor used to log messages.
+        4. Calls ``optim_wrapper.update_params(loss)`` to update model.
+
+        Args:
+            data (List[dict]): Data sampled from dataloader.
+            optim_wrapper (OptimWrapper): OptimWrapper instance
+                used to update model parameters.
+
+        Returns:
+            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
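+
+        Example (a minimal runnable sketch; ``DummyModel`` below is a toy
+        subclass defined only for illustration and is not part of this
+        patch):
+
+            >>> class DummyModel(BaseModel):
+            >>>     def __init__(self):
+            >>>         super().__init__()
+            >>>         self.linear = nn.Linear(2, 1)
+            >>>     def forward(self, batch_inputs, data_samples=None,
+            >>>                 mode='tensor'):
+            >>>         outputs = self.linear(batch_inputs)
+            >>>         if mode == 'loss':
+            >>>             return dict(loss=outputs.sum())
+            >>>         return outputs
+            >>> model = DummyModel()
+            >>> optim_wrapper = OptimWrapper(
+            >>>     torch.optim.SGD(model.parameters(), lr=0.1))
+            >>> data = [dict(inputs=torch.rand(2)), dict(inputs=torch.rand(2))]
+            >>> log_vars = model.train_step(data, optim_wrapper)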
+        """
+        # enable automatic mixed precision training context.
+        with optim_wrapper.precision_context():
+            batch_inputs, data_samples = self.data_preprocessor(data, True)
+            losses = self(batch_inputs, data_samples, mode='loss')
+        parsed_losses, log_vars = self.parse_losses(losses)
+        optim_wrapper.update_params(parsed_losses)
+        return log_vars
+
+    def val_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the predictions of given data.
+
+        Calls ``self.data_preprocessor(data, False)`` and
+        ``self(inputs, data_sample, mode='predict')`` in order. Returns the
+        predictions, which will be passed to the evaluator.
+
+        Args:
+            data (List[dict]): Data sampled from dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of given data.
+        """
+        inputs, data_sample = self.data_preprocessor(data, False)
+        return self(inputs, data_sample, mode='predict')
+
+    def test_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """``BaseModel`` implements ``test_step`` the same as ``val_step``.
+
+        Args:
+            data (List[dict]): Data sampled from dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of given data.
+        """
+        inputs, data_sample = self.data_preprocessor(data, False)
+        return self(inputs, data_sample, mode='predict')
+
+    def parse_losses(
+        self, losses: Dict[str, torch.Tensor]
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """Parses the raw outputs (losses) of the network.
+
+        Args:
+            losses (dict): Raw output of the network, which usually contains
+                losses and other necessary information.
+
+        Returns:
+            tuple[Tensor, dict]: There are two elements. The first is the
+            loss tensor passed to optim_wrapper which may be a weighted sum of
+            all losses, and the second is log_vars which will be sent to the
+            logger.
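+
+        Example (an illustrative sketch; ``model`` is assumed to be an
+        instance of any concrete ``BaseModel`` subclass):
+
+            >>> losses = dict(loss_cls=torch.tensor(1.),
+            >>>               loss_bbox=[torch.tensor(.5), torch.tensor(.5)],
+            >>>               acc=torch.tensor(.9))
+            >>> loss, log_vars = model.parse_losses(losses)
+            >>> # loss is 2.0: only values whose keys contain 'loss' are
+            >>> # summed, while 'acc' is only recorded in ``log_vars``.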
+        """
+        log_vars = OrderedDict()
+        for loss_name, loss_value in losses.items():
+            if isinstance(loss_value, torch.Tensor):
+                log_vars[loss_name] = loss_value.mean()
+            elif is_list_of(loss_value, torch.Tensor):
+                log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+            else:
+                raise TypeError(
+                    f'{loss_name} is not a tensor or list of tensors')
+
+        loss = sum(value for key, value in log_vars.items() if 'loss' in key)
+        log_vars['loss'] = loss
+
+        return loss, log_vars
+
+    def to(self, device: Optional[Union[int, torch.device]], *args,
+           **kwargs) -> nn.Module:
+        """Overrides this method to set the ``device`` attribute of
+        :obj:`BaseDataPreprocessor` additionally
+
+        Args:
+            device (int or torch.device, optional): the desired device of the
+                parameters and buffers in this module.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self.data_preprocessor.device = torch.device(device)
+        return super().to(device)
+
+    def cuda(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the ``device`` attribute of
+        :obj:`BaseDataPreprocessor` additionally
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self.data_preprocessor.device = torch.cuda.current_device()
+        return super().cuda()
+
+    @abstractmethod
+    def forward(self,
+                batch_inputs: torch.Tensor,
+                data_samples: Optional[List[BaseDataElement]] = None,
+                mode: str = 'tensor') -> ForwardResults:
+        """Returns losses or predictions of training, validation, testing, and
+        simple inference process.
+
+        The ``forward`` method of BaseModel is an abstract method, and its
+        subclasses must implement it.
+
+        Accepts ``batch_inputs`` and ``data_samples`` processed by
+        :attr:`data_preprocessor`, and returns results according to the
+        ``mode`` argument.
+
+        During the non-distributed training, validation, and testing process,
+        ``forward`` will be called by ``BaseModel.train_step``,
+        ``BaseModel.val_step`` and ``BaseModel.test_step`` directly.
+
+        During the distributed data parallel training process,
+        ``MMDistributedDataParallel.train_step`` will first call
+        ``DistributedDataParallel.forward`` to enable automatic
+        gradient synchronization, and then call ``forward`` to get the
+        training loss.
+
+        Args:
+            batch_inputs (torch.Tensor): batch input tensor collated by
+                :attr:`data_preprocessor`.
+            data_samples (List[BaseDataElement], optional):
+                data samples collated by :attr:`data_preprocessor`.
+            mode (str): mode should be one of ``loss``, ``predict`` and
+                ``tensor``.
+
+                - ``loss``: Called by ``train_step`` and returns a loss
+                  ``dict`` used for logging.
+                - ``predict``: Called by ``val_step`` and ``test_step``
+                  and returns a list of ``BaseDataElement`` results used for
+                  computing metrics.
+                - ``tensor``: Called by custom code to get raw ``Tensor``
+                  outputs.
+
+        Returns:
+            ForwardResults:
+
+                - If ``mode == loss``, return a ``dict`` of loss tensors used
+                  for backward and logging.
+                - If ``mode == predict``, return a ``list`` of
+                  :obj:`BaseDataElement` for computing metrics
+                  and getting inference results.
+                - If ``mode == tensor``, return a tensor or ``tuple`` of
+                  tensors or a ``dict`` of tensors for custom use.
+        """
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
new file mode 100644
index 00000000..2b9d2cb3
--- /dev/null
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -0,0 +1,213 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from mmengine.data import BaseDataElement
+from mmengine.registry import MODELS
+from ..utils import stach_batch_imgs
+
+
+@MODELS.register_module()
+class BaseDataPreprocessor(nn.Module):
+    """Base data pre-processor used for collating and copying data to the
+    target device.
+
+    ``BaseDataPreprocessor`` performs data pre-processing according to the
+    following steps:
+
+    - Collates the data sampled from dataloader.
+    - Copies data to the target device.
+    - Stacks the input tensor at the first dimension.
+
+    Subclasses that inherit from ``BaseDataPreprocessor`` can override the
+    forward method to implement custom data pre-processing, such as
+    batch resizing, MixUp, or CutMix.
+
+    Args:
+        device (int or torch.device): Target device.
+
+    Warnings:
+        Each item of data sampled from the dataloader must be a dict and
+        contain at least the ``inputs`` key. Furthermore, every ``inputs``
+        value must be a ``Tensor``, and all of them must share the same shape.
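+
+    Example (a minimal sketch of the expected data format; the shapes below
+    are illustrative only):
+
+        >>> data = [
+        >>>     dict(inputs=torch.rand(3, 224, 224),
+        >>>          data_sample=BaseDataElement()),
+        >>>     dict(inputs=torch.rand(3, 224, 224),
+        >>>          data_sample=BaseDataElement()),
+        >>> ]
+        >>> data_preprocessor = BaseDataPreprocessor()
+        >>> batch_inputs, data_samples = data_preprocessor(data)
+        >>> # batch_inputs.shape == (2, 3, 224, 224)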
+    """
+
+    def __init__(self, device: Union[int, torch.device] = 'cpu'):
+        super().__init__()
+        self.device = device
+
+    def collate_data(
+            self,
+            data: Sequence[dict]) -> Tuple[List[torch.Tensor], Optional[list]]:
+        """Collating and copying data to the target device.
+
+        Collates the data sampled from the dataloader into a list of tensors
+        and a list of labels, and then copies the tensors to the target device.
+
+        Subclasses can override it to be compatible with custom data formats
+        sampled from custom dataloaders.
+
+        Args:
+            data (Sequence[dict]): Data sampled from dataloader.
+
+        Returns:
+            Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
+            tensors and list of labels on the target device.
+        """
+        inputs = [_data['inputs'].to(self.device) for _data in data]
+        batch_data_samples: List[BaseDataElement] = []
+        # Model can get predictions without any data samples.
+        for _data in data:
+            if 'data_sample' in _data:
+                batch_data_samples.append(_data['data_sample'])
+        # Move data from CPU to corresponding device.
+        batch_data_samples = [
+            data_sample.to(self.device) for data_sample in batch_data_samples
+        ]
+
+        if not batch_data_samples:
+            batch_data_samples = None  # type: ignore
+
+        return inputs, batch_data_samples
+
+    def forward(self,
+                data: Sequence[dict],
+                training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
+        """Preprocesses the data into the model input format.
+
+        After the data pre-processing of :meth:`collate_data`, ``forward``
+        will stack the list of input tensors into a batch tensor along the
+        first dimension.
+
+        Args:
+            data (Sequence[dict]): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
+            model input.
+        """
+        inputs, batch_data_samples = self.collate_data(data)
+        batch_inputs = torch.stack(inputs, dim=0)
+        return batch_inputs, batch_data_samples
+
+    def to(self, device: Optional[Union[int, torch.device]], *args,
+           **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Args:
+            device (int or torch.device, optional): The desired device of the
+                parameters and buffers in this module.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self.device = torch.device(device)
+        return super().to(device)
+
+    def cuda(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self.device = torch.cuda.current_device()
+        return super().cuda()
+
+
+@MODELS.register_module()
+class ImgDataPreprocessor(BaseDataPreprocessor):
+    """Image pre-processor for normalization and bgr to rgb conversion.
+
+    Accepts the data sampled by the dataloader, and preprocesses it into the
+    format of the model input. ``ImgDataPreprocessor`` provides the
+    basic data pre-processing as follows
+
+    - Collates and moves data to the target device.
+    - Converts inputs from BGR to RGB if the shape of the input is (3, H, W).
+    - Normalizes images with the defined std and mean.
+    - Pads inputs to the maximum size of the current batch with the defined
+      ``pad_value``. The padded size will be divisible by the defined
+      ``pad_size_divisor``.
+    - Stacks inputs into ``batch_inputs``.
+
+    For ``ImgDataPreprocessor``, the shape of each single input must be
+    (3, H, W).
+
+    Note:
+        ``ImgDataPreprocessor`` and its subclasses are built in the
+        constructor of :class:`BaseModel`.
+
+    Args:
+        mean (Sequence[float or int]): The pixel mean of image channels. If
+            ``bgr_to_rgb=True`` it means the mean value of R, G, B channels.
+            If ``mean`` and ``std`` are not specified, ``ImgDataPreprocessor``
+            will normalize images to [-1, 1]. Defaults to (127.5, 127.5,
+            127.5).
+        std (Sequence[float or int]): The pixel standard deviation of image
+            channels. If ``bgr_to_rgb=True`` it means the standard deviation of
+            R, G, B channels. If ``mean`` and ``std`` are not specified,
+            ImgDataPreprocessor will normalize images to [-1, 1]. Defaults
+            to (127.5, 127.5, 127.5).
+        pad_size_divisor (int): The size of padded image should be
+            divisible by ``pad_size_divisor``. Defaults to 1.
+        pad_value (float or int): The padded pixel value. Defaults to 0.
+        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+            Defaults to False.
+        rgb_to_bgr (bool): whether to convert image from RGB to BGR.
+            Defaults to False.
+        device (int or torch.device): Target device.
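+
+    Example (an illustrative configuration; the mean/std values below are the
+    common ImageNet statistics and are only an assumption):
+
+        >>> data_preprocessor = ImgDataPreprocessor(
+        >>>     mean=[123.675, 116.28, 103.53],
+        >>>     std=[58.395, 57.12, 57.375],
+        >>>     bgr_to_rgb=True,
+        >>>     pad_size_divisor=32)
+        >>> data = [dict(inputs=torch.randint(0, 256, (3, 224, 224)))]
+        >>> batch_inputs, data_samples = data_preprocessor(data)
+        >>> # batch_inputs.shape == (1, 3, 224, 224); data_samples is None.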
+    """
+
+    def __init__(self,
+                 mean: Sequence[Union[float, int]] = (127.5, 127.5, 127.5),
+                 std: Sequence[Union[float, int]] = (127.5, 127.5, 127.5),
+                 pad_size_divisor: int = 1,
+                 pad_value: Union[float, int] = 0,
+                 bgr_to_rgb: bool = False,
+                 rgb_to_bgr: bool = False,
+                 device: Union[int, torch.device] = 'cpu'):
+        super().__init__(device)
+        assert len(mean) == 3 or len(mean) == 1, (
+            'The length of mean should be 1 or 3 to be compatible with RGB '
+            f'or gray image, but got {len(mean)}')
+        assert len(std) == 3 or len(std) == 1, (
+            'The length of std should be 1 or 3 to be compatible with RGB '
+            f'or gray image, but got {len(std)}')
+        assert not (bgr_to_rgb and rgb_to_bgr), (
+            '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
+        self.channel_conversion = rgb_to_bgr or bgr_to_rgb
+        self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1), False)
+        self.register_buffer('std', torch.tensor(std).view(-1, 1, 1), False)
+        self.pad_size_divisor = pad_size_divisor
+        self.pad_value = pad_value
+
+    def forward(self,
+                data: Sequence[dict],
+                training: bool = False) -> Tuple[torch.Tensor, Optional[list]]:
+        """Performs normalization、padding and bgr2rgb conversion based on
+        ``BaseDataPreprocessor``.
+
+        Args:
+            data (Sequence[dict]): data sampled from dataloader.
+            training (bool): Whether to enable training time augmentation. If
+                subclasses override this method, they can perform different
+                preprocessing strategies for training and testing based on the
+                value of ``training``.
+
+        Returns:
+            Tuple[torch.Tensor, Optional[list]]: Data in the same format as the
+            model input.
+        """
+        inputs, batch_data_samples = self.collate_data(data)
+        # channel transform
+        if self.channel_conversion:
+            inputs = [_input[[2, 1, 0], ...] for _input in inputs]
+        # Normalization.
+        inputs = [(_input - self.mean) / self.std for _input in inputs]
+        # Pad and stack Tensor.
+        batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor,
+                                        self.pad_value)
+        return batch_inputs, batch_data_samples
diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py
index 0a38bbf5..89f140dc 100644
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -91,7 +91,7 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         logger = MMLogger.get_current_instance()
         logger_name = logger.instance_name
 
-        from .utils.weight_init import initialize, update_init_info
+        from .utils import initialize, update_init_info
         module_name = self.__class__.__name__
         if not self._is_init:
             if self.init_cfg:
diff --git a/mmengine/model/utils/weight_init.py b/mmengine/model/utils.py
similarity index 84%
rename from mmengine/model/utils/weight_init.py
rename to mmengine/model/utils.py
index 1289d7f1..29cc779b 100644
--- a/mmengine/model/utils/weight_init.py
+++ b/mmengine/model/utils.py
@@ -1,11 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import logging
 import math
 import warnings
+from typing import List, Union
 
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 
 from mmengine.logging.logger import MMLogger, print_log
@@ -668,3 +671,131 @@ def trunc_normal_(tensor: Tensor,
         b (float): the maximum cutoff value.
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def stach_batch_imgs(tensor_list: List[torch.Tensor],
+                     pad_size_divisor: int = 1,
+                     pad_value: Union[int, float] = 0) -> torch.Tensor:
+    """Stack multiple tensors to form a batch and pad the images to the max
+    shape use the right bottom padding mode in these images. If
+    ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
+    divisible by ``pad_size_divisor``.
+
+    Args:
+        tensor_list (List[Tensor]): A list of tensors with the same number of
+            dimensions.
+        pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
+            to ensure the shape of each dimension is divisible by
+            ``pad_size_divisor``. This depends on the model, and many
+            models require the shape to be divisible by 32. Defaults to 1.
+        pad_value (int, float): The padding value. Defaults to 0.
+
+    Returns:
+        Tensor: The padded and stacked tensor with a leading batch dimension.
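+
+    Example (a minimal sketch of the padding behaviour; the shapes are
+    illustrative only):
+
+        >>> imgs = [torch.ones(3, 10, 12), torch.ones(3, 12, 10)]
+        >>> batch = stach_batch_imgs(imgs, pad_size_divisor=32, pad_value=0)
+        >>> # Each image is padded on the right/bottom to (3, 32, 32) and the
+        >>> # result is stacked into a tensor of shape (2, 3, 32, 32).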
+    """
+    assert isinstance(
+        tensor_list,
+        list), (f'Expected input type to be list, but got {type(tensor_list)}')
+    assert tensor_list, '`tensor_list` should not be an empty list'
+    assert len({
+        tensor.ndim
+        for tensor in tensor_list
+    }) == 1, (f'Expected the dimensions of all tensors to be the same, '
+              f'but got {[tensor.ndim for tensor in tensor_list]}')
+
+    dim = tensor_list[0].dim()
+    num_img = len(tensor_list)
+    all_sizes: torch.Tensor = torch.Tensor(
+        [tensor.shape for tensor in tensor_list])
+    max_sizes = torch.ceil(
+        torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
+    padded_sizes = max_sizes - all_sizes
+    # The first dim normally means channel, which should not be padded.
+    padded_sizes[:, 0] = 0
+    if padded_sizes.sum() == 0:
+        return torch.stack(tensor_list)
+    # `pad` is the second argument of `F.pad`. If pad is (1, 2, 3, 4),
+    # it means padding the last dim by 1 (left) and 2 (right), and padding
+    # the penultimate dim by 3 (top) and 4 (bottom). The order of `pad` is
+    # opposite to that of `padded_sizes`. Therefore, `padded_sizes` needs to
+    # be reversed, and only the odd indices of `pad` should be assigned, to
+    # keep the padding on the "right" and "bottom".
+    pad = torch.zeros(num_img, 2 * dim, dtype=torch.int)
+    pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)]
+    batch_tensor = []
+    for idx, tensor in enumerate(tensor_list):
+        batch_tensor.append(
+            F.pad(tensor, tuple(pad[idx].tolist()), value=pad_value))
+    return torch.stack(batch_tensor)
+
+
+def detect_anomalous_params(loss: torch.Tensor, model) -> None:
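+    """Log an error for each parameter that requires a gradient but is not
+    included in the computational graph with ``loss`` as the root.
+
+    This is a debugging utility (used by :class:`MMDistributedDataParallel`
+    when ``detect_anomalous_params=True``) and will slow down training.
+
+    Args:
+        loss (torch.Tensor): The loss tensor of the current iteration.
+        model (nn.Module): The model being trained.
+    """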
+    parameters_in_graph = set()
+    visited = set()
+
+    def traverse(grad_fn):
+        if grad_fn is None:
+            return
+        if grad_fn not in visited:
+            visited.add(grad_fn)
+            if hasattr(grad_fn, 'variable'):
+                parameters_in_graph.add(grad_fn.variable)
+            parents = grad_fn.next_functions
+            if parents is not None:
+                for parent in parents:
+                    grad_fn = parent[0]
+                    traverse(grad_fn)
+
+    traverse(loss.grad_fn)
+    from mmengine import MMLogger
+    logger = MMLogger.get_current_instance()
+    for n, p in model.named_parameters():
+        if p not in parameters_in_graph and p.requires_grad:
+            logger.log(
+                level=logging.ERROR,
+                msg=f'{n} with shape {p.size()} is not '
+                f'in the computational graph \n')
+
+
+def merge_dict(*args):
+    """Merge all dictionaries into one dictionary.
+
+    If the PyTorch version is >= 1.8, ``merge_dict`` will be wrapped
+    by ``torch.fx.wrap``, which makes ``torch.fx.symbolic_trace`` skip
+    tracing ``merge_dict``.
+
+    Note:
+        If a function needs to be traced by ``torch.fx.symbolic_trace`` but
+        inevitably needs to use the ``update`` method of ``dict`` (``update``
+        is not traceable), it should use ``merge_dict`` instead of
+        ``xxx.update``.
+
+    Args:
+        *args: The dictionaries to be merged.
+
+    Returns:
+        dict: The merged dict.
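+
+    Example (illustrative):
+
+        >>> head1_losses = dict(loss_cls=torch.tensor(1.))
+        >>> head2_losses = dict(loss_bbox=torch.tensor(2.))
+        >>> losses = merge_dict(head1_losses, head2_losses)
+        >>> # losses == {'loss_cls': tensor(1.), 'loss_bbox': tensor(2.)}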
+    """
+    output = dict()
+    for item in args:
+        assert isinstance(
+            item,
+            dict), (f'all arguments of merge_dict should be a dict, but got '
+                    f'{type(item)}')
+        output.update(item)
+    return output
+
+
+# torch.fx is only available when pytorch version >= 1.8.
+# If a subclass of `BaseModel` has multiple submodules and each submodule
+# returns a loss dict during the training process, e.g., `TwoStageDetector`
+# in MMDetection, it should use `merge_dict` rather than `loss.update` to get
+# the total loss, so as to keep the model traceable.
+try:
+    import torch.fx
+
+    # make torch.fx skip trace `merge_dict`.
+    merge_dict = torch.fx.wrap(merge_dict)
+
+except ImportError:
+    warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function '
+                  'to merge multiple dicts')
diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py
index 1cab521d..d6ece713 100644
--- a/mmengine/model/wrappers/__init__.py
+++ b/mmengine/model/wrappers/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .data_parallel import MMDataParallel, MMDistributedDataParallel
+from .distributed import MMDistributedDataParallel
+from .seperate_distributed import MMSeparateDistributedDataParallel
 from .utils import is_model_wrapper
 
-__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
+__all__ = [
+    'MMDistributedDataParallel', 'is_model_wrapper',
+    'MMSeparateDistributedDataParallel'
+]
diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py
deleted file mode 100644
index d31b009c..00000000
--- a/mmengine/model/wrappers/data_parallel.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from itertools import chain
-
-import torch
-from torch.nn.parallel import DataParallel
-from torch.nn.parallel.distributed import (DistributedDataParallel,
-                                           _find_tensors)
-
-from mmengine.registry import MODEL_WRAPPERS
-from mmengine.utils import TORCH_VERSION, digit_version
-
-MODEL_WRAPPERS.register_module(module=DataParallel)
-MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
-
-
-@MODEL_WRAPPERS.register_module()
-class MMDataParallel(DataParallel):
-    """There is no difference between MMDataParallel and pytorch's
-    DataParallel, "train_step" and "val_step" are added just to avoid bc
-    breaking.
-
-    Warning:
-        MMDataParallel only supports single GPU training, if you
-        need to  train with multiple GPUs, please use MMDistributedDataParallel
-        instead. If you have multiple GPUs and you just want to use
-        MMDataParallel, you can set the environment variable
-        ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
-        ``device_ids=[0]``.
-    """
-
-    def train_step(self, *inputs, **kwargs):
-        assert len(self.device_ids) == 1, \
-            ('MMDataParallel only supports single GPU training, if you need to'
-             ' train with multiple GPUs, please use MMDistributedDataParallel'
-             ' instead.')
-        assert hasattr(self.module, 'train_step')
-        for t in chain(self.module.parameters(), self.module.buffers()):
-            if t.device != self.src_device_obj:
-                raise RuntimeError(
-                    'module must have its parameters and buffers '
-                    f'on device {self.src_device_obj} (device_ids[0]) but '
-                    f'found one of them on device: {t.device}')
-        return self.module.train_step(*inputs, **kwargs)
-
-    def val_step(self, *inputs, **kwargs):
-        assert len(self.device_ids) == 1, \
-            ('MMDataParallel only supports single GPU training, if you need to'
-             ' train with multiple GPUs, please use MMDistributedDataParallel'
-             ' instead.')
-        assert hasattr(self.module, 'val_step')
-        for t in chain(self.module.parameters(), self.module.buffers()):
-            if t.device != self.src_device_obj:
-                raise RuntimeError(
-                    'module must have its parameters and buffers '
-                    f'on device {self.src_device_obj} (device_ids[0]) but '
-                    f'found one of them on device: {t.device}')
-        return self.module.val_step(*inputs, **kwargs)
-
-
-@MODEL_WRAPPERS.register_module()
-class MMDistributedDataParallel(DistributedDataParallel):
-    """There is no difference between MMDistributedDataParallel and pytorch's
-    DistributedDataParallel, "train_step" and "val_step" are added just to
-    avoid bc breaking."""
-
-    def train_step(self, *inputs, **kwargs):
-        """train_step() API for module wrapped by DistributedDataParallel.
-
-        This method is basically the same as
-        ``DistributedDataParallel.forward()``, while replacing
-        ``self.module.forward()`` with ``self.module.train_step()``.
-        It is compatible with PyTorch 1.1 - 1.5.
-        """
-
-        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
-        # end of backward to the beginning of forward.
-        if ('parrots' not in TORCH_VERSION
-                and digit_version(TORCH_VERSION) >= digit_version('1.7')
-                and self.reducer._rebuild_buckets()):
-            # TODO: replace with logger
-            print('Reducer buckets have been rebuilt in this iteration.')
-
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
-
-        if self.device_ids:
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-            if len(self.device_ids) == 1:
-                output = self.module.train_step(*inputs[0], **kwargs[0])
-            else:
-                outputs = self.parallel_apply(
-                    self._module_copies[:len(inputs)], inputs, kwargs)
-                output = self.gather(outputs, self.output_device)
-        else:
-            output = self.module.train_step(*inputs, **kwargs)
-
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
-            if self.find_unused_parameters:
-                self.reducer.prepare_for_backward(list(_find_tensors(output)))
-            else:
-                self.reducer.prepare_for_backward([])
-        else:
-            if ('parrots' not in TORCH_VERSION
-                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
-                self.require_forward_param_sync = False
-        return output
-
-    def val_step(self, *inputs, **kwargs):
-        """val_step() API for module wrapped by DistributedDataParallel.
-
-        This method is basically the same as
-        ``DistributedDataParallel.forward()``, while replacing
-        ``self.module.forward()`` with ``self.module.val_step()``.
-        It is compatible with PyTorch 1.1 - 1.5.
-        """
-
-        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
-        # end of backward to the beginning of forward.
-        if ('parrots' not in TORCH_VERSION
-                and digit_version(TORCH_VERSION) >= digit_version('1.7')
-                and self.reducer._rebuild_buckets()):
-            # TODO: replace with logger
-            print('Reducer buckets have been rebuilt in this iteration.')
-
-        if getattr(self, 'require_forward_param_sync', True):
-            self._sync_params()
-        if self.device_ids:
-            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
-            if len(self.device_ids) == 1:
-                output = self.module.val_step(*inputs[0], **kwargs[0])
-            else:
-                outputs = self.parallel_apply(
-                    self._module_copies[:len(inputs)], inputs, kwargs)
-                output = self.gather(outputs, self.output_device)
-        else:
-            output = self.module.val_step(*inputs, **kwargs)
-
-        if torch.is_grad_enabled() and getattr(
-                self, 'require_backward_grad_sync', True):
-            if self.find_unused_parameters:
-                self.reducer.prepare_for_backward(list(_find_tensors(output)))
-            else:
-                self.reducer.prepare_for_backward([])
-        else:
-            if ('parrots' not in TORCH_VERSION
-                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
-                self.require_forward_param_sync = False
-        return output
diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
new file mode 100644
index 00000000..4084dde7
--- /dev/null
+++ b/mmengine/model/wrappers/distributed.py
@@ -0,0 +1,123 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List
+
+import torch
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from mmengine.data import BaseDataElement
+from mmengine.optim import OptimWrapper
+from mmengine.registry import MODEL_WRAPPERS
+from ..utils import detect_anomalous_params
+
+
+@MODEL_WRAPPERS.register_module()
+class MMDistributedDataParallel(DistributedDataParallel):
+    """A distributed model wrapper used for training,testing and validation in
+    loop.
+
+    Different from DistributedDataParallel, MMDistributedDataParallel
+    implements three methods :meth:`train_step`, :meth:`val_step` and
+    :meth:`test_step`, which will be called by ``train_loop``, ``val_loop``
+    and ``test_loop``.
+
+    - ``train_step``: Called by ``runner.train_loop``, and implements the
+      default model forward, gradient back-propagation and parameter updating
+      logic. To take advantage of DistributedDataParallel's automatic gradient
+      synchronization, ``train_step`` calls ``DistributedDataParallel.forward``
+      to calculate the losses, and calls other methods of :obj:`BaseModel` to
+      pre-process data and parse losses. Finally, it updates model parameters
+      with :obj:`OptimWrapper` and returns the loss dictionary used for
+      logging.
+
+    - ``val_step``: Called by ``runner.val_loop`` to get the inference
+      results. Since there is no gradient synchronization requirement,
+      this procedure is equivalent to ``BaseModel.val_step``.
+
+    - ``test_step``: Called by ``runner.test_loop``, equivalent to
+      ``val_step``.
+
+    Args:
+        detect_anomalous_params (bool): This option is only used for
+            debugging and will slow down the training speed.
+            It detects anomalous parameters that are not included in
+            the computational graph with ``loss`` as the root.
+            There are two cases:
+
+                - Parameters were not used during
+                  the forward pass.
+                - Parameters were not used to produce
+                  the loss.
+            Defaults to False.
+
+        *args: list arguments passed to ``DistributedDataParallel``
+        **kwargs: keyword arguments passed to ``DistributedDataParallel``.
+
+    Note:
+        If the model has multiple submodules and each submodule has a
+        separate optimization strategy,
+        :class:`MMSeparateDistributedDataParallel` should be used to wrap
+        the model.
+
+    Note:
+        If the model itself has a custom optimization strategy, rather than
+        simply forwarding the model and updating it, a custom model wrapper
+        that inherits from ``MMDistributedDataParallel`` should be defined
+        and should override the ``train_step`` method.
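+
+    Example (a minimal sketch; ``MyModel`` is assumed to be a concrete
+    :class:`BaseModel` subclass, ``data`` and ``optim_wrapper`` are assumed
+    to be provided by the runner, and the distributed process group is
+    assumed to be initialized already):
+
+        >>> model = MyModel().cuda()
+        >>> ddp_model = MMDistributedDataParallel(
+        >>>     module=model, device_ids=[torch.cuda.current_device()])
+        >>> log_vars = ddp_model.train_step(data, optim_wrapper=optim_wrapper)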
+    """
+
+    def __init__(self, detect_anomalous_params: bool = False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.detect_anomalous_params = detect_anomalous_params
+
+    def train_step(self, data: List[dict],
+                   optim_wrapper: OptimWrapper) -> Dict[str, torch.Tensor]:
+        """Interface for model forward, backward and parameters updating during
+        training process.
+
+        :meth:`train_step` will perform the following steps in order:
+
+        - If :attr:`module` defines the preprocess method,
+            call ``module.preprocess`` to pre-processing data.
+        - Call ``module.forward(**data)`` and get losses.
+        - Parse losses.
+        - Call ``optim_wrapper.optimizer_step`` to update parameters.
+        - Return log messages of losses.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+            optim_wrapper (OptimWrapper): A wrapper of optimizer to
+                update parameters.
+
+        Returns:
+            Dict[str, torch.Tensor]: A ``dict`` of tensor for logging.
+        """
+        # enable automatic mixed precision training context.
+        with optim_wrapper.precision_context():
+            batch_inputs, data_samples = self.module.data_preprocessor(
+                data, training=True)
+            losses = self(batch_inputs, data_samples, mode='loss')
+        if self.detect_anomalous_params:
+            detect_anomalous_params(losses, model=self)
+        parsed_loss, log_vars = self.module.parse_losses(losses)
+        optim_wrapper.update_params(parsed_loss)
+        return log_vars
+
+    def val_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the prediction of module during validation process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of the given data.
+        """
+        return self.module.val_step(data)
+
+    def test_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the predictions of module during testing process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of given data.
+        """
+        return self.module.test_step(data)
diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
new file mode 100644
index 00000000..a369be06
--- /dev/null
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -0,0 +1,124 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from contextlib import ExitStack, contextmanager
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel.distributed import DistributedDataParallel
+
+from mmengine.data import BaseDataElement
+from mmengine.optim import OptimWrapperDict
+from mmengine.registry import MODEL_WRAPPERS
+from .distributed import MMDistributedDataParallel
+
+
+@MODEL_WRAPPERS.register_module()
+class MMSeparateDistributedDataParallel(DistributedDataParallel):
+    """A DistributedDataParallel wrapper for models in MMGeneration.
+
+    In MMEditing and MMGeneration there is a need to wrap different modules in
+    the models with separate DistributedDataParallel. Otherwise, it will cause
+    errors for GAN training. For example, a GAN model usually has two
+    submodules: generator and discriminator. If we wrap both of them in one
+    standard DistributedDataParallel, it will cause errors during training,
+    because when we update the parameters of the generator (or discriminator),
+    the parameters of the discriminator (or generator) are not updated, which
+    is not allowed for DistributedDataParallel. So we design this wrapper to
+    wrap the generator and the discriminator with separate
+    DistributedDataParallel instances.
+
+    1. Wraps each module in the models with separate MMDistributedDataParallel.
+       Note that only modules with parameters will be wrapped.
+    2. Calls ``train_step``, ``val_step`` and ``test_step`` of submodules to
+       get losses and predictions.
+
+    Args:
+        module (nn.Module): A model containing multiple submodules which have
+            separate updating strategies.
+        *args: list arguments passed to ``MMDistributedDataParallel``
+        **kwargs: keyword arguments passed to ``MMDistributedDataParallel``.
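+
+    Example (an illustrative sketch; ``GANModel`` with ``generator`` and
+    ``discriminator`` submodules, ``data`` and ``optim_wrapper_dict`` are
+    assumptions, not part of this patch):
+
+        >>> model = GANModel()
+        >>> ddp_model = MMSeparateDistributedDataParallel(model)
+        >>> # ``generator`` and ``discriminator`` are now each wrapped in
+        >>> # their own ``MMDistributedDataParallel``.
+        >>> log_vars = ddp_model.train_step(data, optim_wrapper_dict)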
+    """
+
+    def __init__(self, module: nn.Module, *args, **kwargs):
+        super(DistributedDataParallel, self).__init__()
+        self.module = module
+        # Wrap the submodule with parameters of `self.module` to
+        # `MMDistributedDataParallel`
+        for name, _module in module._modules.items():
+            # module without parameters.
+            if next(_module.parameters(), None) is None:
+                _module = _module.cuda()
+            elif all(not p.requires_grad for p in _module.parameters()):
+                _module = _module.cuda()
+            else:
+                _module = MMDistributedDataParallel(
+                    module=_module.cuda(), *args, **kwargs)
+            module._modules[name] = _module
+
+    def train_step(self, data: List[dict],
+                   optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
+        """Interface for model forward, backward and parameters updating during
+        training process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+            optim_wrapper (OptimWrapperDict): A wrapper of optimizer to
+                update parameters.
+
+        Returns:
+            Dict[str, torch.Tensor]: A dict of tensor for logging.
+        """
+        return self.module.train_step(data, optim_wrapper)
+
+    def val_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the predictions of the module during the validation process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of given data.
+        """
+        return self.module.val_step(data)
+
+    def test_step(self, data: List[dict]) -> List[BaseDataElement]:
+        """Gets the predictions of module during testing process.
+
+        Args:
+            data (List[dict]): Data sampled by dataloader.
+
+        Returns:
+            List[BaseDataElement]: The predictions of the given data.
+        """
+        return self.module.test_step(data)
+
+    @contextmanager
+    def no_sync(self):
+        """Enables ``no_sync`` context of all sub ``MMDistributedDataParallel``
+        modules."""
+        with ExitStack() as stack:
+            for sub_ddp_model in self.module._modules.values():
+                stack.enter_context(sub_ddp_model.no_sync())
+            yield
+
+    def train(self, mode: bool = True) -> 'MMSeparateDistributedDataParallel':
+        """Sets the module in training mode.
+
+        In order to make the ddp wrapper inheritance hierarchy more uniform,
+        ``MMSeparateDistributedDataParallel`` inherits from
+        ``DistributedDataParallel``, but does not call its constructor.
+        Since the attributes of ``DistributedDataParallel`` have not been
+        initialized, calling the ``train`` method of
+        ``DistributedDataParallel`` will raise an error if the PyTorch version
+        is <= 1.9. Therefore, this method is overridden to call the ``train``
+        method of the submodules.
+
+        Args:
+            mode (bool): Whether to set training mode (``True``) or
+                evaluation mode (``False``). Defaults to ``True``.
+
+        Returns:
+            Module: self.
+        """
+        self.training = mode
+        self.module.train(mode)
+        return self
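The wrapper above is exercised by the new ``test_model_wrapper.py`` added later
in this patch. For orientation, a minimal usage sketch (not part of the patch)
is shown below; it assumes an initialized distributed environment with
available GPUs, and the ``ToyGAN`` class, its submodule names and the optimizer
settings are purely illustrative::

    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.model import BaseModel, MMSeparateDistributedDataParallel
    from mmengine.optim import OptimWrapper, OptimWrapperDict


    class ToyGAN(BaseModel):
        """Hypothetical model whose submodules are updated separately."""

        def __init__(self):
            super().__init__()
            self.generator = nn.Linear(2, 2)
            self.discriminator = nn.Linear(2, 1)

        def train_step(self, data, optim_wrapper):
            batch_inputs, _ = self.data_preprocessor(data)
            gen_loss = self.generator(batch_inputs).sum()
            optim_wrapper['generator'].update_params(gen_loss)
            disc_loss = self.discriminator(batch_inputs).sum()
            optim_wrapper['discriminator'].update_params(disc_loss)
            return dict(gen_loss=gen_loss, disc_loss=disc_loss)

        def forward(self, *args, **kwargs):
            pass


    model = ToyGAN()
    # Each submodule with trainable parameters gets its own
    # MMDistributedDataParallel instance.
    ddp_model = MMSeparateDistributedDataParallel(module=model)
    optim_wrapper = OptimWrapperDict(
        generator=OptimWrapper(SGD(model.generator.parameters(), lr=0.01)),
        discriminator=OptimWrapper(
            SGD(model.discriminator.parameters(), lr=0.01)))
    # `train_step` is delegated to the wrapped module, which updates each
    # submodule with its own optimizer wrapper, e.g.:
    # log_vars = ddp_model.train_step(data, optim_wrapper=optim_wrapper)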
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index 482207dd..6e186855 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -99,15 +99,20 @@ class EpochBasedTrainLoop(BaseLoop):
         """
         self.runner.call_hook(
             'before_train_iter', batch_idx=idx, data_batch=data_batch)
-        # outputs should be a dict containing one or multiple loss tensors
-        self.runner.outputs = self.runner.model(data_batch, return_loss=True)
+        # Enable gradient accumulation mode and avoid unnecessary gradient
+        # synchronization during the gradient accumulation process.
+        with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
+                                                       self._iter,
+                                                       self._max_iters):
+            # outputs should be a dict of losses.
+            outputs = self.runner.model.train_step(
+                data_batch, optim_wrapper=self.runner.optim_wrapper)
 
         self.runner.call_hook(
             'after_train_iter',
             batch_idx=idx,
             data_batch=data_batch,
-            outputs=self.runner.outputs)
-
+            outputs=outputs)
         self._iter += 1
 
 
@@ -197,14 +202,21 @@ class IterBasedTrainLoop(BaseLoop):
         """
         self.runner.call_hook(
             'before_train_iter', batch_idx=self._iter, data_batch=data_batch)
-        # outputs should be a dict containing loss tensor
-        self.runner.outputs = self.runner.model(data_batch, return_loss=True)
+        # Enable gradient accumulation mode and avoid unnecessary gradient
+        # synchronization during the gradient accumulation process.
+        with self.runner.optim_wrapper.accumulate_grad(self.runner.model,
+                                                       self._iter,
+                                                       self._max_iters):
+            # train_logs should be a dict of losses.
+            train_logs = self.runner.model.train_step(
+                data_batch, optim_wrapper=self.runner.optim_wrapper)
+        self.runner.message_hub.update_info('train_logs', train_logs)
 
         self.runner.call_hook(
             'after_train_iter',
             batch_idx=self._iter,
             data_batch=data_batch,
-            outputs=self.runner.outputs)
+            outputs=train_logs)
         self._iter += 1
 
 
@@ -247,7 +259,6 @@ class ValLoop(BaseLoop):
 
         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-
         self.runner.call_hook('after_val_epoch', metrics=metrics)
         self.runner.call_hook('after_val')
 
@@ -262,7 +273,7 @@ class ValLoop(BaseLoop):
         self.runner.call_hook(
             'before_val_iter', batch_idx=idx, data_batch=data_batch)
         # outputs should be sequence of BaseDataElement
-        outputs = self.runner.model(data_batch)
+        outputs = self.runner.model.val_step(data_batch)
         self.evaluator.process(data_batch, outputs)
         self.runner.call_hook(
             'after_val_iter',
@@ -310,7 +321,6 @@ class TestLoop(BaseLoop):
 
         # compute metrics
         metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
-
         self.runner.call_hook('after_test_epoch', metrics=metrics)
         self.runner.call_hook('after_test')
 
@@ -324,7 +334,7 @@ class TestLoop(BaseLoop):
         self.runner.call_hook(
             'before_test_iter', batch_idx=idx, data_batch=data_batch)
         # predictions should be sequence of BaseDataElement
-        predictions = self.runner.model(data_batch)
+        predictions = self.runner.model.test_step(data_batch)
         self.evaluator.process(data_batch, predictions)
         self.runner.call_hook(
             'after_test_iter',
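With ``OptimizerHook`` gone, the loops above delegate the whole optimization
step to the model. The pattern they follow is roughly the sketch below, where
``runner``, ``data_batch``, ``cur_iter`` and ``max_iters`` stand in for the
loop's own state (``self._iter``, ``self._max_iters``)::

    # Skip unnecessary gradient synchronization on iterations that only
    # accumulate gradients without stepping the optimizer.
    with runner.optim_wrapper.accumulate_grad(runner.model, cur_iter,
                                              max_iters):
        # `train_step` performs data preprocessing, the forward pass, loss
        # parsing, backward and the (possibly deferred) optimizer step, and
        # returns a dict of losses for logging.
        outputs = runner.model.train_step(
            data_batch, optim_wrapper=runner.optim_wrapper)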
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index aff5c6ba..38355db5 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -13,7 +13,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Union
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 
@@ -25,7 +25,8 @@ from mmengine.dist import (broadcast, get_dist_info, get_rank, init_dist,
 from mmengine.evaluator import Evaluator
 from mmengine.hooks import Hook
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
-from mmengine.model import is_model_wrapper
+from mmengine.model import (BaseModel, MMDistributedDataParallel,
+                            is_model_wrapper)
 from mmengine.optim import (OptimWrapper, OptimWrapperDict, _ParamScheduler,
                             build_optim_wrapper)
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
@@ -743,7 +744,7 @@ class Runner:
                 'visualizer should be Visualizer object, a dict or None, '
                 f'but got {visualizer}')
 
-    def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
+    def build_model(self, model: Union[BaseModel, Dict]) -> BaseModel:
         """Build model.
 
         If ``model`` is a dict, it will be used to build a nn.Module object
@@ -755,28 +756,30 @@ class Runner:
             model = dict(type='ResNet')
 
         Args:
-            model (nn.Module or dict): A nn.Module object or a dict to build
+            model (BaseModel or dict): A nn.Module object or a dict to build
                 nn.Module object. If ``model`` is a nn.Module object, just
                 returns itself.
 
         Returns:
             nn.Module: Model build from ``model``.
         """
-        if isinstance(model, nn.Module):
+        if isinstance(model, BaseModel):
             return model
         elif isinstance(model, dict):
             model = MODELS.build(model)
             # init weights
-            if hasattr(model, 'init_weights'):
-                model.init_weights()
-            return model
+            if hasattr(model, 'init_weights'):  # type: ignore
+                model.init_weights()  # type: ignore
+            return model  # type: ignore
         else:
             raise TypeError('model should be a nn.Module object or dict, '
                             f'but got {model}')
 
-    def wrap_model(self, model_wrapper_cfg: Optional[Dict],
-                   model: nn.Module) -> nn.Module:
-        """Wrap model.
+    def wrap_model(
+            self, model_wrapper_cfg: Optional[Dict],
+            model: BaseModel) -> Union[DistributedDataParallel, BaseModel]:
+        """Wrap the model to :obj:``MMDistributedDataParallel`` or other custom
+        distributed data-parallel module wrappers.
 
         An example of ``model_wrapper_cfg``::
 
@@ -789,10 +792,11 @@ class Runner:
             model_wrapper_cfg (dict, optional): Config to wrap model. If not
                 specified, ``DistributedDataParallel`` will be used in
                 distributed environment. Defaults to None.
-            model (nn.Module): Model to be wrapped.
+            model (BaseModel): Model to be wrapped.
 
         Returns:
-            nn.Module: Wrapped model.
+            BaseModel or DistributedDataParallel: The ``BaseModel`` itself, or
+            a subclass of ``DistributedDataParallel`` when wrapped.
         """
         if is_model_wrapper(model):
             if model_wrapper_cfg is not None:
@@ -802,25 +806,26 @@ class Runner:
 
             return model
 
+        # Set `export CUDA_VISIBLE_DEVICES=-1` to enable CPU training.
+        if torch.cuda.is_available():
+            model = model.cuda()
+
+        if not self.distributed:
+            return model
+
         if model_wrapper_cfg is None:
-            if self.distributed:
-                find_unused_parameters = self.cfg.get('find_unused_parameters',
-                                                      False)
-                # Sets the `find_unused_parameters` parameter in
-                # torch.nn.parallel.DistributedDataParallel
-                model = DistributedDataParallel(
-                    self.model.cuda(),
-                    device_ids=[torch.cuda.current_device()],
-                    broadcast_buffers=False,
-                    find_unused_parameters=find_unused_parameters)
-            else:
-                # Set `export CUDA_VISIBLE_DEVICES=-1` can enable CPU training.
-                if torch.cuda.is_available():
-                    model = model.cuda()
+            find_unused_parameters = self.cfg.get('find_unused_parameters',
+                                                  False)
+            # Sets the `find_unused_parameters` parameter in
+            # torch.nn.parallel.DistributedDataParallel
+            model = MMDistributedDataParallel(
+                module=model,
+                device_ids=[torch.cuda.current_device()],
+                broadcast_buffers=False,
+                find_unused_parameters=find_unused_parameters)
         else:
             model = MODEL_WRAPPERS.build(
-                model_wrapper_cfg, default_args=dict(model=self.model))
-
+                model_wrapper_cfg, default_args=dict(module=model))
         return model
 
     def scale_lr(self,
@@ -908,10 +913,17 @@ class Runner:
                 dict to build OptimWrapper objects. If ``optim_wrapper`` is an
                 OptimWrapper, just return an ``OptimizeWrapper`` instance.
 
+        Note:
+            For single-optimizer training, if ``optim_wrapper`` is a config
+            dict, ``type`` is optional (it defaults to :obj:`OptimWrapper`)
+            and the dict must contain ``optimizer`` to build the optimizer.
+
         Examples:
             >>> # build an optimizer
             >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
             ...     type='SGD', lr=0.01))
+            >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
+            >>> # is also valid.
             >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
             >>> optim_wrapper
             Type: OptimWrapper
@@ -1648,8 +1660,6 @@ class Runner:
         +======================+=========================+
         | RuntimeInfoHook      | VERY_HIGH (10)          |
         +----------------------+-------------------------+
-        | OptimizerHook        | HIGH (30)               |
-        +----------------------+-------------------------+
         | IterTimerHook        | NORMAL (40)             |
         +----------------------+-------------------------+
         | DistSamplerSeedHook  | NORMAL (40)             |
@@ -1666,7 +1676,6 @@ class Runner:
 
             default_hooks = dict(
                 runtime_info=dict(type='RuntimeInfoHook'),
-                optimizer=dict(type='OptimizerHook', grad_clip=None),
                 timer=dict(type='IterTimerHook'),
                 sampler_seed=dict(type='DistSamplerSeedHook'),
                 logger=dict(type='LoggerHook'),
@@ -1689,7 +1698,6 @@ class Runner:
         """
         default_hooks: dict = dict(
             runtime_info=dict(type='RuntimeInfoHook'),
-            optimizer=dict(type='OptimizerHook', grad_clip=None),
             timer=dict(type='IterTimerHook'),
             logger=dict(type='LoggerHook'),
             param_scheduler=dict(type='ParamSchedulerHook'),
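Taken together, the runner changes above move every optimization-related
setting out of ``default_hooks`` and into ``optim_wrapper`` and
``model_wrapper_cfg``. A hedged configuration sketch (the field values are
placeholders, and it assumes the wrapper type is registered in
``MODEL_WRAPPERS``)::

    cfg = dict(
        # Any model registered in MODELS that inherits from BaseModel.
        model=dict(type='ToyModel'),
        # `type` may be omitted; it defaults to OptimWrapper.
        optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
        # Optional: choose a custom wrapper instead of the default
        # MMDistributedDataParallel used in distributed environments.
        model_wrapper_cfg=dict(type='MMSeparateDistributedDataParallel'),
        # Note: there is no longer an `optimizer` entry in `default_hooks`.
    )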
diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py
index 13b8e322..5e151561 100644
--- a/mmengine/utils/misc.py
+++ b/mmengine/utils/misc.py
@@ -227,7 +227,7 @@ def check_prerequisites(
         prerequisites,
         checker,
         msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
-        'found, please install them first.'):  # yapf: disable
+                 'found, please install them first.'):  # yapf: disable
     """A decorator factory to check if prerequisites are satisfied.
 
     Args:
@@ -341,7 +341,6 @@ def deprecated_api_warning(name_dict: dict,
             if kwargs:
                 for src_arg_name, dst_arg_name in name_dict.items():
                     if src_arg_name in kwargs:
-
                         assert dst_arg_name not in kwargs, (
                             f'The expected behavior is to replace '
                             f'the deprecated key `{src_arg_name}` to '
@@ -471,7 +470,7 @@ def tensor2imgs(tensor: torch.Tensor,
     if std is None:
         std = (1, ) * channels
     assert (channels == len(mean) == len(std) == 3) or \
-        (channels == len(mean) == len(std) == 1 and not to_bgr)
+           (channels == len(mean) == len(std) == 1 and not to_bgr)
     mean = tensor.new_tensor(mean).view(1, -1)
     std = tensor.new_tensor(std).view(1, -1)
     tensor = tensor.permute(0, 2, 3, 1) * std + mean
diff --git a/tests/test_hook/test_ema_hook.py b/tests/test_hook/test_ema_hook.py
index 4cae6c83..0da87f1a 100644
--- a/tests/test_hook/test_ema_hook.py
+++ b/tests/test_hook/test_ema_hook.py
@@ -9,7 +9,7 @@ import torch.nn as nn
 from torch.utils.data import Dataset
 
 from mmengine.hooks import EMAHook
-from mmengine.model import ExponentialMovingAverage
+from mmengine.model import BaseModel, ExponentialMovingAverage
 from mmengine.optim import OptimWrapper
 from mmengine.registry import DATASETS, MODEL_WRAPPERS
 from mmengine.runner import Runner
@@ -21,25 +21,28 @@ class ToyModel(nn.Module):
         super().__init__()
         self.linear = nn.Linear(2, 1)
 
-    def forward(self, data_batch, return_loss=False):
-        inputs, labels = [], []
-        for x in data_batch:
-            inputs.append(x['inputs'])
-            labels.append(x['data_sample'])
-
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-        inputs = torch.stack(inputs).to(device)
-        labels = torch.stack(labels).to(device)
-        outputs = self.linear(inputs)
-        if return_loss:
+    def forward(self, batch_inputs, labels, mode='tensor'):
+        labels = torch.stack(labels)
+        outputs = self.linear(batch_inputs)
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
             loss = (labels - outputs).sum()
-            outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
+            outputs = dict(loss=loss)
             return outputs
         else:
-            outputs = dict(log_vars=dict(a=1, b=0.5))
             return outputs
 
 
+class ToyModel1(BaseModel, ToyModel):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        return super(BaseModel, self).forward(*args, **kwargs)
+
+
 @DATASETS.register_module()
 class DummyDataset(Dataset):
     METAINFO = dict()  # type: ignore
@@ -67,7 +70,7 @@ class TestEMAHook(TestCase):
 
     def test_ema_hook(self):
         device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-        model = ToyModel().to(device)
+        model = ToyModel1().to(device)
         evaluator = Mock()
         evaluator.evaluate = Mock(return_value=dict(acc=0.5))
         runner = Runner(
@@ -121,7 +124,7 @@ class TestEMAHook(TestCase):
         runner.test()
 
         @MODEL_WRAPPERS.register_module()
-        class DummyWrapper(nn.Module):
+        class DummyWrapper(BaseModel):
 
             def __init__(self, model):
                 super().__init__()
@@ -132,7 +135,7 @@ class TestEMAHook(TestCase):
 
         # with model wrapper
         runner = Runner(
-            model=DummyWrapper(model),
+            model=DummyWrapper(ToyModel()),
             test_dataloader=dict(
                 dataset=dict(type='DummyDataset'),
                 sampler=dict(type='DefaultSampler', shuffle=True),
diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py
deleted file mode 100644
index 1e665a93..00000000
--- a/tests/test_hook/test_optimizer_hook.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest.mock import MagicMock, Mock
-
-import torch
-from torch import nn
-
-from mmengine.hooks import OptimizerHook
-
-
-class TestOptimizerHook:
-
-    def test_after_train_iter(self):
-
-        class Model(nn.Module):
-
-            def __init__(self):
-                super().__init__()
-                self.conv1 = nn.Conv2d(
-                    in_channels=1,
-                    out_channels=2,
-                    kernel_size=3,
-                    stride=1,
-                    padding=1,
-                    dilation=1)
-                self.conv2 = nn.Conv2d(
-                    in_channels=2,
-                    out_channels=2,
-                    kernel_size=3,
-                    stride=1,
-                    padding=1,
-                    dilation=1)
-                self.conv3 = nn.Conv2d(
-                    in_channels=1,
-                    out_channels=2,
-                    kernel_size=3,
-                    stride=1,
-                    padding=1,
-                    dilation=1)
-
-            def forward(self, x):
-                x1 = self.conv1(x)
-                x2 = self.conv2(x1)
-                return x1, x2
-
-        model = Model()
-        x = torch.rand(1, 1, 3, 3)
-
-        dummy_runner = MagicMock()
-        dummy_runner.optim_wrapper.zero_grad = Mock(return_value=None)
-        dummy_runner.optim_wrapper.step = Mock(return_value=None)
-        dummy_runner.model = model
-        dummy_runner.outputs = dict()
-
-        dummy_runner.outputs['num_samples'] = 0
-
-        class DummyLogger():
-
-            def __init__(self):
-                self.msg = ''
-
-            def log(self, msg=None, **kwargs):
-                self.msg += msg
-
-        dummy_runner.logger = DummyLogger()
-        optimizer_hook = OptimizerHook(
-            dict(max_norm=2), detect_anomalous_params=True)
-
-        dummy_runner.outputs['loss'] = model(x)[0].sum()
-
-        dummy_runner.outputs['loss'].backward = Mock(
-            wraps=dummy_runner.outputs['loss'].backward)
-        optimizer_hook.detect_anomalous_parameters = Mock(
-            wraps=optimizer_hook.detect_anomalous_parameters)
-        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
-
-        optimizer_hook.after_train_iter(dummy_runner, 0)
-        # assert the parameters of conv2 and conv3 are not in the
-        # computational graph which is with x1.sum() as root.
-        assert 'conv2.weight' in dummy_runner.logger.msg
-        assert 'conv2.bias' in dummy_runner.logger.msg
-        assert 'conv3.weight' in dummy_runner.logger.msg
-        assert 'conv3.bias' in dummy_runner.logger.msg
-        assert 'conv1.weight' not in dummy_runner.logger.msg
-        assert 'conv1.bias' not in dummy_runner.logger.msg
-        dummy_runner.optim_wrapper.step.assert_called()
-        dummy_runner.outputs['loss'].backward.assert_called()
-        optimizer_hook.clip_grads.assert_called()
-        optimizer_hook.detect_anomalous_parameters.assert_called()
-
-        dummy_runner.outputs['loss'] = model(x)[1].sum()
-        dummy_runner.logger.msg = ''
-        optimizer_hook.after_train_iter(dummy_runner, 0)
-        # assert the parameters of conv3 are not in the computational graph
-        assert 'conv3.weight' in dummy_runner.logger.msg
-        assert 'conv3.bias' in dummy_runner.logger.msg
-        assert 'conv2.weight' not in dummy_runner.logger.msg
-        assert 'conv2.bias' not in dummy_runner.logger.msg
-        assert 'conv1.weight' not in dummy_runner.logger.msg
-        assert 'conv1.bias' not in dummy_runner.logger.msg
-
-        # grad_clip is None and detect_anomalous_parameters is False
-        optimizer_hook = OptimizerHook(detect_anomalous_params=False)
-        optimizer_hook.detect_anomalous_parameters = Mock(
-            wraps=optimizer_hook.detect_anomalous_parameters)
-        optimizer_hook.clip_grads = Mock(wraps=optimizer_hook.clip_grads)
-        dummy_runner.outputs['loss'] = model(x)[0].sum()
-        dummy_runner.outputs['loss'].backward = Mock(
-            wraps=dummy_runner.outputs['loss'].backward)
-
-        optimizer_hook.after_train_iter(dummy_runner, 0)
-
-        dummy_runner.optim_wrapper.step.assert_called()
-        dummy_runner.outputs['loss'].backward.assert_called()
-        optimizer_hook.clip_grads.assert_not_called()
-        optimizer_hook.detect_anomalous_parameters.assert_not_called()
diff --git a/tests/test_hook/test_runtime_info_hook.py b/tests/test_hook/test_runtime_info_hook.py
index 2eb651c5..b57e26eb 100644
--- a/tests/test_hook/test_runtime_info_hook.py
+++ b/tests/test_hook/test_runtime_info_hook.py
@@ -102,12 +102,7 @@ class TestRuntimeInfoHook(TestCase):
         runner.message_hub = message_hub
         hook = RuntimeInfoHook()
         hook.after_train_iter(
-            runner,
-            batch_idx=2,
-            data_batch=None,
-            outputs={'log_vars': {
-                'loss_cls': 1.111
-            }})
+            runner, batch_idx=2, data_batch=None, outputs={'loss_cls': 1.111})
         self.assertEqual(
             message_hub.get_scalar('train/loss_cls').current(), 1.111)
 
diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py
index c36d239d..e64cb88a 100644
--- a/tests/test_logging/test_message_hub.py
+++ b/tests/test_logging/test_message_hub.py
@@ -101,7 +101,7 @@ class TestMessageHub:
         message_hub.update_scalar('lr', 0.1, resumed=False)
         # update runtime information
         message_hub.update_info('iter', 1, resumed=True)
-        message_hub.update_info('feat', [1, 2, 3], resumed=False)
+        message_hub.update_info('tensor', [1, 2, 3], resumed=False)
         obj = pickle.dumps(message_hub)
         instance = pickle.loads(obj)
 
diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py
new file mode 100644
index 00000000..280fcead
--- /dev/null
+++ b/tests/test_model/test_base_model/test_base_model.py
@@ -0,0 +1,125 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+from unittest import TestCase
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+
+from mmengine.model import BaseDataPreprocessor, BaseModel
+from mmengine.optim import OptimWrapper
+from mmengine.registry import MODELS
+from mmengine.testing import assert_allclose
+
+
+@MODELS.register_module()
+class CustomDataPreprocessor(BaseDataPreprocessor):
+
+    def forward(self, data, training=False):
+        if training:
+            return 1
+        else:
+            return 2
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self, data_preprocessor=None):
+        super().__init__(None, data_preprocessor=data_preprocessor)
+        self.conv = nn.Conv2d(3, 1, 1)
+
+    def forward(self, batch_inputs, data_samples=None, mode='tensor'):
+        if mode == 'loss':
+            out = self.conv(batch_inputs)
+            return dict(loss=out)
+        elif mode == 'predict':
+            out = self.conv(batch_inputs)
+            return out
+        elif mode == 'tensor':
+            out = self.conv(batch_inputs)
+            return out
+
+
+class TestBaseModel(TestCase):
+
+    def test_init(self):
+        # Initialize the model without `data_preprocessor`.
+        model = ToyModel()
+        self.assertIsInstance(model.data_preprocessor, BaseDataPreprocessor)
+        data_preprocessor = dict(type='CustomDataPreprocessor')
+        model = ToyModel(data_preprocessor=data_preprocessor)
+        self.assertIsInstance(model.data_preprocessor, CustomDataPreprocessor)
+        self.assertEqual(model.data_preprocessor(1, training=True), 1)
+        self.assertEqual(model.data_preprocessor(1, training=False), 2)
+
+        # Initialize the model with an already built `data_preprocessor`.
+        data_preprocessor = CustomDataPreprocessor()
+        model = ToyModel(data_preprocessor=data_preprocessor)
+        self.assertIs(model.data_preprocessor, data_preprocessor)
+
+        # Initialize the model with a wrong type of `data_preprocessor`.
+        with self.assertRaisesRegex(TypeError, 'data_preprocessor should be'):
+            ToyModel(data_preprocessor=[data_preprocessor])
+
+    def test_parse_losses(self):
+        model = ToyModel()
+        loss_cls = torch.tensor(1, dtype=torch.float32)
+        loss_list = [
+            torch.tensor(2, dtype=torch.float32),
+            torch.tensor(3, dtype=torch.float32)
+        ]
+        losses = dict(loss_cls=loss_cls, loss_list=loss_list)
+        target_parsed_losses = torch.tensor(6, dtype=torch.float32)
+        target_log_vars = dict(
+            loss=torch.tensor(6, dtype=torch.float32),
+            loss_cls=torch.tensor(1, dtype=torch.float32),
+            loss_list=torch.tensor(5, dtype=torch.float32))
+        parsed_losses, log_vars = model.parse_losses(losses)
+        assert_allclose(parsed_losses, target_parsed_losses)
+        for key in log_vars:
+            self.assertIn(key, target_log_vars)
+            assert_allclose(log_vars[key], target_log_vars[key])
+
+        with self.assertRaises(TypeError):
+            losses['error_key'] = dict()
+            model.parse_losses(losses)
+
+    def test_train_step(self):
+        model = ToyModel()
+        optimizer = SGD(model.parameters(), lr=0.1)
+        optim_wrapper = OptimWrapper(optimizer)
+        inputs = torch.randn(3, 1, 1)
+        data = dict(inputs=inputs)
+        # `train_step` should populate gradients and return the parsed loss.
+        log_vars = model.train_step([data], optim_wrapper)
+        self.assertIsNotNone(model.conv.weight.grad)
+        self.assertIsInstance(log_vars['loss'], torch.Tensor)
+
+    def test_val_step(self):
+        inputs = torch.randn(3, 1, 1)
+        data = dict(inputs=inputs)
+        model = ToyModel()
+        out = model.val_step([data])
+        self.assertIsInstance(out, torch.Tensor)
+
+    def test_test_step(self):
+        inputs = torch.randn(3, 1, 1)
+        data = dict(inputs=inputs)
+        model = ToyModel()
+        out = model.test_step([data])
+        self.assertIsInstance(out, torch.Tensor)
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
+    def test_cuda(self):
+        inputs = torch.randn(3, 1, 1).cuda()
+        data = dict(inputs=inputs)
+        model = ToyModel().cuda()
+        model.val_step([data])
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
+    def test_to(self):
+        inputs = torch.randn(3, 1, 1).cuda()
+        data = dict(inputs=inputs)
+        model = ToyModel().to(torch.cuda.current_device())
+        model.val_step([data])
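The tests above rely on ``BaseModel`` dispatching ``train_step``, ``val_step``
and ``test_step`` to ``forward`` with the corresponding ``mode``. A rough
sketch of the single-optimizer ``train_step`` flow these tests assume (not the
verbatim implementation)::

    def train_step_sketch(model, data, optim_wrapper):
        # 1. Collate the sampled data and move it to the right device.
        batch_inputs, data_samples = model.data_preprocessor(data, True)
        # 2. Forward in 'loss' mode to get a dict of losses.
        losses = model(batch_inputs, data_samples, mode='loss')
        # 3. Reduce the dict to a scalar loss plus log variables, e.g.
        #    dict(loss_cls=1., loss_list=[2., 3.]) gives loss == 6.
        parsed_loss, log_vars = model.parse_losses(losses)
        # 4. Backward and (possibly accumulated) optimizer step.
        optim_wrapper.update_params(parsed_loss)
        return log_vars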
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
new file mode 100644
index 00000000..146ed35e
--- /dev/null
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import torch
+import torch.nn.functional as F
+
+from mmengine import InstanceData
+from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
+from mmengine.testing import assert_allclose
+
+
+class TestBaseDataPreprocessor(TestCase):
+
+    def test_init(self):
+        base_data_preprocessor = BaseDataPreprocessor()
+        self.assertEqual(base_data_preprocessor.device, 'cpu')
+
+    def test_forward(self):
+        base_data_preprocessor = BaseDataPreprocessor()
+        input1 = torch.randn(1, 3, 5)
+        input2 = torch.randn(1, 3, 5)
+        label1 = torch.randn(1)
+        label2 = torch.randn(1)
+
+        data = [
+            dict(inputs=input1, data_sample=label1),
+            dict(inputs=input2, data_sample=label2)
+        ]
+
+        batch_inputs, batch_labels = base_data_preprocessor(data)
+        self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
+
+        assert_allclose(input1, batch_inputs[0])
+        assert_allclose(input2, batch_inputs[1])
+        assert_allclose(label1, batch_labels[0])
+        assert_allclose(label2, batch_labels[1])
+
+
+class TestImageDataPreprocessor(TestBaseDataPreprocessor):
+
+    def test_init(self):
+        # Initialize the data preprocessor with default arguments.
+        data_processor = ImgDataPreprocessor()
+        self.assertFalse(data_processor.channel_conversion)
+        assert_allclose(data_processor.mean,
+                        torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1))
+        assert_allclose(data_processor.std,
+                        torch.tensor([127.5, 127.5, 127.5]).view(-1, 1, 1))
+        self.assertEqual(data_processor.pad_size_divisor, 1)
+        assert_allclose(data_processor.pad_value, torch.tensor(0))
+        # Initialize the data preprocessor with customized arguments.
+        data_processor = ImgDataPreprocessor(
+            bgr_to_rgb=True,
+            mean=[0, 0, 0],
+            std=[255, 255, 255],
+            pad_size_divisor=16,
+            pad_value=10)
+        self.assertTrue(data_processor.channel_conversion)
+        assert_allclose(data_processor.mean,
+                        torch.tensor([0, 0, 0]).view(-1, 1, 1))
+        assert_allclose(data_processor.std,
+                        torch.tensor([255, 255, 255]).view(-1, 1, 1))
+        assert_allclose(data_processor.pad_value, torch.tensor(10))
+        self.assertEqual(data_processor.pad_size_divisor, 16)
+
+        with self.assertRaisesRegex(AssertionError, 'The length of mean'):
+            ImgDataPreprocessor(mean=(1, 2))
+
+        with self.assertRaisesRegex(AssertionError, 'The length of std'):
+            ImgDataPreprocessor(std=(1, 2))
+
+        with self.assertRaisesRegex(AssertionError, '`bgr2rgb` and `rgb2bgr`'):
+            ImgDataPreprocessor(bgr_to_rgb=True, rgb_to_bgr=True)
+
+    def test_forward(self):
+        # Test `pad_value`, `rgb_to_bgr`, `pad_size_divisor`.
+        data_preprocessor = ImgDataPreprocessor(
+            mean=[127.5],
+            std=[1, 2, 3],
+            pad_size_divisor=16,
+            pad_value=10,
+            rgb_to_bgr=True,
+        )
+        inputs1 = torch.randn(3, 10, 10)
+        inputs2 = torch.randn(3, 15, 15)
+        data_sample1 = InstanceData(bboxes=torch.randn(5, 4))
+        data_sample2 = InstanceData(bboxes=torch.randn(5, 4))
+        data = [
+            dict(inputs=inputs1.clone(), data_sample=data_sample1.clone()),
+            dict(inputs=inputs2.clone(), data_sample=data_sample2.clone())
+        ]
+
+        std = torch.tensor([1, 2, 3]).view(-1, 1, 1)
+        inputs1 = (inputs1[[2, 1, 0], ...] - 127.5) / std
+        inputs2 = (inputs2[[2, 1, 0], ...] - 127.5) / std
+        inputs1 = F.pad(inputs1, (0, 6, 0, 6), value=10)
+        inputs2 = F.pad(inputs2, (0, 1, 0, 1), value=10)
+
+        target_inputs = [inputs1, inputs2]
+        inputs, data_samples = data_preprocessor(data, True)
+
+        target_data_samples = [data_sample1, data_sample2]
+        for input_, data_sample, target_input, target_data_sample in zip(
+                inputs, data_samples, target_inputs, target_data_samples):
+            assert_allclose(input_, target_input)
+            assert_allclose(data_sample.bboxes, target_data_sample.bboxes)
+
+        # Test empty `data_sample`
+        data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())]
+        data_preprocessor(data, True)
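For reference, the padding amounts asserted above follow directly from
``pad_size_divisor=16``: each spatial size is padded on the right and bottom up
to the next multiple of 16::

    import math

    def pad_amount(size, divisor=16):
        return math.ceil(size / divisor) * divisor - size

    assert pad_amount(10) == 6  # 10x10 image is padded to 16x16
    assert pad_amount(15) == 1  # 15x15 image is padded to 16x16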
diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py
deleted file mode 100644
index c1f96ac4..00000000
--- a/tests/test_model/test_wrappers/test_data_parallel.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from unittest import TestCase
-from unittest.mock import MagicMock, patch
-
-import pytest
-import torch
-import torch.nn as nn
-from torch.nn.parallel import DataParallel
-from torch.nn.parallel.distributed import DistributedDataParallel
-
-from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
-                                     is_model_wrapper)
-from mmengine.registry import MODEL_WRAPPERS
-
-
-def mock(*args, **kwargs):
-    pass
-
-
-@patch('torch.distributed._broadcast_coalesced', mock)
-@patch('torch.distributed.broadcast', mock)
-@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
-def test_is_model_wrapper():
-
-    class Model(nn.Module):
-
-        def __init__(self):
-            super().__init__()
-            self.conv = nn.Conv2d(2, 2, 1)
-
-        def forward(self, x):
-            return self.conv(x)
-
-    # _verify_model_across_ranks is added in torch1.9.0 so we should check
-    # whether _verify_model_across_ranks is the member of torch.distributed
-    # before mocking
-    if hasattr(torch.distributed, '_verify_model_across_ranks'):
-        torch.distributed._verify_model_across_ranks = mock
-
-    # _verify_model_across_ranks is added in torch1.11.0 so we should check
-    # whether _verify_params_across_processes is the member of
-    # torch.distributed before mocking
-    if hasattr(torch.distributed, '_verify_params_across_processes'):
-        torch.distributed._verify_params_across_processes = mock
-
-    model = Model()
-    assert not is_model_wrapper(model)
-
-    mmdp = MMDataParallel(model)
-    assert is_model_wrapper(mmdp)
-
-    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
-    assert is_model_wrapper(mmddp)
-
-    torch_dp = DataParallel(model)
-    assert is_model_wrapper(torch_dp)
-
-    torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
-    assert is_model_wrapper(torch_ddp)
-
-    # test model wrapper registry
-    @MODEL_WRAPPERS.register_module()
-    class ModelWrapper:
-
-        def __init__(self, module):
-            self.module = module
-
-        def forward(self, *args, **kwargs):
-            return self.module(*args, **kwargs)
-
-    model_wrapper = ModelWrapper(model)
-    assert is_model_wrapper(model_wrapper)
-
-
-class TestMMDataParallel(TestCase):
-
-    def setUp(self):
-        """Setup the demo image in every test method.
-
-        TestCase calls functions in this order: setUp() -> testMethod() ->
-        tearDown() -> cleanUp()
-        """
-
-        class Model(nn.Module):
-
-            def __init__(self):
-                super().__init__()
-                self.conv = nn.Conv2d(1, 2, 1)
-
-            def forward(self, x):
-                return self.conv(x)
-
-            def train_step(self, x):
-                return self.forward(x)
-
-            def val_step(self, x):
-                return self.forward(x)
-
-        self.model = Model()
-
-    def test_train_step(self):
-
-        class Model(nn.Module):
-
-            def __init__(self):
-                super().__init__()
-                self.conv = nn.Conv2d(1, 2, 1)
-
-            def forward(self, x):
-                return self.conv(x)
-
-        model = Model()
-        mmdp = MMDataParallel(model)
-
-        # test without train_step attribute
-        with pytest.raises(AssertionError):
-            mmdp.train_step(torch.zeros([1, 1, 3, 3]))
-
-        out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
-        assert out.shape == (1, 2, 3, 3)
-
-    def test_val_step(self):
-
-        class Model(nn.Module):
-
-            def __init__(self):
-                super().__init__()
-                self.conv = nn.Conv2d(1, 2, 1)
-
-            def forward(self, x):
-                return self.conv(x)
-
-        model = Model()
-        mmdp = MMDataParallel(model)
-
-        # test without val_step attribute
-        with pytest.raises(AssertionError):
-            mmdp.val_step(torch.zeros([1, 1, 3, 3]))
-
-        out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
-        assert out.shape == (1, 2, 3, 3)
diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py
new file mode 100644
index 00000000..2a7e74be
--- /dev/null
+++ b/tests/test_model/test_wrappers/test_model_wrapper.py
@@ -0,0 +1,161 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import unittest
+from unittest.mock import MagicMock
+
+import torch
+import torch.distributed as torch_dist
+import torch.nn as nn
+from torch.optim import SGD
+
+from mmengine.model import (BaseModel, MMDistributedDataParallel,
+                            MMSeparateDistributedDataParallel)
+from mmengine.optim import OptimWrapper, OptimWrapperDict
+from mmengine.testing import assert_allclose
+from mmengine.testing._internal import MultiProcessTestCase
+
+
+class ToyModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 1, 1)
+        self.conv2 = nn.Conv2d(1, 1, 1)
+
+    def forward(self, x, data_samples=None, mode='tensor'):
+        if mode == 'loss':
+            x = self.conv1(x)
+            x = self.conv2(x)
+            return dict(loss=x)
+        elif mode == 'predict':
+            return x
+        else:
+            return x
+
+
+class ComplexModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2d(3, 1, 1)
+        self.conv2 = nn.Conv2d(3, 1, 1)
+
+    def train_step(self, data, optim_wrapper):
+        batch_inputs, _ = self.data_preprocessor(data)
+        loss1 = self.conv1(batch_inputs)
+        optim_wrapper['optim_wrapper1'].update_params(loss1)
+        loss2 = self.conv2(batch_inputs)
+        optim_wrapper['optim_wrapper2'].update_params(loss2)
+        return dict(loss1=loss1, loss2=loss2)
+
+    def val_step(self, data):
+        return 1
+
+    def test_step(self, data):
+        return 2
+
+    def forward(self):
+        pass
+
+
+class TestModelWrapper(MultiProcessTestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    def test_train_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        # Test `optim_wrapper` being an instance of `OptimWrapper`.
+        model = ToyModel()
+        ddp_model = MMDistributedDataParallel(module=model)
+        optimizer = SGD(ddp_model.parameters(), lr=0)
+        optim_wrapper = OptimWrapper(optimizer, accumulative_iters=1)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        ddp_model.train_step([data], optim_wrapper=optim_wrapper)
+        grad = ddp_model.module.conv1.weight.grad
+        assert_allclose(grad, torch.zeros_like(grad))
+
+    def test_val_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        ddp_model = MMDistributedDataParallel(module=model)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        # Test get predictions.
+        predictions = ddp_model.val_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)
+
+    def test_test_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ToyModel()
+        ddp_model = MMDistributedDataParallel(module=model)
+        inputs = torch.randn(3, 1, 1) * self.rank * 255
+        data = dict(inputs=inputs, data_sample=MagicMock())
+        predictions = ddp_model.test_step([data])
+        self.assertIsInstance(predictions, torch.Tensor)
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29510'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
+
+
+@unittest.skipIf(
+    not torch.cuda.is_available(), reason='cuda should be available')
+class TestMMSeparateDistributedDataParallel(TestModelWrapper):
+
+    def test_train_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        # Test `optim_wrapper` being an `OptimWrapperDict`. In this case
+        # there will be two independently updated
+        # `MMDistributedDataParallel` submodules.
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
+        optimizer1 = SGD(model.conv1.parameters(), lr=0.1)
+        optimizer2 = SGD(model.conv2.parameters(), lr=0.2)
+        optim_wrapper1 = OptimWrapper(optimizer1, 1)
+        optim_wrapper2 = OptimWrapper(optimizer2, 1)
+        optim_wrapper_dict = OptimWrapperDict(
+            optim_wrapper1=optim_wrapper1, optim_wrapper2=optim_wrapper2)
+        inputs = torch.randn(3, 1, 1).cuda() * self.rank * 255
+        data = dict(inputs=inputs)
+        # Gradients are synchronized automatically since
+        # `accumulative_iters` is 1 for both optimizer wrappers.
+        ddp_model.train()
+        self.assertTrue(ddp_model.training)
+        ddp_model.train_step([data], optim_wrapper=optim_wrapper_dict)
+
+    def test_val_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model)
+        data = torch.randn(3, 1, 1)
+        # Test get predictions.
+        ddp_model.eval()
+        self.assertFalse(ddp_model.training)
+        predictions = ddp_model.val_step([data])
+        self.assertEqual(predictions, 1)
+
+    def test_test_step(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        ddp_model = MMSeparateDistributedDataParallel(model)
+        data = torch.randn(3, 1, 1)
+        # Test get predictions.
+        ddp_model.eval()
+        self.assertFalse(ddp_model.training)
+        predictions = ddp_model.test_step(data)
+        self.assertEqual(predictions, 2)
+
+    def _init_dist_env(self, rank, world_size):
+        """Initialize the distributed environment."""
+        os.environ['MASTER_ADDR'] = '127.0.0.1'
+        os.environ['MASTER_PORT'] = '29515'
+        os.environ['RANK'] = str(rank)
+        torch_dist.init_process_group(
+            backend='gloo', rank=rank, world_size=world_size)
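The ``ComplexModel`` test indexes ``optim_wrapper`` by name, which is the
access pattern ``OptimWrapperDict`` provides. A small sketch with the same key
names as the test (the tensors and learning rates are illustrative)::

    import torch
    import torch.nn as nn
    from torch.optim import SGD

    from mmengine.optim import OptimWrapper, OptimWrapperDict

    conv1 = nn.Conv2d(3, 1, 1)
    conv2 = nn.Conv2d(3, 1, 1)
    optim_wrapper = OptimWrapperDict(
        optim_wrapper1=OptimWrapper(SGD(conv1.parameters(), lr=0.1)),
        optim_wrapper2=OptimWrapper(SGD(conv2.parameters(), lr=0.2)))
    # Each entry is addressed by its keyword name, exactly as
    # `ComplexModel.train_step` does above.
    loss1 = conv1(torch.randn(1, 3, 1, 1)).sum()
    optim_wrapper['optim_wrapper1'].update_params(loss1)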
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 438b840e..d88a0e5b 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import os
 import os.path as osp
 import shutil
 import tempfile
@@ -19,6 +20,7 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             IterTimerHook, LoggerHook, ParamSchedulerHook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
+from mmengine.model import BaseModel
 from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
                             OptimWrapper, OptimWrapperDict, StepLR)
 from mmengine.registry import (DATASETS, HOOKS, LOG_PROCESSORS, LOOPS, METRICS,
@@ -33,29 +35,25 @@ from mmengine.visualization import Visualizer
 
 
 @MODELS.register_module()
-class ToyModel(nn.Module):
+class ToyModel(BaseModel):
 
     def __init__(self):
         super().__init__()
         self.linear1 = nn.Linear(2, 2)
         self.linear2 = nn.Linear(2, 1)
 
-    def forward(self, data_batch, return_loss=False):
-        inputs, labels = [], []
-        for x in data_batch:
-            inputs.append(x['inputs'])
-            labels.append(x['data_sample'])
-
-        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-        inputs = torch.stack(inputs).to(device)
-        labels = torch.stack(labels).to(device)
-        outputs = self.linear1(inputs)
+    def forward(self, batch_inputs, labels, mode='tensor'):
+        labels = torch.stack(labels)
+        outputs = self.linear1(batch_inputs)
         outputs = self.linear2(outputs)
-        if return_loss:
+
+        if mode == 'tensor':
+            return outputs
+        elif mode == 'loss':
             loss = (labels - outputs).sum()
-            outputs = dict(loss=loss, log_vars=dict(loss=loss.item()))
+            outputs = dict(loss=loss)
             return outputs
-        else:
+        elif mode == 'predict':
             outputs = dict(log_vars=dict(a=1, b=0.5))
             return outputs
 
@@ -67,12 +65,43 @@ class ToyModel1(ToyModel):
         super().__init__()
 
 
+@MODELS.register_module()
+class ToyGANModel(BaseModel):
+
+    def __init__(self):
+        super().__init__()
+        self.linear1 = nn.Linear(2, 1)
+        self.linear2 = nn.Linear(2, 1)
+
+    def forward(self, batch_inputs, labels, mode='tensor'):
+        labels = torch.stack(labels)
+        output1 = self.linear1(batch_inputs)
+        output2 = self.linear2(batch_inputs)
+
+        if mode == 'tensor':
+            return output1, output2
+        elif mode == 'loss':
+            loss1 = (labels - output1).sum()
+            loss2 = (labels - output2).sum()
+            outputs = dict(linear1=loss1, linear2=loss2)
+            return outputs
+        elif mode == 'predict':
+            return output1, output2
+
+    def train_step(self, data, optim_wrapper):
+        batch_inputs, batch_labels = self.data_preprocessor(data)
+        loss = self(batch_inputs, batch_labels, mode='loss')
+        optim_wrapper['linear1'].update_params(loss['linear1'])
+        optim_wrapper['linear2'].update_params(loss['linear2'])
+        return loss
+
+
 @MODEL_WRAPPERS.register_module()
 class CustomModelWrapper(nn.Module):
 
-    def __init__(self, model):
+    def __init__(self, module):
         super().__init__()
-        self.model = model
+        self.model = module
 
 
 @OPTIM_WRAPPER_CONSTRUCTORS.register_module()
@@ -294,7 +323,6 @@ class TestRunner(TestCase):
             custom_hooks=[],
             default_hooks=dict(
                 runtime_info=dict(type='RuntimeInfoHook'),
-                optimizer=dict(type='OptimizerHook', grad_clip=None),
                 timer=dict(type='IterTimerHook'),
                 logger=dict(type='LoggerHook'),
                 param_scheduler=dict(type='ParamSchedulerHook'),
@@ -314,7 +342,6 @@ class TestRunner(TestCase):
         self.iter_based_cfg.train_cfg = dict(by_epoch=False, max_iters=12)
         self.iter_based_cfg.default_hooks = dict(
             runtime_info=dict(type='RuntimeInfoHook'),
-            optimizer=dict(type='OptimizerHook', grad_clip=None),
             timer=dict(type='IterTimerHook'),
             logger=dict(type='LoggerHook'),
             param_scheduler=dict(type='ParamSchedulerHook'),
@@ -649,13 +676,21 @@ class TestRunner(TestCase):
         self.assertFalse(model.initiailzed)
 
     def test_wrap_model(self):
-        # TODO: test on distributed environment
         # custom model wrapper
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_wrap_model'
         cfg.model_wrapper_cfg = dict(type='CustomModelWrapper')
         runner = Runner.from_cfg(cfg)
-        self.assertIsInstance(runner.model, CustomModelWrapper)
+        self.assertIsInstance(runner.model, BaseModel)
+        if torch.cuda.is_available():
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+            os.environ['MASTER_PORT'] = '29515'
+            os.environ['RANK'] = str(0)
+            os.environ['WORLD_SIZE'] = str(1)
+            cfg.launcher = 'pytorch'
+            cfg.experiment_name = 'test_wrap_model1'
+            runner = Runner.from_cfg(cfg)
+            self.assertIsInstance(runner.model, CustomModelWrapper)
 
     def test_scale_lr(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1270,35 +1305,35 @@ class TestRunner(TestCase):
 
         # register 7 hooks by default
         runner.register_default_hooks()
-        self.assertEqual(len(runner._hooks), 7)
+        self.assertEqual(len(runner._hooks), 6)
         # the third registered hook should be `DistSamplerSeedHook`
-        self.assertTrue(isinstance(runner._hooks[3], DistSamplerSeedHook))
+        self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
         # the fifth registered hook should be `ParamSchedulerHook`
-        self.assertTrue(isinstance(runner._hooks[5], ParamSchedulerHook))
+        self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
 
         runner._hooks = []
         # remove `ParamSchedulerHook` from default hooks
         runner.register_default_hooks(hooks=dict(timer=None))
-        self.assertEqual(len(runner._hooks), 6)
+        self.assertEqual(len(runner._hooks), 5)
         # `ParamSchedulerHook` was popped so the fifth is `CheckpointHook`
-        self.assertTrue(isinstance(runner._hooks[5], CheckpointHook))
+        self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
 
         # add a new default hook
         runner._hooks = []
         runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
-        self.assertEqual(len(runner._hooks), 8)
-        self.assertTrue(isinstance(runner._hooks[7], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_custom_hooks'
         runner = Runner.from_cfg(cfg)
 
-        self.assertEqual(len(runner._hooks), 7)
+        self.assertEqual(len(runner._hooks), 6)
         custom_hooks = [dict(type='ToyHook')]
         runner.register_custom_hooks(custom_hooks)
-        self.assertEqual(len(runner._hooks), 8)
-        self.assertTrue(isinstance(runner._hooks[7], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_register_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -1309,8 +1344,8 @@ class TestRunner(TestCase):
         custom_hooks = [dict(type='ToyHook')]
         runner.register_hooks(custom_hooks=custom_hooks)
         # six default hooks + custom hook (ToyHook)
-        self.assertEqual(len(runner._hooks), 8)
-        self.assertTrue(isinstance(runner._hooks[7], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_loop(self):
         # test custom loop with additional hook
@@ -1346,12 +1381,11 @@ class TestRunner(TestCase):
             def warmup_iter(self, data_batch):
                 self.runner.call_hook(
                     'before_warmup_iter', data_batch=data_batch)
-                self.runner.outputs = self.runner.model(
-                    data_batch, return_loss=True)
+                train_logs = self.runner.model.train_step(
+                    data_batch, self.runner.optim_wrapper)
+                self.runner.message_hub.update_info('train_logs', train_logs)
                 self.runner.call_hook(
-                    'after_warmup_iter',
-                    data_batch=data_batch,
-                    outputs=self.runner.outputs)
+                    'after_warmup_iter', data_batch=data_batch)
 
         before_warmup_iter_results = []
         after_warmup_iter_results = []
@@ -1513,8 +1547,8 @@ class TestRunner(TestCase):
             linear2=dict(
                 type='OptimWrapper', optimizer=dict(type='Adam', lr=0.02)),
             constructor='ToyMultipleOptimizerConstructor')
+        cfg.model = dict(type='ToyGANModel')
         # disable OptimizerHook because it only works with one optimizer
-        cfg.default_hooks = dict(optimizer=None)
         runner = Runner.from_cfg(cfg)
         runner.train()
         path = osp.join(self.temp_dir, 'epoch_3.pth')
@@ -1533,8 +1567,8 @@ class TestRunner(TestCase):
             linear2=dict(
                 type='OptimWrapper', optimizer=dict(type='Adam', lr=0.03)),
             constructor='ToyMultipleOptimizerConstructor')
+        cfg.model = dict(type='ToyGANModel')
         cfg.param_scheduler = dict(type='MultiStepLR', milestones=[1, 2, 3])
-        cfg.default_hooks = dict(optimizer=None)
         runner = Runner.from_cfg(cfg)
         runner.resume(path)
         self.assertIsInstance(runner.optim_wrapper, OptimWrapperDict)
-- 
GitLab