Skip to content
Snippets Groups Projects
Unverified Commit dceef1f6 authored by Alex Yang's avatar Alex Yang Committed by GitHub
Browse files

[Refactor] Refactor `after_val_epoch` to make it output metric by epoch (#278)

* [Refactor]:Refactor `after_val_epoch` to make it output metric by epoch

* add an option for user to choose the way of outputing metric

* rename variable

* reformat docstring

* add type alias

* reformat code

* add test function

* add comment and test code

* add comment and test code
parent ef946404
No related branches found
No related tags found
No related merge requests found
...@@ -11,6 +11,7 @@ from mmengine.registry import HOOKS ...@@ -11,6 +11,7 @@ from mmengine.registry import HOOKS
from mmengine.utils import is_tuple_of, scandir from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[dict]] DATA_BATCH = Optional[Sequence[dict]]
SUFFIX_TYPE = Union[Sequence[str], str]
@HOOKS.register_module() @HOOKS.register_module()
...@@ -51,6 +52,11 @@ class LoggerHook(Hook): ...@@ -51,6 +52,11 @@ class LoggerHook(Hook):
file_client_args (dict, optional): Arguments to instantiate a file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details. FileClient. See :class:`mmengine.fileio.FileClient` for details.
Defaults to None. Defaults to None.
log_metric_by_epoch (bool): Whether to output metric in validation step
by epoch. It can be true when running in epoch based runner.
If set to True, `after_val_epoch` will set `step` to self.epoch in
`runner.visualizer.add_scalars`. Otherwise `step` will be
self.iter. Default to True.
Examples: Examples:
>>> # The simplest LoggerHook config. >>> # The simplest LoggerHook config.
...@@ -58,17 +64,15 @@ class LoggerHook(Hook): ...@@ -58,17 +64,15 @@ class LoggerHook(Hook):
""" """
priority = 'BELOW_NORMAL' priority = 'BELOW_NORMAL'
def __init__( def __init__(self,
self, interval: int = 10,
interval: int = 10, ignore_last: bool = True,
ignore_last: bool = True, interval_exp_name: int = 1000,
interval_exp_name: int = 1000, out_dir: Optional[Union[str, Path]] = None,
out_dir: Optional[Union[str, Path]] = None, out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'),
out_suffix: Union[Sequence[str], keep_local: bool = True,
str] = ('.json', '.log', '.py', 'yaml'), file_client_args: Optional[dict] = None,
keep_local: bool = True, log_metric_by_epoch: bool = True):
file_client_args: Optional[dict] = None,
):
self.interval = interval self.interval = interval
self.ignore_last = ignore_last self.ignore_last = ignore_last
self.interval_exp_name = interval_exp_name self.interval_exp_name = interval_exp_name
...@@ -91,6 +95,7 @@ class LoggerHook(Hook): ...@@ -91,6 +95,7 @@ class LoggerHook(Hook):
if self.out_dir is not None: if self.out_dir is not None:
self.file_client = FileClient.infer_client(file_client_args, self.file_client = FileClient.infer_client(file_client_args,
self.out_dir) self.out_dir)
self.log_metric_by_epoch = log_metric_by_epoch
def before_run(self, runner) -> None: def before_run(self, runner) -> None:
"""Infer ``self.file_client`` from ``self.out_dir``. Initialize the """Infer ``self.file_client`` from ``self.out_dir``. Initialize the
...@@ -203,8 +208,21 @@ class LoggerHook(Hook): ...@@ -203,8 +208,21 @@ class LoggerHook(Hook):
tag, log_str = runner.log_processor.get_log_after_epoch( tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.val_dataloader), 'val') runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str) runner.logger.info(log_str)
runner.visualizer.add_scalars( if self.log_metric_by_epoch:
tag, step=runner.iter, file_path=self.json_log_path) # when `log_metric_by_epoch` is set to True, it's expected
# that validation metric can be logged by epoch rather than
# by iter. At the same time, scalars related to time should
# still be logged by iter to avoid messy visualized result.
# see details in PR #278.
time_tags = {k: v for k, v in tag.items() if 'time' in k}
metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
runner.visualizer.add_scalars(
time_tags, step=runner.iter, file_path=self.json_log_path)
runner.visualizer.add_scalars(
metric_tags, step=runner.epoch, file_path=self.json_log_path)
else:
runner.visualizer.add_scalars(
tag, step=runner.iter, file_path=self.json_log_path)
def after_test_epoch(self, def after_test_epoch(self,
runner, runner,
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp import os.path as osp
from unittest.mock import MagicMock from unittest.mock import ANY, MagicMock
import pytest import pytest
...@@ -119,6 +119,49 @@ class TestLoggerHook: ...@@ -119,6 +119,49 @@ class TestLoggerHook:
runner.logger.info.assert_called() runner.logger.info.assert_called()
runner.visualizer.add_scalars.assert_called() runner.visualizer.add_scalars.assert_called()
# Test when `log_metric_by_epoch` is True
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=({
'time': 1,
'datatime': 1,
'acc': 0.8
}, 'string'))
logger_hook.after_val_epoch(runner)
args = {'step': ANY, 'file_path': ANY}
# expect visualizer log `time` and `metric` respectively
runner.visualizer.add_scalars.assert_any_call(
{
'time': 1,
'datatime': 1
}, **args)
runner.visualizer.add_scalars.assert_any_call({'acc': 0.8}, **args)
# Test when `log_metric_by_epoch` is False
logger_hook = LoggerHook(log_metric_by_epoch=False)
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=({
'time': 5,
'datatime': 5,
'acc': 0.5
}, 'string'))
logger_hook.after_val_epoch(runner)
# expect visualizer log `time` and `metric` jointly
runner.visualizer.add_scalars.assert_any_call(
{
'time': 5,
'datatime': 5,
'acc': 0.5
}, **args)
with pytest.raises(AssertionError):
runner.visualizer.add_scalars.assert_any_call(
{
'time': 5,
'datatime': 5
}, **args)
with pytest.raises(AssertionError):
runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
def test_after_test_epoch(self): def test_after_test_epoch(self):
logger_hook = LoggerHook() logger_hook = LoggerHook()
runner = MagicMock() runner = MagicMock()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment