diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py
index 1cb5b5356092173ce19e0edc32cb21a1cb92c266..45c3f910a8c13f7b2c56a1cea0e7b28753aa4454 100644
--- a/mmengine/hooks/__init__.py
+++ b/mmengine/hooks/__init__.py
@@ -3,6 +3,7 @@ from .checkpoint_hook import CheckpointHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
+from .logger_hook import LoggerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -10,5 +11,6 @@ from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
+    'LoggerHook'
 ]
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..741bfd95442021855c9f90337012cc5965f8346c
--- /dev/null
+++ b/mmengine/hooks/logger_hook.py
@@ -0,0 +1,509 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import datetime
+import os
+import os.path as osp
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Optional, Sequence, Tuple, Union
+
+import torch
+
+from mmengine.data import BaseDataSample
+from mmengine.fileio import FileClient
+from mmengine.hooks import Hook
+from mmengine.registry import HOOKS
+from mmengine.utils import is_tuple_of, scandir
+
+DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]
+
+
+@HOOKS.register_module()
+class LoggerHook(Hook):
+    """In this logger hook, the information will be printed on the terminal
+    and saved in a JSON file, TensorBoard, wandb, etc.
+
+    Args:
+        by_epoch (bool): Whether ``EpochBasedLoop`` is used.
+            Defaults to True.
+        interval (int): Logging interval (every k iterations).
+            Defaults to 10.
+        custom_keys (dict, optional): Defines the keys in the log and which
+            kinds of statistic methods should be used to log them.
+
+            - ``custom_keys`` contains multiple string-dict pairs. In each
+              string-dict pair, the string defines a key name in the log and
+              the dict is a config that defines the statistic method and
+              corresponding arguments used to log the value. For example,
+              ``dict(loss=dict(method_name='mean', log_name='global_loss',
+              window_size='global'))`` means the log key ``loss`` will be
+              counted as a global mean and additionally logged as
+              ``global_loss``. If ``log_name`` is not defined in the config
+              dict, the original logged key will be overwritten.
+            - Keys in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
+              because ``time`` and ``data_time`` are used to calculate the
+              estimated time of arrival. If you want to recount the time,
+              you should set ``log_name`` in the corresponding config.
+            - For statistic methods with a ``window_size`` argument,
+              ``window_size`` must not be ``'epoch'`` when ``by_epoch`` is
+              set to False.
+        ignore_last (bool): Ignore the log of the last iterations in each
+            epoch if the number of remaining iterations is less than
+            :attr:`interval`. Defaults to True.
+        interval_exp_name (int): Logging interval for experiment name. This
+            feature is to help users conveniently get the experiment
+            information from screen or log file. Defaults to 1000.
+        out_dir (str or Path, optional): The root directory to save logs.
+            If not specified, ``runner.work_dir`` will be used by default.
+            If specified, the final directory will be the concatenation of
+            ``out_dir`` and the last level directory of ``runner.work_dir``.
+            For example, if ``out_dir`` is ``./tmp`` and ``runner.work_dir``
+            is ``./work_dir/cur_exp``, then the log will be saved in
+            ``./tmp/cur_exp``. Defaults to None.
+        out_suffix (Tuple[str] or str): Those filenames ending with
+            ``out_suffix`` will be copied to ``out_dir``. Defaults to
+            ('.log.json', '.log', '.py').
+        keep_local (bool): Whether to keep local logs in the local machine
+            when :attr:`out_dir` is specified. If False, the local log will
+            be removed. Defaults to True.
+        file_client_args (dict, optional): Arguments to instantiate a
+            FileClient. See :class:`mmengine.fileio.FileClient` for details.
+            Defaults to None.
+
+    Examples:
+        >>> # `log_name` is defined, `loss_mean_window` will be an
+        >>> # additional record.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>     custom_keys=dict(
+        >>>         loss=dict(
+        >>>             log_name='loss_mean_window',
+        >>>             method_name='mean',
+        >>>             window_size=10)))
+        >>> # `log_name` is not defined. `loss` will be overwritten by the
+        >>> # global mean statistics.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>     custom_keys=dict(
+        >>>         loss=dict(
+        >>>             method_name='mean',
+        >>>             window_size='global')))
+        >>> # `time` cannot be overwritten, `global_time` will be an
+        >>> # additional record.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>     custom_keys=dict(
+        >>>         time=dict(
+        >>>             log_name='global_time',
+        >>>             method_name='mean',
+        >>>             window_size='global')))
+        >>> # Record loss with different statistics methods.
+        >>> logger_hook_cfg = dict(by_epoch=True,
+        >>>     custom_keys=dict(loss=[
+        >>>         dict(log_name='loss_mean_window',
+        >>>              method_name='mean',
+        >>>              window_size=10),
+        >>>         dict(method_name='mean',
+        >>>              window_size='global')]))
+    """
+    # eta will be calculated by time. `time` and `data_time` should not be
+    # overwritten.
+    fixed_smooth_keys = ('time', 'data_time')
+    priority = 'BELOW_NORMAL'
+
+    def __init__(
+        self,
+        by_epoch: bool = True,
+        interval: int = 10,
+        custom_keys: Optional[dict] = None,
+        ignore_last: bool = True,
+        interval_exp_name: int = 1000,
+        out_dir: Optional[Union[str, Path]] = None,
+        out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
+        keep_local: bool = True,
+        file_client_args: Optional[dict] = None,
+    ):
+        self.by_epoch = by_epoch
+        self.interval = interval
+        self.custom_keys = custom_keys if custom_keys is not None else dict()
+        self.ignore_last = ignore_last
+
+        self.time_sec_tot = 0
+        self.interval_exp_name = interval_exp_name
+        self._check_custom_keys()
+
+        if out_dir is None and file_client_args is not None:
+            raise ValueError(
+                'file_client_args should be "None" when `out_dir` is not '
+                'specified.')
+        if not (out_dir is None or isinstance(out_dir, (str, Path))
+                or is_tuple_of(out_dir, str)):
+            raise TypeError('out_dir should be None, a string, a Path or a '
+                            f'tuple of strings, but got {type(out_dir)}')
+        self.out_dir = out_dir
+        self.out_suffix = out_suffix
+
+        self.keep_local = keep_local
+        self.file_client_args = file_client_args
+        if self.out_dir is not None:
+            self.file_client = FileClient.infer_client(file_client_args,
+                                                       self.out_dir)
+
+    def before_run(self, runner) -> None:
+        """Update ``self.out_dir``, initialize ``self.start_iter`` and
+        record the meta information.
+
+        Args:
+            runner (Runner): The runner of the training process.
+ """ + if self.out_dir is not None: + # The final `self.out_dir` is the concatenation of `self.out_dir` + # and the last level directory of `runner.work_dir` + basename = osp.basename(runner.work_dir.rstrip(osp.sep)) + self.out_dir = self.file_client.join_path(self.out_dir, basename) + runner.logger.info( + (f'Text logs will be saved to {self.out_dir} by ' + f'{self.file_client.name} after the training process.')) + + self.json_log_path = osp.join(runner.work_dir, + f'{runner.timestamp}.log.json') + self.yaml_log_path = osp.join(runner.work_dir, + f'{runner.timestamp}.log.json') + self.start_iter = runner.iter + if runner.meta is not None: + runner.writer.add_params(runner.meta, file_path=self.yaml_log_path) + + def after_train_iter( + self, + runner, + data_batch: DATA_BATCH = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Record training logs. + + Args: + runner (Runner): The runner of the training process. + data_batch (Sequence[BaseDataSample], optional): Data from + dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + """ + if runner.meta is not None and 'exp_name' in runner.meta: + if (self.every_n_iters(runner, self.interval_exp_name)) or ( + self.by_epoch and self.end_of_epoch(runner)): + exp_info = f'Exp name: {runner.meta["exp_name"]}' + runner.logger.info(exp_info) + if self.by_epoch and self.every_n_inner_iters(runner, self.interval): + self._log_train(runner) + elif not self.by_epoch and self.every_n_iters(runner, self.interval): + self._log_train(runner) + elif self.end_of_epoch(runner) and not self.ignore_last: + # `runner.max_iters` may not be divisible by `self.interval`. if + # `self.ignore_last==True`, the log of remaining iterations will + # be recorded (Epoch [4][1000/1007], the logs of 998-1007 + # iterations will be recorded). + self._log_train(runner) + + def after_val_epoch(self, runner) -> None: + """Record validation logs. + + Args: + runner (Runner): The runner of the training process. + """ + self._log_val(runner) + + def after_run(self, runner) -> None: + """Copy logs to ``self.out_dir`` if ``self.out_dir is not None`` + + Args: + runner (Runner): The runner of the training process. + """ + # copy or upload logs to self.out_dir + if self.out_dir is None: + return + for filename in scandir(runner.work_dir, self.out_suffix, True): + local_filepath = osp.join(runner.work_dir, filename) + out_filepath = self.file_client.join_path(self.out_dir, filename) + with open(local_filepath, 'r') as f: + self.file_client.put_text(f.read(), out_filepath) + + runner.logger.info( + (f'The file {local_filepath} has been uploaded to ' + f'{out_filepath}.')) + + if not self.keep_local: + os.remove(local_filepath) + runner.logger.info((f'{local_filepath} was removed due to the ' + '`self.keep_local=False`')) + + def _log_train(self, runner) -> None: + """Collect and record training logs which start named with "train/*". + + Args: + runner (Runner): The runner of the training process. + """ + tag = self._collect_info(runner, 'train') + # The training log default defines `lr`, `momentum`, `time` and + # `data_time`. `log_tag` will pop these keys and loop other keys to + # `log_str`. + log_tag = copy.deepcopy(tag) + cur_iter = self._get_iter(runner, inner_iter=True) + cur_epoch = self._get_epoch(runner, 'train') + + # Record learning rate and momentum. 
+        lr_str_list = []
+        momentum_str_list = []
+        for key, value in tag.items():
+            if key.startswith('lr'):
+                log_tag.pop(key)
+                lr_str_list.append(f'{key}: {value:.3e}')
+        lr_str = ' '.join(lr_str_list)
+        for key, value in tag.items():
+            if key.startswith('momentum'):
+                log_tag.pop(key)
+                momentum_str_list.append(f'{key}: {value:.3e}')
+        momentum_str = ' '.join(momentum_str_list)
+        lr_momentum_str = f'{lr_str} {momentum_str}'
+        # by epoch: Epoch [4][100/1000]
+        # by iter: Iter [100/100000]
+        if self.by_epoch:
+            log_str = f'Epoch [{cur_epoch}]' \
+                      f'[{cur_iter}/{len(runner.data_loader)}]\t'
+        else:
+            log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
+        log_str += f'{lr_momentum_str}, '
+        # Calculate the estimated time of arrival (eta).
+        self.time_sec_tot += (tag['time'] * self.interval)
+        time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
+        eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+        log_str += f'eta: {eta_str}, '
+        log_str += f'time: {tag["time"]:.3f}, ' \
+                   f'data_time: {tag["data_time"]:.3f}, '
+        # Pop the recorded keys.
+        log_tag.pop('time')
+        log_tag.pop('data_time')
+        # Statistic memory.
+        if torch.cuda.is_available():
+            log_str += f'memory: {self._get_max_memory(runner)}, '
+        # Loop over the remaining keys to fill `log_str`.
+        log_items = []
+        for name, val in log_tag.items():
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+        runner.logger.info(log_str)
+        # Write logs to the local JSON file, tensorboard and wandb.
+        runner.writer.add_scalars(
+            tag, step=runner.iter + 1, file_path=self.json_log_path)
+
+    def _log_val(self, runner) -> None:
+        """Collect and record validation logs whose keys start with "val/".
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        tag = self._collect_info(runner, 'val')
+        # Compatible with the `log` function of mmcv:
+        # https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa: E501
+        eval_iter = len(runner.data_loader)
+        cur_iter = self._get_iter(runner)
+        cur_epoch = self._get_epoch(runner, 'val')
+        # val/test time
+        # Here 1000 is the length of the val dataloader.
+        # by epoch: Epoch(val) [4][1000]
+        # by iter: Iter(val) [1000]
+        if self.by_epoch:
+            # `runner.epoch += 1` has been done before the val workflow.
+            log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
+        else:
+            log_str = f'Iter(val) [{eval_iter}]\t'
+
+        log_items = []
+        for name, val in tag.items():
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ', '.join(log_items)
+        runner.logger.info(log_str)
+        # Write tag.
+        runner.writer.add_scalars(
+            tag, step=cur_iter, file_path=self.json_log_path)
+
+    def _get_window_size(self, runner,
+                         window_size: Union[int, str]) -> int:
+        """Parse the ``window_size`` specified in ``self.custom_keys`` to an
+        int value.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            window_size (int or str): Smoothing scale of logs.
+
+        Returns:
+            int: Smoothing window for statistical methods.
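+
+        The mapping implemented below is:
+
+        - an int ``n`` is returned as-is (it must equal ``self.interval``);
+        - ``'epoch'`` maps to ``runner.inner_iter + 1``;
+        - ``'global'`` maps to ``runner.iter + 1``.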
+ """ + if isinstance(window_size, int): + assert window_size == self.interval, \ + 'The value of windows size must equal to LoggerHook.interval' + return window_size + elif window_size == 'epoch': + return runner.inner_iter + 1 + elif window_size == 'global': + return runner.iter + 1 + else: + raise ValueError('window_size should be int, epoch or global, but ' + f'got invalid {window_size}') + + def _collect_info(self, runner, mode: str) -> dict: + """Collect log information to a dict according to mode. + + Args: + runner (Runner): The runner of the training process. + mode (str): 'train' or 'val', which means the prefix attached by + runner. + + Returns: + dict: Statistical values of logs. + """ + tag = OrderedDict() + log_buffers = runner.message_hub.log_buffers + mode_log_buffers = OrderedDict() + # Filter log_buffers which starts with `mode`. + for prefix_key, log_buffer in log_buffers.items(): + if prefix_key.startswith(mode): + key = prefix_key.split('/')[-1] + mode_log_buffers[key] = log_buffer + # Ensure all metric and lr values are latest. + for key in mode_log_buffers: + # Update the latest learning rate and smoothed time logs. + if key in self.fixed_smooth_keys or key.startswith('loss'): + tag[key] = mode_log_buffers[key].mean(self.interval) + else: + tag[key] = mode_log_buffers[key].current() + # Update custom keys. + if mode == 'train': + for log_key, log_cfg in self.custom_keys.items(): + self._parse_custom_keys(runner, log_key, + copy.deepcopy(log_cfg), + mode_log_buffers, tag) + return tag + + def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict, + log_buffers: OrderedDict, tag: OrderedDict) -> None: + """Statistics logs in log_buffers according to custom_keys. + + Args: + runner (Runner): The runner of the training process. + log_key (str): log key specified in ``self.custom_keys`` + log_cfg (dict): A config dict for describing the logging + statistics method. + log_buffers (OrderedDict): All logs for the corresponding phase. + tag (OrderedDict): A dict which defines all statistic values of + logs. + """ + if isinstance(log_cfg, list): + log_names = set() + for cfg in log_cfg: + log_name = cfg.get('log_name', None) + if log_name in log_names: + raise KeyError(f'{cfg["log_name"]} cannot be Redefined in ' + 'log_key') + if log_name is not None: + log_names.add(log_name) + self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag) + assert len(log_names) == len(log_cfg) - 1, \ + f'{log_key} cannot be overwritten multiple times, please ' \ + f'check only one key does not contain `log_name` in {log_cfg}.' + elif isinstance(log_cfg, dict): + if 'window_size' in log_cfg: + log_cfg['window_size'] = \ + self._get_window_size(runner, log_cfg['window_size']) + if 'log_name' in log_cfg: + name = log_cfg.pop('log_name') + else: + name = log_key + tag[name] = log_buffers[log_key].statistics(**log_cfg) + else: + raise ValueError('The structure of `LoggerHook.custom key` is ' + 'wrong, please make sure the type of each key is ' + 'dict or list.') + + def _get_max_memory(self, runner) -> int: + """Returns the maximum GPU memory occupied by tensors in megabytes (MB) + for a given device. + + Args: + runner (Runner): The runner of the training process. + + Returns: + The maximum GPU memory occupied by tensors in megabytes for a given + device. 
+ """ + # TODO use `mmengine.dist.max_memory_allocated` to count mem_mb + device = getattr(runner.model, 'output_device', None) + mem = torch.cuda.max_memory_allocated(device=device) + mem_mb = torch.tensor([int(mem) // (1024 * 1024)], + dtype=torch.int, + device=device) + torch.cuda.reset_peak_memory_stats() + return int(mem_mb.item()) + + def _check_custom_keys(self) -> None: + """Check the legality of ``self.custom_keys``. + + If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The + key of ``self.fixed_smooth_keys`` cannot be overwritten. + """ + + def _check_window_size(item): + if not self.by_epoch: + assert item['window_size'] != 'epoch', \ + 'window_size cannot be epoch if LoggerHook.by_epoch is ' \ + 'False.' + + def _check_fixed_keys(key, item): + if key in self.fixed_smooth_keys: + assert 'log_name' in item, f'{key} cannot be overwritten by ' \ + 'custom keys!' + + for key, value in self.custom_keys.items(): + if isinstance(value, Sequence): + [(_check_window_size(item), _check_fixed_keys(key, item)) + for item in value] + + else: + _check_window_size(value) + _check_fixed_keys(key, value) + + def _get_epoch(self, runner, mode: str) -> int: + """Get epoch according to mode. + + Args: + runner (Runner): The runner of the training process. + mode (str): Train or val. + + Returns: + int: The current epoch. + """ + if mode == 'train': + epoch = runner.epoch + 1 + elif mode == 'val': + # normal val mode + # runner.epoch += 1 has been done before val workflow + epoch = runner.epoch + else: + raise ValueError(f"runner mode should be 'train' or 'val', " + f'but got {runner.mode}') + return epoch + + def _get_iter(self, runner, inner_iter=False) -> int: + """Get the current training iteration step. + Args: + runner (Runner): The runner of the training process. + inner_iter (bool): Whether to return the inner iter of an epoch. + Defaults to False. + + Returns: + int: The current global iter or inner iter. + """ + if self.by_epoch and inner_iter: + current_iter = runner.inner_iter + 1 + else: + current_iter = runner.iter + 1 + return current_iter diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py new file mode 100644 index 0000000000000000000000000000000000000000..631736e8e01161e9201ac687208c68000ca21b93 --- /dev/null +++ b/tests/test_hook/test_logger_hook.py @@ -0,0 +1,355 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import datetime +import logging +import os.path as osp +import sys +from collections import OrderedDict +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from mmengine.fileio.file_client import HardDiskBackend +from mmengine.hooks import LoggerHook + + +class TestLoggerHook: + + def test_init(self): + logger_hook = LoggerHook(out_dir='tmp.txt') + assert logger_hook.by_epoch + assert logger_hook.interval == 10 + assert not logger_hook.custom_keys + assert logger_hook.ignore_last + assert logger_hook.time_sec_tot == 0 + assert logger_hook.interval_exp_name == 1000 + assert logger_hook.out_suffix == ('.log.json', '.log', '.py') + assert logger_hook.keep_local + assert logger_hook.file_client_args is None + assert isinstance(logger_hook.file_client.client, HardDiskBackend) + # out_dir should be None or string or tuple of string. + with pytest.raises(TypeError): + LoggerHook(out_dir=1) + # time cannot be overwritten. 
+        with pytest.raises(AssertionError):
+            LoggerHook(custom_keys=dict(time=dict(method='max')))
+        LoggerHook(
+            custom_keys=dict(time=[
+                dict(method='max', log_name='time_max'),
+                dict(method='min', log_name='time_min')
+            ]))
+        # An epoch window_size cannot be used when
+        # `LoggerHook.by_epoch=False`.
+        with pytest.raises(AssertionError):
+            LoggerHook(
+                by_epoch=False,
+                custom_keys=dict(
+                    time=dict(
+                        method='max', log_name='time_max',
+                        window_size='epoch')))
+        # `file_client_args` requires `out_dir` to be specified.
+        with pytest.raises(ValueError):
+            LoggerHook(file_client_args=dict(enable_mc=True))
+
+    def test_before_run(self):
+        runner = MagicMock()
+        runner.iter = 10
+        runner.timestamp = 'timestamp'
+        runner.work_dir = 'work_dir'
+        runner.logger = MagicMock()
+        logger_hook = LoggerHook(out_dir='out_dir')
+        logger_hook.before_run(runner)
+        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
+        assert logger_hook.json_log_path == osp.join('work_dir',
+                                                     'timestamp.log.json')
+        assert logger_hook.start_iter == runner.iter
+        runner.writer.add_params.assert_called()
+
+    def test_after_run(self, tmp_path):
+        out_dir = tmp_path / 'out_dir'
+        out_dir.mkdir()
+        work_dir = tmp_path / 'work_dir'
+        work_dir.mkdir()
+        work_dir_json = work_dir / 'tmp.log.json'
+        work_dir_json.touch()
+        runner = MagicMock()
+        runner.work_dir = work_dir
+
+        logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
+        logger_hook.out_dir = str(out_dir)
+        logger_hook.after_run(runner)
+        # Verify that the file has been moved to `out_dir`.
+        assert not osp.exists(str(work_dir_json))
+        assert osp.exists(str(out_dir / 'tmp.log.json'))
+
+    def test_after_train_iter(self):
+        # Test LoggerHook by iter.
+        runner = MagicMock()
+        runner.iter = 10
+        logger_hook = LoggerHook(by_epoch=False)
+        logger_hook._log_train = MagicMock()
+        logger_hook.after_train_iter(runner)
+        # `cur_iter=10+1`, which cannot be exactly divided by
+        # `logger_hook.interval`.
+        logger_hook._log_train.assert_not_called()
+        runner.iter = 9
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test LoggerHook by epoch.
+        logger_hook = LoggerHook(by_epoch=True)
+        logger_hook._log_train = MagicMock()
+        # Only `runner.inner_iter` will work.
+        runner.iter = 9
+        runner.inner_iter = 10
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_not_called()
+        runner.inner_iter = 9
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test the end of the epoch.
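+        # With `ignore_last=False`, the last inner iteration of an epoch
+        # triggers a log even though it is not aligned with `interval`.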
+        logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
+        logger_hook._log_train = MagicMock()
+        runner.data_loader = [0] * 5
+        runner.inner_iter = 4
+        logger_hook.after_train_iter(runner)
+        logger_hook._log_train.assert_called()
+
+        # Test printing the exp_name.
+        runner.meta = dict(exp_name='retinanet')
+        logger_hook = LoggerHook()
+        runner.logger = MagicMock()
+        logger_hook._log_train = MagicMock()
+        logger_hook.after_train_iter(runner)
+        runner.logger.info.assert_called_with(
+            f'Exp name: {runner.meta["exp_name"]}')
+
+    def test_after_val_epoch(self):
+        logger_hook = LoggerHook()
+        runner = MagicMock()
+        logger_hook._log_val = MagicMock()
+        logger_hook.after_val_epoch(runner)
+        logger_hook._log_val.assert_called()
+
+    @pytest.mark.parametrize('by_epoch', [True, False])
+    def test_log_train(self, by_epoch, capsys):
+        runner = self._setup_runner()
+        runner.meta = dict(exp_name='retinanet')
+        # Prepare LoggerHook.
+        logger_hook = LoggerHook(by_epoch=by_epoch)
+        logger_hook.time_sec_tot = 1000
+        logger_hook.start_iter = 0
+        logger_hook._get_max_memory = MagicMock(return_value='100')
+        logger_hook.json_log_path = 'tmp.json'
+
+        # Prepare training information.
+        train_infos = dict(
+            lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
+        logger_hook._collect_info = MagicMock(return_value=train_infos)
+        logger_hook._log_train(runner)
+        # Verify that the correct variables have been written.
+        runner.writer.add_scalars.assert_called_with(
+            train_infos, step=11, file_path='tmp.json')
+        # Verify that the correct context has been logged.
+        out, _ = capsys.readouterr()
+        time_avg = logger_hook.time_sec_tot / (
+            runner.iter + 1 - logger_hook.start_iter)
+        eta_second = time_avg * (runner.max_iters - runner.iter - 1)
+        eta_str = str(datetime.timedelta(seconds=int(eta_second)))
+        if by_epoch:
+            prefix = 'Epoch [2][2/5]\t'
+        else:
+            prefix = 'Iter [11/50]\t'
+        # The `memory` field is only logged when CUDA is available.
+        memory_str = 'memory: 100, ' if torch.cuda.is_available() else ''
+        log_str = (prefix +
+                   f"lr: {train_infos['lr']:.3e} "
+                   f"momentum: {train_infos['momentum']:.3e}, "
+                   f'eta: {eta_str}, '
+                   f"time: {train_infos['time']:.3f}, "
+                   f"data_time: {train_infos['data_time']:.3f}, " +
+                   memory_str +
+                   f"loss_cls: {train_infos['loss_cls']:.4f}\n")
+        assert out == log_str
+
+    @pytest.mark.parametrize('by_epoch', [True, False])
+    def test_log_val(self, by_epoch, capsys):
+        runner = self._setup_runner()
+        # Prepare LoggerHook.
+        logger_hook = LoggerHook(by_epoch=by_epoch)
+        logger_hook.json_log_path = 'tmp.json'
+        metric = dict(accuracy=0.9, data_time=1.0)
+        logger_hook._collect_info = MagicMock(return_value=metric)
+        logger_hook._log_val(runner)
+        # Verify that the correct context has been logged.
+        out, _ = capsys.readouterr()
+        runner.writer.add_scalars.assert_called_with(
+            metric, step=11, file_path='tmp.json')
+        if by_epoch:
+            assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
+                          'data_time: 1.0000\n'
+        else:
+            assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
+                          'data_time: 1.0000\n'
+
+    def test_get_window_size(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        # Test getting the window size by name.
+        assert logger_hook._get_window_size(runner, 'epoch') == 2
+        assert logger_hook._get_window_size(runner, 'global') == 11
+        assert logger_hook._get_window_size(runner, 10) == 10
+        # An int window size must equal `logger_hook.interval`.
+        with pytest.raises(AssertionError):
+            logger_hook._get_window_size(runner, 20)
+
+        with pytest.raises(ValueError):
+            logger_hook._get_window_size(runner, 'unknown')
+
+    def test_parse_custom_keys(self):
+        tag = OrderedDict()
+        runner = self._setup_runner()
+        log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
+        cfg_dict = dict(
+            lr=dict(method='min'),
+            loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max', log_name='loss_max')
+            ])
+        logger_hook = LoggerHook()
+        for log_key, log_cfg in cfg_dict.items():
+            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                           log_buffers, tag)
+        assert list(tag) == ['lr', 'loss', 'loss_max']
+        # `_parse_custom_keys` forwards the configs to
+        # `LogBuffer.statistics` with `window_size` already resolved
+        # ('global' -> `runner.iter + 1 == 11`) and `log_name` popped.
+        log_buffers['lr'].statistics.assert_called_with(method='min')
+        log_buffers['loss'].statistics.assert_any_call(
+            method='min', window_size=11)
+        log_buffers['loss'].statistics.assert_any_call(method='max')
+        # `log_name` cannot be repeated.
+        with pytest.raises(KeyError):
+            cfg_dict = dict(loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max', log_name='loss_max'),
+                dict(method='mean', log_name='loss_max')
+            ])
+            logger_hook.custom_keys = cfg_dict
+            for log_key, log_cfg in cfg_dict.items():
+                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                               log_buffers, tag)
+        # `log_key` cannot be overwritten multiple times.
+        with pytest.raises(AssertionError):
+            cfg_dict = dict(loss=[
+                dict(method='min', window_size='global'),
+                dict(method='max'),
+            ])
+            logger_hook.custom_keys = cfg_dict
+            for log_key, log_cfg in cfg_dict.items():
+                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
+                                               log_buffers, tag)
+
+    def test_collect_info(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook(
+            custom_keys=dict(time=dict(method='max', log_name='time_max')))
+        logger_hook._parse_custom_keys = MagicMock()
+        # Collect with prefix.
+        log_buffers = {
+            'train/time': MagicMock(),
+            'lr': MagicMock(),
+            'train/loss_cls': MagicMock(),
+            'val/metric': MagicMock()
+        }
+        runner.message_hub.log_buffers = log_buffers
+        tag = logger_hook._collect_info(runner, mode='train')
+        # Test that `_parse_custom_keys` is called.
+        logger_hook._parse_custom_keys.assert_called()
+        # Test the training keys in tag.
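+        # `lr` has no `train/` prefix, so `_collect_info` filters it out.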
+        assert list(tag.keys()) == ['time', 'loss_cls']
+        # `time` and `loss_cls` are smoothed with `mean` instead of taking
+        # the latest value.
+        log_buffers['train/time'].mean.assert_called()
+        log_buffers['train/loss_cls'].mean.assert_called()
+        log_buffers['train/loss_cls'].current.assert_not_called()
+
+        tag = logger_hook._collect_info(runner, mode='val')
+        assert list(tag.keys()) == ['metric']
+        log_buffers['val/metric'].current.assert_called()
+
+    @patch('torch.distributed.reduce', MagicMock())
+    def test_get_max_memory(self):
+        logger_hook = LoggerHook()
+        runner = MagicMock()
+        runner.world_size = 1
+        runner.model = torch.nn.Linear(1, 1)
+        logger_hook._get_max_memory(runner)
+        torch.distributed.reduce.assert_not_called()
+        runner.world_size = 2
+        logger_hook._get_max_memory(runner)
+        torch.distributed.reduce.assert_called()
+
+    def test_get_iter(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        # Get the global iter when `inner_iter=False`.
+        iter = logger_hook._get_iter(runner)
+        assert iter == 11
+        # Get the inner iter.
+        iter = logger_hook._get_iter(runner, inner_iter=True)
+        assert iter == 2
+        # Still get the global iter when `logger_hook.by_epoch==False`.
+        logger_hook.by_epoch = False
+        iter = logger_hook._get_iter(runner, inner_iter=True)
+        assert iter == 11
+
+    def test_get_epoch(self):
+        runner = self._setup_runner()
+        logger_hook = LoggerHook()
+        epoch = logger_hook._get_epoch(runner, 'train')
+        assert epoch == 2
+        epoch = logger_hook._get_epoch(runner, 'val')
+        assert epoch == 1
+        with pytest.raises(ValueError):
+            logger_hook._get_epoch(runner, 'test')
+
+    def _setup_runner(self):
+        runner = MagicMock()
+        runner.epoch = 1
+        runner.data_loader = [0] * 5
+        runner.inner_iter = 1
+        runner.iter = 10
+        runner.max_iters = 50
+        logger = logging.getLogger()
+        logger.setLevel(logging.INFO)
+        # Stream the logs to stdout so that `capsys` can capture the output.
+        logger.addHandler(logging.StreamHandler(stream=sys.stdout))
+        runner.logger = logger
+        runner.message_hub = MagicMock()
+        runner.composed_writer = MagicMock()
+        return runner