From 824be950b9a36d97a71ba5ef34c2bdf8887f921f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
 <1286304229@qq.com>
Date: Mon, 7 Mar 2022 22:39:25 +0800
Subject: [PATCH] Add writer (#74)

* add writer

* update

* update

* update docstr

* update unittest

* update unittest

* fix comment

* update unittest

* fix comment

* fix comment

* fix comment

* fix comment

* update
---
 mmengine/registry/__init__.py        |   4 +-
 mmengine/registry/root.py            |   3 +-
 mmengine/visualization/__init__.py   |   7 +-
 mmengine/visualization/writer.py     | 822 +++++++++++++++++++++++++++
 tests/test_visualizer/test_writer.py | 505 +++++++++++++---
 5 files changed, 1256 insertions(+), 85 deletions(-)
 create mode 100644 mmengine/visualization/writer.py

diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index 0cfb1e9e..2142fae3 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -3,11 +3,11 @@ from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODEL_WRAPPERS,
                    MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
                    PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
-                   TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS)
+                   TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS)
 
 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
     'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
-    'EVALUATORS', 'MODEL_WRAPPERS', 'VISUALIZERS'
+    'EVALUATORS', 'MODEL_WRAPPERS', 'WRITERS', 'VISUALIZERS'
 ]
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index 39fad795..988cb855 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -39,6 +39,7 @@ TASK_UTILS = Registry('task util')
 
 # manage all kinds of evaluators for computing metrics
 EVALUATORS = Registry('evaluator')
-
 # manage visualizer
 VISUALIZERS = Registry('visualizer')
+# manage writer
+WRITERS = Registry('writer')
diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py
index f1cf58e1..892c3daa 100644
--- a/mmengine/visualization/__init__.py
+++ b/mmengine/visualization/__init__.py
@@ -1,4 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .visualizer import Visualizer
+from .writer import (BaseWriter, ComposedWriter, LocalWriter,
+                     TensorboardWriter, WandbWriter)
 
