From 6f69039ca9abd3c26211c70c1a77d88454bd3e43 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 8 Mar 2022 16:10:30 +0800
Subject: [PATCH] [Feature] Add LoggerHook (#77)

* add logger hook

* update

* update

* update test

* update

* update test

* update

* update

* update

* update

* update

* Add logger hook

* Fix pre-commit

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix bytes

* update

* Fix as comment

* Fix as comment

* Update runner

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment
---
 mmengine/hooks/__init__.py          |   4 +-
 mmengine/hooks/logger_hook.py       | 509 ++++++++++++++++++++++++++++
 tests/test_hook/test_logger_hook.py | 355 +++++++++++++++++++
 3 files changed, 867 insertions(+), 1 deletion(-)
 create mode 100644 mmengine/hooks/logger_hook.py
 create mode 100644 tests/test_hook/test_logger_hook.py

diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 1cb5b535..45c3f910 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -3,6 +3,7 @@ from .checkpoint_hook import CheckpointHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
+from .logger_hook import LoggerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -10,5 +11,6 @@ from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
+    'LoggerHook'
 ]
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
new file mode 100644
index 00000000..741bfd95
--- /dev/null
+++ b/mmengine/hooks/logger_hook.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import torch
+
+from mmengine.data import BaseDataSample
+from mmengine.fileio import FileClient
+from mmengine.hooks import Hook
+from mmengine.registry import HOOKS
+from mmengine.utils import is_tuple_of, scandir
+
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
+
+@HOOKS.register_module()
+class LoggerHook(Hook):
+    """In this logger hook, the information will be printed on the terminal and
+    saved in JSON file, tensorboard, wandb .etc.
+
+    Args:
+        by_epoch (bool): Whether ``EpochBasedLoop`` is used.
+            Defaults to True.
+        interval (int): Logging interval (every k iterations).
+            Defaults to 10.
+        custom_keys (dict, optional): Defines the keys in the log and which
+            statistic methods should be used to log them.
+
+            - ``custom_keys`` contains multiple string-dict pairs. In each
+            string-dict pair, the string defines a key name in the log and
+            the dict is a config that defines the statistic method and its
+            arguments used to log the value. For example,
+            ``dict(loss=dict(method_name='mean', log_name='global_loss',
+            window_size='global'))`` means the log key ``loss`` will be
+            counted as a global mean and additionally logged as
+            ``global_loss``. If ``log_name`` is not defined in the config
+            dict, the original logged key will be overwritten.
+            - The keys in ``LoggerHook.fixed_smooth_keys`` cannot be
+            overwritten because ``time`` and ``data_time`` are used to
+            calculate the estimated time of arrival. If you want to recount
+            the time, you should set ``log_name`` in the corresponding
+            values.
+            - For statistic methods with a ``window_size`` argument, if
+            ``by_epoch`` is set to False, ``window_size`` should not be
+            ``epoch``, because logs cannot be smoothed by epoch in that case.
+        ignore_last (bool): Ignore the log of last iterations in each epoch if
+            the number of remaining iterations is less than :attr:`interval`.
+            Defaults to True.
+        interval_exp_name (int): Logging interval for experiment name. This
+            feature is to help users conveniently get the experiment
+            information from screen or log file. Defaults to 1000.
+        out_dir (str or Path, optional): The root directory to save log
+            files. If not specified, ``runner.work_dir`` will be used
+            by default. If specified, ``out_dir`` will be the concatenation
+            of ``out_dir`` and the last-level directory of
+            ``runner.work_dir``. For example, if ``out_dir`` is
+            ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``,
+            then the logs will be saved in ``./tmp/cur_exp``.
+            Defaults to None.
+        out_suffix (Tuple[str] or str): Those filenames ending with
+            ``out_suffix`` will be copied to ``out_dir``. Defaults to
+            ('.log.json', '.log', '.py').
+        keep_local (bool): Whether to keep local log files on the local
+            machine when :attr:`out_dir` is specified. If False, the local
+            logs will be removed after being copied. Defaults to True.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmengine.fileio.FileClient` for details.
+            Defaults to None.
+
+    Examples:
+        >>> # `log_name` is defined, `loss_mean_window` will be an additional
+        >>> # record.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>                        custom_keys=dict(
+        >>>                            loss=dict(
+        >>>                                log_name='loss_mean_window',
+        >>>                                method_name='mean',
+        >>>                                window_size=10)))
+        >>> # `log_name` is not defined. `loss` will be overwritten by
+        >>> # `global_mean` statistics.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>                        custom_keys=dict(
+        >>>                            loss=dict(
+        >>>                                method_name='mean',
+        >>>                                window_size='global')))
+        >>> # `time` cannot be overwritten, `global_time` will be an additional
+        >>> # record.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>                        custom_keys=dict(
+        >>>                            time=dict(
+        >>>                                 log_name='global_time',
+        >>>                                 method_name='mean',
+        >>>                                 window_size='global')))
+        >>> # Record loss with different statistics methods.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>                        custom_keys=dict(loss=[
+        >>>                            dict(log_name='loss_mean_window',
+        >>>                                 method_name='mean',
+        >>>                                 window_size=10),
+        >>>                            dict(method_name='mean',
+        >>>                                 window_size='global')]))
+    """
+    # The ETA is computed from `time`, so `time` and `data_time` should not
+    # be overwritten.
+    fixed_smooth_keys = ('time', 'data_time')
+    priority = 'BELOW_NORMAL'
+
+    def __init__(
+        self,
+        by_epoch: bool = True,
+        interval: int = 10,
+        custom_keys: Optional[dict] = None,
+        ignore_last: bool = True,
+        interval_exp_name: int = 1000,
+        out_dir: Optional[Union[str, Path]] = None,
+        out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
+        keep_local: bool = True,
+        file_client_args: Optional[dict] = None,
+    ):
+        self.by_epoch = by_epoch
+        self.interval = interval
+        self.custom_keys = custom_keys if custom_keys is not None else dict()
+        self.ignore_last = ignore_last
+
+        self.time_sec_tot = 0
+        self.interval_exp_name = interval_exp_name
+        self._check_custom_keys()
+
+        if out_dir is None and file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+        self.out_dir = out_dir
+
+        if not (out_dir is None or isinstance(out_dir, str)
+                or is_tuple_of(out_dir, str)):
+            raise TypeError('out_dir should be None, a string or a tuple of '
+                            f'strings, but got {type(out_dir)}')
+        self.out_suffix = out_suffix
+
+        self.keep_local = keep_local
+        self.file_client_args = file_client_args
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)
+
+    def before_run(self, runner) -> None:
+        """Infer ``self.file_client`` from ``self.out_dir``. Initialize the
+        ``self.start_iter`` and record the meta information.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.out_dir is not None:
+            # The final `self.out_dir` is the concatenation of `self.out_dir`
+            # and the last level directory of `runner.work_dir`
+            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
+            self.out_dir = self.file_client.join_path(self.out_dir, basename)
+            runner.logger.info(
+                (f'Text logs will be saved to {self.out_dir} by '
+                 f'{self.file_client.name} after the training process.'))
+
+        self.json_log_path = osp.join(runner.work_dir,
+                                      f'{runner.timestamp}.log.json')
+        self.yaml_log_path = osp.join(runner.work_dir,
+                                      f'{runner.timestamp}.log.yaml')
+        self.start_iter = runner.iter
+        if runner.meta is not None:
+            runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
+
+    def after_train_iter(
+            self,
+            runner,
+            data_batch: DATA_BATCH = None,
+            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
+        """Record training logs.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            data_batch (Sequence[BaseDataSample], optional): Data from
+                dataloader. Defaults to None.
+            outputs (Sequence[BaseDataSample], optional): Outputs from model.
+                Defaults to None.
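+
+        Note:
+            The cadence below is only illustrative and assumes a
+            hypothetical setting of ``interval=10``, ``by_epoch=True`` and
+            25 iterations per epoch: ``_log_train`` is called when
+            ``runner.inner_iter`` is 9 or 19, and also at the last inner
+            iteration (24) when ``ignore_last=False``.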
+        """
+        if runner.meta is not None and 'exp_name' in runner.meta:
+            if (self.every_n_iters(runner, self.interval_exp_name)) or (
+                    self.by_epoch and self.end_of_epoch(runner)):
+                exp_info = f'Exp name: {runner.meta["exp_name"]}'
+                runner.logger.info(exp_info)
+        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
+            self._log_train(runner)
+        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
+            self._log_train(runner)
+        elif self.end_of_epoch(runner) and not self.ignore_last:
+            # `runner.max_iters` may not be divisible by `self.interval`. If
+            # `self.ignore_last==False`, the log of the remaining iterations
+            # will still be recorded (Epoch [4][1000/1007], the logs of
+            # 998-1007 iterations will be recorded).
+            self._log_train(runner)
+
+    def after_val_epoch(self, runner) -> None:
+        """Record validation logs.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        self._log_val(runner)
+
+    def after_run(self, runner) -> None:
+        """Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        # copy or upload logs to self.out_dir
+        if self.out_dir is None:
+            return
+        for filename in scandir(runner.work_dir, self.out_suffix, True):
+            local_filepath = osp.join(runner.work_dir, filename)
+            out_filepath = self.file_client.join_path(self.out_dir, filename)
+            with open(local_filepath, 'r') as f:
+                self.file_client.put_text(f.read(), out_filepath)
+
+            runner.logger.info(
+                (f'The file {local_filepath} has been uploaded to '
+                 f'{out_filepath}.'))
+
+            if not self.keep_local:
+                os.remove(local_filepath)
+                runner.logger.info((f'{local_filepath} was removed due to the '
+                                    '`self.keep_local=False`'))
+
+    def _log_train(self, runner) -> None:
+        """Collect and record training logs which start named with "train/*".
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        tag = self._collect_info(runner, 'train')
+        # The training log contains `lr`, `momentum`, `time` and `data_time`
+        # by default. These keys are popped from `log_tag` and the remaining
+        # keys are appended to `log_str`.
+        log_tag = copy.deepcopy(tag)
+        cur_iter = self._get_iter(runner, inner_iter=True)
+        cur_epoch = self._get_epoch(runner, 'train')
+
+        # Record learning rate and momentum.
+        lr_str_list = []
+        momentum_str_list = []
+        for key, value in tag.items():
+            if key.startswith('lr'):
+                log_tag.pop(key)
+                lr_str_list.append(f'{key}: {value:.3e}')
+        lr_str = ' '.join(lr_str_list)
+        for key, value in tag.items():
+            if key.startswith('momentum'):
+                log_tag.pop(key)
+                momentum_str_list.append(f'{key}: {value:.3e}')
+        momentum_str = ' '.join(momentum_str_list)
+        lr_momentum_str = f'{lr_str} {momentum_str}'
+        # by epoch: Epoch [4][100/1000]
+        # by iter:  Iter [100/100000]
+        if self.by_epoch:
+            log_str = f'Epoch [{cur_epoch}]' \
+                      f'[{cur_iter}/{len(runner.data_loader)}]\t'
+        else:
+            log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
+        log_str += f'{lr_momentum_str}, '
+        # Calculate eta time.
+        self.time_sec_tot += (tag['time'] * self.interval)
+        time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
+        eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+        log_str += f'eta: {eta_str}, '
+        log_str += f'time: {tag["time"]:.3f}, ' \
+                   f'data_time: {tag["data_time"]:.3f}, '
+        # Pop recorded keys
+        log_tag.pop('time')
+        log_tag.pop('data_time')
+        # statistic memory
+        if torch.cuda.is_available():
+            log_str += f'memory: {self._get_max_memory(runner)}, '
+        # Loop left keys to fill `log_str`.
+        log_items = []
+        for name, val in log_tag.items():
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+        runner.logger.info(log_str)
+        # Write logs to the local file, tensorboard and wandb.
+        runner.writer.add_scalars(
+            tag, step=runner.iter + 1, file_path=self.json_log_path)
+
+    def _log_val(self, runner) -> None:
+        """Collect and record training logs which start named with "val/*".
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        tag = self._collect_info(runner, 'val')
+        # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
+        eval_iter = len(runner.data_loader)
+        cur_iter = self._get_iter(runner)
+        cur_epoch = self._get_epoch(runner, 'val')
+        # val/test time
+        # here 1000 is the length of the val dataloader
+        # by epoch: Epoch(val) [4][1000]
+        # by iter: Iter(val) [1000]
+        if self.by_epoch:
+            # runner.epoch += 1 has been done before val workflow
+            log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
+        else:
+            log_str = f'Iter(val) [{eval_iter}]\t'
+
+        log_items = []
+        for name, val in tag.items():
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+        runner.logger.info(log_str)
+        # Write tag.
+        runner.writer.add_scalars(
+            tag, step=cur_iter, file_path=self.json_log_path)
+
+    def _get_window_size(self, runner, window_size: Union[int, str]) \
+            -> int:
+        """Parse window_size specified in ``self.custom_keys`` to int value.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            window_size (int or str): Smoothing scale of logs.
+
+        Returns:
+            int: Smoothing window for statistical methods.
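+
+        Examples:
+            >>> # Illustrative values only, assuming a hypothetical runner
+            >>> # with ``inner_iter=4`` and ``iter=99``, and
+            >>> # ``self.interval=10``.
+            >>> self._get_window_size(runner, 10)
+            10
+            >>> self._get_window_size(runner, 'epoch')
+            5
+            >>> self._get_window_size(runner, 'global')
+            100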
+        """
+        if isinstance(window_size, int):
+            assert window_size == self.interval, \
+                'The value of window_size must be equal to ' \
+                'LoggerHook.interval'
+            return window_size
+        elif window_size == 'epoch':
+            return runner.inner_iter + 1
+        elif window_size == 'global':
+            return runner.iter + 1
+        else:
+            raise ValueError("window_size should be int, 'epoch' or 'global',"
+                             f' but got {window_size}')
+
+    def _collect_info(self, runner, mode: str) -> dict:
+        """Collect log information to a dict according to mode.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            mode (str): 'train' or 'val', the prefix attached to the log
+                keys by the runner.
+
+        Returns:
+            dict: Statistical values of logs.
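+
+        Examples:
+            >>> # A sketch with hypothetical log buffers registered in
+            >>> # ``runner.message_hub.log_buffers`` under the keys
+            >>> # 'train/time', 'train/lr', 'train/loss_cls' and 'val/acc'.
+            >>> # Keys are filtered by the ``mode`` prefix; ``time``,
+            >>> # ``data_time`` and ``loss*`` are smoothed over
+            >>> # ``self.interval`` while other keys keep their latest value.
+            >>> tag = self._collect_info(runner, 'train')
+            >>> list(tag)
+            ['time', 'lr', 'loss_cls']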
+        """
+        tag = OrderedDict()
+        log_buffers = runner.message_hub.log_buffers
+        mode_log_buffers = OrderedDict()
+        # Filter the log buffers whose keys start with `mode`.
+        for prefix_key, log_buffer in log_buffers.items():
+            if prefix_key.startswith(mode):
+                key = prefix_key.split('/')[-1]
+                mode_log_buffers[key] = log_buffer
+        # Ensure all metric and lr values are the latest ones.
+        for key in mode_log_buffers:
+            # Smooth time and loss logs; use the latest value for the rest.
+            if key in self.fixed_smooth_keys or key.startswith('loss'):
+                tag[key] = mode_log_buffers[key].mean(self.interval)
+            else:
+                tag[key] = mode_log_buffers[key].current()
+        # Update custom keys.
+        if mode == 'train':
+            for log_key, log_cfg in self.custom_keys.items():
+                self._parse_custom_keys(runner, log_key,
+                                        copy.deepcopy(log_cfg),
+                                        mode_log_buffers, tag)
+        return tag
+
+    def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
+                           log_buffers: OrderedDict, tag: OrderedDict) -> None:
+        """Statistics logs in log_buffers according to custom_keys.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            log_key (str): log key specified in ``self.custom_keys``
+            log_cfg (dict): A config dict for describing the logging
+                statistics method.
+            log_buffers (OrderedDict): All logs for the corresponding phase.
+            tag (OrderedDict): A dict which defines all statistic values of
+                logs.
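+
+        Examples:
+            >>> # A sketch of how a list-type config for the hypothetical
+            >>> # key ``loss`` is resolved. The argument names follow the
+            >>> # class docstring and may differ from the actual
+            >>> # ``LogBuffer.statistics`` signature.
+            >>> log_cfg = [
+            >>>     dict(method_name='mean', window_size=10),
+            >>>     dict(method_name='max', log_name='loss_max')]
+            >>> # After parsing, ``tag['loss']`` holds the windowed mean and
+            >>> # ``tag['loss_max']`` holds the maximum; only the entry
+            >>> # without ``log_name`` overwrites the original key.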
+        """
+        if isinstance(log_cfg, list):
+            log_names = set()
+            for cfg in log_cfg:
+                log_name = cfg.get('log_name', None)
+                if log_name in log_names:
+                    raise KeyError(f'{cfg["log_name"]} cannot be redefined '
+                                   f'in {log_key}')
+                if log_name is not None:
+                    log_names.add(log_name)
+                self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
+            assert len(log_names) == len(log_cfg) - 1, \
+                f'{log_key} cannot be overwritten multiple times, please ' \
+                f'check that exactly one config in {log_cfg} does not ' \
+                'contain `log_name`.'
+        elif isinstance(log_cfg, dict):
+            if 'window_size' in log_cfg:
+                log_cfg['window_size'] = \
+                    self._get_window_size(runner, log_cfg['window_size'])
+            if 'log_name' in log_cfg:
+                name = log_cfg.pop('log_name')
+            else:
+                name = log_key
+            tag[name] = log_buffers[log_key].statistics(**log_cfg)
+        else:
+            raise ValueError('The structure of `LoggerHook.custom_keys` is '
+                             'wrong, please make sure that the value of each '
+                             'key is a dict or a list.')
+
+    def _get_max_memory(self, runner) -> int:
+        """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
+        for a given device.
+
+        Args:
+            runner (Runner): The runner of the training process.
+
+        Returns:
+            int: The maximum GPU memory occupied by tensors in megabytes
+            (MB) for a given device.
+        """
+        # TODO use `mmengine.dist.max_memory_allocated` to count mem_mb
+        device = getattr(runner.model, 'output_device', None)
+        mem = torch.cuda.max_memory_allocated(device=device)
+        mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
+                              dtype=torch.int,
+                              device=device)
+        torch.cuda.reset_peak_memory_stats()
+        return int(mem_mb.item())
+
+    def _check_custom_keys(self) -> None:
+        """Check the legality of ``self.custom_keys``.
+
+        If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
+        key of ``self.fixed_smooth_keys`` cannot be overwritten.
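+
+        Examples:
+            >>> # Valid: ``time`` keeps its original record because
+            >>> # ``log_name`` redirects the statistic to an extra key
+            >>> # (config adapted from the unit tests).
+            >>> LoggerHook(custom_keys=dict(
+            >>>     time=dict(method_name='max', log_name='time_max')))
+            >>> # Invalid: ``time`` would be overwritten, so an
+            >>> # AssertionError is raised.
+            >>> LoggerHook(custom_keys=dict(time=dict(method_name='max')))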
+        """
+
+        def _check_window_size(item):
+            if not self.by_epoch:
+                assert item['window_size'] != 'epoch', \
+                    'window_size cannot be epoch if LoggerHook.by_epoch is ' \
+                    'False.'
+
+        def _check_fixed_keys(key, item):
+            if key in self.fixed_smooth_keys:
+                assert 'log_name' in item, f'{key} cannot be overwritten by ' \
+                                           'custom keys!'
+
+        for key, value in self.custom_keys.items():
+            if isinstance(value, Sequence):
+                for item in value:
+                    _check_window_size(item)
+                    _check_fixed_keys(key, item)
+            else:
+                _check_window_size(value)
+                _check_fixed_keys(key, value)
+
+    def _get_epoch(self, runner, mode: str) -> int:
+        """Get epoch according to mode.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            mode (str): Train or val.
+
+        Returns:
+            int: The current epoch.
+        """
+        if mode == 'train':
+            epoch = runner.epoch + 1
+        elif mode == 'val':
+            # normal val mode
+            # runner.epoch += 1 has been done before val workflow
+            epoch = runner.epoch
+        else:
+            raise ValueError(f"runner mode should be 'train' or 'val', "
+                             f'but got {mode}')
+        return epoch
+
+    def _get_iter(self, runner, inner_iter=False) -> int:
+        """Get the current training iteration step.
+        Args:
+            runner (Runner): The runner of the training process.
+            inner_iter (bool): Whether to return the inner iter of an epoch.
+                Defaults to False.
+
+        Returns:
+            int: The current global iter or inner iter.
+        """
+        if self.by_epoch and inner_iter:
+            current_iter = runner.inner_iter + 1
+        else:
+            current_iter = runner.iter + 1
+        return current_iter
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
new file mode 100644
index 00000000..631736e8
--- /dev/null
+++ b/tests/test_hook/test_logger_hook.py
@@ -0,0 +1,355 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import datetime
+import logging
+import os.path as osp
+import sys
+from collections import OrderedDict
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from mmengine.fileio.file_client import HardDiskBackend
+from mmengine.hooks import LoggerHook
+
+
+class TestLoggerHook:
+
+    def test_init(self):
+        logger_hook = LoggerHook(out_dir='tmp.txt')
+        assert logger_hook.by_epoch
+        assert logger_hook.interval == 10
+        assert not logger_hook.custom_keys
+        assert logger_hook.ignore_last
+        assert logger_hook.time_sec_tot == 0
+        assert logger_hook.interval_exp_name == 1000
+        assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
+        assert logger_hook.keep_local
+        assert logger_hook.file_client_args is None
+        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
+        # out_dir should be None or string or tuple of string.
+        with pytest.raises(TypeError):
+            LoggerHook(out_dir=1)
+        # time cannot be overwritten.
+        with pytest.raises(AssertionError):
+            LoggerHook(custom_keys=dict(time=dict(method='max')))
+        LoggerHook(
+            custom_keys=dict(time=[
+                dict(method='max', log_name='time_max'),
+                dict(method='min', log_name='time_min')
+            ]))
+        # Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
+        with pytest.raises(AssertionError):
+            LoggerHook(
+                by_epoch=False,
+                custom_keys=dict(
+                    time=dict(
+                        method='max', log_name='time_max',
+                        window_size='epoch')))
+        with pytest.raises(ValueError):
+            LoggerHook(file_client_args=dict(enable_mc=True))
+
+    def test_before_run(self):
+        runner = MagicMock()
+        runner.iter = 10
+        runner.timestamp = 'timestamp'
+        runner.work_dir = 'work_dir'
+        runner.logger = MagicMock()
+        logger_hook = LoggerHook(out_dir='out_dir')
+        logger_hook.before_run(runner)
+        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
+        assert logger_hook.json_log_path == osp.join('work_dir',
+                                                     'timestamp.log.json')
+        assert logger_hook.start_iter == runner.iter
+        runner.writer.add_params.assert_called()
+
+    def test_after_run(self, tmp_path):
+        out_dir = tmp_path / 'out_dir'
+        out_dir.mkdir()
+        work_dir = tmp_path / 'work_dir'
+        work_dir.mkdir()
+        work_dir_json = work_dir / 'tmp.log.json'
+        json_f = open(work_dir_json, 'w')
+        json_f.close()
+        runner = MagicMock()
+        runner.work_dir = work_dir
+
+        logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
+        logger_hook.out_dir = str(out_dir)
+        logger_hook.after_run(runner)
+        # Verify that the file has been moved to `out_dir`.
+        assert not osp.exists(str(work_dir_json))
+        assert osp.exists(str(out_dir / 'tmp.log.json'))
+
+    def test_after_train_iter(self):
+        # Test LoggerHook by iter.
+        runner = MagicMock()
+        runner.iter = 10
+        logger_hook = LoggerHook(by_epoch=False)
+        logger_hook._log_train = MagicMock()
+        logger_hook.after_train_iter(runner)
+        # `cur_iter = 10 + 1` is not divisible by `logger_hook.interval`,
+        # so `_log_train` should not be called.
+        logger_hook._log_train.assert_not_called()
+        runner.iter = 9
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test LoggerHook by epoch.
+        logger_hook = LoggerHook(by_epoch=True)
+        logger_hook._log_train = MagicMock()
+        # Only `runner.inner_iter` matters when `by_epoch=True`.
+        runner.iter = 9
+        runner.inner_iter = 10
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_not_called()
+        runner.inner_iter = 9
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test end of the epoch.
+        logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
+        logger_hook._log_train = MagicMock()
+        runner.data_loader = [0] * 5
+        runner.inner_iter = 4
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test print exp_name
+        runner.meta = dict(exp_name='retinanet')
+        logger_hook = LoggerHook()
+        runner.logger = MagicMock()
+        logger_hook._log_train = MagicMock()
+        logger_hook.after_train_iter(runner)
+        runner.logger.info.assert_called_with(
+            f'Exp name: {runner.meta["exp_name"]}')
+
+    def test_after_val_epoch(self):
+        logger_hook = LoggerHook()
+        runner = MagicMock()
+        logger_hook._log_val = MagicMock()
+        logger_hook.after_val_epoch(runner)
+        logger_hook._log_val.assert_called()
+
+    @pytest.mark.parametrize('by_epoch', [True, False])
+    def test_log_train(self, by_epoch, capsys):
+        runner = self._setup_runner()
+        runner.meta = dict(exp_name='retinanet')
+        # Prepare LoggerHook
+        logger_hook = LoggerHook(by_epoch=by_epoch)
+        logger_hook.writer = MagicMock()
+        logger_hook.time_sec_tot = 1000
+        logger_hook.start_iter = 0
+        logger_hook._get_max_memory = MagicMock(return_value='100')
+        logger_hook.json_log_path = 'tmp.json'
+
+        # Prepare training information.
+        train_infos = dict(
+            lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
+        logger_hook._collect_info = MagicMock(return_value=train_infos)
+        logger_hook._log_train(runner)
+        # Verify that the correct variables have been written.
+        runner.writer.add_scalars.assert_called_with(
+            train_infos, step=11, file_path='tmp.json')
+        # Verify that the correct content has been logged.
+        out, _ = capsys.readouterr()
+        time_avg = logger_hook.time_sec_tot / (
+            runner.iter + 1 - logger_hook.start_iter)
+        eta_second = time_avg * (runner.max_iters - runner.iter - 1)
+        eta_str = str(datetime.timedelta(seconds=int(eta_second)))
+        if by_epoch:
+            if torch.cuda.is_available():
+                log_str = 'Epoch [2][2/5]\t' \
+                          f"lr: {train_infos['lr']:.3e} " \
+                          f"momentum: {train_infos['momentum']:.3e}, " \
+                          f'eta: {eta_str}, ' \
+                          f"time: {train_infos['time']:.3f}, " \
+                          f"data_time: {train_infos['data_time']:.3f}, " \
+                          f'memory: 100, ' \
+                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
+            else:
+                log_str = 'Epoch [2][2/5]\t' \
+                          f"lr: {train_infos['lr']:.3e} " \
+                          f"momentum: {train_infos['momentum']:.3e}, " \
+                          f'eta: {eta_str}, ' \
+                          f"time: {train_infos['time']:.3f}, " \
+                          f"data_time: {train_infos['data_time']:.3f}, " \
+                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
+            assert out == log_str
+        else:
+            if torch.cuda.is_available():
+                log_str = 'Iter [11/50]\t' \
+                          f"lr: {train_infos['lr']:.3e} " \
+                          f"momentum: {train_infos['momentum']:.3e}, " \
+                          f'eta: {eta_str}, ' \
+                          f"time: {train_infos['time']:.3f}, " \
+                          f"data_time: {train_infos['data_time']:.3f}, " \
+                          f'memory: 100, ' \
+                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
+            else:
+                log_str = 'Iter [11/50]\t' \
+                          f"lr: {train_infos['lr']:.3e} " \
+                          f"momentum: {train_infos['momentum']:.3e}, " \
+                          f'eta: {eta_str}, ' \
+                          f"time: {train_infos['time']:.3f}, " \
+                          f"data_time: {train_infos['data_time']:.3f}, " \
+                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
+            assert out == log_str
+
+    @pytest.mark.parametrize('by_epoch', [True, False])
+    def test_log_val(self, by_epoch, capsys):
+        runner = self._setup_runner()
+        # Prepare LoggerHook.
+        logger_hook = LoggerHook(by_epoch=by_epoch)
+        logger_hook.json_log_path = 'tmp.json'
+        metric = dict(accuracy=0.9, data_time=1.0)
+        logger_hook._collect_info = MagicMock(return_value=metric)
+        logger_hook._log_val(runner)
+        # Verify that the correct content has been logged.
+        out, _ = capsys.readouterr()
+        runner.writer.add_scalars.assert_called_with(
+            metric, step=11, file_path='tmp.json')
+        if by_epoch:
+            assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
+                          'data_time: 1.0000\n'
+
+        else:
+            assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
+                          'data_time: 1.0000\n'
+
+    def test_get_window_size(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        # Test get window size by name.
+        assert logger_hook._get_window_size(runner, 'epoch') == 2
+        assert logger_hook._get_window_size(runner, 'global') == 11
+        assert logger_hook._get_window_size(runner, 10) == 10
+        # Window size must be equal to `logger_hook.interval`.
+        with pytest.raises(AssertionError):
+            logger_hook._get_window_size(runner, 20)
+
+        with pytest.raises(ValueError):
+            logger_hook._get_window_size(runner, 'unknwon')
+
+    def test_parse_custom_keys(self):
+        tag = OrderedDict()
+        runner = self._setup_runner()
+        log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
+        cfg_dict = dict(
+            lr=dict(method='min'),
+            loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max', log_name='loss_max')
+            ])
+        logger_hook = LoggerHook()
+        for log_key, log_cfg in cfg_dict.items():
+            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                           log_buffers, tag)
+        assert list(tag) == ['lr', 'loss', 'loss_max']
+        log_buffers['lr'].statistics.assert_called()
+        log_buffers['loss'].statistics.assert_called()
+        # `log_name` cannot be repeated.
+        with pytest.raises(KeyError):
+            cfg_dict = dict(loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max', log_name='loss_max'),
+                dict(method='mean', log_name='loss_max')
+            ])
+            logger_hook.custom_keys = cfg_dict
+            for log_key, log_cfg in cfg_dict.items():
+                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                               log_buffers, tag)
+        # `log_key` cannot be overwritten multiple times.
+        with pytest.raises(AssertionError):
+            cfg_dict = dict(loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max'),
+            ])
+            logger_hook.custom_keys = cfg_dict
+            for log_key, log_cfg in cfg_dict.items():
+                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                               log_buffers, tag)
+
+    def test_collect_info(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook(
+            custom_keys=dict(time=dict(method='max', log_name='time_max')))
+        logger_hook._parse_custom_keys = MagicMock()
+        # Collect with prefix.
+        log_buffers = {
+            'train/time': MagicMock(),
+            'lr': MagicMock(),
+            'train/loss_cls': MagicMock(),
+            'val/metric': MagicMock()
+        }
+        runner.message_hub.log_buffers = log_buffers
+        tag = logger_hook._collect_info(runner, mode='train')
+        # Test parse custom_keys
+        logger_hook._parse_custom_keys.assert_called()
+        # Test training key in tag.
+        assert list(tag.keys()) == ['time', 'loss_cls']
+        # Test that time and loss are smoothed with `mean` instead of
+        # `current`.
+        log_buffers['train/time'].mean.assert_called()
+        log_buffers['train/loss_cls'].mean.assert_called()
+        log_buffers['train/loss_cls'].current.assert_not_called()
+
+        tag = logger_hook._collect_info(runner, mode='val')
+        assert list(tag.keys()) == ['metric']
+        log_buffers['val/metric'].current.assert_called()
+
+    @patch('torch.distributed.reduce', MagicMock())
+    def test_get_max_memory(self):
+        logger_hook = LoggerHook()
+        runner = MagicMock()
+        runner.world_size = 1
+        runner.model = torch.nn.Linear(1, 1)
+        logger_hook._get_max_memory(runner)
+        torch.distributed.reduce.assert_not_called()
+        runner.world_size = 2
+        logger_hook._get_max_memory(runner)
+        torch.distributed.reduce.assert_called()
+
+    def test_get_iter(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        # Get global iter when `inner_iter=False`
+        iter = logger_hook._get_iter(runner)
+        assert iter == 11
+        # Get inner iter
+        iter = logger_hook._get_iter(runner, inner_iter=True)
+        assert iter == 2
+        # Still get global iter when `logger_hook.by_epoch==False`
+        logger_hook.by_epoch = False
+        iter = logger_hook._get_iter(runner, inner_iter=True)
+        assert iter == 11
+
+    def test_get_epoch(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        epoch = logger_hook._get_epoch(runner, 'train')
+        assert epoch == 2
+        epoch = logger_hook._get_epoch(runner, 'val')
+        assert epoch == 1
+        with pytest.raises(ValueError):
+            logger_hook._get_epoch(runner, 'test')
+
+    def _setup_runner(self):
+        runner = MagicMock()
+        runner.epoch = 1
+        runner.data_loader = [0] * 5
+        runner.inner_iter = 1
+        runner.iter = 10
+        runner.max_iters = 50
+        logger = logging.getLogger()
+        logger.setLevel(logging.INFO)
+        # Stream logs to stdout so that `capsys` can capture the output.
+        logger.addHandler(logging.StreamHandler(stream=sys.stdout))
+        runner.logger = logger
+        runner.message_hub = MagicMock()
+        runner.composed_writer = MagicMock()
+        return runner
-- 
GitLab