From 5762b288473766e9b3a891ed133dbab9b028d556 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Fri, 7 Apr 2023 16:20:38 +0800
Subject: [PATCH] [Refactor] Refactor logger hook unit tests (#797)

* Enhance config

* add unit test data

* reafactor unittest of loggerhook

* fix rebase error

* Fix permission error in windows

* Fix CI

* Fix windows ci

* Fix windows ci

* Fix windows ci

* Fix windows CI

* Apply suggestions from code review

Co-authored-by: Qian Zhao <112053249+C1rN09@users.noreply.github.com>

* clean the code

* Refine as comment

* Refine error rasing

* Update mmengine/hooks/logger_hook.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* replace assert_called_with with assert_has_calls

* Fix as comment

* Do not remove filehandler and fix unit test

---------

Co-authored-by: Qian Zhao <112053249+C1rN09@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmengine/hooks/logger_hook.py                 |  51 ++++-
 .../config/py_config/test_custom_class.py     |   5 +
 tests/test_hooks/test_logger_hook.py          | 206 ++++++++++--------
 3 files changed, 157 insertions(+), 105 deletions(-)
 create mode 100644 tests/data/config/py_config/test_custom_class.py

diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 26cd35b7..dc310b56 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -14,7 +14,7 @@ from mmengine.fileio.io import get_file_backend
 from mmengine.hooks import Hook
 from mmengine.logging import print_log
 from mmengine.registry import HOOKS
-from mmengine.utils import is_tuple_of, scandir
+from mmengine.utils import is_seq_of, scandir
 
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 SUFFIX_TYPE = Union[Sequence[str], str]
@@ -84,15 +84,30 @@ class LoggerHook(Hook):
                  file_client_args: Optional[dict] = None,
                  log_metric_by_epoch: bool = True,
                  backend_args: Optional[dict] = None):
-        self.interval = interval
-        self.ignore_last = ignore_last
-        self.interval_exp_name = interval_exp_name
+
+        if not isinstance(interval, int):
+            raise TypeError('interval must be an integer')
+        if interval <= 0:
+            raise ValueError('interval must be greater than 0')
+
+        if not isinstance(ignore_last, bool):
+            raise TypeError('ignore_last must be a boolean')
+
+        if not isinstance(interval_exp_name, int):
+            raise TypeError('interval_exp_name must be an integer')
+        if interval_exp_name <= 0:
+            raise ValueError('interval_exp_name must be greater than 0')
+
+        if out_dir is not None and not isinstance(out_dir, (str, Path)):
+            raise TypeError('out_dir must be a str or Path object')
+
+        if not isinstance(keep_local, bool):
+            raise TypeError('keep_local must be a boolean')
 
         if out_dir is None and file_client_args is not None:
             raise ValueError(
                 'file_client_args should be "None" when `out_dir` is not'
                 'specified.')
-        self.out_dir = out_dir
 
         if file_client_args is not None:
             print_log(
@@ -105,12 +120,15 @@ class LoggerHook(Hook):
                     '"file_client_args" and "backend_args" cannot be set '
                     'at the same time.')
 
-        if not (out_dir is None or isinstance(out_dir, str)
-                or is_tuple_of(out_dir, str)):
-            raise TypeError('out_dir should be None or string or tuple of '
-                            f'string, but got {type(out_dir)}')
-        self.out_suffix = out_suffix
+        if not (isinstance(out_suffix, str) or is_seq_of(out_suffix, str)):
+            raise TypeError('out_suffix should be a string or a sequence of '
+                            f'string, but got {type(out_suffix)}')
 
+        self.out_suffix = out_suffix
+        self.out_dir = out_dir
+        self.interval = interval
+        self.ignore_last = ignore_last
+        self.interval_exp_name = interval_exp_name
         self.keep_local = keep_local
         self.file_client_args = file_client_args
         self.json_log_path: Optional[str] = None
