Skip to content
Snippets Groups Projects
Unverified Commit bb56cf42 authored by liukuikun's avatar liukuikun Committed by GitHub
Browse files

[Visualizer] use FigureManager to manage figure to avoid affecting plt.show()...

[Visualizer] use FigureManager to manage figure to avoid affecting plt.show() outside Visualizer(#440)

* figure in Visualizer is not managed by plt

* encapsulate code and remove unused code
parent b75962a6
No related branches found
No related tags found
No related merge requests found
......@@ -8,9 +8,12 @@ import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import (LineCollection, PatchCollection,
PolyCollection)
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from matplotlib.pyplot import new_figure_manager
from mmengine.config import Config
from mmengine.data import BaseDataElement
......@@ -157,7 +160,7 @@ class Visualizer(ManagerMixin):
vis_backends: Optional[List[Dict]] = None,
save_dir: Optional[str] = None,
fig_save_cfg=dict(frameon=False),
fig_show_cfg=dict(frameon=False, num='show')
fig_show_cfg=dict(frameon=False)
) -> None:
super().__init__(name)
self._dataset_meta: Optional[dict] = None
......@@ -196,17 +199,12 @@ class Visualizer(ManagerMixin):
vis_backend.setdefault('save_dir', save_dir)
self._vis_backends[name] = VISBACKENDS.build(vis_backend)
self.is_inline = 'inline' in plt.get_backend()
self.fig_save = None
self.fig_show = None
self.fig_save_num = fig_save_cfg.get('num', None)
self.fig_show_num = fig_show_cfg.get('num', None)
self.fig_save_cfg = fig_save_cfg
self.fig_show_cfg = fig_show_cfg
(self.fig_save, self.ax_save,
self.fig_save_num) = self._initialize_fig(fig_save_cfg)
(self.fig_save_canvas, self.fig_save,
self.ax_save) = self._initialize_fig(fig_save_cfg)
self.dpi = self.fig_save.get_dpi()
if image is not None:
......@@ -242,20 +240,22 @@ class Visualizer(ManagerMixin):
continue_key (str): The key for users to continue. Defaults to
the space key.
"""
if self.is_inline:
return
if self.fig_show is None or not plt.fignum_exists(self.fig_show_num):
(self.fig_show, self.ax_show,
self.fig_show_num) = self._initialize_fig(self.fig_show_cfg)
is_inline = 'inline' in plt.get_backend()
img = self.get_image() if drawn_img is None else drawn_img
self.ax_show.cla()
self.ax_show.axis(False)
self.fig_show.canvas.manager.set_window_title(win_name) # type: ignore
# Refresh canvas, necessary for Qt5 backend.
self.ax_show.imshow(img)
self.fig_show.canvas.draw() # type: ignore
wait_continue(
self.fig_show, timeout=wait_time, continue_key=continue_key)
self._init_manager(win_name)
fig = self.manager.canvas.figure
# remove white edges by set subplot margin
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
fig.clear()
ax = fig.add_subplot()
ax.axis(False)
ax.imshow(img)
self.manager.canvas.draw()
# Find a better way for inline to show the image
if is_inline:
return fig
wait_continue(fig, timeout=wait_time, continue_key=continue_key)
@master_only
def set_image(self, image: np.ndarray) -> None:
......@@ -291,7 +291,7 @@ class Visualizer(ManagerMixin):
np.ndarray: the drawn image which channel is RGB.
"""
assert self._image is not None, 'Please set image using `set_image`'
return img_from_canvas(self.fig_save.canvas) # type: ignore
return img_from_canvas(self.fig_save_canvas) # type: ignore
def _initialize_fig(self, fig_cfg) -> tuple:
"""Build figure according to fig_cfg.
......@@ -300,15 +300,34 @@ class Visualizer(ManagerMixin):
fig_cfg (dict): The config to build figure.
Returns:
tuple: build figure, axes and fig number.
tuple: build canvas figure and axes.
"""
fig = plt.figure(**fig_cfg)
fig = Figure(**fig_cfg)
ax = fig.add_subplot()
ax.axis(False)
# remove white edges by set subplot margin
fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
return (fig, ax, fig.number)
canvas = FigureCanvasAgg(fig)
return canvas, fig, ax
def _init_manager(self, win_name: str) -> None:
"""Initialize the matplot manager.
Args:
win_name (str): The window name.
"""
if getattr(self, 'manager', None) is None:
self.manager = new_figure_manager(
num=1, FigureClass=Figure, **self.fig_show_cfg)
try:
self.manager.set_window_title(win_name)
except Exception:
self.manager = new_figure_manager(
num=1, FigureClass=Figure, **self.fig_show_cfg)
self.manager.set_window_title(win_name)
@master_only
def get_backend(self, name) -> 'BaseVisBackend':
......@@ -982,7 +1001,9 @@ class Visualizer(ManagerMixin):
axes.imshow(
convert_overlay_heatmap(topk_featmap[i], overlaid_image,
alpha))
return img_from_canvas(fig.canvas)
image = img_from_canvas(fig.canvas)
plt.close(fig)
return image
@master_only
def add_config(self, config: Config, **kwargs):
......@@ -1071,9 +1092,6 @@ class Visualizer(ManagerMixin):
def close(self) -> None:
"""close an opened object."""
plt.close(self.fig_save)
if self.fig_show is not None:
plt.close(self.fig_show)
for vis_backend in self._vis_backends.values():
vis_backend.close()
......
......@@ -4,7 +4,6 @@ import time
from typing import Any
from unittest import TestCase
import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
......@@ -171,12 +170,10 @@ class TestVisualizer(TestCase):
image=self.image,
vis_backends=copy.deepcopy(self.vis_backend_cfg),
save_dir='temp_dir')
fig_num = visualizer.fig_save_num
assert fig_num in plt.get_fignums()
for name in ['mock1', 'mock2']:
assert visualizer.get_backend(name)._close is False
visualizer.close()
assert fig_num not in plt.get_fignums()
for name in ['mock1', 'mock2']:
assert visualizer.get_backend(name)._close is True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment