Skip to content
Snippets Groups Projects
Unverified Commit fcd783fc authored by Ma Zerun's avatar Ma Zerun Committed by GitHub
Browse files

[Enhance] Support non-scalar type metric value. (#827)

* [Enhance] Support non-scalar type metric value.

* Refactor support.

* Fix non-scalar tags problem during validation

* Update tag processor.
parent 79067e46
No related branches found
No related tags found
No related merge requests found
......@@ -2,9 +2,13 @@
import os
import os.path as osp
import warnings
from collections import OrderedDict
from pathlib import Path
from typing import Dict, Optional, Sequence, Union
import numpy as np
import torch
from mmengine.fileio import FileClient, dump
from mmengine.fileio.io import get_file_backend
from mmengine.hooks import Hook
......@@ -252,9 +256,33 @@ class LoggerHook(Hook):
metrics, and the values are corresponding results.
"""
tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.test_dataloader), 'test')
runner, len(runner.test_dataloader), 'test', with_non_scalar=True)
runner.logger.info(log_str)
dump(tag, osp.join(runner.log_dir, self.json_log_path)) # type: ignore
dump(
self._process_tags(tag),
osp.join(runner.log_dir, self.json_log_path)) # type: ignore
@staticmethod
def _process_tags(tags: dict):
"""Convert tag values to json-friendly type."""
def process_val(value):
if isinstance(value, (list, tuple)):
# Array type of json
return [process_val(item) for item in value]
elif isinstance(value, dict):
# Object type of json
return {k: process_val(v) for k, v in value.items()}
elif isinstance(value, (str, int, float, bool)) or value is None:
# Other supported type of json
return value
elif isinstance(value, (torch.Tensor, np.ndarray)):
return value.tolist()
# Drop unsupported values.
processed_tags = OrderedDict(process_val(tags))
return processed_tags
def after_run(self, runner) -> None:
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import numpy as np
import torch
from mmengine.registry import HOOKS
from mmengine.utils import get_git_hash
......@@ -9,6 +12,24 @@ from .hook import Hook
DATA_BATCH = Optional[Union[dict, tuple, list]]
def _is_scalar(value: Any) -> bool:
"""Determine the value is a scalar type value.
Args:
value (Any): value of log.
Returns:
bool: whether the value is a scalar type value.
"""
if isinstance(value, np.ndarray):
return value.size == 1
elif isinstance(value, (int, float)):
return True
elif isinstance(value, torch.Tensor):
return value.numel() == 1
return False
@HOOKS.register_module()
class RuntimeInfoHook(Hook):
"""A hook that updates runtime information into message hub.
......@@ -112,7 +133,10 @@ class RuntimeInfoHook(Hook):
"""
if metrics is not None:
for key, value in metrics.items():
runner.message_hub.update_scalar(f'val/{key}', value)
if _is_scalar(value):
runner.message_hub.update_scalar(f'val/{key}', value)
else:
runner.message_hub.update_info(f'val/{key}', value)
def after_test_epoch(self,
runner,
......@@ -128,4 +152,7 @@ class RuntimeInfoHook(Hook):
"""
if metrics is not None:
for key, value in metrics.items():
runner.message_hub.update_scalar(f'test/{key}', value)
if _is_scalar(value):
runner.message_hub.update_scalar(f'test/{key}', value)
else:
runner.message_hub.update_info(f'test/{key}', value)
......@@ -2,8 +2,12 @@
import copy
import datetime
from collections import OrderedDict
from itertools import chain
from typing import List, Optional, Tuple
import numpy as np
import torch
from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.registry import LOG_PROCESSORS
......@@ -206,8 +210,11 @@ class LogProcessor:
log_str += ' '.join(log_items)
return tag, log_str
def get_log_after_epoch(self, runner, batch_idx: int,
mode: str) -> Tuple[dict, str]:
def get_log_after_epoch(self,
runner,
batch_idx: int,
mode: str,
with_non_scalar: bool = False) -> Tuple[dict, str]:
"""Format log string after validation or testing epoch.
Args:
......@@ -215,6 +222,8 @@ class LogProcessor:
batch_idx (int): The index of the current batch in the current
loop.
mode (str): Current mode of runner.
with_non_scalar (bool): Whether to include non-scalar infos in the
returned tag. Defaults to False.
Return:
Tuple(dict, str): Formatted log dict/string which will be
......@@ -230,6 +239,7 @@ class LogProcessor:
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
non_scalar_tag = self._collect_non_scalars(runner, mode)
tag.pop('time', None)
tag.pop('data_time', None)
# By epoch:
......@@ -252,11 +262,17 @@ class LogProcessor:
# `time` and `data_time` will not be recorded in after epoch log
# message.
log_items = []
for name, val in tag.items():
for name, val in chain(tag.items(), non_scalar_tag.items()):
if isinstance(val, float):
val = f'{val:.{self.num_digits}f}'
if isinstance(val, (torch.Tensor, np.ndarray)):
# newline to display tensor and array.
val = f'\n{val}\n'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
if with_non_scalar:
tag.update(non_scalar_tag)
return tag, log_str
def _collect_scalars(self, custom_cfg: List[dict], runner,
......@@ -305,6 +321,28 @@ class LogProcessor:
**log_cfg)
return tag
def _collect_non_scalars(self, runner, mode: str) -> dict:
"""Collect log information to compose a dict according to mode.
Args:
runner (Runner): The runner of the training/testing/validation
process.
mode (str): Current mode of runner.
Returns:
dict: non-scalar infos of the specified mode.
"""
# infos of train/val/test phase.
infos = runner.message_hub.runtime_info
# corresponding mode infos
mode_infos = OrderedDict()
# extract log info and remove prefix to `mode_infos` according to mode.
for prefix_key, value in infos.items():
if prefix_key.startswith(mode):
key = prefix_key.partition('/')[-1]
mode_infos[key] = value
return mode_infos
def _check_custom_cfg(self) -> None:
"""Check the legality of ``self.custom_cfg``."""
......
......@@ -3,7 +3,9 @@ import os.path as osp
from unittest.mock import ANY, MagicMock
import pytest
import torch
from mmengine.fileio import load
from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook
......@@ -178,12 +180,17 @@ class TestLoggerHook:
runner.log_dir = tmp_path
runner.timestamp = 'test_after_test_epoch'
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=(dict(a=1, b=2), 'log_str'))
return_value=(
dict(a=1, b=2, c={'list': [1, 2]}, d=torch.tensor([1, 2, 3])),
'log_str'))
logger_hook.before_run(runner)
logger_hook.after_test_epoch(runner)
runner.log_processor.get_log_after_epoch.assert_called()
runner.logger.info.assert_called()
osp.isfile(osp.join(runner.log_dir, 'test_after_test_epoch.json'))
json_content = load(
osp.join(runner.log_dir, 'test_after_test_epoch.json'))
assert json_content == dict(a=1, b=2, c={'list': [1, 2]}, d=[1, 2, 3])
def test_after_val_iter(self):
logger_hook = LoggerHook()
......
......@@ -144,19 +144,28 @@ class TestLogProcessor:
# Prepare LoggerHook
log_processor = LogProcessor(by_epoch=by_epoch)
# Prepare validation information.
val_logs = dict(accuracy=0.9, data_time=1.0)
log_processor._collect_scalars = MagicMock(return_value=val_logs)
scalar_logs = dict(accuracy=0.9, data_time=1.0)
non_scalar_logs = dict(
recall={
'cat': 1,
'dog': 0
}, cm=torch.tensor([1, 2, 3]))
log_processor._collect_scalars = MagicMock(return_value=scalar_logs)
log_processor._collect_non_scalars = MagicMock(
return_value=non_scalar_logs)
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
expect_metric_str = ("accuracy: 0.9000 recall: {'cat': 1, 'dog': 0} "
'cm: \ntensor([1, 2, 3])\n')
if by_epoch:
if mode == 'test':
assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
assert out == 'Epoch(test) [5/5] ' + expect_metric_str
else:
assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
assert out == 'Epoch(val) [1][10/10] ' + expect_metric_str
else:
if mode == 'test':
assert out == 'Iter(test) [5/5] accuracy: 0.9000'
assert out == 'Iter(test) [5/5] ' + expect_metric_str
else:
assert out == 'Iter(val) [10/10] accuracy: 0.9000'
assert out == 'Iter(val) [10/10] ' + expect_metric_str
def test_collect_scalars(self):
history_count = np.ones(100)
......@@ -196,6 +205,21 @@ class TestLogProcessor:
assert list(tag.keys()) == ['metric']
assert tag['metric'] == metric_scalars[-1]
def test_collect_non_scalars(self):
metric1 = np.random.rand(10)
metric2 = torch.tensor(10)
log_processor = LogProcessor()
# Collect with prefix.
log_infos = {'test/metric1': metric1, 'test/metric2': metric2}
self.runner.message_hub._runtime_info = log_infos
tag = log_processor._collect_non_scalars(self.runner, mode='test')
# Test training key in tag.
assert list(tag.keys()) == ['metric1', 'metric2']
# Test statistics lr with `current`, loss and time with 'mean'
assert tag['metric1'] is metric1
assert tag['metric2'] is metric2
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment