From 3e0c064f4960552fcb0c666fc3d50347cadfbdf3 Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Thu, 10 Mar 2022 17:22:31 +0800 Subject: [PATCH] [Feature] NaiveVisualizationHook (#98) * [WIP] testvisualizationhook * add TestNaiveVisualizationHook * fix comment * unpad * batch imdenormalize * fix comment * fix comment --- mmengine/hooks/__init__.py | 3 +- mmengine/hooks/naive_visualization_hook.py | 71 ++++++++++++++++ mmengine/utils/misc.py | 45 ++++++++++ .../test_naive_visualization_hook.py | 84 +++++++++++++++++++ 4 files changed, 202 insertions(+), 1 deletion(-) create mode 100644 mmengine/hooks/naive_visualization_hook.py create mode 100644 tests/test_hook/test_naive_visualization_hook.py diff --git a/mmengine/hooks/__init__.py b/mmengine/hooks/__init__.py index 45c3f910..44698bbf 100644 --- a/mmengine/hooks/__init__.py +++ b/mmengine/hooks/__init__.py @@ -4,6 +4,7 @@ from .empty_cache_hook import EmptyCacheHook from .hook import Hook from .iter_timer_hook import IterTimerHook from .logger_hook import LoggerHook +from .naive_visualization_hook import NaiveVisualizationHook from .optimizer_hook import OptimizerHook from .param_scheduler_hook import ParamSchedulerHook from .sampler_seed_hook import DistSamplerSeedHook @@ -12,5 +13,5 @@ from .sync_buffer_hook import SyncBuffersHook __all__ = [ 'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook', 'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook', - 'LoggerHook' + 'LoggerHook', 'NaiveVisualizationHook' ] diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py new file mode 100644 index 00000000..434de95f --- /dev/null +++ b/mmengine/hooks/naive_visualization_hook.py @@ -0,0 +1,71 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Any, Optional, Sequence, Tuple + +import cv2 +import numpy as np + +from mmengine.data import BaseDataSample +from mmengine.hooks import Hook +from mmengine.registry import HOOKS +from mmengine.utils.misc import tensor2imgs + + +@HOOKS.register_module() +class NaiveVisualizationHook(Hook): + """Show or Write the predicted results during the process of testing. + + Args: + interval (int): Visualization interval. Default: 1. + draw_gt (bool): Whether to draw the ground truth. Default to True. + draw_pred (bool): Whether to draw the predicted result. + Default to True. + """ + priority = 'NORMAL' + + def __init__(self, + interval: int = 1, + draw_gt: bool = True, + draw_pred: bool = True): + self.draw_gt = draw_gt + self.draw_pred = draw_pred + self._interval = interval + + def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int, + int]) -> np.ndarray: + unpad_width, unpad_height = unpad_shape + unpad_image = input[:unpad_height, :unpad_width] + return unpad_image + + def after_test_iter( + self, + runner, + data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None, + outputs: Optional[Sequence[BaseDataSample]] = None) -> None: + """Show or Write the predicted results. + + Args: + runner (Runner): The runner of the training process. + data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data + from dataloader. Defaults to None. + outputs (Sequence[BaseDataSample], optional): Outputs from model. + Defaults to None. + """ + if self.every_n_iters(runner, self._interval): + inputs, data_samples = data_batch # type: ignore + inputs = tensor2imgs(inputs, + **data_samples[0].get('img_norm_cfg', dict())) + for input, data_sample, output in zip( + inputs, + data_samples, # type: ignore + outputs): # type: ignore + # TODO We will implement a function to revert the augmentation + # in the future. + ori_shape = (data_sample.ori_width, data_sample.ori_height) + if 'pad_shape' in data_sample: + input = self._unpad(input, + data_sample.get('scale', ori_shape)) + origin_image = cv2.resize(input, ori_shape) + name = osp.basename(data_sample.img_path) + runner.writer.add_image(name, origin_image, data_sample, + output, self.draw_gt, self.draw_pred) diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index 79977be0..3a955116 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -11,6 +11,8 @@ from inspect import getfullargspec from itertools import repeat from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union +import numpy as np +import torch import torch.nn as nn from .parrots_wrapper import _BatchNorm, _InstanceNorm @@ -433,3 +435,46 @@ def is_norm(layer: nn.Module, all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm) return isinstance(layer, all_norm_bases) + + +def tensor2imgs(tensor: torch.Tensor, + mean: Optional[Tuple[float, float, float]] = None, + std: Optional[Tuple[float, float, float]] = None, + to_bgr: bool = True): + """Convert tensor to 3-channel images or 1-channel gray images. + + Args: + tensor (torch.Tensor): Tensor that contains multiple images, shape ( + N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format + should be RGB. + mean (tuple[float], optional): Mean of images. If None, + (0, 0, 0) will be used for tensor with 3-channel, + while (0, ) for tensor with 1-channel. Defaults to None. + std (tuple[float], optional): Standard deviation of images. If None, + (1, 1, 1) will be used for tensor with 3-channel, + while (1, ) for tensor with 1-channel. Defaults to None. + to_bgr (bool): For the tensor with 3 channel, convert its format to + BGR. For the tensor with 1 channel, it must be False. Defaults to + True. + + Returns: + list[np.ndarray]: A list that contains multiple images. + """ + + assert torch.is_tensor(tensor) and tensor.ndim == 4 + channels = tensor.size(1) + assert channels in [1, 3] + if mean is None: + mean = (0, ) * channels + if std is None: + std = (1, ) * channels + assert (channels == len(mean) == len(std) == 3) or \ + (channels == len(mean) == len(std) == 1 and not to_bgr) + mean = tensor.new_tensor(mean).view(1, -1) + std = tensor.new_tensor(std).view(1, -1) + tensor = tensor.permute(0, 2, 3, 1) * std + mean + imgs = tensor.detach().cpu().numpy() + if to_bgr and channels == 3: + imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR + imgs = [np.ascontiguousarray(img) for img in imgs] + return imgs diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py new file mode 100644 index 00000000..beb053a4 --- /dev/null +++ b/tests/test_hook/test_naive_visualization_hook.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import Mock + +import torch + +from mmengine.data import BaseDataSample +from mmengine.hooks import NaiveVisualizationHook + + +class TestNaiveVisualizationHook: + + def test_after_train_iter(self): + naive_visualization_hook = NaiveVisualizationHook() + Runner = Mock(iter=1) + Runner.writer.add_image = Mock() + inputs = torch.randn(1, 3, 15, 15) + # test with normalize, resize, pad + gt_datasamples = [ + BaseDataSample( + metainfo=dict( + img_norm_cfg=dict( + mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True), + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')) + ] + pred_datasamples = [BaseDataSample()] + data_batch = (inputs, gt_datasamples) + naive_visualization_hook.after_test_iter(Runner, data_batch, + pred_datasamples) + # test with resize, pad + gt_datasamples = [ + BaseDataSample( + metainfo=dict( + scale=(10, 10), + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')), + ] + pred_datasamples = [BaseDataSample()] + data_batch = (inputs, gt_datasamples) + naive_visualization_hook.after_test_iter(Runner, data_batch, + pred_datasamples) + # test with only resize + gt_datasamples = [ + BaseDataSample( + metainfo=dict( + scale=(15, 15), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')), + ] + pred_datasamples = [BaseDataSample()] + data_batch = (inputs, gt_datasamples) + naive_visualization_hook.after_test_iter(Runner, data_batch, + pred_datasamples) + + # test with only pad + gt_datasamples = [ + BaseDataSample( + metainfo=dict( + pad_shape=(15, 15, 3), + ori_height=5, + ori_width=5, + img_path='tmp.jpg')), + ] + pred_datasamples = [BaseDataSample()] + data_batch = (inputs, gt_datasamples) + naive_visualization_hook.after_test_iter(Runner, data_batch, + pred_datasamples) + + # test no transform + gt_datasamples = [ + BaseDataSample( + metainfo=dict(ori_height=15, ori_width=15, + img_path='tmp.jpg')), + ] + pred_datasamples = [BaseDataSample()] + data_batch = (inputs, gt_datasamples) + naive_visualization_hook.after_test_iter(Runner, data_batch, + pred_datasamples) -- GitLab