Commit 8d3bd4df authored by Haian Huang(深度眸), committed by GitHub

Move get_max_cuda_memory and set_multi_processing to public functions (#250)

* move get_max_cuda_memory and set_multi_processing to a public function

* fix lint

* fix lint

* fix lint

* delete _set_multi_processing

* fix error

* rename
parent a976257c
@@ -3,6 +3,7 @@
 from .config import *
 from .data import *
 from .dataset import *
+from .device import *
 from .fileio import *
 from .hooks import *
 from .logging import *
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import get_max_cuda_memory

__all__ = ['get_max_cuda_memory']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch


def get_max_cuda_memory(device: Optional[torch.device] = None) -> int:
    """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
    for a given device. By default, this returns the peak allocated memory
    since the beginning of this program. Note that the function also resets
    the peak memory statistics, so each subsequent call reports the peak
    reached since the previous call.

    Args:
        device (torch.device, optional): Selected device. Returns the
            statistic for the current device, given by
            :func:`~torch.cuda.current_device`, if ``device`` is None.
            Defaults to None.

    Returns:
        int: The maximum GPU memory occupied by tensors in megabytes
        for a given device.
    """
    mem = torch.cuda.max_memory_allocated(device=device)
    mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
                          dtype=torch.int,
                          device=device)
    torch.cuda.reset_peak_memory_stats()
    return int(mem_mb.item())
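A minimal usage sketch of the new helper (a hypothetical script, assuming a CUDA-capable PyTorch build). Because the helper resets the peak statistics, a throwaway first call makes the second call report only the workload being measured:

# Hypothetical example: measure the peak CUDA memory of one forward pass.
import torch
from mmengine.device import get_max_cuda_memory

if torch.cuda.is_available():
    model = torch.nn.Linear(1024, 1024).cuda()
    inputs = torch.randn(64, 1024, device='cuda')
    get_max_cuda_memory()  # discard the peak accumulated so far
    outputs = model(inputs)
    print(f'peak memory: {get_max_cuda_memory()} MB')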
@@ -6,6 +6,7 @@ from typing import List, Optional, Tuple
 import torch

+from mmengine.device import get_max_cuda_memory
 from mmengine.registry import LOG_PROCESSOR
@@ -345,13 +346,9 @@ class LogProcessor:
         The maximum GPU memory occupied by tensors in megabytes for a given
         device.
         """
         device = getattr(runner.model, 'output_device', None)
-        mem = torch.cuda.max_memory_allocated(device=device)
-        mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
-                              dtype=torch.int,
-                              device=device)
-        torch.cuda.reset_peak_memory_stats()
-        return int(mem_mb.item())
+        return get_max_cuda_memory(device)

     def _get_iter(self, runner, batch_idx: int = None) -> int:
         """Get current iteration index.
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
-import multiprocessing as mp
 import os
 import os.path as osp
-import platform
 import random
@@ -34,7 +32,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
                                count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSOR
 from mmengine.utils import (TORCH_VERSION, digit_version,
-                            find_latest_checkpoint, is_list_of, symlink)
+                            find_latest_checkpoint, is_list_of,
+                            set_multi_processing, symlink)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -582,12 +581,13 @@ class Runner:
         if env_cfg.get('cudnn_benchmark'):
             torch.backends.cudnn.benchmark = True

-        if env_cfg.get('mp_cfg') is not None:
-            self._set_multi_processing(**env_cfg.get('mp_cfg'))  # type: ignore
+        mp_cfg: dict = env_cfg.get('mp_cfg', {})
+        set_multi_processing(**mp_cfg, distributed=self.distributed)

         # init distributed env first, since logger depends on the dist info.
-        if self.distributed and env_cfg.get('dist_cfg') is not None:
-            init_dist(self.launcher, **env_cfg.get('dist_cfg'))  # type: ignore
+        if self.distributed:
+            dist_cfg: dict = env_cfg.get('dist_cfg', {})
+            init_dist(self.launcher, **dist_cfg)
         self._rank, self._world_size = get_dist_info()
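For reference, a sketch of an env_cfg dict that exercises this path. The mp_cfg keys mirror set_multi_processing's signature below; the `backend` key for dist_cfg is an assumption about init_dist's usual kwargs:

# Hypothetical Runner configuration for the setup_env code path.
env_cfg = dict(
    cudnn_benchmark=False,
    # forwarded verbatim to set_multi_processing(...)
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # forwarded verbatim to init_dist(...) in distributed runs
    dist_cfg=dict(backend='nccl'),
)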
@@ -597,59 +597,6 @@ class Runner:
         self._timestamp = time.strftime('%Y%m%d_%H%M%S',
                                         time.localtime(timestamp.item()))
-    def _set_multi_processing(self,
-                              mp_start_method: str = 'fork',
-                              opencv_num_threads: int = 0) -> None:
-        """Set multi-processing related environment.
-
-        Args:
-            mp_start_method (str): Set the method which should be used to
-                start child processes. Defaults to 'fork'.
-            opencv_num_threads (int): Number of threads for opencv.
-                Defaults to 0.
-        """
-        # set multi-process start method as `fork` to speed up the training
-        if platform.system() != 'Windows':
-            current_method = mp.get_start_method(allow_none=True)
-            if (current_method is not None
-                    and current_method != mp_start_method):
-                warnings.warn(
-                    f'Multi-processing start method `{mp_start_method}` is '
-                    f'different from the previous setting `{current_method}`.'
-                    f'It will be force set to `{mp_start_method}`. You can '
-                    'change this behavior by changing `mp_start_method` in '
-                    'your config.')
-            mp.set_start_method(mp_start_method, force=True)
-
-        try:
-            import cv2
-
-            # disable opencv multithreading to avoid system being overloaded
-            cv2.setNumThreads(opencv_num_threads)
-        except ImportError:
-            pass
-
-        # setup OMP threads
-        # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
-        if 'OMP_NUM_THREADS' not in os.environ and self.distributed:
-            omp_num_threads = 1
-            warnings.warn(
-                'Setting OMP_NUM_THREADS environment variable for each process'
-                f' to be {omp_num_threads} in default, to avoid your system '
-                'being overloaded, please further tune the variable for '
-                'optimal performance in your application as needed.')
-            os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
-
-        # setup MKL threads
-        if 'MKL_NUM_THREADS' not in os.environ and self.distributed:
-            mkl_num_threads = 1
-            warnings.warn(
-                'Setting MKL_NUM_THREADS environment variable for each process'
-                f' to be {mkl_num_threads} in default, to avoid your system '
-                'being overloaded, please further tune the variable for '
-                'optimal performance in your application as needed.')
-            os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)

     def set_randomness(self, seed, deterministic: bool = False) -> None:
         """Set random seed to guarantee reproducible results.
@@ -12,6 +12,7 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
 from .parrots_wrapper import TORCH_VERSION
 from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
                    scandir, symlink)
+from .setup_env import set_multi_processing
 from .version_utils import digit_version, get_git_hash

 __all__ = [
@@ -23,5 +24,6 @@ __all__ = [
     'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
     'is_method_overridden', 'has_method', 'mmcv_full_available',
     'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
-    'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin'
+    'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
+    'set_multi_processing'
 ]
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import warnings

import torch.multiprocessing as mp


def set_multi_processing(mp_start_method: str = 'fork',
                         opencv_num_threads: int = 0,
                         distributed: bool = False) -> None:
    """Set multi-processing related environment.

    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
        distributed (bool): True if distributed environment.
            Defaults to False.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        current_method = mp.get_start_method(allow_none=True)
        if (current_method is not None
                and current_method != mp_start_method):
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be forcibly set to `{mp_start_method}`. You can '
                'change this behavior by changing `mp_start_method` in '
                'your config.')
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2

        # disable opencv multithreading to avoid system being overloaded
        cv2.setNumThreads(opencv_num_threads)
    except ImportError:
        pass

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    if 'OMP_NUM_THREADS' not in os.environ and distributed:
        omp_num_threads = 1
        warnings.warn(
            'Setting OMP_NUM_THREADS environment variable for each process '
            f'to be {omp_num_threads} by default, to avoid your system '
            'being overloaded. Please further tune the variable for '
            'optimal performance in your application as needed.')
        os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)

    # setup MKL threads
    if 'MKL_NUM_THREADS' not in os.environ and distributed:
        mkl_num_threads = 1
        warnings.warn(
            'Setting MKL_NUM_THREADS environment variable for each process '
            f'to be {mkl_num_threads} by default, to avoid your system '
            'being overloaded. Please further tune the variable for '
            'optimal performance in your application as needed.')
        os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
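A minimal usage sketch in a hypothetical training script. The call should come before any DataLoader workers or other child processes are created, since the start method only affects processes spawned afterwards:

# Hypothetical script: configure process/thread behavior up front.
from mmengine.utils import set_multi_processing

# In a distributed run this also pins OMP/MKL to one thread per process,
# unless the variables are already set in the environment.
set_multi_processing(
    mp_start_method='spawn',  # safer than 'fork' when workers touch CUDA
    opencv_num_threads=0,     # let the DataLoader own the parallelism
    distributed=True)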
# Copyright (c) OpenMMLab. All rights reserved.
import multiprocessing as mp
import os
import platform

import cv2

from mmengine.utils import set_multi_processing


def test_setup_multi_processes():
    # temporarily save the system settings
    sys_start_method = mp.get_start_method(allow_none=True)
    sys_cv_threads = cv2.getNumThreads()
    # pop and temporarily save the system env vars
    sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None)
    sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None)

    # test distributed
    set_multi_processing(distributed=True)
    assert os.getenv('OMP_NUM_THREADS') == '1'
    assert os.getenv('MKL_NUM_THREADS') == '1'
    # when set to 0, cv2.getNumThreads() reports 1
    assert cv2.getNumThreads() == 1
    if platform.system() != 'Windows':
        assert mp.get_start_method() == 'fork'

    # test non-distributed: env vars must stay untouched
    os.environ.pop('OMP_NUM_THREADS')
    os.environ.pop('MKL_NUM_THREADS')
    set_multi_processing(distributed=False)
    assert 'OMP_NUM_THREADS' not in os.environ
    assert 'MKL_NUM_THREADS' not in os.environ

    # test manually set env var
    os.environ['OMP_NUM_THREADS'] = '4'
    set_multi_processing(distributed=False)
    assert os.getenv('OMP_NUM_THREADS') == '4'

    # test manually set opencv threads and mp start method
    config = dict(
        mp_start_method='spawn', opencv_num_threads=4, distributed=True)
    set_multi_processing(**config)
    assert cv2.getNumThreads() == 4
    assert mp.get_start_method() == 'spawn'

    # revert the settings to avoid affecting other tests
    if sys_start_method:
        mp.set_start_method(sys_start_method, force=True)
    cv2.setNumThreads(sys_cv_threads)
    if sys_omp_threads:
        os.environ['OMP_NUM_THREADS'] = sys_omp_threads
    else:
        os.environ.pop('OMP_NUM_THREADS')
    if sys_mkl_threads:
        os.environ['MKL_NUM_THREADS'] = sys_mkl_threads
    else:
        os.environ.pop('MKL_NUM_THREADS')
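The manual save-and-restore above could alternatively lean on pytest's built-in monkeypatch fixture for the environment variables, which undoes changes automatically after the test; the start method and OpenCV thread count would still need manual restoring. A sketch (reusing the module's imports; the test name is hypothetical):

# Hypothetical alternative for the env-var portion of the test.
def test_distributed_env(monkeypatch):
    monkeypatch.delenv('OMP_NUM_THREADS', raising=False)
    monkeypatch.delenv('MKL_NUM_THREADS', raising=False)
    set_multi_processing(distributed=True)
    assert os.getenv('OMP_NUM_THREADS') == '1'
    assert os.getenv('MKL_NUM_THREADS') == '1'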