From dceef1f66fd5afd6346f9941b6595b81ee6b364e Mon Sep 17 00:00:00 2001 From: Alex Yang <50511903+imabackstabber@users.noreply.github.com> Date: Tue, 21 Jun 2022 15:39:59 +0800 Subject: [PATCH] [Refactor] Refactor `after_val_epoch` to make it output metric by epoch (#278) * [Refactor]:Refactor `after_val_epoch` to make it output metric by epoch * add an option for user to choose the way of outputing metric * rename variable * reformat docstring * add type alias * reformat code * add test function * add comment and test code * add comment and test code --- mmengine/hooks/logger_hook.py | 44 +++++++++++++++++++--------- tests/test_hook/test_logger_hook.py | 45 ++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 14 deletions(-) diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 6d9bc869..b9fb5f82 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -11,6 +11,7 @@ from mmengine.registry import HOOKS from mmengine.utils import is_tuple_of, scandir DATA_BATCH = Optional[Sequence[dict]] +SUFFIX_TYPE = Union[Sequence[str], str] @HOOKS.register_module() @@ -51,6 +52,11 @@ class LoggerHook(Hook): file_client_args (dict, optional): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. Defaults to None. + log_metric_by_epoch (bool): Whether to output metric in validation step + by epoch. It can be true when running in epoch based runner. + If set to True, `after_val_epoch` will set `step` to self.epoch in + `runner.visualizer.add_scalars`. Otherwise `step` will be + self.iter. Default to True. Examples: >>> # The simplest LoggerHook config. @@ -58,17 +64,15 @@ class LoggerHook(Hook): """ priority = 'BELOW_NORMAL' - def __init__( - self, - interval: int = 10, - ignore_last: bool = True, - interval_exp_name: int = 1000, - out_dir: Optional[Union[str, Path]] = None, - out_suffix: Union[Sequence[str], - str] = ('.json', '.log', '.py', 'yaml'), - keep_local: bool = True, - file_client_args: Optional[dict] = None, - ): + def __init__(self, + interval: int = 10, + ignore_last: bool = True, + interval_exp_name: int = 1000, + out_dir: Optional[Union[str, Path]] = None, + out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'), + keep_local: bool = True, + file_client_args: Optional[dict] = None, + log_metric_by_epoch: bool = True): self.interval = interval self.ignore_last = ignore_last self.interval_exp_name = interval_exp_name @@ -91,6 +95,7 @@ class LoggerHook(Hook): if self.out_dir is not None: self.file_client = FileClient.infer_client(file_client_args, self.out_dir) + self.log_metric_by_epoch = log_metric_by_epoch def before_run(self, runner) -> None: """Infer ``self.file_client`` from ``self.out_dir``. Initialize the @@ -203,8 +208,21 @@ class LoggerHook(Hook): tag, log_str = runner.log_processor.get_log_after_epoch( runner, len(runner.val_dataloader), 'val') runner.logger.info(log_str) - runner.visualizer.add_scalars( - tag, step=runner.iter, file_path=self.json_log_path) + if self.log_metric_by_epoch: + # when `log_metric_by_epoch` is set to True, it's expected + # that validation metric can be logged by epoch rather than + # by iter. At the same time, scalars related to time should + # still be logged by iter to avoid messy visualized result. + # see details in PR #278. + time_tags = {k: v for k, v in tag.items() if 'time' in k} + metric_tags = {k: v for k, v in tag.items() if 'time' not in k} + runner.visualizer.add_scalars( + time_tags, step=runner.iter, file_path=self.json_log_path) + runner.visualizer.add_scalars( + metric_tags, step=runner.epoch, file_path=self.json_log_path) + else: + runner.visualizer.add_scalars( + tag, step=runner.iter, file_path=self.json_log_path) def after_test_epoch(self, runner, diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 2cf75dc4..66d0d06c 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest @@ -119,6 +119,49 @@ class TestLoggerHook: runner.logger.info.assert_called() runner.visualizer.add_scalars.assert_called() + # Test when `log_metric_by_epoch` is True + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=({ + 'time': 1, + 'datatime': 1, + 'acc': 0.8 + }, 'string')) + logger_hook.after_val_epoch(runner) + args = {'step': ANY, 'file_path': ANY} + # expect visualizer log `time` and `metric` respectively + runner.visualizer.add_scalars.assert_any_call( + { + 'time': 1, + 'datatime': 1 + }, **args) + runner.visualizer.add_scalars.assert_any_call({'acc': 0.8}, **args) + + # Test when `log_metric_by_epoch` is False + logger_hook = LoggerHook(log_metric_by_epoch=False) + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=({ + 'time': 5, + 'datatime': 5, + 'acc': 0.5 + }, 'string')) + logger_hook.after_val_epoch(runner) + # expect visualizer log `time` and `metric` jointly + runner.visualizer.add_scalars.assert_any_call( + { + 'time': 5, + 'datatime': 5, + 'acc': 0.5 + }, **args) + + with pytest.raises(AssertionError): + runner.visualizer.add_scalars.assert_any_call( + { + 'time': 5, + 'datatime': 5 + }, **args) + with pytest.raises(AssertionError): + runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args) + def test_after_test_epoch(self): logger_hook = LoggerHook() runner = MagicMock() -- GitLab