# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest
import torch

from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook


class TestLoggerHook:

    def test_init(self):
        logger_hook = LoggerHook(out_dir='tmp.txt')
        assert logger_hook.by_epoch
        assert logger_hook.interval == 10
        assert not logger_hook.custom_keys
        assert logger_hook.ignore_last
        assert logger_hook.time_sec_tot == 0
        assert logger_hook.interval_exp_name == 1000
        assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
        assert logger_hook.keep_local
        assert logger_hook.file_client_args is None
        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
        # out_dir should be None, a string, or a tuple of strings.
        with pytest.raises(TypeError):
            LoggerHook(out_dir=1)
        # time cannot be overwritten.
        with pytest.raises(AssertionError):
            LoggerHook(custom_keys=dict(time=dict(method='max')))
        LoggerHook(
            custom_keys=dict(time=[
                dict(method='max', log_name='time_max'),
                dict(method='min', log_name='time_min')
            ]))
        # Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
        with pytest.raises(AssertionError):
            LoggerHook(
                by_epoch=False,
                custom_keys=dict(
                    time=dict(
                        method='max', log_name='time_max',
                        window_size='epoch')))
        with pytest.raises(ValueError):
            LoggerHook(file_client_args=dict(enable_mc=True))

    def test_before_run(self):
        runner = MagicMock()
        runner.iter = 10
        runner.timestamp = 'timestamp'
        runner.work_dir = 'work_dir'
        runner.logger = MagicMock()
        logger_hook = LoggerHook(out_dir='out_dir')
        logger_hook.before_run(runner)
        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
        assert logger_hook.json_log_path == osp.join('work_dir',
                                                     'timestamp.log.json')
        assert logger_hook.start_iter == runner.iter
        runner.writer.add_params.assert_called()

    def test_after_run(self, tmp_path):
        out_dir = tmp_path / 'out_dir'
        out_dir.mkdir()
        work_dir = tmp_path / 'work_dir'
        work_dir.mkdir()
        work_dir_json = work_dir / 'tmp.log.json'
        json_f = open(work_dir_json, 'w')
        json_f.close()
        runner = MagicMock()
        runner.work_dir = work_dir

        logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
        logger_hook.out_dir = str(out_dir)
        logger_hook.after_run(runner)
        # Verify that the file has been moved to `out_dir`.
        assert not osp.exists(str(work_dir_json))
        assert osp.exists(str(out_dir / 'tmp.log.json'))

    def test_after_train_iter(self):
        # Test LoggerHook by iter.
        runner = MagicMock()
        runner.iter = 10
        # In iter-based mode only `runner.iter` decides when to log, so any
        # `batch_idx` value works here.
        batch_idx = 5
        logger_hook = LoggerHook(by_epoch=False)
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        # `cur_iter=10+1` is not divisible by `logger_hook.interval`,
        # so nothing should be logged yet.
        logger_hook._log_train.assert_not_called()
        runner.iter = 9
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test LoggerHook by epoch.
        logger_hook = LoggerHook(by_epoch=True)
        logger_hook._log_train = MagicMock()
        # In epoch-based mode, only `batch_idx` decides when to log.
        runner.iter = 9
        batch_idx = 10
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_not_called()
        batch_idx = 9
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test end of the epoch.
        logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
        logger_hook._log_train = MagicMock()
        runner.train_loop.dataloader = [0] * 5
        batch_idx = 4
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test print exp_name
        runner.meta = dict(exp_name='retinanet')
        logger_hook = LoggerHook()
        runner.logger = MagicMock()
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        runner.logger.info.assert_called_with(
            f'Exp name: {runner.meta["exp_name"]}')

    def test_after_val_epoch(self):
        logger_hook = LoggerHook()
        runner = MagicMock()
        logger_hook._log_val = MagicMock()
        logger_hook.after_val_epoch(runner)
        logger_hook._log_val.assert_called()

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_train(self, by_epoch, capsys):
        runner = self._setup_runner()
        runner.meta = dict(exp_name='retinanet')
        # Prepare LoggerHook
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook._inner_iter = 1
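        # With `runner.iter=10`, `runner.epoch=1` and 5-batch dataloaders from
        # `_setup_runner`, the expected log headers below are
        # 'Epoch [2][2/5]' and 'Iter [11/50]'.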
        logger_hook.writer = MagicMock()
        logger_hook.time_sec_tot = 1000
        logger_hook.start_iter = 0
        logger_hook._get_max_memory = MagicMock(return_value='100')
        logger_hook.json_log_path = 'tmp.json'

        # Prepare training information.
        train_infos = dict(
            lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
        logger_hook._collect_info = MagicMock(return_value=train_infos)
        logger_hook._log_train(runner)
        # Verify that the correct variables have been written.
        runner.writer.add_scalars.assert_called_with(
            train_infos, step=11, file_path='tmp.json')
        # Verify that the correct content has been logged.
        out, _ = capsys.readouterr()
        time_avg = logger_hook.time_sec_tot / (
            runner.iter + 1 - logger_hook.start_iter)
        eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(eta_second)))
        if by_epoch:
            if torch.cuda.is_available():
                log_str = 'Epoch [2][2/5] ' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f'memory: 100, ' \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            else:
                log_str = 'Epoch [2][2/5] ' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            assert out == log_str
        else:
            if torch.cuda.is_available():
                log_str = 'Iter [11/50] ' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f'memory: 100, ' \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            else:
                log_str = 'Iter [11/50] ' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            assert out == log_str

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_val(self, by_epoch, capsys):
        runner = self._setup_runner()
        # Prepare LoggerHook.
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook.json_log_path = 'tmp.json'
        metric = dict(accuracy=0.9, data_time=1.0)
        logger_hook._collect_info = MagicMock(return_value=metric)
        logger_hook._log_val(runner)
        # Verify that the correct content has been logged.
        out, _ = capsys.readouterr()
        runner.writer.add_scalars.assert_called_with(
            metric, step=11, file_path='tmp.json')
        if by_epoch:
            assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
                          'data_time: 1.0000\n'

        else:
            assert out == 'Iter(val) [5] accuracy: 0.9000, ' \
                          'data_time: 1.0000\n'

    def test_get_window_size(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        logger_hook._inner_iter = 1
        # Test get window size by name.
        assert logger_hook._get_window_size(runner, 'epoch') == 2
        assert logger_hook._get_window_size(runner, 'global') == 11
        assert logger_hook._get_window_size(runner, 10) == 10
        # An integer window size must equal `logger_hook.interval`.
        with pytest.raises(AssertionError):
            logger_hook._get_window_size(runner, 20)

        with pytest.raises(ValueError):
            logger_hook._get_window_size(runner, 'unknown')

    def test_parse_custom_keys(self):
        tag = OrderedDict()
        runner = self._setup_runner()
        log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
        cfg_dict = dict(
            lr=dict(method='min'),
            loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max')
            ])
        logger_hook = LoggerHook()
        for log_key, log_cfg in cfg_dict.items():
            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                           log_buffers, tag)
        assert list(tag) == ['lr', 'loss', 'loss_max']
        # `_parse_custom_keys` should invoke the statistics methods configured
        # in `cfg_dict` on the corresponding log buffers.
        log_buffers['lr'].min.assert_called()
        log_buffers['loss'].min.assert_called()
        log_buffers['loss'].max.assert_called()
        # `log_name` cannot be repeated.
        with pytest.raises(KeyError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max'),
                dict(method='mean', log_name='loss_max')
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)
        # `log_key` cannot be overwritten multiple times.
        with pytest.raises(AssertionError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max'),
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)

    def test_collect_info(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook(
            custom_keys=dict(time=dict(method='max', log_name='time_max')))
        logger_hook._parse_custom_keys = MagicMock()
        # Collect with prefix.
        log_buffers = {
            'train/time': MagicMock(),
            'lr': MagicMock(),
            'train/loss_cls': MagicMock(),
            'val/metric': MagicMock()
        }
        runner.message_hub.log_buffers = log_buffers
        tag = logger_hook._collect_info(runner, mode='train')
        # Test parse custom_keys
        logger_hook._parse_custom_keys.assert_called()
        # Test training key in tag.
        assert list(tag.keys()) == ['time', 'loss_cls']
        # `lr` is logged with `current` by default; `loss_cls` and `time`
        # should be averaged with `mean`.
        log_buffers['train/time'].mean.assert_called()
        log_buffers['train/loss_cls'].mean.assert_called()
        log_buffers['train/loss_cls'].current.assert_not_called()

        tag = logger_hook._collect_info(runner, mode='val')
        assert list(tag.keys()) == ['metric']
        log_buffers['val/metric'].current.assert_called()

    @patch('torch.cuda.max_memory_allocated', MagicMock())
    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
    def test_get_max_memory(self):
        logger_hook = LoggerHook()
        runner = MagicMock()
        runner.world_size = 1
        runner.model = torch.nn.Linear(1, 1)
        logger_hook._get_max_memory(runner)
        torch.cuda.max_memory_allocated.assert_called()
        torch.cuda.reset_peak_memory_stats.assert_called()

    def test_get_iter(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        logger_hook._inner_iter = 1
        # Get global iter when `inner_iter=False`
        iter = logger_hook._get_iter(runner)
        assert iter == 11
        # Get inner iter
        iter = logger_hook._get_iter(runner, inner_iter=True)
        assert iter == 2
        # Still get global iter when `logger_hook.by_epoch==False`
        logger_hook.by_epoch = False
        iter = logger_hook._get_iter(runner, inner_iter=True)
        assert iter == 11

    def test_get_epoch(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        epoch = logger_hook._get_epoch(runner, 'train')
        assert epoch == 2
        epoch = logger_hook._get_epoch(runner, 'val')
        assert epoch == 1
        with pytest.raises(ValueError):
            logger_hook._get_epoch(runner, 'test')

    def _setup_runner(self):
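        """Build a mocked runner whose state drives the expected log output:
        ``epoch=1``, ``iter=10``, 5-batch dataloaders, ``max_iters=50`` and a
        logger streaming to stdout so ``capsys`` can capture it."""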
        runner = MagicMock()
        runner.epoch = 1
        runner.train_loop.dataloader = [0] * 5
        runner.val_loop.dataloader = [0] * 5
        runner.test_loop.dataloader = [0] * 5
        runner.iter = 10
        runner.train_loop.max_iters = 50
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        # Make sure a StreamHandler bound to the current stdout is attached so
        # that `capsys` can capture the log output.
        for handler in logger.handlers:
            if isinstance(handler, logging.StreamHandler) and \
                    handler.stream is sys.stdout:
                break
        else:
            logger.addHandler(logging.StreamHandler(stream=sys.stdout))
        runner.logger = logger
        runner.message_hub = MagicMock()
        runner.composed_writer = MagicMock()
        return runner