Skip to content
Snippets Groups Projects
Unverified Commit c89d4ef8 authored by Zaida Zhou's avatar Zaida Zhou Committed by GitHub
Browse files

[Enhance] Remove unnecessary calls and lazily import to speed import performance (#837)

* [Enhance] Remove unnecessary calls to speed import performance

* lazily import matplotlib

* minor refinement
parent fcd783fc
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,6 @@ from tempfile import TemporaryDirectory
from typing import Callable, Dict, Optional
import torch
import torchvision
import mmengine
from mmengine.dist import get_dist_info
......@@ -112,6 +111,7 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def get_torchvision_models():
import torchvision
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
......
......@@ -4,7 +4,6 @@ import os.path as osp
import subprocess
import sys
from collections import OrderedDict, defaultdict
from distutils import errors
import cv2
import numpy as np
......@@ -47,6 +46,8 @@ def collect_env():
- OpenCV (optional): OpenCV version.
- MMENGINE: MMENGINE version.
"""
from distutils import errors
env_info = OrderedDict()
env_info['sys.platform'] = sys.platform
env_info['Python'] = sys.version.replace('\n', '')
......
......@@ -103,7 +103,6 @@ def _get_norm() -> tuple:
_ConvNd, _ConvTransposeMixin = _get_conv()
DataLoader, PoolDataLoader = _get_dataloader()
BuildExtension, CppExtension, CUDAExtension = _get_extension()
_BatchNorm, _InstanceNorm, SyncBatchNorm_ = _get_norm()
_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd = _get_pool()
......
......@@ -3,9 +3,6 @@ import importlib
import os.path as osp
import subprocess
import pkg_resources
from pkg_resources import get_distribution
def is_installed(package: str) -> bool:
"""Check package whether installed.
......@@ -13,6 +10,12 @@ def is_installed(package: str) -> bool:
Args:
package (str): Name of package to be checked.
"""
# When executing `import mmengine.runner`,
# pkg_resources will be imported and it takes too much time.
# Therefore, import it in function scope to save time.
import pkg_resources
from pkg_resources import get_distribution
# refresh the pkg_resources
# more datails at https://github.com/pypa/setuptools/issues/373
importlib.reload(pkg_resources)
......@@ -33,6 +36,8 @@ def get_installed_path(package: str) -> str:
>>> get_installed_path('mmcls')
>>> '.../lib/python3.7/site-packages/mmcls'
"""
from pkg_resources import get_distribution
# if the package name is not the same as module name, module name should be
# inferred. For example, mmcv-full is the package name, but mmcv is module
# name. If we want to get the installed path of mmcv-full, we should concat
......@@ -51,6 +56,7 @@ def package2module(package: str):
Args:
package (str): Package to infer module name.
"""
from pkg_resources import get_distribution
pkg = get_distribution(package)
if pkg.has_metadata('top_level.txt'):
module_name = pkg.get_metadata('top_level.txt').split('\n')[0]
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type, Union
import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backend_bases import CloseEvent
from matplotlib.backends.backend_agg import FigureCanvasAgg
if TYPE_CHECKING:
from matplotlib.backends.backend_agg import FigureCanvasAgg
def tensor2ndarray(value: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
......@@ -131,6 +130,7 @@ def color_str2rgb(color: str) -> tuple:
Returns:
tuple: RGB color.
"""
import matplotlib
rgb_color: tuple = matplotlib.colors.to_rgb(color)
rgb_color = tuple(int(c * 255) for c in rgb_color)
return rgb_color
......@@ -186,6 +186,8 @@ def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int:
int: If zero, means time out or the user pressed ``continue_key``,
and if one, means the user closed the show figure.
""" # noqa: E501
import matplotlib.pyplot as plt
from matplotlib.backend_bases import CloseEvent
is_inline = 'inline' in plt.get_backend()
if is_inline:
# If use inline backend, interactive input and timeout is no use.
......@@ -226,7 +228,7 @@ def wait_continue(figure, timeout: int = 0, continue_key: str = ' ') -> int:
return 0 # Quit for continue.
def img_from_canvas(canvas: FigureCanvasAgg) -> np.ndarray:
def img_from_canvas(canvas: 'FigureCanvasAgg') -> np.ndarray:
"""Get RGB image from ``FigureCanvasAgg``.
Args:
......
......@@ -4,16 +4,9 @@ import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import (LineCollection, PatchCollection,
PolyCollection)
from matplotlib.figure import Figure
from matplotlib.patches import Circle
from matplotlib.pyplot import new_figure_manager
from mmengine.config import Config
from mmengine.dist import master_only
......@@ -240,6 +233,7 @@ class Visualizer(ManagerMixin):
continue_key (str): The key for users to continue. Defaults to
the space key.
"""
import matplotlib.pyplot as plt
is_inline = 'inline' in plt.get_backend()
img = self.get_image() if drawn_img is None else drawn_img
self._init_manager(win_name)
......@@ -302,7 +296,8 @@ class Visualizer(ManagerMixin):
Returns:
tuple: build canvas figure and axes.
"""
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
fig = Figure(**fig_cfg)
ax = fig.add_subplot()
ax.axis(False)
......@@ -318,6 +313,8 @@ class Visualizer(ManagerMixin):
Args:
win_name (str): The window name.
"""
from matplotlib.figure import Figure
from matplotlib.pyplot import new_figure_manager
if getattr(self, 'manager', None) is None:
self.manager = new_figure_manager(
num=1, FigureClass=Figure, **self.fig_show_cfg)
......@@ -546,6 +543,7 @@ class Visualizer(ManagerMixin):
If ``line_widths`` is single value, all the lines will
have the same linewidth. Defaults to 2.
"""
from matplotlib.collections import LineCollection
check_type('x_datas', x_datas, (np.ndarray, torch.Tensor))
x_datas = tensor2ndarray(x_datas)
check_type('y_datas', y_datas, (np.ndarray, torch.Tensor))
......@@ -614,6 +612,8 @@ class Visualizer(ManagerMixin):
alpha (Union[int, float]): The transparency of circles.
Defaults to 0.8.
"""
from matplotlib.collections import PatchCollection
from matplotlib.patches import Circle
check_type('center', center, (np.ndarray, torch.Tensor))
center = tensor2ndarray(center)
check_type('radius', radius, (np.ndarray, torch.Tensor))
......@@ -760,6 +760,7 @@ class Visualizer(ManagerMixin):
alpha (Union[int, float]): The transparency of polygons.
Defaults to 0.8.
"""
from matplotlib.collections import PolyCollection
check_type('polygons', polygons, (list, np.ndarray, torch.Tensor))
edge_colors = color_val_matplotlib(edge_colors) # type: ignore
face_colors = color_val_matplotlib(face_colors) # type: ignore
......@@ -916,6 +917,7 @@ class Visualizer(ManagerMixin):
Returns:
np.ndarray: RGB image.
"""
import matplotlib.pyplot as plt
assert isinstance(featmap,
torch.Tensor), (f'`featmap` should be torch.Tensor,'
f' but got {type(featmap)}')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment