From bb56cf42abf98bc838e66a66abb9be422b95267b Mon Sep 17 00:00:00 2001 From: liukuikun <24622904+Harold-lkk@users.noreply.github.com> Date: Mon, 22 Aug 2022 17:16:14 +0800 Subject: [PATCH] [Visualizer] use FigureManager to manage figure to avoid affecting plt.show() outside Visualizer(#440) * figure in Visualizer is not managed by plt * encapsulate code and remove unused code --- mmengine/visualization/visualizer.py | 76 +++++++++++++++--------- tests/test_visualizer/test_visualizer.py | 5 +- 2 files changed, 48 insertions(+), 33 deletions(-) diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py index a0596dc1..39ad4783 100644 --- a/mmengine/visualization/visualizer.py +++ b/mmengine/visualization/visualizer.py @@ -8,9 +8,12 @@ import matplotlib.pyplot as plt import numpy as np import torch import torch.nn.functional as F +from matplotlib.backends.backend_agg import FigureCanvasAgg from matplotlib.collections import (LineCollection, PatchCollection, PolyCollection) +from matplotlib.figure import Figure from matplotlib.patches import Circle +from matplotlib.pyplot import new_figure_manager from mmengine.config import Config from mmengine.data import BaseDataElement @@ -157,7 +160,7 @@ class Visualizer(ManagerMixin): vis_backends: Optional[List[Dict]] = None, save_dir: Optional[str] = None, fig_save_cfg=dict(frameon=False), - fig_show_cfg=dict(frameon=False, num='show') + fig_show_cfg=dict(frameon=False) ) -> None: super().__init__(name) self._dataset_meta: Optional[dict] = None @@ -196,17 +199,12 @@ class Visualizer(ManagerMixin): vis_backend.setdefault('save_dir', save_dir) self._vis_backends[name] = VISBACKENDS.build(vis_backend) - self.is_inline = 'inline' in plt.get_backend() - self.fig_save = None - self.fig_show = None - self.fig_save_num = fig_save_cfg.get('num', None) - self.fig_show_num = fig_show_cfg.get('num', None) self.fig_save_cfg = fig_save_cfg self.fig_show_cfg = fig_show_cfg - (self.fig_save, self.ax_save, - self.fig_save_num) = self._initialize_fig(fig_save_cfg) + (self.fig_save_canvas, self.fig_save, + self.ax_save) = self._initialize_fig(fig_save_cfg) self.dpi = self.fig_save.get_dpi() if image is not None: @@ -242,20 +240,22 @@ class Visualizer(ManagerMixin): continue_key (str): The key for users to continue. Defaults to the space key. """ - if self.is_inline: - return - if self.fig_show is None or not plt.fignum_exists(self.fig_show_num): - (self.fig_show, self.ax_show, - self.fig_show_num) = self._initialize_fig(self.fig_show_cfg) + is_inline = 'inline' in plt.get_backend() img = self.get_image() if drawn_img is None else drawn_img - self.ax_show.cla() - self.ax_show.axis(False) - self.fig_show.canvas.manager.set_window_title(win_name) # type: ignore - # Refresh canvas, necessary for Qt5 backend. - self.ax_show.imshow(img) - self.fig_show.canvas.draw() # type: ignore - wait_continue( - self.fig_show, timeout=wait_time, continue_key=continue_key) + self._init_manager(win_name) + fig = self.manager.canvas.figure + # remove white edges by set subplot margin + fig.subplots_adjust(left=0, right=1, bottom=0, top=1) + fig.clear() + ax = fig.add_subplot() + ax.axis(False) + ax.imshow(img) + self.manager.canvas.draw() + + # Find a better way for inline to show the image + if is_inline: + return fig + wait_continue(fig, timeout=wait_time, continue_key=continue_key) @master_only def set_image(self, image: np.ndarray) -> None: @@ -291,7 +291,7 @@ class Visualizer(ManagerMixin): np.ndarray: the drawn image which channel is RGB. """ assert self._image is not None, 'Please set image using `set_image`' - return img_from_canvas(self.fig_save.canvas) # type: ignore + return img_from_canvas(self.fig_save_canvas) # type: ignore def _initialize_fig(self, fig_cfg) -> tuple: """Build figure according to fig_cfg. @@ -300,15 +300,34 @@ class Visualizer(ManagerMixin): fig_cfg (dict): The config to build figure. Returns: - tuple: build figure, axes and fig number. + tuple: build canvas figure and axes. """ - fig = plt.figure(**fig_cfg) + + fig = Figure(**fig_cfg) ax = fig.add_subplot() ax.axis(False) # remove white edges by set subplot margin fig.subplots_adjust(left=0, right=1, bottom=0, top=1) - return (fig, ax, fig.number) + canvas = FigureCanvasAgg(fig) + return canvas, fig, ax + + def _init_manager(self, win_name: str) -> None: + """Initialize the matplot manager. + + Args: + win_name (str): The window name. + """ + if getattr(self, 'manager', None) is None: + self.manager = new_figure_manager( + num=1, FigureClass=Figure, **self.fig_show_cfg) + + try: + self.manager.set_window_title(win_name) + except Exception: + self.manager = new_figure_manager( + num=1, FigureClass=Figure, **self.fig_show_cfg) + self.manager.set_window_title(win_name) @master_only def get_backend(self, name) -> 'BaseVisBackend': @@ -982,7 +1001,9 @@ class Visualizer(ManagerMixin): axes.imshow( convert_overlay_heatmap(topk_featmap[i], overlaid_image, alpha)) - return img_from_canvas(fig.canvas) + image = img_from_canvas(fig.canvas) + plt.close(fig) + return image @master_only def add_config(self, config: Config, **kwargs): @@ -1071,9 +1092,6 @@ class Visualizer(ManagerMixin): def close(self) -> None: """close an opened object.""" - plt.close(self.fig_save) - if self.fig_show is not None: - plt.close(self.fig_show) for vis_backend in self._vis_backends.values(): vis_backend.close() diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py index a487501f..dc0485fa 100644 --- a/tests/test_visualizer/test_visualizer.py +++ b/tests/test_visualizer/test_visualizer.py @@ -4,7 +4,6 @@ import time from typing import Any from unittest import TestCase -import matplotlib.pyplot as plt import numpy as np import pytest import torch @@ -171,12 +170,10 @@ class TestVisualizer(TestCase): image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg), save_dir='temp_dir') - fig_num = visualizer.fig_save_num - assert fig_num in plt.get_fignums() + for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._close is False visualizer.close() - assert fig_num not in plt.get_fignums() for name in ['mock1', 'mock2']: assert visualizer.get_backend(name)._close is True -- GitLab