From 5762b288473766e9b3a891ed133dbab9b028d556 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 7 Apr 2023 16:20:38 +0800
Subject: [PATCH] [Refactor] Refactor logger hook unit tests (#797)

* Enhance config

* add unit test data

* reafactor unittest of loggerhook

* fix rebase error

* Fix permission error in windows

* Fix CI

* Fix windows ci

* Fix windows ci

* Fix windows ci

* Fix windows CI

* Apply suggestions from code review

Co-authored-by: Qian Zhao <112053249+C1rN09@users.noreply.github.com>

* clean the code

* Refine as comment

* Refine error rasing

* Update mmengine/hooks/logger_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* replace assert_called_with with assert_has_calls

* Fix as comment

* Do not remove filehandler and fix unit test

---------

Co-authored-by: Qian Zhao <112053249+C1rN09@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/hooks/logger_hook.py                 |  51 ++++-
 .../config/py_config/test_custom_class.py     |   5 +
 tests/test_hooks/test_logger_hook.py          | 206 ++++++++++--------
 3 files changed, 157 insertions(+), 105 deletions(-)
 create mode 100644 tests/data/config/py_config/test_custom_class.py

diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 26cd35b7..dc310b56 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -14,7 +14,7 @@ from mmengine.fileio.io import get_file_backend
 from mmengine.hooks import Hook
 from mmengine.logging import print_log
 from mmengine.registry import HOOKS
-from mmengine.utils import is_tuple_of, scandir
+from mmengine.utils import is_seq_of, scandir
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 SUFFIX_TYPE = Union[Sequence[str], str]
@@ -84,15 +84,30 @@ class LoggerHook(Hook):
                  file_client_args: Optional[dict] = None,
                  log_metric_by_epoch: bool = True,
                  backend_args: Optional[dict] = None):
-        self.interval = interval
-        self.ignore_last = ignore_last
-        self.interval_exp_name = interval_exp_name
+
+        if not isinstance(interval, int):
+            raise TypeError('interval must be an integer')
+        if interval <= 0:
+            raise ValueError('interval must be greater than 0')
+
+        if not isinstance(ignore_last, bool):
+            raise TypeError('ignore_last must be a boolean')
+
+        if not isinstance(interval_exp_name, int):
+            raise TypeError('interval_exp_name must be an integer')
+        if interval_exp_name <= 0:
+            raise ValueError('interval_exp_name must be greater than 0')
+
+        if out_dir is not None and not isinstance(out_dir, (str, Path)):
+            raise TypeError('out_dir must be a str or Path object')
+
+        if not isinstance(keep_local, bool):
+            raise TypeError('keep_local must be a boolean')
 
         if out_dir is None and file_client_args is not None:
             raise ValueError(
                 'file_client_args should be "None" when `out_dir` is not'
                 'specified.')
-        self.out_dir = out_dir
 
         if file_client_args is not None:
             print_log(
@@ -105,12 +120,15 @@ class LoggerHook(Hook):
                 '"file_client_args" and "backend_args" cannot be set '
                 'at the same time.')
 
-        if not (out_dir is None or isinstance(out_dir, str)
-                or is_tuple_of(out_dir, str)):
-            raise TypeError('out_dir should be None or string or tuple of '
-                            f'string, but got {type(out_dir)}')
-        self.out_suffix = out_suffix
+        if not (isinstance(out_suffix, str) or is_seq_of(out_suffix, str)):
+            raise TypeError('out_suffix should be a string or a sequence of '
+                            f'string, but got {type(out_suffix)}')
+        self.out_suffix = out_suffix
+        self.out_dir = out_dir
+        self.interval = interval
+        self.ignore_last = ignore_last
+        self.interval_exp_name = interval_exp_name
 
         self.keep_local = keep_local
         self.file_client_args = file_client_args
         self.json_log_path: Optional[str] = None
@@ -291,8 +309,11 @@ class LoggerHook(Hook):
         # copy or upload logs to self.out_dir
         if self.out_dir is None:
             return
+
+        removed_files = []
         for filename in scandir(runner._log_dir, self.out_suffix, True):
             local_filepath = osp.join(runner._log_dir, filename)
+            removed_files.append(local_filepath)
             out_filepath = self.file_backend.join_path(self.out_dir, filename)
             with open(local_filepath) as f:
                 self.file_backend.put_text(f.read(), out_filepath)
@@ -302,7 +323,15 @@ class LoggerHook(Hook):
                 f'{out_filepath}.')
 
             if not self.keep_local:
-                os.remove(local_filepath)
                 runner.logger.info(f'{local_filepath} was removed due to the '
                                    '`self.keep_local=False`. You can check '
                                    f'the running logs in {out_filepath}')
+
+        if not self.keep_local:
+            # Close file handler to avoid PermissionError on Windows.
+            for handler in runner.logger.handlers:
+                if isinstance(handler, logging.FileHandler):
+                    handler.close()
+
+            for file in removed_files:
+                os.remove(file)
diff --git a/tests/data/config/py_config/test_custom_class.py b/tests/data/config/py_config/test_custom_class.py
new file mode 100644
index 00000000..ad706b08
--- /dev/null
+++ b/tests/data/config/py_config/test_custom_class.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+class A:
+    ...
+
+item_a = dict(a=A)
diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py
index aab2817a..b6f6268f 100644
--- a/tests/test_hooks/test_logger_hook.py
+++ b/tests/test_hooks/test_logger_hook.py
@@ -1,89 +1,65 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
 import os.path as osp