-__all__ = ['Visualizer']
+__all__ = [
+    'Visualizer', 'BaseWriter', 'LocalWriter', 'WandbWriter',
+    'TensorboardWriter', 'ComposedWriter'
+]
diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py
new file mode 100644
index 00000000..31ab3e08
--- /dev/null
+++ b/mmengine/visualization/writer.py
@@ -0,0 +1,822 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import time
+from abc import ABCMeta, abstractmethod
+from typing import Any, List, Optional, Union
+
+import cv2
+import numpy as np
+import torch
+
+from mmengine.data import BaseDataSample
+from mmengine.fileio import dump
+from mmengine.logging import BaseGlobalAccessible
+from mmengine.registry import VISUALIZERS, WRITERS
+from mmengine.utils import TORCH_VERSION
+from .visualizer import Visualizer
+
+
+class BaseWriter(metaclass=ABCMeta):
+    """Base class for writer.
+
+    Each writer can inherit ``BaseWriter`` and implement
+    the required functions.
+
+    Args:
+        visualizer (dict, :obj:`Visualizer`, optional):
+            Visualizer instance or dictionary. Default to None.
+        save_dir (str, optional): The root directory to save
+            the files produced by the writer. Default to None.
+    """
+
+    def __init__(self,
+                 visualizer: Optional[Union[dict, 'Visualizer']] = None,
+                 save_dir: Optional[str] = None):
+        self._save_dir = save_dir
+        if self._save_dir:
+            timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+            self._save_dir = osp.join(
+                self._save_dir, f'write_data_{timestamp}')  # type: ignore
+        self._visualizer = visualizer
+        if visualizer:
+            if isinstance(visualizer, dict):
+                self._visualizer = VISUALIZERS.build(visualizer)
+            else:
+                assert isinstance(visualizer, Visualizer), \
+                    'visualizer should be an instance of Visualizer, ' \
+                    f'but got {type(visualizer)}'
+
+    @property
+    def visualizer(self) -> 'Visualizer':
+        """Return the visualizer object.
+
+        You can get the drawing backend through the visualizer property for
+        custom drawing.
+        """
+        return self._visualizer  # type: ignore
+
+    @property
+    @abstractmethod
+    def experiment(self) -> Any:
+        """Return the experiment object associated with this writer.
+
+        The experiment attribute can get the write backend, such as wandb,
+        tensorboard. If you want to write other data, such as writing a table,
+        you can directly get the write backend through experiment.
+        """
+        pass
+
+    def add_params(self, params_dict: dict, **kwargs) -> None:
+        """Record a set of parameters.
+
+        Args:
+            params_dict (dict): Each key-value pair in the dictionary is the
+                  name of the parameters and it's corresponding value.
+        """
+        pass
+
+    def add_graph(self, model: torch.nn.Module,
+                  input_tensor: Union[torch.Tensor,
+                                      List[torch.Tensor]], **kwargs) -> None:
+        """Record graph.
+
+        Args:
+            model (torch.nn.Module): Model to draw.
+            input_tensor (torch.Tensor, list[torch.Tensor]): A variable
+                or a tuple of variables to be fed.
+        """
+        pass
+
+    def add_image(self,
+                  name: str,
+                  image: Optional[np.ndarray] = None,
+                  gt_sample: Optional['BaseDataSample'] = None,
+                  pred_sample: Optional['BaseDataSample'] = None,
+                  draw_gt: bool = True,
+                  draw_pred: bool = True,
+                  step: int = 0,
+                  **kwargs) -> None:
+        """Record image.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray, optional): The image to be saved. The format
+                should be RGB. Default to None.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default: True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        pass
+
+    def add_scalar(self,
+                   name: str,
+                   value: Union[int, float],
+                   step: int = 0,
+                   **kwargs) -> None:
+        """Record scalar.
+
+        Args:
+            name (str): The unique identifier for the scalar to save.
+            value (float, int): Value to save.
+            step (int): Global step value to record. Default to 0.
+        """
+        pass
+
+    def add_scalars(self,
+                    scalar_dict: dict,
+                    step: int = 0,
+                    file_path: Optional[str] = None,
+                    **kwargs) -> None:
+        """Record scalars' data.
+
+        Args:
+            scalar_dict (dict): Key-value pair storing the tag and
+                corresponding values.
+            step (int): Global step value to record. Default to 0.
+            file_path (str, optional): The scalar's data will be
+                saved to the `file_path` file at the same time
+                if the `file_path` parameter is specified.
+                Default to None.
+        """
+        pass
+
+    def close(self) -> None:
+        """close an opened object."""
+        pass
+
+
+@WRITERS.register_module()
+class LocalWriter(BaseWriter):
+    """Local write class.
+
+    It can write image, hyperparameters, scalars, etc.
+    to the local hard disk. You can get the drawing backend
+    through the visualizer property for custom drawing.
+
+    Examples:
+        >>> from mmengine.visualization import LocalWriter
+        >>> import numpy as np
+        >>> local_writer = LocalWriter(dict(type='DetVisualizer'),\
+            save_dir='temp_dir')
+        >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+        >>> local_writer.add_image('img', img)
+        >>> local_writer.add_scaler('mAP', 0.6)
+        >>> local_writer.add_scalars({'loss': [1, 2, 3], 'acc': 0.8})
+        >>> local_writer.add_params(dict(lr=0.1, mode='linear'))
+
+        >>> local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
+            edgecolors='g')
+        >>> local_writer.add_image('img', \
+            local_writer.visualizer.get_image())
+
+    Args:
+        save_dir (str): The root directory to save the files
+            produced by the writer.
+        visualizer (dict, :obj:`Visualizer`, optional): Visualizer
+            instance or dictionary. Default to None
+        img_save_dir (str): The directory to save images.
+            Default to 'writer_image'.
+        params_save_file (str): The file to save parameters.
+            Default to 'parameters.yaml'.
+        scalar_save_file (str):  The file to save scalar values.
+            Default to 'scalars.json'.
+        img_show (bool): Whether to show the image when calling add_image.
+            Default to False.
+    """
+
+    def __init__(self,
+                 save_dir: str,
+                 visualizer: Optional[Union[dict, 'Visualizer']] = None,
+                 img_save_dir: str = 'writer_image',
+                 params_save_file: str = 'parameters.yaml',
+                 scalar_save_file: str = 'scalars.json',
+                 img_show: bool = False):
+        assert params_save_file.split('.')[-1] == 'yaml'
+        assert scalar_save_file.split('.')[-1] == 'json'
+        super(LocalWriter, self).__init__(visualizer, save_dir)
+        os.makedirs(self._save_dir, exist_ok=True)  # type: ignore
+        self._img_save_dir = osp.join(
+            self._save_dir,  # type: ignore
+            img_save_dir)
+        self._scalar_save_file = osp.join(
+            self._save_dir,  # type: ignore
+            scalar_save_file)
+        self._params_save_file = osp.join(
+            self._save_dir,  # type: ignore
+            params_save_file)
+        self._img_show = img_show
+
+    @property
+    def experiment(self) -> 'LocalWriter':
+        """Return the experiment object associated with this writer."""
+        return self
+
+    def add_params(self, params_dict: dict, **kwargs) -> None:
+        """Record parameters to disk.
+
+        Args:
+            params_dict (dict): The dict of parameters to save.
+        """
+        assert isinstance(params_dict, dict)
+        self._dump(params_dict, self._params_save_file, 'yaml')
+
+    def add_image(self,
+                  name: str,
+                  image: Optional[np.ndarray] = None,
+                  gt_sample: Optional['BaseDataSample'] = None,
+                  pred_sample: Optional['BaseDataSample'] = None,
+                  draw_gt: bool = True,
+                  draw_pred: bool = True,
+                  step: int = 0,
+                  **kwargs) -> None:
+        """Record image to disk.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray, optional): The image to be saved. The format
+                should be RGB. Default to None.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default to True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        assert self.visualizer, 'Please instantiate the visualizer ' \
+                                'object with initialization parameters.'
+        self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
+        if self._img_show:
+            self.visualizer.show()
+        else:
+            drawn_image = cv2.cvtColor(self.visualizer.get_image(),
+                                       cv2.COLOR_RGB2BGR)
+            os.makedirs(self._img_save_dir, exist_ok=True)
+            save_file_name = f'{name}_{step}.png'
+            cv2.imwrite(
+                osp.join(self._img_save_dir, save_file_name), drawn_image)
+
+    def add_scalar(self,
+                   name: str,
+                   value: Union[int, float],
+                   step: int = 0,
+                   **kwargs) -> None:
+        """Add scalar data to disk.
+
+        Args:
+            name (str): The unique identifier for the scalar to save.
+            value (float, int): Value to save.
+            step (int): Global step value to record. Default to 0.
+        """
+        self._dump({name: value, 'step': step}, self._scalar_save_file, 'json')
+
+    def add_scalars(self,
+                    scalar_dict: dict,
+                    step: int = 0,
+                    file_path: Optional[str] = None,
+                    **kwargs) -> None:
+        """Record scalars. The scalar dict will be written to the default and
+        specified files if ``file_name`` is specified.
+
+        Args:
+            scalar_dict (dict): Key-value pair storing the tag and
+                corresponding values.
+            step (int): Global step value to record. Default to 0.
+            file_path (str, optional): The scalar's data will be
+                saved to the ``file_path`` file at the same time
+                if the ``file_path`` parameter is specified.
+                Default to None.
+        """
+        assert isinstance(scalar_dict, dict)
+        scalar_dict.setdefault('step', step)
+        if file_path is not None:
+            assert file_path.split('.')[-1] == 'json'
+            new_save_file_path = osp.join(
+                self._save_dir,  # type: ignore
+                file_path)
+            assert new_save_file_path != self._scalar_save_file, \
+                '"file_path" and "scalar_save_file" have the same name, ' \
+                'please set "file_path" to another value'
+            self._dump(scalar_dict, new_save_file_path, 'json')
+        self._dump(scalar_dict, self._scalar_save_file, 'json')
+
+    def _dump(self, value_dict: dict, file_path: str,
+              file_format: str) -> None:
+        """dump dict to file.
+
+        Args:
+           value_dict (dict) : Save dict data.
+           file_path (str): The file path to save data.
+           file_format (str): The file format to save data.
+        """
+        with open(file_path, 'a+') as f:
+            dump(value_dict, f, file_format=file_format)
+            f.write('\n')
+
+
+@WRITERS.register_module()
+class WandbWriter(BaseWriter):
+    """Write various types of data to wandb.
+
+    Examples:
+        >>> from mmengine.visualization import WandbWriter
+        >>> import numpy as np
+        >>> wandb_writer = WandbWriter(dict(type='DetVisualizer'))
+        >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+        >>> wandb_writer.add_image('img', img)
+        >>> wandb_writer.add_scaler('mAP', 0.6)
+        >>> wandb_writer.add_scalars({'loss': [1, 2, 3],'acc': 0.8})
+        >>> wandb_writer.add_params(dict(lr=0.1, mode='linear'))
+
+        >>> wandb_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
+            edgecolors='g')
+        >>> wandb_writer.add_image('img', \
+            wandb_writer.visualizer.get_image())
+
+        >>> wandb_writer = WandbWriter()
+        >>> assert wandb_writer.visualizer is None
+        >>> wandb_writer.add_image('img', img)
+
+    Args:
+        init_kwargs (dict, optional): wandb initialization
+            input parameters. Default to None.
+        commit: (bool, optional) Save the metrics dict to the wandb server
+                and increment the step.  If false `wandb.log` just
+                updates the current metrics dict with the row argument
+                and metrics won't be saved until `wandb.log` is called
+                with `commit=True`. Default to True.
+        visualizer (dict, :obj:`Visualizer`, optional):
+            Visualizer instance or dictionary. Default to None.
+        save_dir (str, optional): The root directory to save the files
+            produced by the writer. Default to None.
+    """
+
+    def __init__(self,
+                 init_kwargs: Optional[dict] = None,
+                 commit: Optional[bool] = True,
+                 visualizer: Optional[Union[dict, 'Visualizer']] = None,
+                 save_dir: Optional[str] = None):
+        super(WandbWriter, self).__init__(visualizer, save_dir)
+        self._commit = commit
+        self._wandb = self._setup_env(init_kwargs)
+
+    @property
+    def experiment(self):
+        """Return wandb object.
+
+        The experiment attribute can get the wandb backend, If you want to
+        write other data, such as writing a table, you can directly get the
+        wandb backend through experiment.
+        """
+        return self._wandb
+
+    def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any:
+        """Setup env.
+
+        Args:
+            init_kwargs (dict): The init args.
+
+        Return:
+            :obj:`wandb`
+        """
+        try:
+            import wandb
+        except ImportError:
+            raise ImportError(
+                'Please run "pip install wandb" to install wandb')
+        if init_kwargs:
+            wandb.init(**init_kwargs)
+        else:
+            wandb.init()
+
+        return wandb
+
+    def add_params(self, params_dict: dict, **kwargs) -> None:
+        """Record a set of parameters to be compared in wandb.
+
+        Args:
+            params_dict (dict): Each key-value pair in the dictionary
+                is the name of the parameters and it's
+                corresponding value.
+        """
+        assert isinstance(params_dict, dict)
+        self._wandb.log(params_dict, commit=self._commit)
+
+    def add_image(self,
+                  name: str,
+                  image: Optional[np.ndarray] = None,
+                  gt_sample: Optional['BaseDataSample'] = None,
+                  pred_sample: Optional['BaseDataSample'] = None,
+                  draw_gt: bool = True,
+                  draw_pred: bool = True,
+                  step: int = 0,
+                  **kwargs) -> None:
+        """Record image to wandb.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray, optional): The image to be saved. The format
+                should be RGB. Default to None.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default: True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        if self.visualizer:
+            self.visualizer.draw(image, gt_sample, pred_sample, draw_gt,
+                                 draw_pred)
+            self._wandb.log({name: self.visualizer.get_image()},
+                            commit=self._commit,
+                            step=step)
+        else:
+            self.add_image_to_wandb(name, image, gt_sample, pred_sample,
+                                    draw_gt, draw_pred, step, **kwargs)
+
+    def add_scalar(self,
+                   name: str,
+                   value: Union[int, float],
+                   step: int = 0,
+                   **kwargs) -> None:
+        """Record scalar data to wandb.
+
+        Args:
+            name (str): The unique identifier for the scalar to save.
+            value (float, int): Value to save.
+            step (int): Global step value to record. Default to 0.
+        """
+        self._wandb.log({name: value}, commit=self._commit, step=step)
+
+    def add_scalars(self,
+                    scalar_dict: dict,
+                    step: int = 0,
+                    file_path: Optional[str] = None,
+                    **kwargs) -> None:
+        """Record scalar's data to wandb.
+
+        Args:
+            scalar_dict (dict): Key-value pair storing the tag and
+                corresponding values.
+            step (int): Global step value to record. Default to 0.
+            file_path (str, optional): Useless parameter. Just for
+                interface unification. Default to None.
+        """
+        self._wandb.log(scalar_dict, commit=self._commit, step=step)
+
+    def add_image_to_wandb(self,
+                           name: str,
+                           image: np.ndarray,
+                           gt_sample: Optional['BaseDataSample'] = None,
+                           pred_sample: Optional['BaseDataSample'] = None,
+                           draw_gt: bool = True,
+                           draw_pred: bool = True,
+                           step: int = 0,
+                           **kwargs) -> None:
+        """Record image to wandb.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray): The image to be saved. The format
+                should be BGR.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default to True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        raise NotImplementedError()
+
+    def close(self) -> None:
+        """close an opened wandb object."""
+        if hasattr(self, '_wandb'):
+            self._wandb.join()
+
+
+@WRITERS.register_module()
+class TensorboardWriter(BaseWriter):
+    """Tensorboard write class. It can write images, hyperparameters, scalars,
+    etc. to a tensorboard file.
+
+    Its drawing function is provided by Visualizer.
+
+    Examples:
+        >>> from mmengine.visualization import TensorboardWriter
+        >>> import numpy as np
+        >>> tensorboard_writer = TensorboardWriter(dict(type='DetVisualizer'),\
+            save_dir='temp_dir')
+        >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+        >>> tensorboard_writer.add_image('img', img)
+        >>> tensorboard_writer.add_scaler('mAP', 0.6)
+        >>> tensorboard_writer.add_scalars({'loss': 0.1,'acc':0.8})
+        >>> tensorboard_writer.add_params(dict(lr=0.1, mode='linear'))
+
+        >>> tensorboard_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
+            edgecolors='g')
+        >>> tensorboard_writer.add_image('img', \
+            tensorboard_writer.visualizer.get_image())
+
+    Args:
+        save_dir (str): The root directory to save the files
+            produced by the writer.
+        visualizer (dict, :obj:`Visualizer`, optional): Visualizer instance
+            or dictionary. Default to None.
+        log_dir (str): Save directory location. Default to 'tf_writer'.
+    """
+
+    def __init__(self,
+                 save_dir: str,
+                 visualizer: Optional[Union[dict, 'Visualizer']] = None,
+                 log_dir: str = 'tf_logs'):
+        super(TensorboardWriter, self).__init__(visualizer, save_dir)
+        self._tensorboard = self._setup_env(log_dir)
+
+    def _setup_env(self, log_dir: str):
+        """Setup env.
+
+        Args:
+            log_dir (str): Save directory location. Default 'tf_writer'.
+
+        Return:
+            :obj:`SummaryWriter`
+        """
+        if TORCH_VERSION == 'parrots':
+            try:
+                from tensorboardX import SummaryWriter
+            except ImportError:
+                raise ImportError('Please install tensorboardX to use '
+                                  'TensorboardLoggerHook.')
+        else:
+            try:
+                from torch.utils.tensorboard import SummaryWriter
+            except ImportError:
+                raise ImportError(
+                    'Please run "pip install future tensorboard" to install '
+                    'the dependencies to use torch.utils.tensorboard '
+                    '(applicable to PyTorch 1.1 or higher)')
+
+        self.log_dir = osp.join(self._save_dir, log_dir)  # type: ignore
+        return SummaryWriter(self.log_dir)
+
+    @property
+    def experiment(self):
+        """Return Tensorboard object."""
+        return self._tensorboard
+
+    def add_graph(self, model: torch.nn.Module,
+                  input_tensor: Union[torch.Tensor,
+                                      List[torch.Tensor]], **kwargs) -> None:
+        """Record graph data to tensorboard.
+
+        Args:
+            model (torch.nn.Module): Model to draw.
+            input_tensor (torch.Tensor, list[torch.Tensor]): A variable
+                or a tuple of variables to be fed.
+        """
+        if isinstance(input_tensor, list):
+            for array in input_tensor:
+                assert array.ndim == 4
+                assert isinstance(array, torch.Tensor)
+        else:
+            assert isinstance(input_tensor,
+                              torch.Tensor) and input_tensor.ndim == 4
+        self._tensorboard.add_graph(model, input_tensor)
+
+    def add_params(self, params_dict: dict, **kwargs) -> None:
+        """Record a set of hyperparameters to be compared in TensorBoard.
+
+        Args:
+            params_dict (dict): Each key-value pair in the dictionary is the
+                  name of the hyper parameter and it's corresponding value.
+                  The type of the value can be one of `bool`, `string`,
+                   `float`, `int`, or `None`.
+        """
+        assert isinstance(params_dict, dict)
+        self._tensorboard.add_hparams(params_dict, {})
+
+    def add_image(self,
+                  name: str,
+                  image: Optional[np.ndarray] = None,
+                  gt_sample: Optional['BaseDataSample'] = None,
+                  pred_sample: Optional['BaseDataSample'] = None,
+                  draw_gt: bool = True,
+                  draw_pred: bool = True,
+                  step: int = 0,
+                  **kwargs) -> None:
+        """Record image to tensorboard.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray, optional): The image to be saved. The format
+                should be RGB. Default to None.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default to True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        assert self.visualizer, 'Please instantiate the visualizer ' \
+                                'object with initialization parameters.'
+        self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
+        self._tensorboard.add_image(
+            name, self.visualizer.get_image(), step, dataformats='HWC')
+
+    def add_scalar(self,
+                   name: str,
+                   value: Union[int, float],
+                   step: int = 0,
+                   **kwargs) -> None:
+        """Record scalar data to summary.
+
+        Args:
+            name (str): The unique identifier for the scalar to save.
+            value (float, int): Value to save.
+            step (int): Global step value to record. Default to 0.
+        """
+        self._tensorboard.add_scalar(name, value, step)
+
+    def add_scalars(self,
+                    scalar_dict: dict,
+                    step: int = 0,
+                    file_path: Optional[str] = None,
+                    **kwargs) -> None:
+        """Record scalar's data to summary.
+
+        Args:
+            scalar_dict (dict): Key-value pair storing the tag and
+                corresponding values.
+            step (int): Global step value to record. Default to 0.
+            file_path (str, optional): Useless parameter. Just for
+                interface unification. Default to None.
+        """
+        assert isinstance(scalar_dict, dict)
+        assert 'step' not in scalar_dict, 'Please set it directly ' \
+                                          'through the step parameter'
+        for key, value in scalar_dict.items():
+            self.add_scalar(key, value, step)
+
+    def close(self):
+        """close an opened tensorboard object."""
+        if hasattr(self, '_tensorboard'):
+            self._tensorboard.close()
+
+
+class ComposedWriter(BaseGlobalAccessible):
+    """Wrapper class to compose multiple a subclass of :class:`BaseWriter`
+    instances. By inheriting BaseGlobalAccessible, it can be accessed anywhere
+    once instantiated.
+
+    Examples:
+        >>> from mmengine.visualization import ComposedWriter
+        >>> import numpy as np
+        >>> composed_writer= ComposedWriter.create_instance( \
+            'composed_writer', writers=[dict(type='LocalWriter', \
+            visualizer=dict(type='DetVisualizer'), \
+            save_dir='temp_dir'), dict(type='WandbWriter')])
+        >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+        >>> composed_writer.add_image('img', img)
+        >>> composed_writer.add_scalar('mAP', 0.6)
+        >>> composed_writer.add_scalars({'loss': 0.1,'acc':0.8})
+        >>> composed_writer.add_params(dict(lr=0.1, mode='linear'))
+
+    Args:
+        name (str): The name of the instance. Defaults: 'composed_writer'.
+        writers (list, optional): The writers to compose. Default to None
+    """
+
+    def __init__(self,
+                 name: str = 'composed_writer',
+                 writers: Optional[List[Union[dict, 'BaseWriter']]] = None):
+        super().__init__(name)
+        self._writers = []
+        if writers is not None:
+            assert isinstance(writers, list)
+            for writer in writers:
+                if isinstance(writer, dict):
+                    self._writers.append(WRITERS.build(writer))
+                else:
+                    assert isinstance(writer, BaseWriter), \
+                        f'writer should be an instance of a subclass of ' \
+                        f'BaseWriter, but got {type(writer)}'
+                    self._writers.append(writer)
+
+    def __len__(self):
+        return len(self._writers)
+
+    def get_writer(self, index: int) -> 'BaseWriter':
+        """Returns the writer object corresponding to the specified index."""
+        return self._writers[index]
+
+    def get_experiment(self, index: int) -> Any:
+        """Returns the writer's experiment object corresponding to the
+        specified index."""
+        return self._writers[index].experiment
+
+    def get_visualizer(self, index: int) -> 'Visualizer':
+        """Returns the writer's visualizer object corresponding to the
+        specified index."""
+        return self._writers[index].visualizer
+
+    def add_params(self, params_dict: dict, **kwargs):
+        """Record parameters.
+
+        Args:
+            params_dict (dict): The dictionary of parameters to save.
+        """
+        for writer in self._writers:
+            writer.add_params(params_dict, **kwargs)
+
+    def add_graph(self, model: torch.nn.Module,
+                  input_array: Union[torch.Tensor,
+                                     List[torch.Tensor]], **kwargs) -> None:
+        """Record graph data.
+
+        Args:
+            model (torch.nn.Module): Model to draw.
+            input_array (torch.Tensor, list[torch.Tensor]): A variable
+                or a tuple of variables to be fed.
+        """
+        for writer in self._writers:
+            writer.add_graph(model, input_array, **kwargs)
+
+    def add_image(self,
+                  name: str,
+                  image: Optional[np.ndarray] = None,
+                  gt_sample: Optional['BaseDataSample'] = None,
+                  pred_sample: Optional['BaseDataSample'] = None,
+                  draw_gt: bool = True,
+                  draw_pred: bool = True,
+                  step: int = 0,
+                  **kwargs) -> None:
+        """Record image.
+
+        Args:
+            name (str): The unique identifier for the image to save.
+            image (np.ndarray, optional): The image to be saved. The format
+                should be RGB. Default to None.
+            gt_sample (:obj:`BaseDataSample`, optional): The ground truth data
+                structure of OpenMMlab. Default to None.
+            pred_sample (:obj:`BaseDataSample`, optional): The predicted result
+                data structure of OpenMMlab. Default to None.
+            draw_gt (bool): Whether to draw the ground truth. Default to True.
+            draw_pred (bool): Whether to draw the predicted result.
+                Default to True.
+            step (int): Global step value to record. Default to 0.
+        """
+        for writer in self._writers:
+            writer.add_image(name, image, gt_sample, pred_sample, draw_gt,
+                             draw_pred, step, **kwargs)
+
+    def add_scalar(self,
+                   name: str,
+                   value: Union[int, float],
+                   step: int = 0,
+                   **kwargs) -> None:
+        """Record scalar data.
+
+        Args:
+            name (str): The unique identifier for the scalar to save.
+            value (float, int): Value to save.
+            step (int): Global step value to record. Default to 0.
+        """
+        for writer in self._writers:
+            writer.add_scalar(name, value, step, **kwargs)
+
+    def add_scalars(self,
+                    scalar_dict: dict,
+                    step: int = 0,
+                    file_path: Optional[str] = None,
+                    **kwargs) -> None:
+        """Record scalars' data.
+
+        Args:
+            scalar_dict (dict): Key-value pair storing the tag and
+                corresponding values.
+            step (int): Global step value to record. Default to 0.
+            file_path (str, optional): The scalar's data will be
+                saved to the `file_path` file at the same time
+                if the `file_path` parameter is specified.
+                Default to None.
+        """
+        for writer in self._writers:
+            writer.add_scalars(scalar_dict, step, file_path, **kwargs)
+
+    def close(self) -> None:
+        """close an opened object."""
+        for writer in self._writers:
+            writer.close()
diff --git a/tests/test_visualizer/test_writer.py b/tests/test_visualizer/test_writer.py
index 718ff8da..5219a2a4 100644
--- a/tests/test_visualizer/test_writer.py
+++ b/tests/test_visualizer/test_writer.py
@@ -1,141 +1,484 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import random
+import os
+import shutil
 import sys
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock, patch
 
 import numpy as np
 import pytest
 import torch
+import torch.nn as nn
 
-from mmengine.data import BaseDataElement, BaseDataSample
-from mmengine.visualizer import (VISUALIZERS, LocalWriter, TensorboardWriter,
-                                 WandbWriter)
+from mmengine.fileio import load
+from mmengine.registry import VISUALIZERS, WRITERS
+from mmengine.visualization import (ComposedWriter, LocalWriter,
+                                    TensorboardWriter, WandbWriter)
 
 
-def get_demo_datasample():
-    metainfo = dict(
-        img_id=random.randint(0, 100),
-        img_shape=(random.randint(400, 600), random.randint(400, 600)))
-    gt_instances = BaseDataElement(
-        data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
-    pred_instances = BaseDataElement(
-        data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
-    data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
-    data_sample = BaseDataSample(data=data, metainfo=metainfo)
-    return data_sample
+def draw(self, image, gt_sample, pred_sample, show_gt=True, show_pred=True):
+    self.set_image(image)
 
 
 class TestLocalWriter:
 
+    def test_init(self):
+        # visuailzer must be a dictionary or an instance
+        # of Visualizer and its subclasses
+        with pytest.raises(AssertionError):
+            LocalWriter('temp_dir', [dict(type='Visualizer')])
+
+        # 'params_save_file' format must be yaml
+        with pytest.raises(AssertionError):
+            LocalWriter('temp_dir', params_save_file='a.txt')
+
+        # 'scalar_save_file' format must be json
+        with pytest.raises(AssertionError):
+            LocalWriter('temp_dir', scalar_save_file='a.yaml')
+
+        local_writer = LocalWriter('temp_dir')
+        assert os.path.exists(local_writer._save_dir)
+        shutil.rmtree('temp_dir')
+
+        local_writer = WRITERS.build(
+            dict(
+                type='LocalWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir'))
+        assert os.path.exists(local_writer._save_dir)
+        shutil.rmtree('temp_dir')
+
+    def test_experiment(self):
+        local_writer = LocalWriter('temp_dir')
+        assert local_writer.experiment == local_writer
+        shutil.rmtree('temp_dir')
+
+    def test_add_params(self):
+        local_writer = LocalWriter('temp_dir')
+
+        # 'params_dict' must be dict
+        with pytest.raises(AssertionError):
+            local_writer.add_params(['lr', 0])
+
+        params_dict = dict(lr=0.1, wd=[1.0, 0.1, 0.001], mode='linear')
+        local_writer.add_params(params_dict)
+        out_dict = load(local_writer._params_save_file, 'yaml')
+        assert out_dict == params_dict
+        shutil.rmtree('temp_dir')
+
+    @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
     def test_add_image(self):
-        image = np.random.randint(0, 256, size=(10, 10, 3))
-        data_sample = get_demo_datasample()
+        image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+
+        # The visuailzer parameter must be set when
+        # the local_writer object is instantiated and
+        # the `add_image` method is called.
+        with pytest.raises(AssertionError):
+            local_writer = LocalWriter('temp_dir')
+            local_writer.add_image('img', image)
 
-        local_writer = LocalWriter(visuailzer=dict(type='Visualizer'))
+        local_writer = LocalWriter('temp_dir', dict(type='Visualizer'))
         local_writer.add_image('img', image)
-        local_writer.add_image('img', image, data_sample)
+        assert os.path.exists(
+            os.path.join(local_writer._img_save_dir, 'img_0.png'))
 
         bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
         local_writer.visualizer.draw_bboxes(bboxes)
-        local_writer.add_image('img', local_writer.visualizer.get_image())
+        local_writer.add_image(
+            'img', local_writer.visualizer.get_image(), step=2)
+        assert os.path.exists(
+            os.path.join(local_writer._img_save_dir, 'img_2.png'))
 
         visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
-        local_writer = LocalWriter(visuailzer=visuailzer)
+        local_writer = LocalWriter('temp_dir', visuailzer)
         local_writer.add_image('img', image)
-        local_writer.add_image('img', image, data_sample)
+        assert os.path.exists(
+            os.path.join(local_writer._img_save_dir, 'img_0.png'))
+
+        shutil.rmtree('temp_dir')
+
+    def test_add_scalar(self):
+        local_writer = LocalWriter('temp_dir')
+        local_writer.add_scalar('map', 0.9)
+        out_dict = load(local_writer._scalar_save_file, 'json')
+        assert out_dict == {'map': 0.9, 'step': 0}
+        shutil.rmtree('temp_dir')
+
+        # test append mode
+        local_writer = LocalWriter('temp_dir')
+        local_writer.add_scalar('map', 0.9, step=0)
+        local_writer.add_scalar('map', 0.95, step=1)
+        with open(local_writer._scalar_save_file) as f:
+            out_dict = f.read()
+        assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \
+                           '0.95, "step": 1}\n'
+        shutil.rmtree('temp_dir')
+
+    def test_add_scalars(self):
+        local_writer = LocalWriter('temp_dir')
+        input_dict = {'map': 0.7, 'acc': 0.9}
+        local_writer.add_scalars(input_dict)
+        out_dict = load(local_writer._scalar_save_file, 'json')
+        assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0}
+
+        # test append mode
+        local_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
+        with open(local_writer._scalar_save_file) as f:
+            out_dict = f.read()
+        assert out_dict == '{"map": 0.7, "acc": 0.9, ' \
+                           '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n'
+
+        # test file_path
+        local_writer = LocalWriter('temp_dir')
+        local_writer.add_scalars(input_dict, file_path='temp.json')
+        assert os.path.exists(local_writer._scalar_save_file)
+        assert os.path.exists(
+            os.path.join(local_writer._save_dir, 'temp.json'))
+
+        # file_path and scalar_save_file cannot be the same
+        with pytest.raises(AssertionError):
+            local_writer.add_scalars(input_dict, file_path='scalars.json')
+
+        shutil.rmtree('temp_dir')
 
-        # test `visuailzer` parameter
-        # `visuailzer` parameter which must be either Visualizer instance
-        # or valid dictionary.
+
+class TestTensorboardWriter:
+    sys.modules['torch.utils.tensorboard'] = MagicMock()
+    sys.modules['tensorboardX'] = MagicMock()
+
+    def test_init(self):
+        # visuailzer must be a dictionary or an instance
+        # of Visualizer and its subclasses
         with pytest.raises(AssertionError):
+            LocalWriter('temp_dir', [dict(type='Visualizer')])
+
+        TensorboardWriter('temp_dir')
+        WRITERS.build(
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir'))
+
+    def test_experiment(self):
+        tensorboard_writer = TensorboardWriter('temp_dir')
+        assert tensorboard_writer.experiment == tensorboard_writer._tensorboard
+
+    def test_add_graph(self):
 
-            class A:
-                pass
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 2, 1)
+
+            def forward(self, x, y=None):
+                return self.conv(x)
+
+        tensorboard_writer = TensorboardWriter('temp_dir')
+
+        # input must be tensor
+        with pytest.raises(AssertionError):
+            tensorboard_writer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
 
-            LocalWriter(visuailzer=A())
+        # input must be 4d tensor
         with pytest.raises(AssertionError):
-            LocalWriter(visuailzer=dict(a='Visualizer'))
+            tensorboard_writer.add_graph(Model(), torch.zeros([1, 3, 3]))
+
+        # If the input is a list, the inner element must be a 4d tensor
+        with pytest.raises(AssertionError):
+            tensorboard_writer.add_graph(
+                Model(), [torch.zeros([1, 1, 3, 3]),
+                          torch.zeros([1, 3, 3])])
+
+        tensorboard_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3]))
+        tensorboard_writer.add_graph(
+            Model(), [torch.zeros([1, 1, 3, 3]),
+                      torch.zeros([1, 1, 3, 3])])
+
+    def test_add_params(self):
+        tensorboard_writer = TensorboardWriter('temp_dir')
+
+        # 'params_dict' must be dict
+        with pytest.raises(AssertionError):
+            tensorboard_writer.add_params(['lr', 0])
+
+        params_dict = dict(lr=0.1, wd=0.2, mode='linear')
+        tensorboard_writer.add_params(params_dict)
+
+    @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
+    def test_add_image(self):
+        image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
 
-        # test not visuailzer
         # The visuailzer parameter must be set when
         # the local_writer object is instantiated and
         # the `add_image` method is called.
         with pytest.raises(AssertionError):
-            local_writer = LocalWriter()
-            local_writer.add_image('img', image)
+            tensorboard_writer = TensorboardWriter('temp_dir')
+            tensorboard_writer.add_image('img', image)
+
+        tensorboard_writer = TensorboardWriter('temp_dir',
+                                               dict(type='Visualizer'))
+        tensorboard_writer.add_image('img', image)
 
-    def test_add_scaler(self):
-        local_writer = LocalWriter()
-        local_writer.add_scaler('map', 0.9)
+        bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
+        tensorboard_writer.visualizer.draw_bboxes(bboxes)
+        tensorboard_writer.add_image(
+            'img', tensorboard_writer.visualizer.get_image(), step=2)
 
-    def test_add_hyperparams(self):
-        local_writer = LocalWriter()
-        local_writer.add_hyperparams('hyper', dict(lr=0.01))
+        visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
+        tensorboard_writer = TensorboardWriter('temp_dir', visuailzer)
+        tensorboard_writer.add_image('img', image)
+
+    def test_add_scalar(self):
+        tensorboard_writer = TensorboardWriter('temp_dir')
+        tensorboard_writer.add_scalar('map', 0.9)
+        # test append mode
+        tensorboard_writer.add_scalar('map', 0.9, step=0)
+        tensorboard_writer.add_scalar('map', 0.95, step=1)
+
+    def test_add_scalars(self):
+        tensorboard_writer = TensorboardWriter('temp_dir')
+        # The step value must be passed through the parameter
+        with pytest.raises(AssertionError):
+            tensorboard_writer.add_scalars({'map': 0.7, 'acc': 0.9, 'step': 1})
+
+        input_dict = {'map': 0.7, 'acc': 0.9}
+        tensorboard_writer.add_scalars(input_dict)
+        # test append mode
+        tensorboard_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
 
 
 class TestWandbWriter:
     sys.modules['wandb'] = MagicMock()
 
-    def test_add_image(self):
-        image = np.random.randint(0, 256, size=(10, 10, 3))
-        data_sample = get_demo_datasample()
+    def test_init(self):
+        WandbWriter()
+        WRITERS.build(
+            dict(
+                type='WandbWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir'))
 
+    def test_experiment(self):
         wandb_writer = WandbWriter()
-        assert not wandb_writer.visualizer
-        wandb_writer.add_image('img', image, data_sample)
+        assert wandb_writer.experiment == wandb_writer._wandb
 
-        wandb_writer = WandbWriter(visuailzer=dict(type='Visualizer'))
-        assert wandb_writer.visualizer
+    def test_add_params(self):
+        wandb_writer = WandbWriter()
+
+        # 'params_dict' must be dict
+        with pytest.raises(AssertionError):
+            wandb_writer.add_params(['lr', 0])
+
+        params_dict = dict(lr=0.1, wd=0.2, mode='linear')
+        wandb_writer.add_params(params_dict)
+
+    @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
+    @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb',
+           Mock)
+    def test_add_image(self):
+        image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+
+        wandb_writer = WandbWriter()
         wandb_writer.add_image('img', image)
-        wandb_writer.add_image('img', image, data_sample)
 
+        wandb_writer = WandbWriter(visualizer=dict(type='Visualizer'))
+        bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
         wandb_writer.visualizer.set_image(image)
-        wandb_writer.add_image('img', wandb_writer.visualizer.get_image())
+        wandb_writer.visualizer.draw_bboxes(bboxes)
+        wandb_writer.add_image(
+            'img', wandb_writer.visualizer.get_image(), step=2)
 
-        # TODO test file exist
+        visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
+        wandb_writer = WandbWriter(visualizer=visuailzer)
+        wandb_writer.add_image('img', image)
 
-    def test_add_scaler(self):
+    def test_add_scalar(self):
         wandb_writer = WandbWriter()
-        wandb_writer.add_scaler('map', 0.9)
+        wandb_writer.add_scalar('map', 0.9)
+        # test append mode
+        wandb_writer.add_scalar('map', 0.9, step=0)
+        wandb_writer.add_scalar('map', 0.95, step=1)
 
-    def test_add_hyperparams(self):
+    def test_add_scalars(self):
         wandb_writer = WandbWriter()
-        wandb_writer.add_hyperparams('hyper', dict(lr=0.01))
+        input_dict = {'map': 0.7, 'acc': 0.9}
+        wandb_writer.add_scalars(input_dict)
+        # test append mode
+        wandb_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
 
 
-class TestTensorboardWriter:
-    sys.modules['torch.utils.tensorboard.SummaryWriter'] = MagicMock()
+class TestComposedWriter:
+    sys.modules['torch.utils.tensorboard'] = MagicMock()
+    sys.modules['tensorboardX'] = MagicMock()
+    sys.modules['wandb'] = MagicMock()
 
-    def test_add_image(self):
-        image = np.random.randint(0, 256, size=(10, 10, 3))
-        data_sample = get_demo_datasample()
+    def test_init(self):
 
-        tensorboard_writer = TensorboardWriter()
-        assert not tensorboard_writer.visualizer
-        tensorboard_writer.add_image('img', image, data_sample)
+        class A:
+            pass
 
-        tensorboard_writer = TensorboardWriter(
-            visuailzer=dict(type='Visualizer'))
-        assert tensorboard_writer.visualizer
-        tensorboard_writer.add_image('img', image)
-        tensorboard_writer.add_image('img', image, data_sample)
+        # The writers inner element must be a dictionary or a
+        # subclass of Writer.
+        with pytest.raises(AssertionError):
+            ComposedWriter(writers=[A()])
+
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        assert len(composed_writer._writers) == 2
+
+        # test global
+        composed_writer = ComposedWriter.create_instance(
+            'composed_writer',
+            writers=[
+                WandbWriter(),
+                dict(
+                    type='TensorboardWriter',
+                    visualizer=dict(type='Visualizer'),
+                    save_dir='temp_dir')
+            ])
+        assert len(composed_writer._writers) == 2
+        composed_writer_any = ComposedWriter.get_instance('composed_writer')
+        assert composed_writer_any == composed_writer
+
+    def test_get_writer(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        assert isinstance(composed_writer.get_writer(0), WandbWriter)
+        assert isinstance(composed_writer.get_writer(1), TensorboardWriter)
+
+    def test_get_experiment(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        assert composed_writer.get_experiment(
+            0) == composed_writer._writers[0].experiment
+        assert composed_writer.get_experiment(
+            1) == composed_writer._writers[1].experiment
+
+    def test_get_visualizer(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        assert composed_writer.get_visualizer(
+            0) == composed_writer._writers[0].visualizer
+        assert composed_writer.get_visualizer(
+            1) == composed_writer._writers[1].visualizer
+
+    def test_add_params(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+
+        # 'params_dict' must be dict
+        with pytest.raises(AssertionError):
+            composed_writer.add_params(['lr', 0])
 
-        tensorboard_writer.visualizer.set_image(image)
-        tensorboard_writer.add_image('img',
-                                     tensorboard_writer.visualizer.get_image())
+        params_dict = dict(lr=0.1, wd=0.2, mode='linear')
+        composed_writer.add_params(params_dict)
 
-        # test no visualizer
-        # The visuailzer parameter must be set when
-        # the tensorboard_writer object is instantiated and
-        # the `add_image` method is called.
+    def test_add_graph(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 2, 1)
+
+            def forward(self, x, y=None):
+                return self.conv(x)
+
+        # input must be tensor
         with pytest.raises(AssertionError):
-            tensorboard_writer = TensorboardWriter()
-            tensorboard_writer.add_image('img', image)
+            composed_writer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
+
+        # input must be 4d tensor
+        with pytest.raises(AssertionError):
+            composed_writer.add_graph(Model(), torch.zeros([1, 3, 3]))
 
-    def test_add_scaler(self):
-        tensorboard_writer = TensorboardWriter()
-        tensorboard_writer.add_scaler('map', 0.9)
+        # If the input is a list, the inner element must be a 4d tensor
+        with pytest.raises(AssertionError):
+            composed_writer.add_graph(
+                Model(), [torch.zeros([1, 1, 3, 3]),
+                          torch.zeros([1, 3, 3])])
+
+        composed_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3]))
+        composed_writer.add_graph(
+            Model(), [torch.zeros([1, 1, 3, 3]),
+                      torch.zeros([1, 1, 3, 3])])
+
+    @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
+    @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb',
+           Mock)
+    def test_add_image(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
 
-    def test_add_hyperparams(self):
-        tensorboard_writer = TensorboardWriter()
-        tensorboard_writer.add_hyperparams('hyper', dict(lr=0.01))
+        image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+        composed_writer.add_image('img', image)
+
+        bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
+        composed_writer.get_writer(1).visualizer.draw_bboxes(bboxes)
+        composed_writer.get_writer(1).add_image(
+            'img',
+            composed_writer.get_writer(1).visualizer.get_image(),
+            step=2)
+
+    def test_add_scalar(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        composed_writer.add_scalar('map', 0.9)
+        # test append mode
+        composed_writer.add_scalar('map', 0.9, step=0)
+        composed_writer.add_scalar('map', 0.95, step=1)
+
+    def test_add_scalars(self):
+        composed_writer = ComposedWriter(writers=[
+            WandbWriter(),
+            dict(
+                type='TensorboardWriter',
+                visualizer=dict(type='Visualizer'),
+                save_dir='temp_dir')
+        ])
+        input_dict = {'map': 0.7, 'acc': 0.9}
+        composed_writer.add_scalars(input_dict)
+        # test append mode
+        composed_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
-- 
GitLab