From 6996bdc89279ff6783c7c58f9015cd543ce3a67e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Wed, 27 Apr 2022 19:44:40 +0800 Subject: [PATCH] Update visualizer code (#184) * update vis backend * update vis backend * update vis backend * update visualizer * update visualizer * update visualizer * update featmap * update featmap * update visualizer and unitest * add draw points unitest and refactor vis_backend * fix typo and close unitest * fix comment * add docstring * fix comment * add master only * fix comment Co-authored-by: liukuikun <641417025@qq.com> --- mmengine/visualization/utils.py | 116 +++- mmengine/visualization/vis_backend.py | 379 ++++++++----- mmengine/visualization/visualizer.py | 636 ++++++++++++---------- tests/test_visualizer/test_vis_backend.py | 124 +++-- tests/test_visualizer/test_visualizer.py | 257 ++++++--- 5 files changed, 949 insertions(+), 563 deletions(-) diff --git a/mmengine/visualization/utils.py b/mmengine/visualization/utils.py index a0033dac..c9a9c8dc 100644 --- a/mmengine/visualization/utils.py +++ b/mmengine/visualization/utils.py @@ -4,8 +4,11 @@ from typing import Any, List, Optional, Tuple, Type, Union import cv2 import matplotlib +import matplotlib.pyplot as plt import numpy as np import torch +from matplotlib.backend_bases import CloseEvent +from matplotlib.backends.backend_agg import FigureCanvasAgg def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray: @@ -89,13 +92,16 @@ def check_type_and_length(name: str, value: Any, check_length(name, value, valid_length) -def color_val_matplotlib(colors): +def color_val_matplotlib( + colors: Union[str, tuple, List[Union[str, tuple]]] +) -> Union[str, tuple, List[Union[str, tuple]]]: """Convert various input in RGB order to normalized RGB matplotlib color tuples, Args: - color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs + colors (Union[str, tuple, List[Union[str, tuple]]]): Color inputs Returns: - tuple[float]: A tuple of 3 normalized floats indicating RGB channels. + Union[str, tuple, List[Union[str, tuple]]]: A tuple of 3 normalized + floats indicating RGB channels. """ if isinstance(colors, str): return colors @@ -106,16 +112,28 @@ def color_val_matplotlib(colors): colors = [channel / 255 for channel in colors] return tuple(colors) elif isinstance(colors, list): - colors = [color_val_matplotlib(color) for color in colors] + colors = [ + color_val_matplotlib(color) # type:ignore + for color in colors + ] return colors else: raise TypeError(f'Invalid type for color: {type(colors)}') -def str_color_to_rgb(color): - color = matplotlib.colors.to_rgb(color) - color = tuple([int(c * 255) for c in color]) - return color +def color_str2rgb(color: str) -> tuple: + """Convert Matplotlib str color to an RGB color which range is 0 to 255, + silently dropping the alpha channel. + + Args: + color (str): Matplotlib color. + + Returns: + tuple: RGB color. + """ + rgb_color: tuple = matplotlib.colors.to_rgb(color) + rgb_color = tuple([int(c * 255) for c in rgb_color]) + return rgb_color def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor], @@ -129,18 +147,96 @@ def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor], the image width. img (np.ndarray, optional): The origin image. The format should be RGB. Defaults to None. - alpha (float): The transparency of origin image. Defaults to 0.5. + alpha (float): The transparency of featmap. Defaults to 0.5. Returns: np.ndarray: heatmap """ + assert feat_map.ndim == 2 or (feat_map.ndim == 3 + and feat_map.shape[0] in [1, 3]) if isinstance(feat_map, torch.Tensor): feat_map = feat_map.detach().cpu().numpy() + + if feat_map.ndim == 3: + feat_map = feat_map.transpose(1, 2, 0) + norm_img = np.zeros(feat_map.shape) norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX) norm_img = np.asarray(norm_img, dtype=np.uint8) heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET) heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB) if img is not None: - heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0) + heat_img = cv2.addWeighted(img, 1 - alpha, heat_img, alpha, 0) return heat_img + + +def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int: + """Show the image and wait for the user's input. + + This implementation refers to + https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py + + Args: + timeout (int): If positive, continue after ``timeout`` seconds. + Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. + + Returns: + int: If zero, means time out or the user pressed ``continue_key``, + and if one, means the user closed the show figure. + """ # noqa: E501 + is_inline = 'inline' in plt.get_backend() + if is_inline: + # If use inline backend, interactive input and timeout is no use. + return 0 + + if figure.canvas.manager: # type: ignore + # Ensure that the figure is shown + figure.show() # type: ignore + + while True: + + # Connect the events to the handler function call. + event = None + + def handler(ev): + # Set external event variable + nonlocal event + # Qt backend may fire two events at the same time, + # use a condition to avoid missing close event. + event = ev if not isinstance(event, CloseEvent) else event + figure.canvas.stop_event_loop() + + cids = [ + figure.canvas.mpl_connect(name, handler) # type: ignore + for name in ('key_press_event', 'close_event') + ] + + try: + figure.canvas.start_event_loop(timeout) # type: ignore + finally: # Run even on exception like ctrl-c. + # Disconnect the callbacks. + for cid in cids: + figure.canvas.mpl_disconnect(cid) # type: ignore + + if isinstance(event, CloseEvent): + return 1 # Quit for close. + elif event is None or event.key == continue_key: + return 0 # Quit for continue. + + +def img_from_canvas(canvas: FigureCanvasAgg) -> np.ndarray: + """Get RGB image from ``FigureCanvasAgg``. + + Args: + canvas (FigureCanvasAgg): The canvas to get image. + + Returns: + np.ndarray: the output of image in RGB. + """ # noqa: E501 + s, (width, height) = canvas.print_to_buffer() + buffer = np.frombuffer(s, dtype='uint8') + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype('uint8') diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py index 13de36d8..5c135fa4 100644 --- a/mmengine/visualization/vis_backend.py +++ b/mmengine/visualization/vis_backend.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. +import functools import os import os.path as osp -import time +import warnings from abc import ABCMeta, abstractmethod -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Optional, Sequence, Union import cv2 import numpy as np @@ -15,37 +16,78 @@ from mmengine.registry import VISBACKENDS from mmengine.utils import TORCH_VERSION +def force_init_env(old_func: Callable) -> Any: + """Those methods decorated by ``force_init_env`` will be forced to call + ``_init_env`` if the instance has not been fully initiated. This function + will decorated all the `add_xxx` method and `experiment` method, because + `VisBackend` is initialized only when used its API. + + Args: + old_func (Callable): Decorated function, make sure the first arg is an + instance with ``_init_env`` method. + + Returns: + Any: Depends on old_func. + """ + + @functools.wraps(old_func) + def wrapper(obj: object, *args, **kwargs): + # The instance must have `_init_env` method. + if not hasattr(obj, '_init_env'): + raise AttributeError(f'{type(obj)} does not have _init_env ' + 'method.') + # If instance does not have `_env_initialized` attribute or + # `_env_initialized` is False, call `_init_env` and set + # `_env_initialized` to True + if not getattr(obj, '_env_initialized', False): + warnings.warn('Attribute `_env_initialized` is not defined in ' + f'{type(obj)} or `type(obj)._env_initialized is ' + 'False, `_init_env` will be called and ' + f'{type(obj)}._env_initialized will be set to ' + 'True') + obj._init_env() # type: ignore + obj._env_initialized = True # type: ignore + + return old_func(obj, *args, **kwargs) + + return wrapper + + class BaseVisBackend(metaclass=ABCMeta): - """Base class for vis backend. + """Base class for visualization backend. All backends must inherit ``BaseVisBackend`` and implement the required functions. Args: save_dir (str, optional): The root directory to save - the files produced by the backend. Default to None. + the files produced by the backend. """ - def __init__(self, save_dir: Optional[str] = None): + def __init__(self, save_dir: str): self._save_dir = save_dir - if self._save_dir: - timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) - self._save_dir = osp.join(self._save_dir, - f'vis_data_{timestamp}') # type: ignore + self._env_initialized = False @property @abstractmethod def experiment(self) -> Any: - """Return the experiment object associated with this writer. + """Return the experiment object associated with this visualization + backend. - The experiment attribute can get the visualizer backend, such as wandb, - tensorboard. If you want to write other data, such as writing a table, - you can directly get the visualizer backend through experiment. + The experiment attribute can get the visualization backend, such as + wandb, tensorboard. If you want to write other data, such as writing a + table, you can directly get the visualization backend through + experiment. """ pass + @abstractmethod + def _init_env(self) -> Any: + """Setup env for VisBackend.""" + pass + def add_config(self, config: Config, **kwargs) -> None: - """Record a set of parameters. + """Record the config. Args: config (Config): The Config object @@ -54,7 +96,7 @@ class BaseVisBackend(metaclass=ABCMeta): def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], **kwargs) -> None: - """Record graph. + """Record the model graph. Args: model (torch.nn.Module): Model to draw. @@ -67,11 +109,11 @@ class BaseVisBackend(metaclass=ABCMeta): image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record image. + """Record the image. Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format + name (str): The image identifier. + image (np.ndarray): The image to be saved. The format should be RGB. Default to None. step (int): Global step value to record. Default to 0. """ @@ -82,11 +124,11 @@ class BaseVisBackend(metaclass=ABCMeta): value: Union[int, float], step: int = 0, **kwargs) -> None: - """Record scalar. + """Record the scalar. Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. + name (str): The scalar identifier. + value (int, float): Value to save. step (int): Global step value to record. Default to 0. """ pass @@ -96,7 +138,7 @@ class BaseVisBackend(metaclass=ABCMeta): step: int = 0, file_path: Optional[str] = None, **kwargs) -> None: - """Record scalars' data. + """Record the scalars' data. Args: scalar_dict (dict): Key-value pair storing the tag and @@ -116,108 +158,130 @@ class BaseVisBackend(metaclass=ABCMeta): @VISBACKENDS.register_module() class LocalVisBackend(BaseVisBackend): - """Local vis backend class. + """Local visualization backend class. It can write image, config, scalars, etc. to the local hard disk. You can get the drawing backend - through the visualizer property for custom drawing. + through the experiment property for custom drawing. Examples: >>> from mmengine.visualization import LocalVisBackend >>> import numpy as np >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir') - >>> img=np.random.randint(0, 256, size=(10, 10, 3)) + >>> img = np.random.randint(0, 256, size=(10, 10, 3)) >>> local_vis_backend.add_image('img', img) - >>> local_vis_backend.add_scaler('mAP', 0.6) + >>> local_vis_backend.add_scalar('mAP', 0.6) >>> local_vis_backend.add_scalars({'loss': [1, 2, 3], 'acc': 0.8}) - >>> local_vis_backend.add_image('img', image) + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> local_vis_backend.add_config(cfg) Args: save_dir (str, optional): The root directory to save the files - produced by the writer. If it is none, it means no data - is stored. Default None. + produced by the visualizer. If it is none, it means no data + is stored. img_save_dir (str): The directory to save images. - Default to 'writer_image'. - config_save_file (str): The file to save parameters. - Default to 'parameters.yaml'. - scalar_save_file (str): The file to save scalar values. + Default to 'vis_image'. + config_save_file (str): The file name to save config. + Default to 'config.py'. + scalar_save_file (str): The file name to save scalar values. Default to 'scalars.json'. """ def __init__(self, - save_dir: Optional[str] = None, + save_dir: str, img_save_dir: str = 'vis_image', config_save_file: str = 'config.py', scalar_save_file: str = 'scalars.json'): assert config_save_file.split('.')[-1] == 'py' assert scalar_save_file.split('.')[-1] == 'json' super(LocalVisBackend, self).__init__(save_dir) - if self._save_dir is not None: - os.makedirs(self._save_dir, exist_ok=True) # type: ignore - self._img_save_dir = osp.join( - self._save_dir, # type: ignore - img_save_dir) - self._scalar_save_file = osp.join( - self._save_dir, # type: ignore - scalar_save_file) - self._config_save_file = osp.join( - self._save_dir, # type: ignore - config_save_file) - - @property + self._img_save_dir = img_save_dir + self._config_save_file = config_save_file + self._scalar_save_file = scalar_save_file + + def _init_env(self): + """Init save dir.""" + if not os.path.exists(self._save_dir): + os.makedirs(self._save_dir, exist_ok=True) + self._img_save_dir = osp.join( + self._save_dir, # type: ignore + self._img_save_dir) + self._config_save_file = osp.join( + self._save_dir, # type: ignore + self._config_save_file) + self._scalar_save_file = osp.join( + self._save_dir, # type: ignore + self._scalar_save_file) + + @property # type: ignore + @force_init_env def experiment(self) -> 'LocalVisBackend': - """Return the experiment object associated with this visualizer + """Return the experiment object associated with this visualization backend.""" return self + @force_init_env def add_config(self, config: Config, **kwargs) -> None: - # TODO + """Record the config to disk. + + Args: + config (Config): The Config object + """ assert isinstance(config, Config) + config.dump(self._config_save_file) + @force_init_env def add_image(self, name: str, - image: np.ndarray = None, + image: np.array, step: int = 0, **kwargs) -> None: - """Record image to disk. + """Record the image to disk. Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format + name (str): The image identifier. + image (np.ndarray): The image to be saved. The format should be RGB. Default to None. step (int): Global step value to record. Default to 0. """ - + assert image.dtype == np.uint8 drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) os.makedirs(self._img_save_dir, exist_ok=True) save_file_name = f'{name}_{step}.png' cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image) + @force_init_env def add_scalar(self, name: str, - value: Union[int, float], + value: Union[int, float, torch.Tensor, np.ndarray], step: int = 0, **kwargs) -> None: - """Add scalar data to disk. + """Record the scalar data to disk. Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. + name (str): The scalar identifier. + value (int, float, torch.Tensor, np.ndarray): Value to save. step (int): Global step value to record. Default to 0. """ + if isinstance(value, torch.Tensor): + value = value.item() self._dump({name: value, 'step': step}, self._scalar_save_file, 'json') + @force_init_env def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: Optional[str] = None, **kwargs) -> None: - """Record scalars. The scalar dict will be written to the default and - specified files if ``file_name`` is specified. + """Record the scalars to disk. + + The scalar dict will be written to the default and + specified files if ``file_path`` is specified. Args: scalar_dict (dict): Key-value pair storing the tag and - corresponding values. + corresponding values. The value must be dumped + into json format. step (int): Global step value to record. Default to 0. file_path (str, optional): The scalar's data will be saved to the ``file_path`` file at the same time @@ -226,14 +290,15 @@ class LocalVisBackend(BaseVisBackend): """ assert isinstance(scalar_dict, dict) scalar_dict.setdefault('step', step) + if file_path is not None: assert file_path.split('.')[-1] == 'json' new_save_file_path = osp.join( self._save_dir, # type: ignore file_path) assert new_save_file_path != self._scalar_save_file, \ - '"file_path" and "scalar_save_file" have the same name, ' \ - 'please set "file_path" to another value' + '``file_path`` and ``scalar_save_file`` have the ' \ + 'same name, please set ``file_path`` to another value' self._dump(scalar_dict, new_save_file_path, 'json') self._dump(scalar_dict, self._scalar_save_file, 'json') @@ -242,7 +307,7 @@ class LocalVisBackend(BaseVisBackend): """dump dict to file. Args: - value_dict (dict) : Save dict data. + value_dict (dict) : The dict data to saved. file_path (str): The file path to save data. file_format (str): The file format to save data. """ @@ -253,7 +318,7 @@ class LocalVisBackend(BaseVisBackend): @VISBACKENDS.register_module() class WandbVisBackend(BaseVisBackend): - """Write various types of data to wandb. + """Wandb visualization backend class. Examples: >>> from mmengine.visualization import WandbVisBackend @@ -263,9 +328,12 @@ class WandbVisBackend(BaseVisBackend): >>> wandb_vis_backend.add_image('img', img) >>> wandb_vis_backend.add_scaler('mAP', 0.6) >>> wandb_vis_backend.add_scalars({'loss': [1, 2, 3],'acc': 0.8}) - >>> wandb_vis_backend.add_image('img', img) + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> wandb_vis_backend.add_config(cfg) Args: + save_dir (str, optional): The root directory to save the files + produced by the visualizer. init_kwargs (dict, optional): wandb initialization input parameters. Default to None. commit: (bool, optional) Save the metrics dict to the wandb server @@ -273,19 +341,35 @@ class WandbVisBackend(BaseVisBackend): updates the current metrics dict with the row argument and metrics won't be saved until `wandb.log` is called with `commit=True`. Default to True. - save_dir (str, optional): The root directory to save the files - produced by the writer. Default to None. """ def __init__(self, + save_dir: str, init_kwargs: Optional[dict] = None, - commit: Optional[bool] = True, - save_dir: Optional[str] = None): + commit: Optional[bool] = True): super(WandbVisBackend, self).__init__(save_dir) + self._init_kwargs = init_kwargs self._commit = commit - self._wandb = self._setup_env(init_kwargs) - @property + def _init_env(self): + """Setup env for wandb.""" + if not os.path.exists(self._save_dir): + os.makedirs(self._save_dir, exist_ok=True) # type: ignore + if self._init_kwargs is None: + self._init_kwargs = {'dir': self._save_dir} + else: + self._init_kwargs.setdefault('dir', self._save_dir) + try: + import wandb + except ImportError: + raise ImportError( + 'Please run "pip install wandb" to install wandb') + + wandb.init(**self._init_kwargs) + self._wandb = wandb + + @property # type: ignore + @force_init_env def experiment(self): """Return wandb object. @@ -295,75 +379,69 @@ class WandbVisBackend(BaseVisBackend): """ return self._wandb - def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any: - """Setup env. + @force_init_env + def add_config(self, config: Config, **kwargs) -> None: + """Record the config to wandb. Args: - init_kwargs (dict): The init args. - - Return: - :obj:`wandb` + config (Config): The Config object """ - try: - import wandb - except ImportError: - raise ImportError( - 'Please run "pip install wandb" to install wandb') - if init_kwargs: - wandb.init(**init_kwargs) - else: - wandb.init() - - return wandb - - def add_config(self, config: Config, **kwargs) -> None: - # TODO - pass + cfg_path = os.path.join(self._wandb.run.dir, 'config.py') + config.dump(cfg_path) + # Files under run.dir are automatically uploaded, + # so no need to manually call save. + # self._wandb.save(cfg_path) + @force_init_env def add_image(self, name: str, - image: np.ndarray = None, + image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record image to wandb. + """Record the image to wandb. Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. - step (int): Global step value to record. Default to 0. + name (str): The image identifier. + image (np.ndarray): The image to be saved. The format + should be RGB. + step (int): Useless parameter. Wandb does not + need this parameter. Default to 0. """ - self._wandb.log({name: image}, commit=self._commit, step=step) + self._wandb.log({name: image}, commit=self._commit) + @force_init_env def add_scalar(self, name: str, - value: Union[int, float], + value: Union[int, float, torch.Tensor, np.ndarray], step: int = 0, **kwargs) -> None: - """Record scalar data to wandb. + """Record the scalar data to wandb. Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. - step (int): Global step value to record. Default to 0. + name (str): The scalar identifier. + value (int, float, torch.Tensor, np.ndarray): Value to save. + step (int): Useless parameter. Wandb does not + need this parameter. Default to 0. """ - self._wandb.log({name: value}, commit=self._commit, step=step) + self._wandb.log({name: value}, commit=self._commit) + @force_init_env def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: Optional[str] = None, **kwargs) -> None: - """Record scalar's data to wandb. + """Record the scalar's data to wandb. Args: scalar_dict (dict): Key-value pair storing the tag and corresponding values. - step (int): Global step value to record. Default to 0. + step (int): Useless parameter. Wandb does not + need this parameter. Default to 0. file_path (str, optional): Useless parameter. Just for interface unification. Default to None. """ - self._wandb.log(scalar_dict, commit=self._commit, step=step) + self._wandb.log(scalar_dict, commit=self._commit) def close(self) -> None: """close an opened wandb object.""" @@ -373,43 +451,35 @@ class WandbVisBackend(BaseVisBackend): @VISBACKENDS.register_module() class TensorboardVisBackend(BaseVisBackend): - """Tensorboard class. It can write images, config, scalars, etc. to a - tensorboard file. + """Tensorboard visualization backend class. - Its drawing function is provided by Visualizer. + It can write images, config, scalars, etc. to a + tensorboard file. Examples: >>> from mmengine.visualization import TensorboardVisBackend >>> import numpy as np - >>> tensorboard_visualizer = TensorboardVisBackend(save_dir='temp_dir') + >>> tensorboard_vis_backend = \ + >>> TensorboardVisBackend(save_dir='temp_dir') >>> img=np.random.randint(0, 256, size=(10, 10, 3)) - >>> tensorboard_visualizer.add_image('img', img) - >>> tensorboard_visualizer.add_scaler('mAP', 0.6) - >>> tensorboard_visualizer.add_scalars({'loss': 0.1,'acc':0.8}) - >>> tensorboard_visualizer.add_image('img', image) + >>> tensorboard_vis_backend.add_image('img', img) + >>> tensorboard_vis_backend.add_scaler('mAP', 0.6) + >>> tensorboard_vis_backend.add_scalars({'loss': 0.1,'acc':0.8}) + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> tensorboard_vis_backend.add_config(cfg) Args: save_dir (str): The root directory to save the files produced by the backend. - log_dir (str): Save directory location. Default to 'tf_logs'. """ - def __init__(self, - save_dir: Optional[str] = None, - log_dir: str = 'tf_logs'): + def __init__(self, save_dir: str): super(TensorboardVisBackend, self).__init__(save_dir) - if save_dir is not None: - self._tensorboard = self._setup_env(log_dir) - - def _setup_env(self, log_dir: str): - """Setup env. - Args: - log_dir (str): Save directory location. - - Return: - :obj:`SummaryWriter` - """ + def _init_env(self): + """Setup env for Tensorboard.""" + if not os.path.exists(self._save_dir): + os.makedirs(self._save_dir, exist_ok=True) # type: ignore if TORCH_VERSION == 'parrots': try: from tensorboardX import SummaryWriter @@ -424,56 +494,65 @@ class TensorboardVisBackend(BaseVisBackend): 'Please run "pip install future tensorboard" to install ' 'the dependencies to use torch.utils.tensorboard ' '(applicable to PyTorch 1.1 or higher)') - if self._save_dir is None: - return SummaryWriter(f'./{log_dir}') - else: - self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore - return SummaryWriter(self.log_dir) + self._tensorboard = SummaryWriter(self._save_dir) - @property + @property # type: ignore + @force_init_env def experiment(self): """Return Tensorboard object.""" return self._tensorboard + @force_init_env def add_config(self, config: Config, **kwargs) -> None: - # TODO - pass + """Record the config to tensorboard. + Args: + config (Config): The Config object + """ + self._tensorboard.add_text('config', config.pretty_text) + + @force_init_env def add_image(self, name: str, image: np.ndarray, step: int = 0, **kwargs) -> None: - """Record image to tensorboard. + """Record the image to tensorboard. Args: - name (str): The unique identifier for the image to save. - image (np.ndarray, optional): The image to be saved. The format - should be RGB. Default to None. + name (str): The image identifier. + image (np.ndarray): The image to be saved. The format + should be RGB. step (int): Global step value to record. Default to 0. """ self._tensorboard.add_image(name, image, step, dataformats='HWC') + @force_init_env def add_scalar(self, name: str, - value: Union[int, float], + value: Union[int, float, torch.Tensor, np.ndarray], step: int = 0, **kwargs) -> None: - """Record scalar data to summary. + """Record the scalar data to tensorboard. Args: - name (str): The unique identifier for the scalar to save. - value (float, int): Value to save. + name (str): The scalar identifier. + value (int, float, torch.Tensor, np.ndarray): Value to save. step (int): Global step value to record. Default to 0. """ - self._tensorboard.add_scalar(name, value, step) + if isinstance(value, (int, float, torch.Tensor, np.ndarray)): + self._tensorboard.add_scalar(name, value, step) + else: + warnings.warn(f'Got {type(value)}, but numpy array, torch tensor, ' + f'int or float are expected. skip itï¼') + @force_init_env def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: Optional[str] = None, **kwargs) -> None: - """Record scalar's data to summary. + """Record the scalar's data to tensorboard. Args: scalar_dict (dict): Key-value pair storing the tag and diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index d8025616..880683b4 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -6,22 +7,21 @@ import cv2 import matplotlib.pyplot as plt import numpy as np import torch -from matplotlib.backend_bases import CloseEvent -from matplotlib.backends.backend_agg import FigureCanvasAgg +import torch.nn.functional as F from matplotlib.collections import (LineCollection, PatchCollection, PolyCollection) -from matplotlib.figure import Figure from matplotlib.patches import Circle from mmengine.config import Config from mmengine.data import BaseDataElement +from mmengine.dist import master_only from mmengine.registry import VISBACKENDS, VISUALIZERS from mmengine.utils import ManagerMixin from mmengine.visualization.utils import (check_type, check_type_and_length, - color_val_matplotlib, + color_str2rgb, color_val_matplotlib, convert_overlay_heatmap, - str_color_to_rgb, tensor2ndarray, - value2list) + img_from_canvas, tensor2ndarray, + value2list, wait_continue) from mmengine.visualization.vis_backend import BaseVisBackend @@ -30,40 +30,55 @@ class Visualizer(ManagerMixin): """MMEngine provides a Visualizer class that uses the ``Matplotlib`` library as the backend. It has the following functions: - - Basic info methods - - - set_image: sets the original image data - - get_image: get the image data in Numpy format after drawing - - show: visualization. - - register_task: registers the drawing function. - - Basic drawing methods - draw_bboxes: draw single or multiple bounding boxes - draw_texts: draw single or multiple text boxes + - draw_points: draw single or multiple points - draw_lines: draw single or multiple line segments - draw_circles: draw single or multiple circles - draw_polygons: draw single or multiple polygons - draw_binary_masks: draw single or multiple binary masks - - draw: The abstract drawing interface used by the user + - draw_featmap: draw feature map - - Enhanced methods + - Basic visualizer backend methods + + - add_configs: write config to all vis storage backends + - add_graph: write model graph to all vis storage backends + - add_image: write image to all vis storage backends + - add_scalar: write scalar to all vis storage backends + - add_scalars: write scalars to all vis storage backends + - add_datasample: write datasample to all vis storage \ + backends. The abstract drawing interface used by the user + + - Basic info methods + + - set_image: sets the original image data + - get_image: get the image data in Numpy format after drawing + - show: visualization + - close: close all resources that have been opened + - get_backend: get the specified vis backend - - draw_featmap: draw feature map All the basic drawing methods support chain calls, which is convenient for overlaydrawing and display. Each downstream algorithm library can inherit - ``Visualizer`` and implement the draw logic in the draw interface. For - example, ``DetVisualizer`` in MMDetection inherits from ``Visualizer`` + ``Visualizer`` and implement the add_datasample logic. For example, + ``DetLocalVisualizer`` in MMDetection inherits from ``Visualizer`` and implements functions, such as visual detection boxes, instance masks, - and semantic segmentation maps in the draw interface. + and semantic segmentation maps in the add_datasample interface. Args: - metadata (dict, optional): A dict contains the meta information - of single image. such as ``dict(img_shape=(512, 512, 3), - scale_factor=(1, 1, 1, 1))``. Defaults to None. + name (str): Name of the instance. Defaults to 'visualizer'. image (np.ndarray, optional): the origin image to draw. The format should be RGB. Defaults to None. + vis_backends (list, optional): Visual backend config list. + Default to None. + save_dir (str, optional): Save file dir for all storage backends. + If it is None, the backend storage will not save any data. + fig_save_cfg (dict): Keyword parameters of figure for saving. + Defaults to empty dict. + fig_show_cfg (dict): Keyword parameters of figure for showing. + Defaults to empty dict. Examples: >>> # Basic info methods @@ -73,83 +88,111 @@ class Visualizer(ManagerMixin): >>> vis.show() >>> # Basic drawing methods - >>> vis = Visualizer(metadata=metadata, image=image) + >>> vis = Visualizer(image=image) >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g') >>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]), - edge_colors=['g', 'r'], is_filling=True) + >>> edge_colors=['g', 'r']) >>> vis.draw_lines(x_datas=np.array([1, 3]), - y_datas=np.array([1, 3]), - colors='r', line_widths=1) + >>> y_datas=np.array([1, 3]), + >>> colors='r', line_widths=1) >>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]), - y_datas=np.array([[1, 3], [2, 4]]), - colors=['r', 'r'], line_widths=[1, 2]) + >>> y_datas=np.array([[1, 3], [2, 4]]), + >>> colors=['r', 'r'], line_widths=[1, 2]) >>> vis.draw_texts(text='MMEngine', - position=np.array([2, 2]), - colors='b') - >>> vis.draw_texts(text=['MMEngine','OpenMMLab'] - position=np.array([[2, 2], [5, 5]]), - colors=['b', 'b']) + >>> position=np.array([2, 2]), + >>> colors='b') + >>> vis.draw_texts(text=['MMEngine','OpenMMLab'], + >>> position=np.array([[2, 2], [5, 5]]), + >>> colors=['b', 'b']) >>> vis.draw_circles(circle_coord=np.array([2, 2]), radius=np.array[1]) >>> vis.draw_circles(circle_coord=np.array([[2, 2], [3, 5]), - radius=np.array[1, 2], colors=['g', 'r'], - is_filling=True) + >>> radius=np.array[1, 2], colors=['g', 'r']) >>> vis.draw_polygons(np.array([0, 0, 1, 0, 1, 1, 0, 1]), - edge_colors='g') + >>> edge_colors='g') >>> vis.draw_polygons(bbox=[np.array([0, 0, 1, 0, 1, 1, 0, 1], - np.array([2, 2, 3, 2, 3, 3, 2, 3]]), - edge_colors=['g', 'r'], is_filling=True) + >>> np.array([2, 2, 3, 2, 3, 3, 2, 3]], + >>> edge_colors=['g', 'r']) >>> vis.draw_binary_masks(binary_mask, alpha=0.6) + >>> heatmap = vis.draw_featmap(featmap, img, + >>> channel_reduction='select_max') + >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, + >>> topk=8, arrangement=(4, 2)) + >>> heatmap = vis.draw_featmap(featmap, img, channel_reduction=None, + >>> topk=-1) >>> # chain calls >>> vis.draw_bboxes().draw_texts().draw_circle().draw_binary_masks() - >>> # Enhanced method - >>> vis = Visualizer(metadata=metadata, image=image) - >>> heatmap = vis.draw_featmap(tensor_chw, img, mode='mean') - >>> heatmap = vis.draw_featmap(tensor_chw, img, mode=None, - topk=8, arrangement=(4, 2)) - >>> heatmap = vis.draw_featmap(tensor_chw, img, mode=None, - topk=-1) + >>> # Backend related methods + >>> vis = Visualizer(vis_backends=[dict(type='LocalVisBackend')], + >>> save_dir='temp_dir') + >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + >>> vis.add_config(cfg) + >>> image=np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + >>> vis.add_image('image',image) + >>> vis.add_scaler('mAP', 0.6) + >>> vis.add_scalars({'loss': 0.1,'acc':0.8}) >>> # inherit - >>> class DetVisualizer2(Visualizer): - >>> def add_datasample(self, - >>> image: Optional[np.ndarray] = None, - >>> gt_sample: Optional['BaseDataElement'] = None, - >>> pred_sample: Optional['BaseDataElement'] = None, - >>> show_gt: bool = True, - >>> show_pred: bool = True, - >>> show:bool = True) -> None: - >>> pass + >>> class DetLocalVisualizer(Visualizer): + >>> def add_datasample(self, + >>> name, + >>> image: np.ndarray, + >>> gt_sample: + >>> Optional['BaseDataElement'] = None, + >>> pred_sample: + >>> Optional['BaseDataElement'] = None, + >>> draw_gt: bool = True, + >>> draw_pred: bool = True, + >>> show: bool = False, + >>> wait_time: int = 0, + >>> step: int = 0) -> None: + >>> pass """ def __init__( self, name='visualizer', image: Optional[np.ndarray] = None, - vis_backends: Optional[Dict] = None, + vis_backends: Optional[List[Dict]] = None, save_dir: Optional[str] = None, fig_save_cfg=dict(frameon=False), fig_show_cfg=dict(frameon=False, num='show') ) -> None: - super().__init__(name) - self._dataset_meta: Union[None, dict] = None + super(Visualizer, self).__init__(name) + self._dataset_meta: Optional[dict] = None self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict() - if vis_backends: - with_name = False - without_name = False - for vis_backend in vis_backends: - if 'name' in vis_backend: - with_name = True + if save_dir is None: + warnings.warn('`Visualizer` backend is not initialized ' + 'because save_dir is None.') + elif vis_backends is not None: + assert len(vis_backends) > 0, 'empty list' + names = [ + vis_backend.get('name', None) for vis_backend in vis_backends + ] + if None in names: + if len(set(names)) > 1: + raise RuntimeError( + 'If one of them has a name attribute, ' + 'all backends must use the name attribute') else: - without_name = True - if with_name and without_name: - raise AssertionError + type_names = [ + vis_backend['type'] for vis_backend in vis_backends + ] + if len(set(type_names)) != len(type_names): + raise RuntimeError( + 'The same vis backend cannot exist in ' + '`vis_backend` config. ' + 'Please specify the name field.') + + if None not in names and len(set(names)) != len(names): + raise RuntimeError('The name fields cannot be the same') + + save_dir = osp.join(save_dir, 'vis_data') for vis_backend in vis_backends: name = vis_backend.pop('name', vis_backend['type']) - assert name not in self._vis_backends vis_backend.setdefault('save_dir', save_dir) self._vis_backends[name] = VISBACKENDS.build(vis_backend) @@ -165,17 +208,23 @@ class Visualizer(ManagerMixin): (self.fig_save, self.ax_save, self.fig_save_num) = self._initialize_fig(fig_save_cfg) self.dpi = self.fig_save.get_dpi() + if image is not None: self.set_image(image) - @property + @property # type: ignore + @master_only def dataset_meta(self) -> Optional[dict]: + """Optional[dict]: Meta info of the dataset.""" return self._dataset_meta - @dataset_meta.setter + @dataset_meta.setter # type: ignore + @master_only def dataset_meta(self, dataset_meta: dict) -> None: + """Set the dataset meta info to the Visualizer.""" self._dataset_meta = dataset_meta + @master_only def show(self, drawn_img: Optional[np.ndarray] = None, win_name: str = 'image', @@ -184,8 +233,14 @@ class Visualizer(ManagerMixin): """Show the drawn image. Args: - wait_time (int, optional): Delay in milliseconds. 0 is the special + drawn_img (np.ndarray, optional): The image to show. If drawn_img + is None, it will show the image got by Visualizer. Defaults + to None. + win_name (str): The image title. Defaults to 'image'. + wait_time (int): Delay in milliseconds. 0 is the special value that means "forever". Defaults to 0. + continue_key (str): The key for users to continue. Defaults to + the space key. """ if self.is_inline: return @@ -193,20 +248,16 @@ class Visualizer(ManagerMixin): (self.fig_show, self.ax_show, self.fig_show_num) = self._initialize_fig(self.fig_show_cfg) img = self.get_image() if drawn_img is None else drawn_img - # dpi = self.fig_show.get_dpi() - # height, width = img.shape[:2] - # self.fig_show.set_size_inches((width + 1e-2) / dpi, - # (height + 1e-2) / dpi) self.ax_show.cla() self.ax_show.axis(False) - # self.ax_show.set_title(win_name) - # self.fig_show.set_label(win_name) - + self.fig_show.canvas.manager.set_window_title(win_name) # type: ignore # Refresh canvas, necessary for Qt5 backend. self.ax_show.imshow(img) self.fig_show.canvas.draw() # type: ignore - self._wait_continue(timeout=wait_time, continue_key=continue_key) + wait_continue( + self.fig_show, timeout=wait_time, continue_key=continue_key) + @master_only def set_image(self, image: np.ndarray) -> None: """Set the image to draw. @@ -232,30 +283,43 @@ class Visualizer(ManagerMixin): extent=(0, self.width, self.height, 0), interpolation='none') + @master_only def get_image(self) -> np.ndarray: """Get the drawn image. The format is RGB. Returns: - np.ndarray: the drawn image which channel is rgb. + np.ndarray: the drawn image which channel is RGB. """ assert self._image is not None, 'Please set image using `set_image`' - canvas = self.fig_save.canvas # type: ignore - s, (width, height) = canvas.print_to_buffer() - buffer = np.frombuffer(s, dtype='uint8') - img_rgba = buffer.reshape(height, width, 4) - rgb, alpha = np.split(img_rgba, [3], axis=2) - return rgb.astype('uint8') - - def _initialize_fig(self, fig_cfg): + return img_from_canvas(self.fig_save.canvas) # type: ignore + + def _initialize_fig(self, fig_cfg) -> tuple: + """Build figure according to fig_cfg. + + Args: + fig_cfg (dict): The config to build figure. + + Returns: + tuple: build figure, axes and fig number. + """ fig = plt.figure(**fig_cfg) ax = fig.add_subplot() ax.axis(False) # remove white edges by set subplot margin fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - return fig, ax, fig.number + return (fig, ax, fig.number) + @master_only def get_backend(self, name) -> 'BaseVisBackend': + """get vis backend by name. + + Args: + name (str): The name of vis backend + + Returns: + BaseVisBackend: The vis backend. + """ return self._vis_backends.get(name) # type: ignore def _is_posion_valid(self, position: np.ndarray) -> bool: @@ -274,65 +338,28 @@ class Visualizer(ManagerMixin): (position[..., 1] >= 0).all() return flag - def _wait_continue(self, timeout: int = 0, continue_key=' ') -> int: - """Show the image and wait for the user's input. - - This implementation refers to - https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py - - Args: - timeout (int): If positive, continue after ``timeout`` seconds. - Defaults to 0. - continue_key (str): The key for users to continue. Defaults to - the space key. - - Returns: - int: If zero, means time out or the user pressed ``continue_key``, - and if one, means the user closed the show figure. - """ # noqa: E501 - if self.is_inline: - # If use inline backend, interactive input and timeout is no use. - return 0 - - if self.fig_show.canvas.manager: # type: ignore - # Ensure that the figure is shown - self.fig_show.show() # type: ignore - - while True: - - # Connect the events to the handler function call. - event = None - - def handler(ev): - # Set external event variable - nonlocal event - # Qt backend may fire two events at the same time, - # use a condition to avoid missing close event. - event = ev if not isinstance(event, CloseEvent) else event - self.fig_show.canvas.stop_event_loop() - - cids = [ - self.fig_show.canvas.mpl_connect(name, handler) # type: ignore - for name in ('key_press_event', 'close_event') - ] - - try: - self.fig_show.canvas.start_event_loop(timeout) # type: ignore - finally: # Run even on exception like ctrl-c. - # Disconnect the callbacks. - for cid in cids: - self.fig_show.canvas.mpl_disconnect(cid) # type: ignore - - if isinstance(event, CloseEvent): - return 1 # Quit for close. - elif event is None or event.key == continue_key: - return 0 # Quit for continue. - + @master_only def draw_points(self, positions: Union[np.ndarray, torch.Tensor], colors: Union[str, tuple, List[str], List[tuple]] = 'g', marker: Optional[str] = None, sizes: Optional[Union[np.ndarray, torch.Tensor]] = None): + """Draw single or multiple points. + + Args: + positions (Union[np.ndarray, torch.Tensor]): Positions to draw. + colors (Union[str, tuple, List[str], List[tuple]]): The colors + of points. ``colors`` can have the same length with points or + just single value. If ``colors`` is single value, all the + points will have the same colors. Reference to + https://matplotlib.org/stable/gallery/color/named_colors.html + for more details. Defaults to 'g. + marker (str, optional): The marker style. + See :mod:`matplotlib.markers` for more information about + marker styles. Default to None. + sizes (Optional[Union[np.ndarray, torch.Tensor]]): The marker size. + Default to None. + """ check_type('positions', positions, (np.ndarray, torch.Tensor)) positions = tensor2ndarray(positions) @@ -341,11 +368,12 @@ class Visualizer(ManagerMixin): assert positions.shape[-1] == 2, ( 'The shape of `positions` should be (N, 2), ' f'but got {positions.shape}') - colors = color_val_matplotlib(colors) + colors = color_val_matplotlib(colors) # type: ignore self.ax_save.scatter( positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker) return self + @master_only def draw_texts( self, texts: Union[str, List[str]], @@ -355,7 +383,6 @@ class Visualizer(ManagerMixin): vertical_alignments: Union[str, List[str]] = 'top', horizontal_alignments: Union[str, List[str]] = 'left', font_families: Union[str, List[str]] = 'sans-serif', - rotations: Union[int, str, List[Union[int, str]]] = 0, bboxes: Optional[Union[dict, List[dict]]] = None) -> 'Visualizer': """Draw single or multiple text boxes. @@ -398,11 +425,6 @@ class Visualizer(ManagerMixin): the texts will have the same font family. font_familiy can be 'serif', 'sans-serif', 'cursive', 'fantasy' or 'monospace'. Defaults to 'sans-serif'. - rotations (Union[int, List[int]]): The rotation degrees of - texts. ``rotations`` can have the same length with texts or - just single value. If ``rotations`` is single value, all the - texts will have the same rotation. rotation can be angle - in degrees, 'vertical' or 'horizontal'. Defaults to 0. bboxes (Union[dict, List[dict]], optional): The bounding box of the texts. If bboxes is None, there are no bounding box around texts. ``bboxes`` can have the same length with texts or @@ -436,7 +458,7 @@ class Visualizer(ManagerMixin): check_type_and_length('colors', colors, (str, tuple, list), num_text) colors = value2list(colors, (str, tuple), num_text) - colors = color_val_matplotlib(colors) + colors = color_val_matplotlib(colors) # type: ignore check_type_and_length('vertical_alignments', vertical_alignments, (str, list), num_text) @@ -447,9 +469,6 @@ class Visualizer(ManagerMixin): horizontal_alignments = value2list(horizontal_alignments, str, num_text) - check_type_and_length('rotations', rotations, (int, list), num_text) - rotations = value2list(rotations, int, num_text) - check_type_and_length('font_families', font_families, (str, list), num_text) font_families = value2list(font_families, str, num_text) @@ -473,6 +492,7 @@ class Visualizer(ManagerMixin): color=colors[i]) return self + @master_only def draw_lines( self, x_datas: Union[np.ndarray, torch.Tensor], @@ -519,7 +539,7 @@ class Visualizer(ManagerMixin): if len(x_datas.shape) == 1: x_datas = x_datas[None] y_datas = y_datas[None] - colors = color_val_matplotlib(colors) + colors = color_val_matplotlib(colors) # type: ignore lines = np.concatenate( (x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1) if not self._is_posion_valid(lines): @@ -534,23 +554,24 @@ class Visualizer(ManagerMixin): self.ax_save.add_collection(line_collect) return self + @master_only def draw_circles( self, center: Union[np.ndarray, torch.Tensor], radius: Union[np.ndarray, torch.Tensor], - alpha: Union[float, int] = 0.8, edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', line_styles: Union[str, List[str]] = '-', line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', + alpha: Union[float, int] = 0.8, ) -> 'Visualizer': """Draw single or multiple circles. Args: center (Union[np.ndarray, torch.Tensor]): The x coordinate of - each line' start and end points. + each line' start and end points. radius (Union[np.ndarray, torch.Tensor]): The y coordinate of - each line' start and end points. + each line' start and end points. edge_colors (Union[str, tuple, List[str], List[tuple]]): The colors of circles. ``colors`` can have the same length with lines or just single value. If ``colors`` is single value, @@ -569,8 +590,10 @@ class Visualizer(ManagerMixin): the same length with lines or just single value. If ``line_widths`` is single value, all the lines will have the same linewidth. Defaults to 2. - is_filling (bool): Whether to fill all the circles. Defaults to - False. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Default to None. + alpha (Union[int, float]): The transparency of circles. + Defaults to 0.8. """ check_type('center', center, (np.ndarray, torch.Tensor)) center = tensor2ndarray(center) @@ -591,8 +614,8 @@ class Visualizer(ManagerMixin): center = center.tolist() radius = radius.tolist() - edge_colors = color_val_matplotlib(edge_colors) - face_colors = color_val_matplotlib(face_colors) + edge_colors = color_val_matplotlib(edge_colors) # type: ignore + face_colors = color_val_matplotlib(face_colors) # type: ignore circles = [] for i in range(len(center)): circles.append(Circle(tuple(center[i]), radius[i])) @@ -613,14 +636,15 @@ class Visualizer(ManagerMixin): self.ax_save.add_collection(p) return self + @master_only def draw_bboxes( self, bboxes: Union[np.ndarray, torch.Tensor], - alpha: Union[int, float] = 0.8, edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', line_styles: Union[str, List[str]] = '-', line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', + alpha: Union[int, float] = 0.8, ) -> 'Visualizer': """Draw single or multiple bboxes. @@ -644,9 +668,11 @@ class Visualizer(ManagerMixin): The linewidth of lines. ``line_widths`` can have the same length with lines or just single value. If ``line_widths`` is single value, all the lines will - have the same linewidth. Defaults to 1. - is_filling (bool): Whether to fill all the bboxes. Defaults to - False. + have the same linewidth. Defaults to 2. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Default to None. + alpha (Union[int, float]): The transparency of bboxes. + Defaults to 0.8. """ check_type('bboxes', bboxes, (np.ndarray, torch.Tensor)) bboxes = tensor2ndarray(bboxes) @@ -675,15 +701,16 @@ class Visualizer(ManagerMixin): line_widths=line_widths, face_colors=face_colors) + @master_only def draw_polygons( self, polygons: Union[Union[np.ndarray, torch.Tensor], List[Union[np.ndarray, torch.Tensor]]], - alpha: Union[int, float] = 0.8, edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g', line_styles: Union[str, List[str]] = '-', line_widths: Union[Union[int, float], List[Union[int, float]]] = 2, - face_colors: Union[str, tuple, List[str], List[tuple]] = 'none' + face_colors: Union[str, tuple, List[str], List[tuple]] = 'none', + alpha: Union[int, float] = 0.8, ) -> 'Visualizer': """Draw single or multiple bboxes. @@ -709,12 +736,14 @@ class Visualizer(ManagerMixin): the same length with lines or just single value. If ``line_widths`` is single value, all the lines will have the same linewidth. Defaults to 2. - is_filling (bool): Whether to fill all the polygons. Defaults to - False. + face_colors (Union[str, tuple, List[str], List[tuple]]): + The face colors. Default to None. + alpha (Union[int, float]): The transparency of polygons. + Defaults to 0.8. """ check_type('polygons', polygons, (list, np.ndarray, torch.Tensor)) - edge_colors = color_val_matplotlib(edge_colors) - face_colors = color_val_matplotlib(face_colors) + edge_colors = color_val_matplotlib(edge_colors) # type: ignore + face_colors = color_val_matplotlib(face_colors) # type: ignore if isinstance(polygons, (np.ndarray, torch.Tensor)): polygons = [polygons] @@ -746,12 +775,12 @@ class Visualizer(ManagerMixin): self.ax_save.add_collection(polygon_collection) return self + @master_only def draw_binary_masks( - self, - binary_masks: Union[np.ndarray, torch.Tensor], - alphas: Union[float, List[float]] = 0.8, - colors: Union[str, tuple, List[str], - List[tuple]] = 'g') -> 'Visualizer': + self, + binary_masks: Union[np.ndarray, torch.Tensor], + colors: Union[str, tuple, List[str], List[tuple]] = 'g', + alphas: Union[float, List[float]] = 0.8) -> 'Visualizer': """Draw single or multiple binary masks. Args: @@ -764,8 +793,8 @@ class Visualizer(ManagerMixin): single value. If ``colors`` is single value, all the binary_masks will convert to the same colors. The colors format is RGB. Defaults to np.array([0, 255, 0]). - alphas (Union[int, List[int]]): The transparency of origin image. - Defaults to 0.5. + alphas (Union[int, List[int]]): The transparency of masks. + Defaults to 0.8. """ check_type('binary_masks', binary_masks, (np.ndarray, torch.Tensor)) binary_masks = tensor2ndarray(binary_masks) @@ -777,23 +806,22 @@ class Visualizer(ManagerMixin): if binary_masks.ndim == 2: binary_masks = binary_masks[None] assert img.shape[:2] == binary_masks.shape[ - 1:], '`binary_marks` must have the same shpe with image' + 1:], '`binary_marks` must have ' \ + 'the same shape with image' binary_mask_len = binary_masks.shape[0] check_type_and_length('colors', colors, (str, tuple, list), binary_mask_len) colors = value2list(colors, (str, tuple), binary_mask_len) colors = [ - str_color_to_rgb(color) if isinstance(color, str) else color + color_str2rgb(color) if isinstance(color, str) else color for color in colors ] for color in colors: assert len(color) == 3 for channel in color: assert 0 <= channel <= 255 # type: ignore - colors = np.array(colors) - if colors.ndim == 1: # type: ignore - colors = np.tile(colors, (binary_mask_len, 1)) + if isinstance(alphas, float): alphas = [alphas] * binary_mask_len @@ -813,138 +841,177 @@ class Visualizer(ManagerMixin): return self @staticmethod - def draw_featmap(tensor_chw: torch.Tensor, - image: Optional[np.ndarray] = None, - mode: str = 'mean', - topk: int = 10, - arrangement: Tuple[int, int] = (5, 2), - alpha: float = 0.8) -> np.ndarray: - """Draw featmap. If img is not None, the final image will be the - weighted sum of img and featmap. It support the mode: - - - if mode is not None, it will compress tensor_chw to single channel - image and sum to image. - - if mode is None. - - - if topk <= 0, tensor_chw is assert to be one or three - channel and treated as image and will be sum to ``image``. - - if topk > 0, it will select topk channel to show by the sum of - each channel. + @master_only + def draw_featmap(featmap: torch.Tensor, + overlaid_image: Optional[np.ndarray] = None, + channel_reduction: Optional[str] = 'squeeze_mean', + topk: int = 20, + arrangement: Tuple[int, int] = (4, 5), + resize_shape: Optional[tuple] = None, + alpha: float = 0.5) -> np.ndarray: + """Draw featmap. + + - If `overlaid_image` is not None, the final output image will be the + weighted sum of img and featmap. + + - If `resize_shape` is specified, `featmap` and `overlaid_image` + are interpolated. + + - If `resize_shape` is None and `overlaid_image` is not None, + the feature map will be interpolated to the spatial size of the image + in the case where the spatial dimensions of `overlaid_image` and + `featmap` are different. + + - If `channel_reduction` is "squeeze_mean" and "select_max", + it will compress featmap to single channel image and weighted + sum to `overlaid_image`. + + - if `channel_reduction` is None + + - If topk <= 0, featmap is assert to be one or three + channel and treated as image and will be weighted sum + to ``overlaid_image``. + - If topk > 0, it will select topk channel to show by the sum of + each channel. At the same time, you can specify the `arrangement` + to set the window layout. Args: - tensor_chw (torch.Tensor): The featmap to draw which format is + featmap (torch.Tensor): The featmap to draw which format is (C, H, W). - image (np.ndarray): The colors which binary_masks will convert to. - ``colors`` can have the same length with binary_masks or just - single value. If ``colors`` is single value, all the - binary_masks will convert to the same colors. The colors format - is rgb. Defaults to np.array([0, 255, 0]). - mode (str): The mode to compress `tensor_chw` to single channel. - Defaults to 'mean'. - topk (int): If mode is not None and topk > 0, it will select topk - channel to show by the sum of each channel. if topk <= 0, - tensor_chw is assert to be one or three. Defaults to 10. - arrangement (Tuple[int, int]): The arrangement of featmaps when - mode is not None and topk > 0. Defaults to (5, 2). - alphas (Union[int, List[int]]): The transparency of origin image. + overlaid_image (np.ndarray, optional): The overlaid image. + Default to None. + channel_reduction (str, optional): Reduce multiple channels to a + single channel. The optional value is 'squeeze_mean' + or 'select_max'. Defaults to 'squeeze_mean'. + topk (int): If channel_reduction is not None and topk > 0, + it will select topk channel to show by the sum of each channel. + if topk <= 0, tensor_chw is assert to be one or three. + Defaults to 20. + arrangement (Tuple[int, int]): The arrangement of featmap when + channel_reduction is not None and topk > 0. Defaults to (4, 5). + resize_shape (tuple, optional): The shape to scale the feature map. + Default to None. + alpha (Union[int, List[int]]): The transparency of featmap. Defaults to 0.5. + Returns: - np.ndarray: featmap. + np.ndarray: RGB image. """ - assert isinstance( - tensor_chw, - torch.Tensor), (f'`tensor_chw` should be {torch.Tensor} ' - f' but got {type(tensor_chw)}') - tensor_chw = tensor_chw.detach().cpu() - assert tensor_chw.ndim == 3, 'Input dimension must be 3' - if image is not None: - assert image.shape[:2] == tensor_chw.shape[1:] - if image.ndim == 2: - image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) - if mode is not None: - assert mode in [ - 'mean', 'max', 'min' - ], (f'Mode only support "mean", "max", "min", but got {mode}') - if mode == 'max': - feat_map, _ = torch.max(tensor_chw, dim=0) - elif mode == 'mean': - feat_map = torch.mean(tensor_chw, dim=0) - return convert_overlay_heatmap(feat_map, image, alpha) - - if topk <= 0: - tensor_chw_channel = tensor_chw.shape[0] - assert tensor_chw_channel in [ + assert isinstance(featmap, + torch.Tensor), (f'`featmap` should be torch.Tensor,' + f' but got {type(featmap)}') + assert featmap.ndim == 3, f'Input dimension must be 3, ' \ + f'but got {featmap.ndim}' + featmap = featmap.detach().cpu() + + if overlaid_image is not None: + if overlaid_image.ndim == 2: + overlaid_image = cv2.cvtColor(overlaid_image, + cv2.COLOR_GRAY2RGB) + + if overlaid_image.shape[:2] != featmap.shape[1:]: + warnings.warn( + f'Since the spatial dimensions of ' + f'overlaid_image: {overlaid_image.shape[:2]} and ' + f'featmap: {featmap.shape[1:]} are not same, ' + f'the feature map will be interpolated. ' + f'This may cause mismatch problems ï¼') + if resize_shape is None: + featmap = F.interpolate( + featmap[None], + overlaid_image.shape[:2], + mode='bilinear', + align_corners=False)[0] + + if resize_shape is not None: + featmap = F.interpolate( + featmap[None], + resize_shape, + mode='bilinear', + align_corners=False)[0] + if overlaid_image is not None: + overlaid_image = cv2.resize(overlaid_image, resize_shape[::-1]) + + if channel_reduction is not None: + assert channel_reduction in [ + 'squeeze_mean', 'select_max'], \ + f'Mode only support "squeeze_mean", "select_max", ' \ + f'but got {channel_reduction}' + if channel_reduction == 'select_max': + sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) + _, indices = torch.topk(sum_channel_featmap, 1) + feat_map = featmap[indices] + else: + feat_map = torch.mean(featmap, dim=0) + return convert_overlay_heatmap(feat_map, overlaid_image, alpha) + elif topk <= 0: + featmap_channel = featmap.shape[0] + assert featmap_channel in [ 1, 3 ], ('The input tensor channel dimension must be 1 or 3 ' 'when topk is less than 1, but the channel ' - f'dimension you input is {tensor_chw_channel}, you can use the' - ' mode parameter or set topk greater than 0 to solve ' - 'the error') - if tensor_chw_channel == 1: - return convert_overlay_heatmap(tensor_chw[0], image, alpha) - else: - tensor_chw = tensor_chw.permute(1, 2, 0).numpy() - norm_img = cv2.normalize(tensor_chw, None, 0, 255, - cv2.NORM_MINMAX) - heat_img = np.asarray(norm_img, dtype=np.uint8) - if image is not None: - heat_img = cv2.addWeighted(image, 1 - alpha, heat_img, - alpha, 0) - return heat_img + f'dimension you input is {featmap_channel}, you can use the' + ' channel_reduction parameter or set topk greater than ' + '0 to solve the error') + return convert_overlay_heatmap(featmap, overlaid_image, alpha) else: row, col = arrangement - channel, height, width = tensor_chw.shape - assert row * col >= topk - sum_channel = torch.sum(tensor_chw, dim=(1, 2)) + channel, height, width = featmap.shape + assert row * col >= topk, 'The product of row and col in ' \ + 'the `arrangement` is less than ' \ + 'topk, please set the ' \ + '`arrangement` correctly' + + # Extract the feature map of topk topk = min(channel, topk) - _, indices = torch.topk(sum_channel, topk) - topk_tensor = tensor_chw[indices] - fig = Figure(frameon=False) + sum_channel_featmap = torch.sum(featmap, dim=(1, 2)) + _, indices = torch.topk(sum_channel_featmap, topk) + topk_featmap = featmap[indices] + + fig = plt.figure(frameon=False) + # Set the window layout fig.subplots_adjust( left=0, right=1, bottom=0, top=1, wspace=0, hspace=0) dpi = fig.get_dpi() fig.set_size_inches((width * col + 1e-2) / dpi, (height * row + 1e-2) / dpi) - canvas = FigureCanvasAgg(fig) - fig.subplots_adjust(wspace=0, hspace=0) - fig.tight_layout(h_pad=0, w_pad=0) - for i in range(topk): axes = fig.add_subplot(row, col, i + 1) axes.axis('off') + axes.text(2, 15, f'channel: {indices[i]}', fontsize=10) axes.imshow( - convert_overlay_heatmap(topk_tensor[i], image, alpha)) - s, (width, height) = canvas.print_to_buffer() - buffer = np.frombuffer(s, dtype='uint8') - img_rgba = buffer.reshape(height, width, 4) - rgb, alpha = np.split(img_rgba, [3], axis=2) - return rgb.astype('uint8') + convert_overlay_heatmap(topk_featmap[i], overlaid_image, + alpha)) + return img_from_canvas(fig.canvas) + @master_only def add_config(self, config: Config, **kwargs): - """Record parameters. + """Record the config. Args: config (Config): The Config object. """ for vis_backend in self._vis_backends.values(): - vis_backend.add_config(config, **kwargs) # type: ignore + vis_backend.add_config(config, **kwargs) + @master_only def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict], **kwargs) -> None: - """Record graph data. + """Record the model graph. Args: model (torch.nn.Module): Model to draw. data_batch (Sequence[dict]): Batch of data from dataloader. """ for vis_backend in self._vis_backends.values(): - vis_backend.add_graph(model, data_batch, **kwargs) # type: ignore + vis_backend.add_graph(model, data_batch, **kwargs) + @master_only def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None: - """Record image. + """Record the image. Args: - name (str): The unique identifier for the image to save. + name (str): The image identifier. image (np.ndarray, optional): The image to be saved. The format should be RGB. Default to None. step (int): Global step value to record. Default to 0. @@ -952,27 +1019,29 @@ class Visualizer(ManagerMixin): for vis_backend in self._vis_backends.values(): vis_backend.add_image(name, image, step) # type: ignore + @master_only def add_scalar(self, name: str, value: Union[int, float], step: int = 0, **kwargs) -> None: - """Record scalar data. + """Record the scalar data. Args: - name (str): The unique identifier for the scalar to save. + name (str): The scalar identifier. value (float, int): Value to save. step (int): Global step value to record. Default to 0. """ for vis_backend in self._vis_backends.values(): vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore + @master_only def add_scalars(self, scalar_dict: dict, step: int = 0, file_path: Optional[str] = None, **kwargs) -> None: - """Record scalars' data. + """Record the scalars' data. Args: scalar_dict (dict): Key-value pair storing the tag and @@ -984,9 +1053,9 @@ class Visualizer(ManagerMixin): Default to None. """ for vis_backend in self._vis_backends.values(): - vis_backend.add_scalars( # type: ignore - scalar_dict, step, file_path, **kwargs) + vis_backend.add_scalars(scalar_dict, step, file_path, **kwargs) + @master_only def add_datasample(self, name, image: np.ndarray, @@ -997,6 +1066,7 @@ class Visualizer(ManagerMixin): show: bool = False, wait_time: int = 0, step: int = 0) -> None: + """Draw datasample.""" pass def close(self) -> None: @@ -1005,7 +1075,7 @@ class Visualizer(ManagerMixin): if self.fig_show is not None: plt.close(self.fig_show) for vis_backend in self._vis_backends.values(): - vis_backend.close() # type: ignore + vis_backend.close() @classmethod def get_instance(cls, name: str, **kwargs) -> 'Visualizer': @@ -1034,7 +1104,7 @@ class Visualizer(ManagerMixin): >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3) Args: - name (str): Name of instance. Defaults to ''. + name (str): Name of instance. Returns: object: Corresponding name instance. diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py index da662a65..1ac5a1bf 100644 --- a/tests/test_visualizer/test_vis_backend.py +++ b/tests/test_visualizer/test_vis_backend.py @@ -6,7 +6,9 @@ from unittest.mock import MagicMock import numpy as np import pytest +import torch +from mmengine import Config from mmengine.fileio import load from mmengine.registry import VISBACKENDS from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend, @@ -16,7 +18,6 @@ from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend, class TestLocalVisBackend: def test_init(self): - # 'config_save_file' format must be py with pytest.raises(AssertionError): LocalVisBackend('temp_dir', config_save_file='a.txt') @@ -25,42 +26,35 @@ class TestLocalVisBackend: with pytest.raises(AssertionError): LocalVisBackend('temp_dir', scalar_save_file='a.yaml') - local_vis_backend = LocalVisBackend('temp_dir') - assert os.path.exists(local_vis_backend._save_dir) - shutil.rmtree('temp_dir') - local_vis_backend = VISBACKENDS.build( dict(type='LocalVisBackend', save_dir='temp_dir')) - assert os.path.exists(local_vis_backend._save_dir) - shutil.rmtree('temp_dir') + assert isinstance(local_vis_backend, LocalVisBackend) def test_experiment(self): local_vis_backend = LocalVisBackend('temp_dir') assert local_vis_backend.experiment == local_vis_backend - shutil.rmtree('temp_dir') def test_add_config(self): + cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) local_vis_backend = LocalVisBackend('temp_dir') - - # 'params_dict' must be dict - with pytest.raises(AssertionError): - local_vis_backend.add_config(['lr', 0]) - - # TODO - + local_vis_backend.add_config(cfg) + assert os.path.exists(local_vis_backend._config_save_file) shutil.rmtree('temp_dir') def test_add_image(self): - image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) + image = np.random.randint(0, 256, size=(10, 10, 3)) local_vis_backend = LocalVisBackend('temp_dir') - local_vis_backend.add_image('img', image) + + # image must be in np.uint8 format + with pytest.raises(AssertionError): + local_vis_backend.add_image('img', image) + + local_vis_backend.add_image('img', image.astype(np.uint8)) assert os.path.exists( os.path.join(local_vis_backend._img_save_dir, 'img_0.png')) - - local_vis_backend.add_image('img', image, step=2) + local_vis_backend.add_image('img', image.astype(np.uint8), step=2) assert os.path.exists( os.path.join(local_vis_backend._img_save_dir, 'img_2.png')) - shutil.rmtree('temp_dir') def test_add_scalar(self): @@ -72,17 +66,25 @@ class TestLocalVisBackend: # test append mode local_vis_backend = LocalVisBackend('temp_dir') - local_vis_backend.add_scalar('map', 0.9, step=0) + local_vis_backend.add_scalar('map', 1, step=0) local_vis_backend.add_scalar('map', 0.95, step=1) + # local_vis_backend.add_scalar('map', torch.IntTensor(1), step=2) + local_vis_backend.add_scalar('map', np.array(0.9), step=2) with open(local_vis_backend._scalar_save_file) as f: out_dict = f.read() - assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \ - '0.95, "step": 1}\n' + assert out_dict == '{"map": 1, "step": 0}\n' \ + '{"map": 0.95, "step": 1}\n' \ + '{"map": 0.9, "step": 2}\n' shutil.rmtree('temp_dir') - def test_add_scalars(self): local_vis_backend = LocalVisBackend('temp_dir') + local_vis_backend.add_scalar('map', torch.tensor(1.)) + assert os.path.exists(local_vis_backend._scalar_save_file) + shutil.rmtree('temp_dir') + + def test_add_scalars(self): input_dict = {'map': 0.7, 'acc': 0.9} + local_vis_backend = LocalVisBackend('temp_dir') local_vis_backend.add_scalars(input_dict) out_dict = load(local_vis_backend._scalar_save_file, 'json') assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0} @@ -95,7 +97,6 @@ class TestLocalVisBackend: '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n' # test file_path - local_vis_backend = LocalVisBackend('temp_dir') local_vis_backend.add_scalars(input_dict, file_path='temp.json') assert os.path.exists(local_vis_backend._scalar_save_file) assert os.path.exists( @@ -113,7 +114,6 @@ class TestTensorboardVisBackend: sys.modules['tensorboardX'] = MagicMock() def test_init(self): - TensorboardVisBackend('temp_dir') VISBACKENDS.build( dict(type='TensorboardVisBackend', save_dir='temp_dir')) @@ -122,22 +122,20 @@ class TestTensorboardVisBackend: tensorboard_vis_backend = TensorboardVisBackend('temp_dir') assert (tensorboard_vis_backend.experiment == tensorboard_vis_backend._tensorboard) - - def test_add_graph(self): - # TODO - pass + shutil.rmtree('temp_dir') def test_add_config(self): - # TODO - pass + cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend.add_config(cfg) + shutil.rmtree('temp_dir') def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - tensorboard_vis_backend = TensorboardVisBackend('temp_dir') tensorboard_vis_backend.add_image('img', image) - tensorboard_vis_backend.add_image('img', image, step=2) + shutil.rmtree('temp_dir') def test_add_scalar(self): tensorboard_vis_backend = TensorboardVisBackend('temp_dir') @@ -146,6 +144,11 @@ class TestTensorboardVisBackend: tensorboard_vis_backend.add_scalar('map', 0.9, step=0) tensorboard_vis_backend.add_scalar('map', 0.95, step=1) + # Unprocessable data will output a warning message + with pytest.warns(Warning): + tensorboard_vis_backend.add_scalar('map', [0.95]) + shutil.rmtree('temp_dir') + def test_add_scalars(self): tensorboard_vis_backend = TensorboardVisBackend('temp_dir') # The step value must be passed through the parameter @@ -156,45 +159,72 @@ class TestTensorboardVisBackend: 'step': 1 }) + # Unprocessable data will output a warning message + with pytest.warns(Warning): + tensorboard_vis_backend.add_scalars({ + 'map': [1, 2], + }) + input_dict = {'map': 0.7, 'acc': 0.9} tensorboard_vis_backend.add_scalars(input_dict) # test append mode tensorboard_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + shutil.rmtree('temp_dir') + + def test_close(self): + tensorboard_vis_backend = TensorboardVisBackend('temp_dir') + tensorboard_vis_backend._init_env() + tensorboard_vis_backend.close() + shutil.rmtree('temp_dir') class TestWandbVisBackend: sys.modules['wandb'] = MagicMock() + sys.modules['wandb.run'] = MagicMock() def test_init(self): - WandbVisBackend() + WandbVisBackend('temp_dir') VISBACKENDS.build(dict(type='WandbVisBackend', save_dir='temp_dir')) def test_experiment(self): - wandb_vis_backend = WandbVisBackend() + wandb_vis_backend = WandbVisBackend('temp_dir') assert wandb_vis_backend.experiment == wandb_vis_backend._wandb + shutil.rmtree('temp_dir') def test_add_config(self): - # TODO - pass + cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + wandb_vis_backend = WandbVisBackend('temp_dir') + _wandb = wandb_vis_backend.experiment + _wandb.run.dir = 'temp_dir' + wandb_vis_backend.add_config(cfg) + shutil.rmtree('temp_dir') def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) - - wandb_vis_backend = WandbVisBackend() + wandb_vis_backend = WandbVisBackend('temp_dir') wandb_vis_backend.add_image('img', image) - - wandb_vis_backend.add_image('img', image, step=2) + wandb_vis_backend.add_image('img', image) + shutil.rmtree('temp_dir') def test_add_scalar(self): - wandb_vis_backend = WandbVisBackend() + wandb_vis_backend = WandbVisBackend('temp_dir') wandb_vis_backend.add_scalar('map', 0.9) # test append mode - wandb_vis_backend.add_scalar('map', 0.9, step=0) - wandb_vis_backend.add_scalar('map', 0.95, step=1) + wandb_vis_backend.add_scalar('map', 0.9) + wandb_vis_backend.add_scalar('map', 0.95) + shutil.rmtree('temp_dir') def test_add_scalars(self): - wandb_vis_backend = WandbVisBackend() + wandb_vis_backend = WandbVisBackend('temp_dir') input_dict = {'map': 0.7, 'acc': 0.9} wandb_vis_backend.add_scalars(input_dict) # test append mode - wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1) + wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}) + wandb_vis_backend.add_scalars({'map': [0.8], 'acc': 0.8}) + shutil.rmtree('temp_dir') + + def test_close(self): + wandb_vis_backend = WandbVisBackend('temp_dir') + wandb_vis_backend._init_env() + wandb_vis_backend.close() + shutil.rmtree('temp_dir') diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index ce3de94d..cc8fa2b8 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import Any, List, Optional, Union +from typing import Any from unittest import TestCase import matplotlib.pyplot as plt @@ -9,14 +9,14 @@ import pytest import torch import torch.nn as nn -from mmengine import VISBACKENDS +from mmengine import VISBACKENDS, Config from mmengine.visualization import Visualizer @VISBACKENDS.register_module() class MockVisBackend: - def __init__(self, save_dir: Optional[str] = None): + def __init__(self, save_dir: str): self._save_dir = save_dir self._close = False @@ -24,33 +24,22 @@ class MockVisBackend: def experiment(self) -> Any: return self - def add_config(self, params_dict: dict, **kwargs) -> None: + def add_config(self, config, **kwargs) -> None: self._add_config = True - def add_graph(self, model: torch.nn.Module, - input_tensor: Union[torch.Tensor, - List[torch.Tensor]], **kwargs) -> None: - + def add_graph(self, model, data_batch, **kwargs) -> None: self._add_graph = True - def add_image(self, - name: str, - image: np.ndarray, - step: int = 0, - **kwargs) -> None: + def add_image(self, name, image, step=0, **kwargs) -> None: self._add_image = True - def add_scalar(self, - name: str, - value: Union[int, float], - step: int = 0, - **kwargs) -> None: + def add_scalar(self, name, value, step=0, **kwargs) -> None: self._add_scalar = True def add_scalars(self, - scalar_dict: dict, - step: int = 0, - file_path: Optional[str] = None, + scalar_dict, + step=0, + file_path=None, **kwargs) -> None: self._add_scalars = True @@ -70,22 +59,68 @@ class TestVisualizer(TestCase): self.image = np.random.randint( 0, 256, size=(10, 10, 3)).astype('uint8') self.vis_backend_cfg = [ - dict(type='MockVisBackend', name='mock1', save_dir='tmp'), - dict(type='MockVisBackend', name='mock2', save_dir='tmp') + dict(type='MockVisBackend', name='mock1'), + dict(type='MockVisBackend', name='mock2') ] def test_init(self): visualizer = Visualizer(image=self.image) visualizer.get_image() + # test save_dir + with pytest.warns( + Warning, + match='`Visualizer` backend is not initialized ' + 'because save_dir is None.'): + Visualizer() + visualizer = Visualizer( vis_backends=copy.deepcopy(self.vis_backend_cfg)) + assert visualizer.get_backend('mock1') is None + + visualizer = Visualizer( + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') assert isinstance(visualizer.get_backend('mock1'), MockVisBackend) assert len(visualizer._vis_backends) == 2 - # test global + # test empty list + with pytest.raises(AssertionError): + Visualizer(vis_backends=[], save_dir='temp_dir') + + # test name + # If one of them has a name attribute, all backends must + # use the name attribute + with pytest.raises(RuntimeError): + Visualizer( + vis_backends=[ + dict(type='MockVisBackend'), + dict(type='MockVisBackend', name='mock2') + ], + save_dir='temp_dir') + + # The name fields cannot be the same + with pytest.raises(RuntimeError): + Visualizer( + vis_backends=[ + dict(type='MockVisBackend'), + dict(type='MockVisBackend') + ], + save_dir='temp_dir') + + with pytest.raises(RuntimeError): + Visualizer( + vis_backends=[ + dict(type='MockVisBackend', name='mock1'), + dict(type='MockVisBackend', name='mock1') + ], + save_dir='temp_dir') + + # test global init visualizer = Visualizer.get_instance( - 'visualizer', vis_backends=copy.deepcopy(self.vis_backend_cfg)) + 'visualizer', + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') assert len(visualizer._vis_backends) == 2 visualizer_any = Visualizer.get_instance('visualizer') assert visualizer_any == visualizer @@ -131,7 +166,9 @@ class TestVisualizer(TestCase): def test_close(self): visualizer = Visualizer( - image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + image=self.image, + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') fig_num = visualizer.fig_save_num assert fig_num in plt.get_fignums() for name in ['mock1', 'mock2']: @@ -141,6 +178,23 @@ class TestVisualizer(TestCase): for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._close is True + def test_draw_points(self): + visualizer = Visualizer(image=self.image) + + with pytest.raises(TypeError): + visualizer.draw_points(positions=[1, 2]) + with pytest.raises(AssertionError): + visualizer.draw_points(positions=np.array([1, 2, 3])) + # test color + visualizer.draw_points( + positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 0)]) + visualizer.draw_points( + positions=torch.tensor([[1, 1], [3, 3]]), + colors=['g', (255, 255, 0)], + marker='.', + sizes=[1, 5]) + def test_draw_texts(self): visualizer = Visualizer(image=self.image) @@ -340,17 +394,48 @@ class TestVisualizer(TestCase): def test_draw_featmap(self): visualizer = Visualizer() image = np.random.randint(0, 256, size=(3, 3, 3), dtype='uint8') + + # must be Tensor + with pytest.raises( + AssertionError, + match='`featmap` should be torch.Tensor, but got ' + "<class 'numpy.ndarray'>"): + visualizer.draw_featmap(np.ones((3, 3, 3))) + # test tensor format - with pytest.raises(AssertionError, match='Input dimension must be 3'): + with pytest.raises( + AssertionError, match='Input dimension must be 3, but got 4'): visualizer.draw_featmap(torch.randn(1, 1, 3, 3)) - # test mode parameter - # mode only supports 'mean' and 'max' - with pytest.raises(AssertionError): - visualizer.draw_featmap(torch.randn(2, 3, 3), mode='xx') - # test tensor_chw and img have difference height and width + # test overlaid_image shape + with pytest.warns(Warning): + visualizer.draw_featmap(torch.randn(1, 4, 3), overlaid_image=image) + + # test resize_shape + featmap = visualizer.draw_featmap( + torch.randn(1, 4, 3), resize_shape=(6, 7)) + assert featmap.shape[:2] == (6, 7) + featmap = visualizer.draw_featmap( + torch.randn(1, 4, 3), overlaid_image=image, resize_shape=(6, 7)) + assert featmap.shape[:2] == (6, 7) + + # test channel_reduction parameter + # mode only supports 'squeeze_mean' and 'select_max' with pytest.raises(AssertionError): - visualizer.draw_featmap(torch.randn(2, 3, 3), mode='xx') + visualizer.draw_featmap( + torch.randn(2, 3, 3), channel_reduction='xx') + + featmap = visualizer.draw_featmap( + torch.randn(2, 3, 3), channel_reduction='squeeze_mean') + assert featmap.shape[:2] == (3, 3) + featmap = visualizer.draw_featmap( + torch.randn(2, 3, 3), channel_reduction='select_max') + assert featmap.shape[:2] == (3, 3) + featmap = visualizer.draw_featmap( + torch.randn(2, 4, 3), + overlaid_image=image, + channel_reduction='select_max') + assert featmap.shape[:2] == (3, 3) # test topk parameter with pytest.raises( @@ -358,36 +443,56 @@ class TestVisualizer(TestCase): match='The input tensor channel dimension must be 1 or 3 ' 'when topk is less than 1, but the channel ' 'dimension you input is 6, you can use the ' - 'mode parameter or set topk greater than 0 to solve ' - 'the error'): - visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=0) - - visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean') - visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean') - visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max') - visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10) - visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1) - visualizer.draw_featmap( - torch.randn(3, 3, 3), image=image, mode=None, topk=-1) - visualizer.draw_featmap(torch.randn(6, 3, 3), mode=None, topk=4) - visualizer.draw_featmap( - torch.randn(6, 3, 3), image=image, mode=None, topk=8) + 'channel_reduction parameter or set topk ' + 'greater than 0 to solve the error'): + visualizer.draw_featmap( + torch.randn(6, 3, 3), channel_reduction=None, topk=0) + + featmap = visualizer.draw_featmap( + torch.randn(6, 3, 3), channel_reduction='select_max', topk=10) + assert featmap.shape[:2] == (3, 3) + featmap = visualizer.draw_featmap( + torch.randn(1, 4, 3), channel_reduction=None, topk=-1) + assert featmap.shape[:2] == (4, 3) + + featmap = visualizer.draw_featmap( + torch.randn(3, 4, 3), + overlaid_image=image, + channel_reduction=None, + topk=-1) + assert featmap.shape[:2] == (3, 3) + featmap = visualizer.draw_featmap( + torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(2, 2)) + assert featmap.shape[:2] == (6, 6) + featmap = visualizer.draw_featmap( + torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 4)) + assert featmap.shape[:2] == (3, 12) + with pytest.raises( + AssertionError, + match='The product of row and col in the `arrangement` ' + 'is less than topk, please set ' + 'the `arrangement` correctly'): + visualizer.draw_featmap( + torch.randn(6, 3, 3), + channel_reduction=None, + topk=4, + arrangement=(1, 2)) # test gray - visualizer.draw_featmap( + featmap = visualizer.draw_featmap( torch.randn(6, 3, 3), - image=np.random.randint(0, 256, size=(3, 3), dtype='uint8'), - mode=None, - topk=8) - - # test arrangement - with pytest.raises(AssertionError): - visualizer.draw_featmap( - torch.randn(10, 3, 3), - image=image, - mode=None, - topk=8, - arrangement=(2, 2)) + overlaid_image=np.random.randint( + 0, 256, size=(3, 3), dtype='uint8'), + channel_reduction=None, + topk=4, + arrangement=(2, 2)) + assert featmap.shape[:2] == (6, 6) def test_chain_call(self): visualizer = Visualizer(image=self.image) @@ -402,22 +507,26 @@ class TestVisualizer(TestCase): def test_get_backend(self): visualizer = Visualizer( - image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg)) + image=self.image, + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') for name in ['mock1', 'mock2']: assert isinstance(visualizer.get_backend(name), MockVisBackend) def test_add_config(self): visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') - params_dict = dict(lr=0.1, wd=0.2, mode='linear') - visualizer.add_config(params_dict) + cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) + visualizer.add_config(cfg) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_config is True def test_add_graph(self): visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') class Model(nn.Module): @@ -435,7 +544,8 @@ class TestVisualizer(TestCase): def test_add_image(self): image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8) visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_image('img', image) for name in ['mock1', 'mock2']: @@ -443,14 +553,16 @@ class TestVisualizer(TestCase): def test_add_scalar(self): visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') visualizer.add_scalar('map', 0.9, step=0) for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._add_scalar is True def test_add_scalars(self): visualizer = Visualizer( - vis_backends=copy.deepcopy(self.vis_backend_cfg)) + vis_backends=copy.deepcopy(self.vis_backend_cfg), + save_dir='temp_dir') input_dict = {'map': 0.7, 'acc': 0.9} visualizer.add_scalars(input_dict) for name in ['mock1', 'mock2']: @@ -468,8 +580,7 @@ class TestVisualizer(TestCase): visualizer3 = DetLocalVisualizer.get_current_instance() assert id(visualizer1) == id(visualizer2) == id(visualizer3) - -if __name__ == '__main__': - t = TestVisualizer() - t.setUp() - t.test_init() + def test_data_info(self): + visualizer = Visualizer() + visualizer.dataset_meta = {'class': 'cat'} + assert visualizer.dataset_meta['class'] == 'cat' -- GitLab