Commit 6f69039c authored by Mashiro, committed by GitHub

[Feature] Add LoggerHook (#77)

* add logger hook

* update

* update

* update test

* update

* update test

* update

* update

* update

* update

* update

* Add logger hook

* Fix pre-commit

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment

* Fix bytes

* update

* Fix as comment

* Fix as comment

* Update runner

* Fix as comment

* Fix as comment

* Fix as comment

* Fix as comment
parent 49b7d0ce
@@ -3,6 +3,7 @@ from .checkpoint_hook import CheckpointHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
+from .logger_hook import LoggerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -10,5 +11,6 @@ from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
+    'LoggerHook'
 ]
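With the export in place, the hook can be built through the registry. A minimal usage sketch (the config values below are illustrative; `HOOKS.build` is the standard mmengine registry call):

from mmengine.registry import HOOKS

# Build the hook from a config dict via the registry. Note that an int
# `window_size` must equal `interval` (see `_get_window_size` below);
# `out_dir` is an illustrative path.
logger_hook = HOOKS.build(
    dict(
        type='LoggerHook',
        by_epoch=True,
        interval=10,
        custom_keys=dict(
            loss=dict(
                log_name='loss_mean_window',
                method_name='mean',
                window_size=10)),
        out_dir='./tmp'))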
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
import os
import os.path as osp
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
import torch
from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils import is_tuple_of, scandir
DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
@HOOKS.register_module()
class LoggerHook(Hook):
"""In this logger hook, the information will be printed on the terminal and
saved in JSON file, tensorboard, wandb .etc.
Args:
by_epoch (bool): Whether ``EpochBasedLoop`` is used.
Defaults to True.
interval (int): Logging interval (every k iterations).
Defaults to 10.
custom_keys (dict, optional): Defines the keys in the log and which
kinds of statistic methods should be used to log them.
- ``custom_keys`` contains multiple string-dict pairs. In each
string-dict pair, the string defines a key name in the log and the
dict is a config defines the statistic methods and corresponding
arguments used to log the value. For example,
            ``dict(loss=dict(method_name='mean', log_name='global_loss',
            window_size='global'))`` means the log key ``loss`` will be
            aggregated as a global mean and additionally logged as
            ``global_loss``. If ``log_name`` is not defined in the config
            dict, the original log key will be overwritten.
        - The keys in ``LoggerHook.fixed_smooth_keys`` cannot be
          overwritten because ``time`` and ``data_time`` are used to
          calculate the estimated time of arrival. If you want to recount
          the time, you should set ``log_name`` in the corresponding
          config.
        - For statistic methods that take a ``window_size`` argument,
          ``window_size`` must not be ``epoch`` when ``by_epoch`` is set
          to False.
ignore_last (bool): Ignore the log of last iterations in each epoch if
the number of remaining iterations is less than :attr:`interval`.
Defaults to True.
interval_exp_name (int): Logging interval for experiment name. This
feature is to help users conveniently get the experiment
information from screen or log file. Defaults to 1000.
        out_dir (str or Path, optional): The root directory to save the
            logs. If not specified, ``runner.work_dir`` will be used
            by default. If specified, the final ``out_dir`` will be the
            concatenation of ``out_dir`` and the last-level directory of
            ``runner.work_dir``. For example, if the input ``out_dir`` is
            ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``,
            then the log will be saved in ``./tmp/cur_exp``. Defaults
            to None.
out_suffix (Tuple[str] or str): Those filenames ending with
``out_suffix`` will be copied to ``out_dir``. Defaults to
('.log.json', '.log', '.py').
keep_local (bool): Whether to keep local logs in the local machine
when :attr:`out_dir` is specified. If False, the local log will be
removed. Defaults to True.
file_client_args (dict, optional): Arguments to instantiate a
FileClient. See :class:`mmengine.fileio.FileClient` for details.
Defaults to None.
Examples:
>>> # `log_name` is defined, `loss_mean_window` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10)))
        >>> # `log_name` is not defined. `loss` will be overwritten by the
        >>> # global mean statistics.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> method_name='mean',
>>> window_size='global')))
>>> # `time` cannot be overwritten, `global_time` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> time=dict(
>>> log_name='global_time',
        >>>                 method_name='mean',
>>> window_size='global')))
>>> # Record loss with different statistics methods.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(loss=[
>>> dict(log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10),
>>> dict(method_name='mean',
>>> window_size='global')]))
"""
    # ETA is calculated from `time`, so `time` and `data_time` must not be
    # overwritten.
fixed_smooth_keys = ('time', 'data_time')
priority = 'BELOW_NORMAL'
def __init__(
self,
by_epoch: bool = True,
interval: int = 10,
custom_keys: Optional[dict] = None,
ignore_last: bool = True,
interval_exp_name: int = 1000,
out_dir: Optional[Union[str, Path]] = None,
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
        keep_local: bool = True,
        file_client_args: Optional[dict] = None,
):
self.by_epoch = by_epoch
self.interval = interval
self.custom_keys = custom_keys if custom_keys is not None else dict()
self.ignore_last = ignore_last
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
self._check_custom_keys()
if out_dir is None and file_client_args is not None:
            raise ValueError(
                'file_client_args should be None when `out_dir` is not '
                'specified.')
self.out_dir = out_dir
if not (out_dir is None or isinstance(out_dir, str)
or is_tuple_of(out_dir, str)):
            raise TypeError('out_dir should be None, a string, or a tuple '
                            f'of strings, but got {type(out_dir)}')
self.out_suffix = out_suffix
self.keep_local = keep_local
self.file_client_args = file_client_args
if self.out_dir is not None:
self.file_client = FileClient.infer_client(file_client_args,
self.out_dir)
def before_run(self, runner) -> None:
"""Infer ``self.file_client`` from ``self.out_dir``. Initialize the
``self.start_iter`` and record the meta information.
Args:
runner (Runner): The runner of the training process.
"""
if self.out_dir is not None:
# The final `self.out_dir` is the concatenation of `self.out_dir`
# and the last level directory of `runner.work_dir`
basename = osp.basename(runner.work_dir.rstrip(osp.sep))
self.out_dir = self.file_client.join_path(self.out_dir, basename)
runner.logger.info(
(f'Text logs will be saved to {self.out_dir} by '
f'{self.file_client.name} after the training process.'))
self.json_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
        self.yaml_log_path = osp.join(runner.work_dir,
                                      f'{runner.timestamp}.log.yaml')
self.start_iter = runner.iter
if runner.meta is not None:
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
def after_train_iter(
self,
runner,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Record training logs.
Args:
runner (Runner): The runner of the training process.
            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                Data from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
self._log_train(runner)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
self._log_train(runner)
elif self.end_of_epoch(runner) and not self.ignore_last:
            # `runner.max_iters` may not be divisible by `self.interval`.
            # If `self.ignore_last == False`, the log of the remaining
            # iterations will also be recorded (e.g. for
            # Epoch [4][1000/1007] with interval 10, iterations 1001-1007
            # are logged at the end of the epoch).
self._log_train(runner)
def after_val_epoch(self, runner) -> None:
"""Record validation logs.
Args:
runner (Runner): The runner of the training process.
"""
self._log_val(runner)
def after_run(self, runner) -> None:
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
Args:
runner (Runner): The runner of the training process.
"""
# copy or upload logs to self.out_dir
if self.out_dir is None:
return
for filename in scandir(runner.work_dir, self.out_suffix, True):
local_filepath = osp.join(runner.work_dir, filename)
out_filepath = self.file_client.join_path(self.out_dir, filename)
with open(local_filepath, 'r') as f:
self.file_client.put_text(f.read(), out_filepath)
runner.logger.info(
(f'The file {local_filepath} has been uploaded to '
f'{out_filepath}.'))
if not self.keep_local:
os.remove(local_filepath)
                runner.logger.info(f'{local_filepath} was removed due to '
                                   '`self.keep_local=False`.')
def _log_train(self, runner) -> None:
"""Collect and record training logs which start named with "train/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'train')
        # By default the training log contains `lr`, `momentum`, `time` and
        # `data_time`. These keys are popped from `log_tag`, and the
        # remaining keys are appended to `log_str`.
log_tag = copy.deepcopy(tag)
cur_iter = self._get_iter(runner, inner_iter=True)
cur_epoch = self._get_epoch(runner, 'train')
# Record learning rate and momentum.
lr_str_list = []
momentum_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str = ' '.join(lr_str_list)
for key, value in tag.items():
if key.startswith('momentum'):
log_tag.pop(key)
momentum_str_list.append(f'{key}: {value:.3e}')
momentum_str = ' '.join(momentum_str_list)
lr_momentum_str = f'{lr_str} {momentum_str}'
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{cur_epoch}]' \
f'[{cur_iter}/{len(runner.data_loader)}]\t'
else:
log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
log_str += f'{lr_momentum_str}, '
        # Calculate ETA (the estimated time of arrival).
self.time_sec_tot += (tag['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
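        # A worked numeric example of the ETA arithmetic (illustrative
        # values): with interval=10, an average `tag['time']` of 1.2 s,
        # iter=10, start_iter=0 and max_iters=50:
        #   time_sec_tot = 1.2 * 10 = 12.0
        #   time_sec_avg = 12.0 / (10 - 0 + 1) ~= 1.09
        #   eta_sec = 1.09 * (50 - 10 - 1) ~= 42.5  -> 'eta: 0:00:42'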
log_str += f'eta: {eta_str}, '
log_str += f'time: {tag["time"]:.3f}, ' \
f'data_time: {tag["data_time"]:.3f}, '
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
        # Record the peak GPU memory.
if torch.cuda.is_available():
log_str += f'memory: {self._get_max_memory(runner)}, '
        # Append the remaining keys to `log_str`.
log_items = []
for name, val in log_tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
        # Write logs to local file, TensorBoard, and wandb.
runner.writer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path)
def _log_val(self, runner) -> None:
"""Collect and record training logs which start named with "val/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'val')
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
eval_iter = len(runner.data_loader)
cur_iter = self._get_iter(runner)
cur_epoch = self._get_epoch(runner, 'val')
# val/test time
# here 1000 is the length of the val dataloader
        # by epoch: Epoch(val) [4][1000]
        # by iter: Iter(val) [1000]
if self.by_epoch:
# runner.epoch += 1 has been done before val workflow
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
else:
log_str = f'Iter(val) [{eval_iter}]\t'
log_items = []
for name, val in tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write tag.
runner.writer.add_scalars(
tag, step=cur_iter, file_path=self.json_log_path)
def _get_window_size(self, runner, window_size: Union[int, str]) \
-> int:
"""Parse window_size specified in ``self.custom_keys`` to int value.
Args:
runner (Runner): The runner of the training process.
window_size (int or str): Smoothing scale of logs.
Returns:
int: Smoothing window for statistical methods.
"""
if isinstance(window_size, int):
            assert window_size == self.interval, \
                'The value of window_size must be equal to ' \
                'LoggerHook.interval.'
return window_size
elif window_size == 'epoch':
return runner.inner_iter + 1
elif window_size == 'global':
return runner.iter + 1
else:
            raise ValueError("window_size should be an int, 'epoch' or "
                             f"'global', but got {window_size}")
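    # A worked example of the mapping above (illustrative counter values):
    # with `interval=10`, `runner.inner_iter=4` and `runner.iter=99`,
    #   _get_window_size(runner, 10)       -> 10  (must equal `interval`)
    #   _get_window_size(runner, 'epoch')  -> 5   (inner_iter + 1)
    #   _get_window_size(runner, 'global') -> 100 (iter + 1)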
def _collect_info(self, runner, mode: str) -> dict:
"""Collect log information to a dict according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): 'train' or 'val', which means the prefix attached by
runner.
Returns:
dict: Statistical values of logs.
"""
tag = OrderedDict()
log_buffers = runner.message_hub.log_buffers
mode_log_buffers = OrderedDict()
# Filter log_buffers which starts with `mode`.
for prefix_key, log_buffer in log_buffers.items():
if prefix_key.startswith(mode):
key = prefix_key.split('/')[-1]
mode_log_buffers[key] = log_buffer
        # Collect the smoothed or latest value of each key.
        for key in mode_log_buffers:
            # `time`, `data_time` and losses are averaged over the window;
            # other keys (e.g. lr) use their latest value.
if key in self.fixed_smooth_keys or key.startswith('loss'):
tag[key] = mode_log_buffers[key].mean(self.interval)
else:
tag[key] = mode_log_buffers[key].current()
# Update custom keys.
if mode == 'train':
for log_key, log_cfg in self.custom_keys.items():
self._parse_custom_keys(runner, log_key,
copy.deepcopy(log_cfg),
mode_log_buffers, tag)
return tag
def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
log_buffers: OrderedDict, tag: OrderedDict) -> None:
"""Statistics logs in log_buffers according to custom_keys.
Args:
runner (Runner): The runner of the training process.
log_key (str): log key specified in ``self.custom_keys``
log_cfg (dict): A config dict for describing the logging
statistics method.
log_buffers (OrderedDict): All logs for the corresponding phase.
tag (OrderedDict): A dict which defines all statistic values of
logs.
"""
if isinstance(log_cfg, list):
log_names = set()
for cfg in log_cfg:
log_name = cfg.get('log_name', None)
if log_name in log_names:
raise KeyError(f'{cfg["log_name"]} cannot be Redefined in '
'log_key')
if log_name is not None:
log_names.add(log_name)
self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
            assert len(log_names) == len(log_cfg) - 1, \
                f'{log_key} cannot be overwritten multiple times, please ' \
                f'check that at most one config in {log_cfg} does not ' \
                'contain `log_name`.'
elif isinstance(log_cfg, dict):
if 'window_size' in log_cfg:
log_cfg['window_size'] = \
self._get_window_size(runner, log_cfg['window_size'])
if 'log_name' in log_cfg:
name = log_cfg.pop('log_name')
else:
name = log_key
tag[name] = log_buffers[log_key].statistics(**log_cfg)
else:
            raise ValueError('The structure of `LoggerHook.custom_keys` is '
                             'wrong, please make sure the value of each '
                             'key is a dict or a list.')
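    # A worked example of the parsing above, matching the last example in
    # the class docstring: custom_keys=dict(loss=[
    #     dict(log_name='loss_mean_window', method_name='mean',
    #          window_size=10),
    #     dict(method_name='mean', window_size='global')])
    # yields two tag entries: `loss_mean_window` (the windowed mean) and
    # `loss` (overwritten by the global mean).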
def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
for a given device.
Args:
runner (Runner): The runner of the training process.
Returns:
The maximum GPU memory occupied by tensors in megabytes for a given
device.
"""
# TODO use `mmengine.dist.max_memory_allocated` to count mem_mb
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
        # Reduce across ranks so that rank 0 logs the global peak memory
        # (the test below expects `torch.distributed.reduce` to be called
        # when `runner.world_size > 1`).
        if runner.world_size > 1:
            torch.distributed.reduce(
                mem_mb, 0, op=torch.distributed.ReduceOp.MAX)
        torch.cuda.reset_peak_memory_stats()
        return int(mem_mb.item())
def _check_custom_keys(self) -> None:
"""Check the legality of ``self.custom_keys``.
If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
key of ``self.fixed_smooth_keys`` cannot be overwritten.
"""
def _check_window_size(item):
if not self.by_epoch:
assert item['window_size'] != 'epoch', \
'window_size cannot be epoch if LoggerHook.by_epoch is ' \
'False.'
def _check_fixed_keys(key, item):
if key in self.fixed_smooth_keys:
assert 'log_name' in item, f'{key} cannot be overwritten by ' \
'custom keys!'
for key, value in self.custom_keys.items():
if isinstance(value, Sequence):
                for item in value:
                    _check_window_size(item)
                    _check_fixed_keys(key, item)
else:
_check_window_size(value)
_check_fixed_keys(key, value)
def _get_epoch(self, runner, mode: str) -> int:
"""Get epoch according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): Train or val.
Returns:
int: The current epoch.
"""
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
return epoch
def _get_iter(self, runner, inner_iter=False) -> int:
"""Get the current training iteration step.
Args:
runner (Runner): The runner of the training process.
inner_iter (bool): Whether to return the inner iter of an epoch.
Defaults to False.
Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and inner_iter:
current_iter = runner.inner_iter + 1
else:
current_iter = runner.iter + 1
return current_iter
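A small standalone sketch of the `out_dir` handling described in the docstring and implemented in `before_run`/`after_run` above (the paths are illustrative):

import os.path as osp

# out_dir='./tmp' combined with runner.work_dir='./work_dir/cur_exp' means
# the logs are copied to './tmp/cur_exp' after the training process.
out_dir = './tmp'
work_dir = './work_dir/cur_exp'
basename = osp.basename(work_dir.rstrip(osp.sep))
print(osp.join(out_dir, basename))  # ./tmp/cur_exp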
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook
class TestLoggerHook:
def test_init(self):
logger_hook = LoggerHook(out_dir='tmp.txt')
assert logger_hook.by_epoch
assert logger_hook.interval == 10
assert not logger_hook.custom_keys
assert logger_hook.ignore_last
assert logger_hook.time_sec_tot == 0
assert logger_hook.interval_exp_name == 1000
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
assert logger_hook.keep_local
assert logger_hook.file_client_args is None
assert isinstance(logger_hook.file_client.client, HardDiskBackend)
        # out_dir should be None, a string, or a tuple of strings.
with pytest.raises(TypeError):
LoggerHook(out_dir=1)
# time cannot be overwritten.
with pytest.raises(AssertionError):
LoggerHook(custom_keys=dict(time=dict(method='max')))
LoggerHook(
custom_keys=dict(time=[
dict(method='max', log_name='time_max'),
dict(method='min', log_name='time_min')
]))
# Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
with pytest.raises(AssertionError):
LoggerHook(
by_epoch=False,
custom_keys=dict(
time=dict(
method='max', log_name='time_max',
window_size='epoch')))
with pytest.raises(ValueError):
LoggerHook(file_client_args=dict(enable_mc=True))
def test_before_run(self):
runner = MagicMock()
runner.iter = 10
runner.timestamp = 'timestamp'
runner.work_dir = 'work_dir'
runner.logger = MagicMock()
logger_hook = LoggerHook(out_dir='out_dir')
logger_hook.before_run(runner)
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
assert logger_hook.json_log_path == osp.join('work_dir',
'timestamp.log.json')
assert logger_hook.start_iter == runner.iter
runner.writer.add_params.assert_called()
def test_after_run(self, tmp_path):
out_dir = tmp_path / 'out_dir'
out_dir.mkdir()
work_dir = tmp_path / 'work_dir'
work_dir.mkdir()
work_dir_json = work_dir / 'tmp.log.json'
        work_dir_json.touch()
runner = MagicMock()
runner.work_dir = work_dir
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
logger_hook.out_dir = str(out_dir)
logger_hook.after_run(runner)
# Verify that the file has been moved to `out_dir`.
assert not osp.exists(str(work_dir_json))
assert osp.exists(str(out_dir / 'tmp.log.json'))
def test_after_train_iter(self):
# Test LoggerHook by iter.
runner = MagicMock()
runner.iter = 10
logger_hook = LoggerHook(by_epoch=False)
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner)
        # `cur_iter = 10 + 1`, which is not divisible by
        # `logger_hook.interval`.
logger_hook._log_train.assert_not_called()
runner.iter = 9
logger_hook.after_train_iter(runner)
logger_hook._log_train.assert_called()
# Test LoggerHook by epoch.
logger_hook = LoggerHook(by_epoch=True)
logger_hook._log_train = MagicMock()
# Only `runner.inner_iter` will work.
runner.iter = 9
runner.inner_iter = 10
logger_hook.after_train_iter(runner)
logger_hook._log_train.assert_not_called()
runner.inner_iter = 9
logger_hook.after_train_iter(runner)
logger_hook._log_train.assert_called()
# Test end of the epoch.
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
logger_hook._log_train = MagicMock()
runner.data_loader = [0] * 5
runner.inner_iter = 4
logger_hook.after_train_iter(runner)
logger_hook._log_train.assert_called()
# Test print exp_name
runner.meta = dict(exp_name='retinanet')
logger_hook = LoggerHook()
runner.logger = MagicMock()
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner)
runner.logger.info.assert_called_with(
f'Exp name: {runner.meta["exp_name"]}')
def test_after_val_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
logger_hook._log_val = MagicMock()
logger_hook.after_val_epoch(runner)
logger_hook._log_val.assert_called()
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_train(self, by_epoch, capsys):
runner = self._setup_runner()
runner.meta = dict(exp_name='retinanet')
# Prepare LoggerHook
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook.writer = MagicMock()
logger_hook.time_sec_tot = 1000
logger_hook.start_iter = 0
logger_hook._get_max_memory = MagicMock(return_value='100')
logger_hook.json_log_path = 'tmp.json'
# Prepare training information.
train_infos = dict(
lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
logger_hook._collect_info = MagicMock(return_value=train_infos)
logger_hook._log_train(runner)
# Verify that the correct variables have been written.
runner.writer.add_scalars.assert_called_with(
train_infos, step=11, file_path='tmp.json')
        # Verify that the correct content has been logged.
out, _ = capsys.readouterr()
time_avg = logger_hook.time_sec_tot / (
runner.iter + 1 - logger_hook.start_iter)
eta_second = time_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
if by_epoch:
if torch.cuda.is_available():
log_str = 'Epoch [2][2/5]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Epoch [2][2/5]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
else:
if torch.cuda.is_available():
log_str = 'Iter [11/50]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Iter [11/50]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_val(self, by_epoch, capsys):
runner = self._setup_runner()
# Prepare LoggerHook.
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook.json_log_path = 'tmp.json'
metric = dict(accuracy=0.9, data_time=1.0)
logger_hook._collect_info = MagicMock(return_value=metric)
logger_hook._log_val(runner)
        # Verify that the correct content has been logged.
out, _ = capsys.readouterr()
runner.writer.add_scalars.assert_called_with(
metric, step=11, file_path='tmp.json')
if by_epoch:
assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
'data_time: 1.0000\n'
else:
assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
'data_time: 1.0000\n'
def test_get_window_size(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
# Test get window size by name.
assert logger_hook._get_window_size(runner, 'epoch') == 2
assert logger_hook._get_window_size(runner, 'global') == 11
assert logger_hook._get_window_size(runner, 10) == 10
        # An int window size must equal `logger_hook.interval`.
with pytest.raises(AssertionError):
logger_hook._get_window_size(runner, 20)
with pytest.raises(ValueError):
            logger_hook._get_window_size(runner, 'unknown')
def test_parse_custom_keys(self):
tag = OrderedDict()
runner = self._setup_runner()
log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
cfg_dict = dict(
lr=dict(method='min'),
loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max')
])
logger_hook = LoggerHook()
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
assert list(tag) == ['lr', 'loss', 'loss_max']
        # The buffers are MagicMocks, so statistics are dispatched through
        # `statistics(...)`; verify it was called for both keys.
        log_buffers['lr'].statistics.assert_called()
        log_buffers['loss'].statistics.assert_called()
        # `log_name` cannot be repeated.
with pytest.raises(KeyError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max'),
dict(method='mean', log_name='loss_max')
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
# `log_key` cannot be overwritten multiple times.
with pytest.raises(AssertionError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max'),
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
def test_collect_info(self):
runner = self._setup_runner()
logger_hook = LoggerHook(
custom_keys=dict(time=dict(method='max', log_name='time_max')))
logger_hook._parse_custom_keys = MagicMock()
# Collect with prefix.
log_buffers = {
'train/time': MagicMock(),
'lr': MagicMock(),
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
runner.message_hub.log_buffers = log_buffers
tag = logger_hook._collect_info(runner, mode='train')
# Test parse custom_keys
logger_hook._parse_custom_keys.assert_called()
# Test training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls']
        # loss and time are smoothed with `mean`; other keys would use
        # `current`.
log_buffers['train/time'].mean.assert_called()
log_buffers['train/loss_cls'].mean.assert_called()
log_buffers['train/loss_cls'].current.assert_not_called()
tag = logger_hook._collect_info(runner, mode='val')
assert list(tag.keys()) == ['metric']
log_buffers['val/metric'].current.assert_called()
@patch('torch.distributed.reduce', MagicMock())
def test_get_max_memory(self):
logger_hook = LoggerHook()
runner = MagicMock()
runner.world_size = 1
runner.model = torch.nn.Linear(1, 1)
logger_hook._get_max_memory(runner)
torch.distributed.reduce.assert_not_called()
runner.world_size = 2
logger_hook._get_max_memory(runner)
torch.distributed.reduce.assert_called()
def test_get_iter(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
# Get global iter when `inner_iter=False`
iter = logger_hook._get_iter(runner)
assert iter == 11
# Get inner iter
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 2
# Still get global iter when `logger_hook.by_epoch==False`
logger_hook.by_epoch = False
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 11
def test_get_epoch(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
epoch = logger_hook._get_epoch(runner, 'train')
assert epoch == 2
epoch = logger_hook._get_epoch(runner, 'val')
assert epoch == 1
with pytest.raises(ValueError):
logger_hook._get_epoch(runner, 'test')
def _setup_runner(self):
runner = MagicMock()
runner.epoch = 1
runner.data_loader = [0] * 5
runner.inner_iter = 1
runner.iter = 10
runner.max_iters = 50
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for handler in logger.handlers:
if not isinstance(handler, logging.StreamHandler):
continue
else:
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
runner.logger = logger
runner.message_hub = MagicMock()
        runner.composed_writer = MagicMock()
return runner
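The suite above can be run directly with pytest; a minimal sketch (the module path is illustrative and depends on the repository layout):

import pytest

# Run only this test module, quietly.
pytest.main(['-q', 'tests/test_hooks/test_logger_hook.py'])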