diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 45c3f910a8c13f7b2c56a1cea0e7b28753aa4454..44698bbffc6ee6ede3bb717c2b7413c322184f0f 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -4,6 +4,7 @@ from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .logger_hook import LoggerHook
+from .naive_visualization_hook import NaiveVisualizationHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -12,5 +13,5 @@ from .sync_buffer_hook import SyncBuffersHook
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
-    'LoggerHook'
+    'LoggerHook', 'NaiveVisualizationHook'
 ]
diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..434de95fc21b9cfab0c32062ca840cff2d2674df
--- /dev/null
+++ b/mmengine/hooks/naive_visualization_hook.py
@@ -0,0 +1,71 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+from typing import Any, Optional, Sequence, Tuple
+
+import cv2
+import numpy as np
+
+from mmengine.data import BaseDataSample
+from mmengine.hooks import Hook
+from mmengine.registry import HOOKS
+from mmengine.utils.misc import tensor2imgs
+
+
+@HOOKS.register_module()
+class NaiveVisualizationHook(Hook):
+    """Show or Write the predicted results during the process of testing.
+
+    Args:
+        interval (int): Visualization interval. Default: 1.
+        draw_gt (bool): Whether to draw the ground truth. Default to True.
+        draw_pred (bool): Whether to draw the predicted result.
+            Default to True.
+    """
+    priority = 'NORMAL'
+
+    def __init__(self,
+                 interval: int = 1,
+                 draw_gt: bool = True,
+                 draw_pred: bool = True):
+        self.draw_gt = draw_gt
+        self.draw_pred = draw_pred
+        self._interval = interval
+
+    def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int,
+                                                           int]) -> np.ndarray:
+        unpad_width, unpad_height = unpad_shape
+        unpad_image = input[:unpad_height, :unpad_width]
+        return unpad_image
+
+    def after_test_iter(
+            self,
+            runner,
+            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Show or Write the predicted results.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
+                from dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
+        """
+        if self.every_n_iters(runner, self._interval):
+            inputs, data_samples = data_batch  # type: ignore
+            inputs = tensor2imgs(inputs,
+                                 **data_samples[0].get('img_norm_cfg', dict()))
+            for input, data_sample, output in zip(
+                    inputs,
+                    data_samples,  # type: ignore
+                    outputs):  # type: ignore
+                # TODO We will implement a function to revert the augmentation
+                # in the future.
+                ori_shape = (data_sample.ori_width, data_sample.ori_height)
+                if 'pad_shape' in data_sample:
+                    input = self._unpad(input,
+                                        data_sample.get('scale', ori_shape))
+                origin_image = cv2.resize(input, ori_shape)
+                name = osp.basename(data_sample.img_path)
+                runner.writer.add_image(name, origin_image, data_sample,
+                                        output, self.draw_gt, self.draw_pred)
diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py
index 79977be0f3d2226b97211263dd10eae67da571e5..3a95511624b05f48e5a9b3579e7bf7785d0236ca 100644
--- a/mmengine/utils/misc.py
+++ b/mmengine/utils/misc.py
@@ -11,6 +11,8 @@ from inspect import getfullargspec
 from itertools import repeat
 from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
 
+import numpy as np
+import torch
 import torch.nn as nn
 
 from .parrots_wrapper import _BatchNorm, _InstanceNorm
@@ -433,3 +435,46 @@ def is_norm(layer: nn.Module,
 
     all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
     return isinstance(layer, all_norm_bases)
+
+
+def tensor2imgs(tensor: torch.Tensor,
+                mean: Optional[Tuple[float, float, float]] = None,
+                std: Optional[Tuple[float, float, float]] = None,
+                to_bgr: bool = True):
+    """Convert tensor to 3-channel images or 1-channel gray images.
+
+    Args:
+        tensor (torch.Tensor): Tensor that contains multiple images, shape (
+            N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
+            should be RGB.
+        mean (tuple[float], optional): Mean of images. If None,
+            (0, 0, 0) will be used for tensor with 3-channel,
+            while (0, ) for tensor with 1-channel. Defaults to None.
+        std (tuple[float], optional): Standard deviation of images. If None,
+            (1, 1, 1) will be used for tensor with 3-channel,
+            while (1, ) for tensor with 1-channel. Defaults to None.
+        to_bgr (bool): For the tensor with 3 channel, convert its format to
+            BGR. For the tensor with 1 channel, it must be False. Defaults to
+            True.
+
+    Returns:
+        list[np.ndarray]: A list that contains multiple images.
+    """
+
+    assert torch.is_tensor(tensor) and tensor.ndim == 4
+    channels = tensor.size(1)
+    assert channels in [1, 3]
+    if mean is None:
+        mean = (0, ) * channels
+    if std is None:
+        std = (1, ) * channels
+    assert (channels == len(mean) == len(std) == 3) or \
+        (channels == len(mean) == len(std) == 1 and not to_bgr)
+    mean = tensor.new_tensor(mean).view(1, -1)
+    std = tensor.new_tensor(std).view(1, -1)
+    tensor = tensor.permute(0, 2, 3, 1) * std + mean
+    imgs = tensor.detach().cpu().numpy()
+    if to_bgr and channels == 3:
+        imgs = imgs[:, :, :, (2, 1, 0)]  # RGB2BGR
+    imgs = [np.ascontiguousarray(img) for img in imgs]
+    return imgs
diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb053a4a6405df8100da67aade565f057002ce0
--- /dev/null
+++ b/tests/test_hook/test_naive_visualization_hook.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import Mock
+
+import torch
+
+from mmengine.data import BaseDataSample
+from mmengine.hooks import NaiveVisualizationHook
+
+
+class TestNaiveVisualizationHook:
+
+    def test_after_train_iter(self):
+        naive_visualization_hook = NaiveVisualizationHook()
+        Runner = Mock(iter=1)
+        Runner.writer.add_image = Mock()
+        inputs = torch.randn(1, 3, 15, 15)
+        # test with normalize, resize, pad
+        gt_datasamples = [
+            BaseDataSample(
+                metainfo=dict(
+                    img_norm_cfg=dict(
+                        mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
+                    scale=(10, 10),
+                    pad_shape=(15, 15, 3),
+                    ori_height=5,
+                    ori_width=5,
+                    img_path='tmp.jpg'))
+        ]
+        pred_datasamples = [BaseDataSample()]
+        data_batch = (inputs, gt_datasamples)
+        naive_visualization_hook.after_test_iter(Runner, data_batch,
+                                                 pred_datasamples)
+        # test with resize, pad
+        gt_datasamples = [
+            BaseDataSample(
+                metainfo=dict(
+                    scale=(10, 10),
+                    pad_shape=(15, 15, 3),
+                    ori_height=5,
+                    ori_width=5,
+                    img_path='tmp.jpg')),
+        ]
+        pred_datasamples = [BaseDataSample()]
+        data_batch = (inputs, gt_datasamples)
+        naive_visualization_hook.after_test_iter(Runner, data_batch,
+                                                 pred_datasamples)
+        # test with only resize
+        gt_datasamples = [
+            BaseDataSample(
+                metainfo=dict(
+                    scale=(15, 15),
+                    ori_height=5,
+                    ori_width=5,
+                    img_path='tmp.jpg')),
+        ]
+        pred_datasamples = [BaseDataSample()]
+        data_batch = (inputs, gt_datasamples)
+        naive_visualization_hook.after_test_iter(Runner, data_batch,
+                                                 pred_datasamples)
+
+        # test with only pad
+        gt_datasamples = [
+            BaseDataSample(
+                metainfo=dict(
+                    pad_shape=(15, 15, 3),
+                    ori_height=5,
+                    ori_width=5,
+                    img_path='tmp.jpg')),
+        ]
+        pred_datasamples = [BaseDataSample()]
+        data_batch = (inputs, gt_datasamples)
+        naive_visualization_hook.after_test_iter(Runner, data_batch,
+                                                 pred_datasamples)
+
+        # test no transform
+        gt_datasamples = [
+            BaseDataSample(
+                metainfo=dict(ori_height=15, ori_width=15,
+                              img_path='tmp.jpg')),
+        ]
+        pred_datasamples = [BaseDataSample()]
+        data_batch = (inputs, gt_datasamples)
+        naive_visualization_hook.after_test_iter(Runner, data_batch,
+                                                 pred_datasamples)