-from unittest.mock import ANY, MagicMock
+import shutil
+from unittest.mock import ANY, MagicMock, call
 
-import pytest
 import torch
 
 from mmengine.fileio import load
-from mmengine.fileio.file_client import HardDiskBackend
 from mmengine.hooks import LoggerHook
+from mmengine.logging import MMLogger
+from mmengine.testing import RunnerTestCase
+from mmengine.utils import mkdir_or_exist, scandir
 
 
-class TestLoggerHook:
+class TestLoggerHook(RunnerTestCase):
 
     def test_init(self):
-        logger_hook = LoggerHook(out_dir='tmp.txt')
-        assert logger_hook.interval == 10
-        assert logger_hook.ignore_last
-        assert logger_hook.interval_exp_name == 1000
-        assert logger_hook.out_suffix == ('.json', '.log', '.py', 'yaml')
-        assert logger_hook.keep_local
-        assert logger_hook.file_client_args is None
-        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
+        # Test build logger hook.
+        LoggerHook()
+        LoggerHook(interval=100, ignore_last=False, interval_exp_name=100)
+
+        with self.assertRaisesRegex(TypeError, 'interval must be'):
+            LoggerHook(interval='100')
+
+        with self.assertRaisesRegex(ValueError, 'interval must be'):
+            LoggerHook(interval=-1)
+
+        with self.assertRaisesRegex(TypeError, 'ignore_last must be'):
+            LoggerHook(ignore_last='False')
+
+        with self.assertRaisesRegex(TypeError, 'interval_exp_name'):
+            LoggerHook(interval_exp_name='100')
+
+        with self.assertRaisesRegex(ValueError, 'interval_exp_name'):
+            LoggerHook(interval_exp_name=-1)
+
+        with self.assertRaisesRegex(TypeError, 'out_suffix'):
+            LoggerHook(out_suffix=[100])
+
         # out_dir should be None or string or tuple of string.
-        with pytest.raises(TypeError):
+        with self.assertRaisesRegex(TypeError, 'out_dir must be'):
             LoggerHook(out_dir=1)
 
-        with pytest.raises(ValueError):
+        with self.assertRaisesRegex(ValueError, 'file_client_args'):
             LoggerHook(file_client_args=dict(enable_mc=True))
 
-        # test `file_client_args` and `backend_args`
-        # TODO Refine this unit test
-        # with pytest.warns(
-        #         DeprecationWarning,
-        #         match='"file_client_args" will be deprecated in future'):
-        #     logger_hook = LoggerHook(
-        #         out_dir='tmp.txt', file_client_args={'backend': 'disk'})
+        # test deprecated warning raised by `file_client_args`
+        logger = MMLogger.get_current_instance()
+        with self.assertLogs(logger, level='WARNING'):
+            LoggerHook(
+                out_dir=self.temp_dir.name,
+                file_client_args=dict(backend='disk'))
 
-        with pytest.raises(
+        with self.assertRaisesRegex(
                 ValueError,
-                match='"file_client_args" and "backend_args" cannot be '
-                'set at the same time'):
-            logger_hook = LoggerHook(
-                out_dir='tmp.txt',
-                file_client_args={'backend': 'disk'},
-                backend_args={'backend': 'local'})
-
-    def test_before_run(self):
-        runner = MagicMock()
-        runner.iter = 10
-        runner.timestamp = '20220429'
-        runner._log_dir = f'work_dir/{runner.timestamp}'
-        runner.work_dir = 'work_dir'
-        runner.logger = MagicMock()
-        logger_hook = LoggerHook(out_dir='out_dir')
-        logger_hook.before_run(runner)
-        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
-        assert logger_hook.json_log_path == f'{runner.timestamp}.json'
-
-    def test_after_run(self, tmp_path):
-        # Test
-        timestamp = '20220429'
-        out_dir = tmp_path / 'out_dir'
-        out_dir.mkdir()
-        work_dir = tmp_path / 'work_dir'
-        work_dir.mkdir()
-        log_dir = work_dir / timestamp
-        log_dir.mkdir()
-        log_dir_json = log_dir / 'tmp.log.json'
-        runner = MagicMock()
-        runner._log_dir = str(log_dir)
-        runner.timestamp = timestamp
-        runner.work_dir = str(work_dir)
-        # Test without out_dir.
-        logger_hook = LoggerHook()
-        logger_hook.after_run(runner)
-        # Test with out_dir and make sure json file has been moved to out_dir.
-        json_f = open(log_dir_json, 'w')
-        json_f.close()
-        logger_hook = LoggerHook(out_dir=str(out_dir), keep_local=False)
-        logger_hook.out_dir = str(out_dir)
-        logger_hook.before_run(runner)
-        logger_hook.after_run(runner)
-        # Verify that the file has been moved to `out_dir`.
-        assert not osp.exists(str(log_dir_json))
-        assert osp.exists(str(out_dir / 'work_dir' / 'tmp.log.json'))
+            '"file_client_args" and "backend_args" cannot be '):
+            LoggerHook(
+                out_dir=self.temp_dir.name,
+                file_client_args=dict(enable_mc=True),
+                backend_args=dict(enable_mc=True))
 
     def test_after_train_iter(self):
         # Test LoggerHook by iter.
