# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest
import torch

from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook


class TestLoggerHook:
    """Unit tests for ``LoggerHook`` driven by ``MagicMock`` runners."""

    def test_init(self):
        """Check constructor defaults and argument validation."""
        logger_hook = LoggerHook(out_dir='tmp.txt')
        assert logger_hook.by_epoch
        assert logger_hook.interval == 10
        assert not logger_hook.custom_keys
        assert logger_hook.ignore_last
        assert logger_hook.time_sec_tot == 0
        assert logger_hook.interval_exp_name == 1000
        assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
        assert logger_hook.keep_local
        assert logger_hook.file_client_args is None
        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
        # out_dir should be None or string or tuple of string.
        with pytest.raises(TypeError):
            LoggerHook(out_dir=1)
        # time cannot be overwritten.
        with pytest.raises(AssertionError):
            LoggerHook(custom_keys=dict(time=dict(method='max')))
        # Giving every `time` statistic its own `log_name` is accepted.
        LoggerHook(
            custom_keys=dict(time=[
                dict(method='max', log_name='time_max'),
                dict(method='min', log_name='time_min')
            ]))
        # Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
        with pytest.raises(AssertionError):
            LoggerHook(
                by_epoch=False,
                custom_keys=dict(
                    time=dict(
                        method='max',
                        log_name='time_max',
                        window_size='epoch')))
        with pytest.raises(ValueError):
            LoggerHook(file_client_args=dict(enable_mc=True))

    def test_before_run(self):
        """Check the paths and start iteration set by ``before_run``."""
        runner = MagicMock()
        runner.iter = 10
        runner.timestamp = 'timestamp'
        runner.work_dir = 'work_dir'
        runner.logger = MagicMock()
        logger_hook = LoggerHook(out_dir='out_dir')
        logger_hook.before_run(runner)
        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
        assert logger_hook.json_log_path == osp.join('work_dir',
                                                     'timestamp.log.json')
        assert logger_hook.start_iter == runner.iter
        runner.writer.add_params.assert_called()

    def test_after_run(self, tmp_path):
        """Check that ``after_run`` moves log files when not kept locally."""
        out_dir = tmp_path / 'out_dir'
        out_dir.mkdir()
        work_dir = tmp_path / 'work_dir'
        work_dir.mkdir()
        work_dir_json = work_dir / 'tmp.log.json'
        # Create an empty log file for the hook to pick up.
        work_dir_json.touch()
        runner = MagicMock()
        runner.work_dir = work_dir

        logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
        logger_hook.out_dir = str(out_dir)
        logger_hook.after_run(runner)
        # Verify that the file has been moved to `out_dir`.
        assert not osp.exists(str(work_dir_json))
        assert osp.exists(str(out_dir / 'tmp.log.json'))

    def test_after_train_iter(self):
        """Check when ``after_train_iter`` triggers ``_log_train``."""
        # Test LoggerHook by iter.
        runner = MagicMock()
        runner.iter = 10
        batch_idx = 5
        logger_hook = LoggerHook(by_epoch=False)
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        # `cur_iter=10+1`, which cannot be exact division by
        # `logger_hook.interval`
        logger_hook._log_train.assert_not_called()
        runner.iter = 9
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test LoggerHook by epoch.
        logger_hook = LoggerHook(by_epoch=True)
        logger_hook._log_train = MagicMock()
        # Only `batch_idx` decides whether to log here: `runner.iter` stays
        # at 9 for both calls below, yet only `batch_idx=9` triggers a log.
        runner.iter = 9
        batch_idx = 10
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_not_called()
        batch_idx = 9
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test end of the epoch.
        logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
        logger_hook._log_train = MagicMock()
        runner.train_loop.dataloader = [0] * 5
        batch_idx = 4
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        logger_hook._log_train.assert_called()

        # Test print exp_name
        runner.meta = dict(exp_name='retinanet')
        logger_hook = LoggerHook()
        runner.logger = MagicMock()
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
        runner.logger.info.assert_called_with(
            f'Exp name: {runner.meta["exp_name"]}')

    def test_after_val_epoch(self):
        """Check that ``after_val_epoch`` delegates to ``_log_val``."""
        logger_hook = LoggerHook()
        runner = MagicMock()
        logger_hook._log_val = MagicMock()
        logger_hook.after_val_epoch(runner)
        logger_hook._log_val.assert_called()

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_train(self, by_epoch, capsys):
        """Check the scalars written and the text logged by ``_log_train``."""
        runner = self._setup_runner()
        runner.meta = dict(exp_name='retinanet')
        # Prepare LoggerHook
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook._inner_iter = 1
        logger_hook.writer = MagicMock()
        logger_hook.time_sec_tot = 1000
        logger_hook.start_iter = 0
        logger_hook._get_max_memory = MagicMock(return_value='100')
        logger_hook.json_log_path = 'tmp.json'
        # Prepare training information.
        train_infos = dict(
            lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
        logger_hook._collect_info = MagicMock(return_value=train_infos)
        logger_hook._log_train(runner)
        # Verify that the correct variables have been written.
        runner.writer.add_scalars.assert_called_with(
            train_infos, step=11, file_path='tmp.json')
        # Verify that the correct context have been logged.
        out, _ = capsys.readouterr()
        time_avg = logger_hook.time_sec_tot / (
            runner.iter + 1 - logger_hook.start_iter)
        eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(eta_second)))
        # The expected line differs only in its prefix (epoch vs. iter mode)
        # and in whether a memory field is present (CUDA available or not).
        prefix = 'Epoch [2][2/5] ' if by_epoch else 'Iter [11/50] '
        memory_str = 'memory: 100, ' if torch.cuda.is_available() else ''
        log_str = (f'{prefix}'
                   f"lr: {train_infos['lr']:.3e} "
                   f"momentum: {train_infos['momentum']:.3e}, "
                   f'eta: {eta_str}, '
                   f"time: {train_infos['time']:.3f}, "
                   f"data_time: {train_infos['data_time']:.3f}, "
                   f'{memory_str}'
                   f"loss_cls: {train_infos['loss_cls']:.4f}\n")
        assert out == log_str

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_val(self, by_epoch, capsys):
        """Check the scalars written and the text logged by ``_log_val``."""
        runner = self._setup_runner()
        # Prepare LoggerHook.
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook.json_log_path = 'tmp.json'
        metric = dict(accuracy=0.9, data_time=1.0)
        logger_hook._collect_info = MagicMock(return_value=metric)
        logger_hook._log_val(runner)
        # Verify that the correct context have been logged.
        out, _ = capsys.readouterr()
        runner.writer.add_scalars.assert_called_with(
            metric, step=11, file_path='tmp.json')
        if by_epoch:
            assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
                          'data_time: 1.0000\n'
        else:
            assert out == 'Iter(val) [5] accuracy: 0.9000, ' \
                          'data_time: 1.0000\n'

    def test_get_window_size(self):
        """Check resolution of named and numeric window sizes."""
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        logger_hook._inner_iter = 1
        # Test get window size by name.
        assert logger_hook._get_window_size(runner, 'epoch') == 2
        assert logger_hook._get_window_size(runner, 'global') == 11
        assert logger_hook._get_window_size(runner, 10) == 10
        # Window size must equal to `logger_hook.interval`.
        with pytest.raises(AssertionError):
            logger_hook._get_window_size(runner, 20)

        with pytest.raises(ValueError):
            logger_hook._get_window_size(runner, 'unknwon')

    def test_parse_custom_keys(self):
        """Check tag names and validation errors of ``_parse_custom_keys``."""
        tag = OrderedDict()
        runner = self._setup_runner()
        log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
        cfg_dict = dict(
            lr=dict(method='min'),
            loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max')
            ])
        logger_hook = LoggerHook()
        for log_key, log_cfg in cfg_dict.items():
            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                           log_buffers, tag)
        assert list(tag) == ['lr', 'loss', 'loss_max']
        # NOTE(review): the four asserts below reference `assert_called`
        # without calling it, so each only checks the truthiness of a bound
        # method and always passes. Turning them into real call checks needs
        # confirmation of which buffer method the hook actually invokes.
        assert log_buffers['lr'].min.assert_called
        assert log_buffers['loss'].min.assert_called
        assert log_buffers['loss'].max.assert_called
        assert log_buffers['loss'].mean.assert_called
        # `log_name` Cannot be repeated.
        with pytest.raises(KeyError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max'),
                dict(method='mean', log_name='loss_max')
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)
        # `log_key` cannot be overwritten multiple times.
        with pytest.raises(AssertionError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max'),
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)

    def test_collect_info(self):
        """Check which buffers ``_collect_info`` reads for each mode."""
        runner = self._setup_runner()
        logger_hook = LoggerHook(
            custom_keys=dict(time=dict(method='max', log_name='time_max')))
        logger_hook._parse_custom_keys = MagicMock()
        # Collect with prefix.
        log_buffers = {
            'train/time': MagicMock(),
            'lr': MagicMock(),
            'train/loss_cls': MagicMock(),
            'val/metric': MagicMock()
        }
        runner.message_hub.log_buffers = log_buffers
        tag = logger_hook._collect_info(runner, mode='train')
        # Test parse custom_keys
        logger_hook._parse_custom_keys.assert_called()
        # Test training key in tag.
        assert list(tag.keys()) == ['time', 'loss_cls']
        # Test statistics lr with `current`, loss and time with 'mean'
        log_buffers['train/time'].mean.assert_called()
        log_buffers['train/loss_cls'].mean.assert_called()
        log_buffers['train/loss_cls'].current.assert_not_called()

        tag = logger_hook._collect_info(runner, mode='val')
        assert list(tag.keys()) == ['metric']
        log_buffers['val/metric'].current.assert_called()

    @patch('torch.cuda.max_memory_allocated', MagicMock())
    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
    def test_get_max_memory(self):
        """Check that ``_get_max_memory`` queries the patched CUDA stats."""
        logger_hook = LoggerHook()
        runner = MagicMock()
        runner.world_size = 1
        runner.model = torch.nn.Linear(1, 1)
        logger_hook._get_max_memory(runner)
        torch.cuda.max_memory_allocated.assert_called()
        torch.cuda.reset_peak_memory_stats.assert_called()

    def test_get_iter(self):
        """Check global vs. inner iteration numbers from ``_get_iter``."""
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        logger_hook._inner_iter = 1
        # Get global iter when `inner_iter=False`
        cur_iter = logger_hook._get_iter(runner)
        assert cur_iter == 11
        # Get inner iter
        cur_iter = logger_hook._get_iter(runner, inner_iter=True)
        assert cur_iter == 2
        # Still get global iter when `logger_hook.by_epoch==False`
        logger_hook.by_epoch = False
        cur_iter = logger_hook._get_iter(runner, inner_iter=True)
        assert cur_iter == 11

    def test_get_epoch(self):
        """Check the epoch reported for train/val and the invalid mode."""
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        epoch = logger_hook._get_epoch(runner, 'train')
        assert epoch == 2
        epoch = logger_hook._get_epoch(runner, 'val')
        assert epoch == 1
        with pytest.raises(ValueError):
            logger_hook._get_epoch(runner, 'test')

    def _setup_runner(self):
        """Build a mock runner whose logger is the root logger.

        Returns:
            MagicMock: Runner at epoch 1, iteration 10, with 5-item train,
            val and test dataloaders and ``max_iters`` of 50.
        """
        runner = MagicMock()
        runner.epoch = 1
        runner.train_loop.dataloader = [0] * 5
        runner.val_loop.dataloader = [0] * 5
        runner.test_loop.dataloader = [0] * 5
        runner.iter = 10
        runner.train_loop.max_iters = 50
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        # A new handler bound to the current `sys.stdout` is attached on
        # every call. The previous `for ... else` scan over existing handlers
        # had no `break`, so its `else` branch always ran; the scan was dead
        # code and is dropped without changing behaviour.
        # NOTE(review): presumably a fresh handler is needed so that `capsys`
        # sees the output of each test -- confirm before deduplicating.
        logger.addHandler(logging.StreamHandler(stream=sys.stdout))
        runner.logger = logger
        runner.message_hub = MagicMock()
        # NOTE(review): `composed_wirter` looks like a typo of
        # `composed_writer`; kept as-is because the tests above only read
        # `runner.writer` -- confirm before renaming.
        runner.composed_wirter = MagicMock()
        return runner