@@ -291,8 +309,11 @@ class LoggerHook(Hook):
         # copy or upload logs to self.out_dir
         if self.out_dir is None:
             return
+
+        removed_files = []
         for filename in scandir(runner._log_dir, self.out_suffix, True):
             local_filepath = osp.join(runner._log_dir, filename)
+            removed_files.append(local_filepath)
             out_filepath = self.file_backend.join_path(self.out_dir, filename)
             with open(local_filepath) as f:
                 self.file_backend.put_text(f.read(), out_filepath)
@@ -302,7 +323,15 @@ class LoggerHook(Hook):
                 f'{out_filepath}.')
 
             if not self.keep_local:
-                os.remove(local_filepath)
                 runner.logger.info(f'{local_filepath} was removed due to the '
                                    '`self.keep_local=False`. You can check '
                                    f'the running logs in {out_filepath}')
+
+        if not self.keep_local:
+            # Close file handler to avoid PermissionError on Windows.
+            for handler in runner.logger.handlers:
+                if isinstance(handler, logging.FileHandler):
+                    handler.close()
+
+            for file in removed_files:
+                os.remove(file)
diff --git a/tests/data/config/py_config/test_custom_class.py b/tests/data/config/py_config/test_custom_class.py
new file mode 100644
index 00000000..ad706b08
--- /dev/null
+++ b/tests/data/config/py_config/test_custom_class.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+class A:
+    ...
+
+item_a = dict(a=A)
diff --git a/tests/test_hooks/test_logger_hook.py b/tests/test_hooks/test_logger_hook.py
index aab2817a..b6f6268f 100644
--- a/tests/test_hooks/test_logger_hook.py
+++ b/tests/test_hooks/test_logger_hook.py
@@ -1,89 +1,65 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import os
 import os.path as osp
-from unittest.mock import ANY, MagicMock
+import shutil
+from unittest.mock import ANY, MagicMock, call
 
-import pytest
 import torch
 
 from mmengine.fileio import load
-from mmengine.fileio.file_client import HardDiskBackend
 from mmengine.hooks import LoggerHook
+from mmengine.logging import MMLogger
+from mmengine.testing import RunnerTestCase
+from mmengine.utils import mkdir_or_exist, scandir
 
 
-class TestLoggerHook:
+class TestLoggerHook(RunnerTestCase):
 
     def test_init(self):
-        logger_hook = LoggerHook(out_dir='tmp.txt')
-        assert logger_hook.interval == 10
-        assert logger_hook.ignore_last
-        assert logger_hook.interval_exp_name == 1000
-        assert logger_hook.out_suffix == ('.json', '.log', '.py', 'yaml')
-        assert logger_hook.keep_local
-        assert logger_hook.file_client_args is None
-        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
+        # Test build logger hook.
+        LoggerHook()
+        LoggerHook(interval=100, ignore_last=False, interval_exp_name=100)
+
+        with self.assertRaisesRegex(TypeError, 'interval must be'):
+            LoggerHook(interval='100')
+
+        with self.assertRaisesRegex(ValueError, 'interval must be'):
+            LoggerHook(interval=-1)
+
+        with self.assertRaisesRegex(TypeError, 'ignore_last must be'):
+            LoggerHook(ignore_last='False')
+
+        with self.assertRaisesRegex(TypeError, 'interval_exp_name'):
+            LoggerHook(interval_exp_name='100')
+
+        with self.assertRaisesRegex(ValueError, 'interval_exp_name'):
+            LoggerHook(interval_exp_name=-1)
+
+        with self.assertRaisesRegex(TypeError, 'out_suffix'):
+            LoggerHook(out_suffix=[100])
+
         # out_dir should be None or string or tuple of string.
-        with pytest.raises(TypeError):
+        with self.assertRaisesRegex(TypeError, 'out_dir must be'):
             LoggerHook(out_dir=1)
 