@@ -130,13 +106,6 @@ class TestLoggerHook:
     def test_after_val_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
-        runner.log_processor.get_log_after_epoch = MagicMock(
-            return_value=(dict(), 'string'))
-        logger_hook.after_val_epoch(runner)
-        runner.log_processor.get_log_after_epoch.assert_called()
-        runner.logger.info.assert_called()
-        runner.visualizer.add_scalars.assert_called()
-
         # Test when `log_metric_by_epoch` is True
         runner.log_processor.get_log_after_epoch = MagicMock(
             return_value=({
@@ -145,14 +114,19 @@ class TestLoggerHook:
                 'acc': 0.8
             }, 'string'))
         logger_hook.after_val_epoch(runner)
-        args = {'step': ANY, 'file_path': ANY}
+
         # expect visualizer log `time` and `metric` respectively
-        runner.visualizer.add_scalars.assert_called_with(
-            {
+        args = {'step': ANY, 'file_path': ANY}
+        calls = [
+            call({
                 'time': 1,
                 'datatime': 1,
                 'acc': 0.8
-            }, **args)
+            }, **args),
+        ]
+        self.assertEqual(
+            len(calls), len(runner.visualizer.add_scalars.mock_calls))
+        runner.visualizer.add_scalars.assert_has_calls(calls)
 
         # Test when `log_metric_by_epoch` is False
         logger_hook = LoggerHook(log_metric_by_epoch=False)
@@ -163,27 +137,28 @@ class TestLoggerHook:
                 'acc': 0.5
             }, 'string'))
         logger_hook.after_val_epoch(runner)
+
         # expect visualizer log `time` and `metric` jointly
-        runner.visualizer.add_scalars.assert_called_with(
-            {
+        calls = [
+            call({
+                'time': 1,
+                'datatime': 1,
+                'acc': 0.8
+            }, **args),
+            call({
                 'time': 5,
                 'datatime': 5,
                 'acc': 0.5
-            }, **args)
-
-        with pytest.raises(AssertionError):
-            runner.visualizer.add_scalars.assert_any_call(
-                {
-                    'time': 5,
-                    'datatime': 5
-                }, **args)
-        with pytest.raises(AssertionError):
-            runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
-
-    def test_after_test_epoch(self, tmp_path):
+            }, **args),
+        ]
+        self.assertEqual(
+            len(calls), len(runner.visualizer.add_scalars.mock_calls))
+        runner.visualizer.add_scalars.assert_has_calls(calls)
+
+    def test_after_test_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
-        runner.log_dir = tmp_path
+        runner.log_dir = self.temp_dir.name
         runner.timestamp = 'test_after_test_epoch'
         runner.log_processor.get_log_after_epoch = MagicMock(
             return_value=(
@@ -219,3 +194,46 @@ class TestLoggerHook:
         runner.log_processor.get_log_after_iter.assert_not_called()
         logger_hook.after_test_iter(runner, 9)
         runner.log_processor.get_log_after_iter.assert_called()
+
+    def test_with_runner(self):
+        # Test dumped the json exits
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.default_hooks.logger = dict(type='LoggerHook')
+        cfg.train_cfg.max_epochs = 10
+        runner = self.build_runner(cfg)
+        runner.train()
+        json_path = osp.join(runner._log_dir, 'vis_data',
+                             f'{runner.timestamp}.json')
+        self.assertTrue(osp.isfile(json_path))
+
+        # Test out_dir
+        out_dir = osp.join(cfg.work_dir, 'test')
+        mkdir_or_exist(out_dir)
+        cfg.default_hooks.logger = dict(type='LoggerHook', out_dir=out_dir)
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(os.listdir(out_dir))
+        # clean the out_dir
+        for filename in os.listdir(out_dir):
+            shutil.rmtree(osp.join(out_dir, filename))
+
+        # Test out_suffix
+        cfg.default_hooks.logger = dict(
+            type='LoggerHook', out_dir=out_dir, out_suffix='.log')
+        runner = self.build_runner(cfg)
+        runner.train()
+        filenames = scandir(out_dir, recursive=True)
+        self.assertTrue(
+            all(filename.endswith('.log') for filename in filenames))
+
+        # Test keep_local=False
+        cfg.default_hooks.logger = dict(
+            type='LoggerHook', out_dir=out_dir, keep_local=False)
+        runner = self.build_runner(cfg)
+        runner.train()
+        filenames = scandir(runner._log_dir, recursive=True)
+
+        for filename in filenames:
+            self.assertFalse(
+                filename.endswith(('.log', '.json', '.py', '.yaml')),
+                f'{filename} should not be kept.')
-- 
GitLab