Unverified commit 492b2f2f authored by Haian Huang (深度眸), committed by GitHub

[Feature] Add MMDataParallel and MMDistributedDataParallel (#44)

* add mmddp

* update

* update code

* update unittest

* fix comment
parent 95172781
@@ -97,12 +97,12 @@ import numpy as np
@EVALUATORS.register_module()
class AccuracyEvaluator(BaseEvaluator):
def process(self, data_samples: Dict, predictions: Dict):
"""Process one batch of data and predictions. The processed
Results should be stored in `self.results`, which will be used
to computed the metrics when all batches have been processed.
Args:
data_samples (dict): The data samples from the dataset.
predictions (dict): The output of the model.
@@ -113,16 +113,16 @@ class AccuracyEvaluator(BaseEvaluator):
'pred': predictions.pred_label,
'gt': data_samples.gt_label
)
# Store the results of the current batch into self.results
self.results.append(result)
def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
dict: The computed metrics. The keys are the names of the metrics
and the values are the corresponding results.
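The body of `compute_metrics` is collapsed in this diff. A minimal, hypothetical completion, assuming each entry appended by `process` is a dict with `pred` and `gt` labels, could look like this:

```python
from typing import Dict, List


def compute_metrics(self, results: List) -> Dict:
    """Hypothetical accuracy computation over the accumulated results."""
    # Each item was appended by ``process`` as dict(pred=..., gt=...).
    correct = sum(int(res['pred'] == res['gt']) for res in results)
    return dict(accuracy=correct / max(len(results), 1))
```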
@@ -219,6 +219,7 @@ MMEngine's registries support cross-project use, i.e. modules registered in one project can be used in another
- DATA_SAMPLERS: the `sampler` of the `Dataloader`, used for sampling data
- PIPELINES: various data preprocessing transforms, such as `Resize` and `Reshape`
- MODELS: various modules of a model
- MODEL_WRAPPERS: model wrappers such as `MMDistributedDataParallel`, used for distributed data-parallel training (a registration sketch follows this list)
- WEIGHT_INITIALIZERS: tools for weight initialization
- OPTIMIZERS: registers all `optimizer`s in PyTorch as well as custom optimizers
- OPTIMIZER_CONSTRUCTORS: constructors of optimizers
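The registration workflow is the same for all of these registries. A minimal sketch for the new MODEL_WRAPPERS registry (the `MyWrapper` class below is purely illustrative, not part of MMEngine):

```python
import torch.nn as nn

from mmengine.registry import MODEL_WRAPPERS


@MODEL_WRAPPERS.register_module()
class MyWrapper:
    """Hypothetical wrapper that simply delegates calls to the module."""

    def __init__(self, module: nn.Module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


# Look up the registered class by name and wrap a model with it.
wrapper_cls = MODEL_WRAPPERS.get('MyWrapper')
wrapped = wrapper_cls(nn.Linear(4, 2))
```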
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .config import *
from .data import *
from .dataset import *
from .fileio import *
from .registry import *
from .utils import *
# Copyright (c) OpenMMLab. All rights reserved.
from .wrappers import (MMDataParallel, MMDistributedDataParallel,
is_model_wrapper)
__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
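Because of the re-exports above, downstream code can import the wrappers straight from the subpackage (the unit test at the end of this diff uses the deeper `mmengine.model.wrappers` path):

```python
from mmengine.model import (MMDataParallel, MMDistributedDataParallel,
                            is_model_wrapper)
```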
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MMDataParallel, MMDistributedDataParallel
from .utils import is_model_wrapper
__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import chain
import torch
from torch.nn.parallel import DataParallel
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
from mmengine.registry import MODEL_WRAPPERS
from mmengine.utils import TORCH_VERSION, digit_version
@MODEL_WRAPPERS.register_module()
class MMDataParallel(DataParallel):
"""There is no difference between MMDataParallel and pytorch's
DataParallel, "train_step" and "val_step" are added just to avoid bc
breaking.
Warning:
MMDataParallel only supports single-GPU training. If you need to
train with multiple GPUs, please use MMDistributedDataParallel
instead. If you have multiple GPUs and you just want to use
MMDataParallel, you can set the environment variable
``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
``device_ids=[0]``.
"""
def train_step(self, *inputs, **kwargs):
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single-GPU training; if you need '
'to train with multiple GPUs, please use '
'MMDistributedDataParallel instead.')
assert hasattr(self.module, 'train_step')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
return self.module.train_step(*inputs, **kwargs)
def val_step(self, *inputs, **kwargs):
assert len(self.device_ids) == 1, \
('MMDataParallel only supports single-GPU training; if you need '
'to train with multiple GPUs, please use '
'MMDistributedDataParallel instead.')
assert hasattr(self.module, 'val_step')
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
'module must have its parameters and buffers '
f'on device {self.src_device_obj} (device_ids[0]) but '
f'found one of them on device: {t.device}')
return self.module.val_step(*inputs, **kwargs)
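The wrapper is used like a plain DataParallel. A usage sketch, assuming a single visible GPU; the `ToyModel` below is illustrative and not part of this commit:

```python
import torch
import torch.nn as nn

from mmengine.model import MMDataParallel


class ToyModel(nn.Module):
    """Hypothetical module exposing the train_step/val_step interface."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

    def train_step(self, x):
        return self.forward(x)

    def val_step(self, x):
        return self.forward(x)


# Requires exactly one GPU, as enforced by the assertions above.
model = MMDataParallel(ToyModel().cuda(), device_ids=[0])
out = model.train_step(torch.randn(4, 2).cuda())
```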
@MODEL_WRAPPERS.register_module()
class MMDistributedDataParallel(DistributedDataParallel):
"""There is no difference between MMDistributedDataParallel and pytorch's
DistributedDataParallel, "train_step" and "val_step" are added just to
avoid bc breaking."""
def train_step(self, *inputs, **kwargs):
"""train_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.train_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
# TODO: replace with logger
print('Reducer buckets have been rebuilt in this iteration.')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.train_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.train_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
def val_step(self, *inputs, **kwargs):
"""val_step() API for module wrapped by DistributedDataParallel.
This method is basically the same as
``DistributedDataParallel.forward()``, while replacing
``self.module.forward()`` with ``self.module.val_step()``.
It is compatible with PyTorch 1.1 - 1.5.
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.7')
and self.reducer._rebuild_buckets()):
# TODO: replace with logger
print('Reducer buckets have been rebuilt in this iteration.')
if getattr(self, 'require_forward_param_sync', True):
self._sync_params()
if self.device_ids:
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
output = self.module.val_step(*inputs[0], **kwargs[0])
else:
outputs = self.parallel_apply(
self._module_copies[:len(inputs)], inputs, kwargs)
output = self.gather(outputs, self.output_device)
else:
output = self.module.val_step(*inputs, **kwargs)
if torch.is_grad_enabled() and getattr(
self, 'require_backward_grad_sync', True):
if self.find_unused_parameters:
self.reducer.prepare_for_backward(list(_find_tensors(output)))
else:
self.reducer.prepare_for_backward([])
else:
if ('parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) > digit_version('1.2')):
self.require_forward_param_sync = False
return output
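A distributed usage sketch, under the assumption that the default process group has already been initialised by a launcher (e.g. `torch.distributed.launch`) and reusing the hypothetical `ToyModel` from the previous sketch:

```python
import torch

from mmengine.model import MMDistributedDataParallel

# Assumes torch.distributed.init_process_group() has already been called
# by the launch utility and that this process owns one GPU.
model = ToyModel().cuda()
ddp_model = MMDistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()])
# train_step() scatters the inputs, calls module.train_step() on the local
# replica and keeps DDP's gradient-synchronisation bookkeeping intact.
out = ddp_model.train_step(torch.randn(4, 2).cuda())
```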
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import MODEL_WRAPPERS
def is_model_wrapper(model):
"""Check if a module is a model wrapper.
Modules registered in ``MODEL_WRAPPERS`` (and their subclasses), such as
MMDataParallel and MMDistributedDataParallel, are regarded as model
wrappers. You may add your own model wrapper by registering it to
mmengine.registry.MODEL_WRAPPERS.
Args:
model (nn.Module): The model to be checked.
Returns:
bool: True if the input model is a model wrapper.
"""
model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values())
return isinstance(model, model_wrappers)
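A common pattern built on this check is to unwrap a model before saving or inspecting it; the helper below is a hypothetical sketch, not part of the commit:

```python
import torch.nn as nn

from mmengine.model import is_model_wrapper


def get_bare_model(model: nn.Module) -> nn.Module:
    """Return the wrapped module if ``model`` is a registered wrapper."""
    # Both wrappers in this commit store the real model in ``.module``.
    return model.module if is_model_wrapper(model) else model
```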
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODEL_WRAPPERS,
MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, WEIGHT_INITIALIZERS)
__all__ = [
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
'EVALUATORS', 'MODEL_WRAPPERS'
]
@@ -22,6 +22,8 @@ TRANSFORMS = Registry('transform')
# manage all kinds of modules inheriting `nn.Module`
MODELS = Registry('model')
# manage all kinds of model wrappers like 'MMDistributedDataParallel'
MODEL_WRAPPERS = Registry('model_wrapper')
# manage all kinds of weight initialization modules like `Uniform`
WEIGHT_INITIALIZERS = Registry('weight initializer')
@@ -6,8 +6,10 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
requires_executable, requires_package, slice_list,
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
tuple_cast)
from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .version_utils import digit_version, get_git_hash
__all__ = [
'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of',
@@ -16,5 +18,6 @@ __all__ = [
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'deprecated_api_warning', 'import_modules_from_strings',
'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple',
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION'
]
# Copyright (c) OpenMMLab. All rights reserved.
import os
import subprocess
import warnings
from packaging.version import parse
def digit_version(version_str: str, length: int = 4):
"""Convert a version string into a tuple of integers.
This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int]: The version info in digits (integers).
"""
assert 'parrots' not in version_str
version = parse(version_str)
assert version.release, f'failed to parse version {version_str}'
release = list(version.release)
release = release[:length]
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
mapping = {'a': -3, 'b': -2, 'rc': -1}
val = -4
# version.pre can be None
if version.pre:
if version.pre[0] not in mapping:
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
'version checking may go wrong')
else:
val = mapping[version.pre[0]]
release.extend([val, version.pre[-1]])
else:
release.extend([val, 0])
elif version.is_postrelease:
release.extend([1, version.post])
else:
release.extend([0, 0])
return tuple(release)
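A few concrete values illustrate the ordering this produces (derived from the logic above; the import path follows the updated `mmengine.utils` exports):

```python
from mmengine.utils import digit_version

assert digit_version('1.7.0') == (1, 7, 0, 0, 0, 0)
assert digit_version('1.7.0rc1') == (1, 7, 0, 0, -1, 1)  # rc sorts before the release
assert digit_version('1.6.0') < digit_version('1.7.0rc1') < digit_version('1.7.0')
```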
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(
cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
def get_git_hash(fallback='unknown', digits=None):
"""Get the git hash of the current repo.
Args:
fallback (str, optional): The fallback string when git hash is
unavailable. Defaults to 'unknown'.
digits (int, optional): kept digits of the hash. Defaults to None,
meaning all digits are kept.
Returns:
str: Git commit hash.
"""
if digits is not None and not isinstance(digits, int):
raise TypeError('digits must be None or an integer')
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
if digits is not None:
sha = sha[:digits]
except OSError:
sha = fallback
return sha
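Typical usage; the concrete hash is repository-dependent, so the value in the comment is only an example:

```python
from mmengine.utils import get_git_hash

short_sha = get_git_hash(digits=7)  # e.g. '492b2f2' for the commit in this page
full_sha = get_git_hash()           # full hash, or 'unknown' outside a git repo
```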
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
is_model_wrapper)
from mmengine.registry import MODEL_WRAPPERS
def mock(*args, **kwargs):
pass
@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_model_wrapper():
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(2, 2, 1)
def forward(self, x):
return self.conv(x)
# ``_verify_model_across_ranks`` was added in torch 1.9.0, so check whether
# it is a member of ``torch.distributed`` before mocking it.
if hasattr(torch.distributed, '_verify_model_across_ranks'):
torch.distributed._verify_model_across_ranks = mock
model = Model()
assert not is_model_wrapper(model)
mmdp = MMDataParallel(model)
assert is_model_wrapper(mmdp)
mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
assert is_model_wrapper(mmddp)
# test model wrapper registry
@MODEL_WRAPPERS.register_module()
class ModelWrapper(object):
def __init__(self, module):
self.module = module
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
model_wrapper = ModelWrapper(model)
assert is_model_wrapper(model_wrapper)
class TestMMDataParallel:
def setup_method(self):
"""Set up the demo model before every test method.
Pytest calls ``setup_method()`` before each test method of this class,
so ``self.model`` is available in every test.
"""
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 2, 1)
def forward(self, x):
return self.conv(x)
def train_step(self, x):
return self.forward(x)
def val_step(self, x):
return self.forward(x)
self.model = Model()
def test_train_step(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 2, 1)
def forward(self, x):
return self.conv(x)
model = Model()
mmdp = MMDataParallel(model)
# test without train_step attribute
with pytest.raises(AssertionError):
mmdp.train_step(torch.zeros([1, 1, 3, 3]))
out = self.model.train_step(torch.zeros([1, 1, 3, 3]))
assert out.shape == (1, 2, 3, 3)
def test_val_step(self):
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 2, 1)
def forward(self, x):
return self.conv(x)
model = Model()
mmdp = MMDataParallel(model)
# test without val_step attribute
with pytest.raises(AssertionError):
mmdp.val_step(torch.zeros([1, 1, 3, 3]))
out = self.model.val_step(torch.zeros([1, 1, 3, 3]))
assert out.shape == (1, 2, 3, 3)