-        with pytest.raises(ValueError):
+        with self.assertRaisesRegex(ValueError, 'file_client_args'):
             LoggerHook(file_client_args=dict(enable_mc=True))
 
-        # test `file_client_args` and `backend_args`
-        # TODO Refine this unit test
-        # with pytest.warns(
-        #         DeprecationWarning,
-        #         match='"file_client_args" will be deprecated in future'):
-        #     logger_hook = LoggerHook(
-        #         out_dir='tmp.txt', file_client_args={'backend': 'disk'})
+        # test deprecated warning raised by `file_client_args`
+        logger = MMLogger.get_current_instance()
+        with self.assertLogs(logger, level='WARNING'):
+            LoggerHook(
+                out_dir=self.temp_dir.name,
+                file_client_args=dict(backend='disk'))
 
-        with pytest.raises(
+        with self.assertRaisesRegex(
                 ValueError,
-                match='"file_client_args" and "backend_args" cannot be '
-                'set at the same time'):
-            logger_hook = LoggerHook(
-                out_dir='tmp.txt',
-                file_client_args={'backend': 'disk'},
-                backend_args={'backend': 'local'})
-
-    def test_before_run(self):
-        runner = MagicMock()
-        runner.iter = 10
-        runner.timestamp = '20220429'
-        runner._log_dir = f'work_dir/{runner.timestamp}'
-        runner.work_dir = 'work_dir'
-        runner.logger = MagicMock()
-        logger_hook = LoggerHook(out_dir='out_dir')
-        logger_hook.before_run(runner)
-        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
-        assert logger_hook.json_log_path == f'{runner.timestamp}.json'
-
-    def test_after_run(self, tmp_path):
-        # Test
-        timestamp = '20220429'
-        out_dir = tmp_path / 'out_dir'
-        out_dir.mkdir()
-        work_dir = tmp_path / 'work_dir'
-        work_dir.mkdir()
-        log_dir = work_dir / timestamp
-        log_dir.mkdir()
-        log_dir_json = log_dir / 'tmp.log.json'
-        runner = MagicMock()
-        runner._log_dir = str(log_dir)
-        runner.timestamp = timestamp
-        runner.work_dir = str(work_dir)
-        # Test without out_dir.
-        logger_hook = LoggerHook()
-        logger_hook.after_run(runner)
-        # Test with out_dir and make sure json file has been moved to out_dir.
-        json_f = open(log_dir_json, 'w')
-        json_f.close()
-        logger_hook = LoggerHook(out_dir=str(out_dir), keep_local=False)
-        logger_hook.out_dir = str(out_dir)
-        logger_hook.before_run(runner)
-        logger_hook.after_run(runner)
-        # Verify that the file has been moved to `out_dir`.
-        assert not osp.exists(str(log_dir_json))
-        assert osp.exists(str(out_dir / 'work_dir' / 'tmp.log.json'))
+                '"file_client_args" and "backend_args" cannot be '):
+            LoggerHook(
+                out_dir=self.temp_dir.name,
+                file_client_args=dict(enable_mc=True),
+                backend_args=dict(enable_mc=True))
 
     def test_after_train_iter(self):
         # Test LoggerHook by iter.
@@ -130,13 +106,6 @@ class TestLoggerHook:
     def test_after_val_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
-        runner.log_processor.get_log_after_epoch = MagicMock(
-            return_value=(dict(), 'string'))
-        logger_hook.after_val_epoch(runner)
-        runner.log_processor.get_log_after_epoch.assert_called()
-        runner.logger.info.assert_called()
-        runner.visualizer.add_scalars.assert_called()
-
         # Test when `log_metric_by_epoch` is True
         runner.log_processor.get_log_after_epoch = MagicMock(
             return_value=({
@@ -145,14 +114,19 @@ class TestLoggerHook:
                 'acc': 0.8
             }, 'string'))
         logger_hook.after_val_epoch(runner)
-        args = {'step': ANY, 'file_path': ANY}
+
         # expect visualizer log `time` and `metric` respectively
