From dceef1f66fd5afd6346f9941b6595b81ee6b364e Mon Sep 17 00:00:00 2001
From: Alex Yang <50511903+imabackstabber@users.noreply.github.com>
Date: Tue, 21 Jun 2022 15:39:59 +0800
Subject: [PATCH] [Refactor] Refactor `after_val_epoch` to make it output
 metric by epoch (#278)

* [Refactor]:Refactor `after_val_epoch` to make it output metric by epoch

* add an option for user to choose the way of outputing metric

* rename variable

* reformat docstring

* add type alias

* reformat code

* add test function

* add comment and test code

* add comment and test code
---
 mmengine/hooks/logger_hook.py       | 44 +++++++++++++++++++---------
 tests/test_hook/test_logger_hook.py | 45 ++++++++++++++++++++++++++++-
 2 files changed, 75 insertions(+), 14 deletions(-)

diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 6d9bc869..b9fb5f82 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -11,6 +11,7 @@ from mmengine.registry import HOOKS
 from mmengine.utils import is_tuple_of, scandir
 
 DATA_BATCH = Optional[Sequence[dict]]
+SUFFIX_TYPE = Union[Sequence[str], str]
 
 
 @HOOKS.register_module()
@@ -51,6 +52,11 @@ class LoggerHook(Hook):
         file_client_args (dict, optional): Arguments to instantiate a
             FileClient. See :class:`mmengine.fileio.FileClient` for details.
             Defaults to None.
+        log_metric_by_epoch (bool): Whether to output metric in validation step
+            by epoch. It can be true when running in epoch based runner.
+            If set to True, `after_val_epoch` will set `step` to self.epoch in
+            `runner.visualizer.add_scalars`. Otherwise `step` will be
+            self.iter. Default to True.
 
     Examples:
         >>> # The simplest LoggerHook config.
@@ -58,17 +64,15 @@ class LoggerHook(Hook):
     """
     priority = 'BELOW_NORMAL'
 
-    def __init__(
-        self,
-        interval: int = 10,
-        ignore_last: bool = True,
-        interval_exp_name: int = 1000,
-        out_dir: Optional[Union[str, Path]] = None,
-        out_suffix: Union[Sequence[str],
-                          str] = ('.json', '.log', '.py', 'yaml'),
-        keep_local: bool = True,
-        file_client_args: Optional[dict] = None,
-    ):
+    def __init__(self,
+                 interval: int = 10,
+                 ignore_last: bool = True,
+                 interval_exp_name: int = 1000,
+                 out_dir: Optional[Union[str, Path]] = None,
+                 out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'),
+                 keep_local: bool = True,
+                 file_client_args: Optional[dict] = None,
+                 log_metric_by_epoch: bool = True):
         self.interval = interval
         self.ignore_last = ignore_last
         self.interval_exp_name = interval_exp_name
@@ -91,6 +95,7 @@ class LoggerHook(Hook):
         if self.out_dir is not None:
             self.file_client = FileClient.infer_client(file_client_args,
                                                        self.out_dir)
+        self.log_metric_by_epoch = log_metric_by_epoch
 
     def before_run(self, runner) -> None:
         """Infer ``self.file_client`` from ``self.out_dir``. Initialize the
@@ -203,8 +208,21 @@ class LoggerHook(Hook):
         tag, log_str = runner.log_processor.get_log_after_epoch(
             runner, len(runner.val_dataloader), 'val')
         runner.logger.info(log_str)
-        runner.visualizer.add_scalars(
-            tag, step=runner.iter, file_path=self.json_log_path)
+        if self.log_metric_by_epoch:
+            # when `log_metric_by_epoch` is set to True, it's expected
+            # that validation metric can be logged by epoch rather than
+            # by iter. At the same time, scalars related to time should
+            # still be logged by iter to avoid messy visualized result.
+            # see details in PR #278.
+            time_tags = {k: v for k, v in tag.items() if 'time' in k}
+            metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
+            runner.visualizer.add_scalars(
+                time_tags, step=runner.iter, file_path=self.json_log_path)
+            runner.visualizer.add_scalars(
+                metric_tags, step=runner.epoch, file_path=self.json_log_path)
+        else:
+            runner.visualizer.add_scalars(
+                tag, step=runner.iter, file_path=self.json_log_path)
 
     def after_test_epoch(self,
                          runner,
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 2cf75dc4..66d0d06c 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-from unittest.mock import MagicMock
+from unittest.mock import ANY, MagicMock
 
 import pytest
 
@@ -119,6 +119,49 @@ class TestLoggerHook:
         runner.logger.info.assert_called()
         runner.visualizer.add_scalars.assert_called()
 
+        # Test when `log_metric_by_epoch` is True
+        runner.log_processor.get_log_after_epoch = MagicMock(
+            return_value=({
+                'time': 1,
+                'datatime': 1,
+                'acc': 0.8
+            }, 'string'))
+        logger_hook.after_val_epoch(runner)
+        args = {'step': ANY, 'file_path': ANY}
+        # expect visualizer log `time` and `metric` respectively
+        runner.visualizer.add_scalars.assert_any_call(
+            {
+                'time': 1,
+                'datatime': 1
+            }, **args)
+        runner.visualizer.add_scalars.assert_any_call({'acc': 0.8}, **args)
+
+        # Test when `log_metric_by_epoch` is False
+        logger_hook = LoggerHook(log_metric_by_epoch=False)
+        runner.log_processor.get_log_after_epoch = MagicMock(
+            return_value=({
+                'time': 5,
+                'datatime': 5,
+                'acc': 0.5
+            }, 'string'))
+        logger_hook.after_val_epoch(runner)
+        # expect visualizer log `time` and `metric` jointly
+        runner.visualizer.add_scalars.assert_any_call(
+            {
+                'time': 5,
+                'datatime': 5,
+                'acc': 0.5
+            }, **args)
+
+        with pytest.raises(AssertionError):
+            runner.visualizer.add_scalars.assert_any_call(
+                {
+                    'time': 5,
+                    'datatime': 5
+                }, **args)
+        with pytest.raises(AssertionError):
+            runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
+
     def test_after_test_epoch(self):
         logger_hook = LoggerHook()
         runner = MagicMock()
-- 
GitLab