From bb56cf42abf98bc838e66a66abb9be422b95267b Mon Sep 17 00:00:00 2001
From: liukuikun <24622904+Harold-lkk@users.noreply.github.com>
Date: Mon, 22 Aug 2022 17:16:14 +0800
Subject: [PATCH] [Visualizer] use FigureManager to manage figure to avoid
 affecting plt.show() outside Visualizer(#440)

* figure in Visualizer is not managed by plt

* encapsulate code and remove unused code
---
 mmengine/visualization/visualizer.py     | 76 +++++++++++++++---------
 tests/test_visualizer/test_visualizer.py |  5 +-
 2 files changed, 48 insertions(+), 33 deletions(-)

diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py
index a0596dc1..39ad4783 100644
--- a/mmengine/visualization/visualizer.py
+++ b/mmengine/visualization/visualizer.py
@@ -8,9 +8,12 @@ import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torch.nn.functional as F
+from matplotlib.backends.backend_agg import FigureCanvasAgg
 from matplotlib.collections import (LineCollection, PatchCollection,
                                     PolyCollection)
+from matplotlib.figure import Figure
 from matplotlib.patches import Circle
+from matplotlib.pyplot import new_figure_manager
 
 from mmengine.config import Config
 from mmengine.data import BaseDataElement
@@ -157,7 +160,7 @@ class Visualizer(ManagerMixin):
         vis_backends: Optional[List[Dict]] = None,
         save_dir: Optional[str] = None,
         fig_save_cfg=dict(frameon=False),
-        fig_show_cfg=dict(frameon=False, num='show')
+        fig_show_cfg=dict(frameon=False)
     ) -> None:
         super().__init__(name)
         self._dataset_meta: Optional[dict] = None
@@ -196,17 +199,12 @@ class Visualizer(ManagerMixin):
                 vis_backend.setdefault('save_dir', save_dir)
                 self._vis_backends[name] = VISBACKENDS.build(vis_backend)
 
-        self.is_inline = 'inline' in plt.get_backend()
-
         self.fig_save = None
-        self.fig_show = None
-        self.fig_save_num = fig_save_cfg.get('num', None)
-        self.fig_show_num = fig_show_cfg.get('num', None)
         self.fig_save_cfg = fig_save_cfg
         self.fig_show_cfg = fig_show_cfg
 
-        (self.fig_save, self.ax_save,
-         self.fig_save_num) = self._initialize_fig(fig_save_cfg)
+        (self.fig_save_canvas, self.fig_save,
+         self.ax_save) = self._initialize_fig(fig_save_cfg)
         self.dpi = self.fig_save.get_dpi()
 
         if image is not None:
@@ -242,20 +240,22 @@ class Visualizer(ManagerMixin):
             continue_key (str): The key for users to continue. Defaults to
                 the space key.
         """
-        if self.is_inline:
-            return
-        if self.fig_show is None or not plt.fignum_exists(self.fig_show_num):
-            (self.fig_show, self.ax_show,
-             self.fig_show_num) = self._initialize_fig(self.fig_show_cfg)
+        is_inline = 'inline' in plt.get_backend()
         img = self.get_image() if drawn_img is None else drawn_img
-        self.ax_show.cla()
-        self.ax_show.axis(False)
-        self.fig_show.canvas.manager.set_window_title(win_name)  # type: ignore
-        # Refresh canvas, necessary for Qt5 backend.
-        self.ax_show.imshow(img)
-        self.fig_show.canvas.draw()  # type: ignore
-        wait_continue(
-            self.fig_show, timeout=wait_time, continue_key=continue_key)
+        self._init_manager(win_name)
+        fig = self.manager.canvas.figure
+        # remove white edges by set subplot margin
+        fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
+        fig.clear()
+        ax = fig.add_subplot()
+        ax.axis(False)
+        ax.imshow(img)
+        self.manager.canvas.draw()
+
+        # Find a better way for inline to show the image
+        if is_inline:
+            return fig
+        wait_continue(fig, timeout=wait_time, continue_key=continue_key)
 
     @master_only
     def set_image(self, image: np.ndarray) -> None:
@@ -291,7 +291,7 @@ class Visualizer(ManagerMixin):
             np.ndarray: the drawn image which channel is RGB.
         """
         assert self._image is not None, 'Please set image using `set_image`'
-        return img_from_canvas(self.fig_save.canvas)  # type: ignore
+        return img_from_canvas(self.fig_save_canvas)  # type: ignore
 
     def _initialize_fig(self, fig_cfg) -> tuple:
         """Build figure according to fig_cfg.
@@ -300,15 +300,34 @@ class Visualizer(ManagerMixin):
             fig_cfg (dict): The config to build figure.
 
         Returns:
-             tuple: build figure, axes and fig number.
+             tuple: build canvas figure and axes.
         """
-        fig = plt.figure(**fig_cfg)
+
+        fig = Figure(**fig_cfg)
         ax = fig.add_subplot()
         ax.axis(False)
 
         # remove white edges by set subplot margin
         fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
-        return (fig, ax, fig.number)
+        canvas = FigureCanvasAgg(fig)
+        return canvas, fig, ax
+
+    def _init_manager(self, win_name: str) -> None:
+        """Initialize the matplot manager.
+
+        Args:
+            win_name (str): The window name.
+        """
+        if getattr(self, 'manager', None) is None:
+            self.manager = new_figure_manager(
+                num=1, FigureClass=Figure, **self.fig_show_cfg)
+
+        try:
+            self.manager.set_window_title(win_name)
+        except Exception:
+            self.manager = new_figure_manager(
+                num=1, FigureClass=Figure, **self.fig_show_cfg)
+            self.manager.set_window_title(win_name)
 
     @master_only
     def get_backend(self, name) -> 'BaseVisBackend':
@@ -982,7 +1001,9 @@ class Visualizer(ManagerMixin):
                 axes.imshow(
                     convert_overlay_heatmap(topk_featmap[i], overlaid_image,
                                             alpha))
-            return img_from_canvas(fig.canvas)
+            image = img_from_canvas(fig.canvas)
+            plt.close(fig)
+            return image
 
     @master_only
     def add_config(self, config: Config, **kwargs):
@@ -1071,9 +1092,6 @@ class Visualizer(ManagerMixin):
 
     def close(self) -> None:
         """close an opened object."""
-        plt.close(self.fig_save)
-        if self.fig_show is not None:
-            plt.close(self.fig_show)
         for vis_backend in self._vis_backends.values():
             vis_backend.close()
 
diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py
index a487501f..dc0485fa 100644
--- a/tests/test_visualizer/test_visualizer.py
+++ b/tests/test_visualizer/test_visualizer.py
@@ -4,7 +4,6 @@ import time
 from typing import Any
 from unittest import TestCase
 
-import matplotlib.pyplot as plt
 import numpy as np
 import pytest
 import torch
@@ -171,12 +170,10 @@ class TestVisualizer(TestCase):
             image=self.image,
             vis_backends=copy.deepcopy(self.vis_backend_cfg),
             save_dir='temp_dir')
-        fig_num = visualizer.fig_save_num
-        assert fig_num in plt.get_fignums()
+
         for name in ['mock1', 'mock2']:
             assert visualizer.get_backend(name)._close is False
         visualizer.close()
-        assert fig_num not in plt.get_fignums()
         for name in ['mock1', 'mock2']:
             assert visualizer.get_backend(name)._close is True
 
-- 
GitLab