-        runner.visualizer.add_scalars.assert_called_with(
-            {
+        args = {'step': ANY, 'file_path': ANY}
+        calls = [
+            call({
                 'time': 1,
                 'datatime': 1,
                 'acc': 0.8
-            }, **args)
+            }, **args),
+        ]
+        self.assertEqual(
+            len(calls), len(runner.visualizer.add_scalars.mock_calls))
+        runner.visualizer.add_scalars.assert_has_calls(calls)
 
         # Test when `log_metric_by_epoch` is False
         logger_hook = LoggerHook(log_metric_by_epoch=False)
@@ -163,27 +137,28 @@ class TestLoggerHook:
                 'acc': 0.5
             }, 'string'))
         logger_hook.after_val_epoch(runner)
+
         # expect visualizer log `time` and `metric` jointly
-        runner.visualizer.add_scalars.assert_called_with(
-            {
+        calls = [
+            call({
+                'time': 1,
+                'datatime': 1,
+                'acc': 0.8
+            }, **args),
+            call({
                 'time': 5,
                 'datatime': 5,
                 'acc': 0.5
-            }, **args)
-
-        with pytest.raises(AssertionError):
-            runner.visualizer.add_scalars.assert_any_call(
-                {
-                    'time': 5,
-                    'datatime': 5
-                }, **args)
-        with pytest.raises(AssertionError):
-            runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
-
-    def test_after_test_epoch(self, tmp_path):
+            }, **args),
+        ]
+        self.assertEqual(
+            len(calls), len(runner.visualizer.add_scalars.mock_calls))
+        runner.visualizer.add_scalars.assert_has_calls(calls)
+
+    def test_after_test_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
-        runner.log_dir = tmp_path
+        runner.log_dir = self.temp_dir.name
         runner.timestamp = 'test_after_test_epoch'
         runner.log_processor.get_log_after_epoch = MagicMock(
             return_value=(
@@ -219,3 +194,46 @@ class TestLoggerHook:
         runner.log_processor.get_log_after_iter.assert_not_called()
         logger_hook.after_test_iter(runner, 9)
         runner.log_processor.get_log_after_iter.assert_called()
+
+    def test_with_runner(self):
+        # Test dumped the json exits
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.default_hooks.logger = dict(type='LoggerHook')
+        cfg.train_cfg.max_epochs = 10
+        runner = self.build_runner(cfg)
+        runner.train()
+        json_path = osp.join(runner._log_dir, 'vis_data',
+                             f'{runner.timestamp}.json')
+        self.assertTrue(osp.isfile(json_path))
+
+        # Test out_dir
+        out_dir = osp.join(cfg.work_dir, 'test')
+        mkdir_or_exist(out_dir)
+        cfg.default_hooks.logger = dict(type='LoggerHook', out_dir=out_dir)
+        runner = self.build_runner(cfg)
+        runner.train()
+        self.assertTrue(os.listdir(out_dir))
+        # clean the out_dir
+        for filename in os.listdir(out_dir):
+            shutil.rmtree(osp.join(out_dir, filename))
+
+        # Test out_suffix
+        cfg.default_hooks.logger = dict(
+            type='LoggerHook', out_dir=out_dir, out_suffix='.log')
+        runner = self.build_runner(cfg)
+        runner.train()
+        filenames = scandir(out_dir, recursive=True)
+        self.assertTrue(
+            all(filename.endswith('.log') for filename in filenames))
+
+        # Test keep_local=False
+        cfg.default_hooks.logger = dict(
+            type='LoggerHook', out_dir=out_dir, keep_local=False)
+        runner = self.build_runner(cfg)
+        runner.train()
+        filenames = scandir(runner._log_dir, recursive=True)
+
+        for filename in filenames:
+            self.assertFalse(
+                filename.endswith(('.log', '.json', '.py', '.yaml')),
+                f'{filename} should not be kept.')
-- 
GitLab