diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md
index abcd6fb9cb3737fbde25469d3e3c46926079896d..d2ab35b66168753846a9e122f3de9b431b546259 100644
--- a/docs/zh_cn/tutorials/evaluator.md
+++ b/docs/zh_cn/tutorials/evaluator.md
@@ -97,12 +97,12 @@ import numpy as np
 @EVALUATORS.register_module()
 class AccuracyEvaluator(BaseEvaluator):
-    
+
     def process(self, data_samples: Dict, predictions: Dict):
         """Process one batch of data and predictions. The processed
         results should be stored in `self.results`, which will be used
         to compute the metrics when all batches have been processed.
-        
+
         Args:
             data_samples (dict): The data samples from the dataset.
             predictions (dict): The output of the model.
@@ -113,16 +113,16 @@ class AccuracyEvaluator(BaseEvaluator):
             'pred': predictions.pred_label,
             'gt': data_samples.gt_label
         )
-        
+
         # Store the results of the current batch into self.results
         self.results.append(result)
-    
+
     def compute_metrics(self, results: List):
         """Compute the metrics from processed results.
 
         Args:
             results (list): The processed results of each batch.
-        
+
         Returns:
             Dict: The computed metrics. The keys are the names of the
                 metrics, and the values are corresponding results.
diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index 164cb0ea2b9dba8ba16b150289b0c88406179b7c..b0c5b7ce74abdfd0a18c43eb90a0f1802bd677a7 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -219,6 +219,7 @@ The registries in MMEngine support cross-project calls, i.e. modules from one project can be used in
 - DATA_SAMPLERS: the `sampler` of `Dataloader`, used to sample data
 - PIPELINES: various data preprocessing transforms, such as `Resize` and `Reshape`
 - MODELS: the various modules of a model
+- MODEL_WRAPPERS: model wrappers such as `MMDistributedDataParallel`, used for distributed data parallelism
 - WEIGHT_INITIALIZERS: tools for weight initialization
 - OPTIMIZERS: registers all `optimizer`s in PyTorch as well as custom `optimizer`s
 - OPTIMIZER_CONSTRUCTORS: constructors of `optimizer`s
diff --git a/mmengine/__init__.py b/mmengine/__init__.py
index 9e229e25e641a984933ecec67a3c181cec5564b9..55e7b929a4b0248231834eb21bbd3f99dc635407 100644
--- a/mmengine/__init__.py
+++ b/mmengine/__init__.py
@@ -1,8 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # flake8: noqa
 from .config import *
-from .dataset import *
 from .data import *
+from .dataset import *
 from .fileio import *
 from .registry import *
 from .utils import *
diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3620b7ff3f1c3e8b37be0a58f991c1703dc70a34
--- /dev/null
+++ b/mmengine/model/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .wrappers import (MMDataParallel, MMDistributedDataParallel,
+                       is_model_wrapper)
+
+__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
diff --git a/mmengine/model/wrappers/__init__.py b/mmengine/model/wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cab521decd4ac9cac834efc01c62113c17fafcd
--- /dev/null
+++ b/mmengine/model/wrappers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .data_parallel import MMDataParallel, MMDistributedDataParallel
+from .utils import is_model_wrapper
+
+__all__ = ['MMDistributedDataParallel', 'MMDataParallel', 'is_model_wrapper']
diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2967ceae536a539ddef80810272f9381be80664
--- /dev/null
+++ b/mmengine/model/wrappers/data_parallel.py
@@ -0,0 +1,146 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import chain
+
+import torch
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+                                            _find_tensors)
+
+from mmengine.registry import MODEL_WRAPPERS
+from mmengine.utils import TORCH_VERSION, digit_version
+
+
+@MODEL_WRAPPERS.register_module()
+class MMDataParallel(DataParallel):
+    """There is no difference between MMDataParallel and PyTorch's
+    DataParallel; ``train_step`` and ``val_step`` are added only to avoid
+    breaking backward compatibility.
+
+    Warning:
+        MMDataParallel only supports single-GPU training. If you need to
+        train with multiple GPUs, please use MMDistributedDataParallel
+        instead. If you have multiple GPUs and just want to use
+        MMDataParallel, you can set the environment variable
+        ``CUDA_VISIBLE_DEVICES=0`` or instantiate ``MMDataParallel`` with
+        ``device_ids=[0]``.
+    """
+
+    def train_step(self, *inputs, **kwargs):
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+        assert hasattr(self.module, 'train_step')
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+        return self.module.train_step(*inputs, **kwargs)
+
+    def val_step(self, *inputs, **kwargs):
+        assert len(self.device_ids) == 1, \
+            ('MMDataParallel only supports single GPU training, if you need to'
+             ' train with multiple GPUs, please use MMDistributedDataParallel'
+             ' instead.')
+        assert hasattr(self.module, 'val_step')
+        for t in chain(self.module.parameters(), self.module.buffers()):
+            if t.device != self.src_device_obj:
+                raise RuntimeError(
+                    'module must have its parameters and buffers '
+                    f'on device {self.src_device_obj} (device_ids[0]) but '
+                    f'found one of them on device: {t.device}')
+        return self.module.val_step(*inputs, **kwargs)
+
+
+@MODEL_WRAPPERS.register_module()
+class MMDistributedDataParallel(DistributedDataParallel):
+    """There is no difference between MMDistributedDataParallel and PyTorch's
+    DistributedDataParallel; ``train_step`` and ``val_step`` are added only to
+    avoid breaking backward compatibility."""
+
+    def train_step(self, *inputs, **kwargs):
+        """train_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.train_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            # TODO: replace with logger
+            print('Reducer buckets have been rebuilt in this iteration.')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.train_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.train_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
+
+    def val_step(self, *inputs, **kwargs):
+        """val_step() API for module wrapped by DistributedDataParallel.
+
+        This method is basically the same as
+        ``DistributedDataParallel.forward()``, while replacing
+        ``self.module.forward()`` with ``self.module.val_step()``.
+        It is compatible with PyTorch 1.1 - 1.5.
+        """
+
+        # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
+        # end of backward to the beginning of forward.
+        if ('parrots' not in TORCH_VERSION
+                and digit_version(TORCH_VERSION) >= digit_version('1.7')
+                and self.reducer._rebuild_buckets()):
+            # TODO: replace with logger
+            print('Reducer buckets have been rebuilt in this iteration.')
+
+        if getattr(self, 'require_forward_param_sync', True):
+            self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                output = self.module.val_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(
+                    self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            output = self.module.val_step(*inputs, **kwargs)
+
+        if torch.is_grad_enabled() and getattr(
+                self, 'require_backward_grad_sync', True):
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        else:
+            if ('parrots' not in TORCH_VERSION
+                    and digit_version(TORCH_VERSION) > digit_version('1.2')):
+                self.require_forward_param_sync = False
+        return output
diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f888f49c92b20c0390873042403f264cc6b53b6b
--- /dev/null
+++ b/mmengine/model/wrappers/utils.py
@@ -0,0 +1,20 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.registry import MODEL_WRAPPERS
+
+
+def is_model_wrapper(model):
+    """Check if a module is a model wrapper.
+
+    The following four models in MMEngine (and their subclasses) are regarded
+    as model wrappers: DataParallel, DistributedDataParallel,
+    MMDataParallel, MMDistributedDataParallel. You may add your own
+    model wrapper by registering it to mmengine.registry.MODEL_WRAPPERS.
+
+    Args:
+        model (nn.Module): The model to be checked.
+
+    Returns:
+        bool: True if the input model is a model wrapper.
+ """ + model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values()) + return isinstance(model, model_wrappers) diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 24ebcb6e7f3546bb0e4b61e8c54b0701d52d052b..b0f2f6497b0251e9050b093bbac0bcb3678f96b3 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,13 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from .registry import Registry, build_from_cfg -from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODELS, - OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, - RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, - WEIGHT_INITIALIZERS) +from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODEL_WRAPPERS, + MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, + PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, + TRANSFORMS, WEIGHT_INITIALIZERS) __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'EVALUATORS' + 'EVALUATORS', 'MODEL_WRAPPERS' ] diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index 9636fae21f135b0d94124c53c3aecf5a814a3acf..5edf22212883653b5be0f4b280392f8cc923645c 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -22,6 +22,8 @@ TRANSFORMS = Registry('transform') # mangage all kinds of modules inheriting `nn.Module` MODELS = Registry('model') +# mangage all kinds of model wrappers like 'MMDistributedDataParallel' +MODEL_WRAPPERS = Registry('model_wrapper') # mangage all kinds of weight initialization modules like `Uniform` WEIGHT_INITIALIZERS = Registry('weight initializer') diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index c3ee0dd9816db386d57abfd7ef4a8f0f147d8bd9..cee1ac98064c60add1821ba60cd3dfb2aaa7ba51 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -6,8 +6,10 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning, requires_executable, requires_package, slice_list, to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, tuple_cast) +from .parrots_wrapper import TORCH_VERSION from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, scandir, symlink) +from .version_utils import digit_version, get_git_hash __all__ = [ 'is_str', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', @@ -16,5 +18,6 @@ __all__ = [ 'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'deprecated_api_warning', 'import_modules_from_strings', 'to_1tuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'to_ntuple', - 'is_method_overridden', 'has_method', 'mmcv_full_available' + 'is_method_overridden', 'has_method', 'mmcv_full_available', + 'digit_version', 'get_git_hash', 'TORCH_VERSION' ] diff --git a/mmengine/utils/version_utils.py b/mmengine/utils/version_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..963c45a2e8a86a88413ab6c18c22481fb9831985 --- /dev/null +++ b/mmengine/utils/version_utils.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import subprocess +import warnings + +from packaging.version import parse + + +def digit_version(version_str: str, length: int = 4): + """Convert a version string into a tuple of integers. + + This method is usually used for comparing two versions. For pre-release + versions: alpha < beta < rc. 
+
+    Args:
+        version_str (str): The version string.
+        length (int): The maximum number of version levels. Default: 4.
+
+    Returns:
+        tuple[int]: The version info in digits (integers).
+    """
+    assert 'parrots' not in version_str
+    version = parse(version_str)
+    assert version.release, f'failed to parse version {version_str}'
+    release = list(version.release)
+    release = release[:length]
+    if len(release) < length:
+        release = release + [0] * (length - len(release))
+    if version.is_prerelease:
+        mapping = {'a': -3, 'b': -2, 'rc': -1}
+        val = -4
+        # version.pre can be None
+        if version.pre:
+            if version.pre[0] not in mapping:
+                warnings.warn(f'unknown prerelease version {version.pre[0]}, '
+                              'version checking may go wrong')
+            else:
+                val = mapping[version.pre[0]]
+            release.extend([val, version.pre[-1]])
+        else:
+            release.extend([val, 0])
+
+    elif version.is_postrelease:
+        release.extend([1, version.post])
+    else:
+        release.extend([0, 0])
+    return tuple(release)
+
+
+def _minimal_ext_cmd(cmd):
+    # construct minimal environment
+    env = {}
+    for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+        v = os.environ.get(k)
+        if v is not None:
+            env[k] = v
+    # LANGUAGE is used on win32
+    env['LANGUAGE'] = 'C'
+    env['LANG'] = 'C'
+    env['LC_ALL'] = 'C'
+    out = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+    return out
+
+
+def get_git_hash(fallback='unknown', digits=None):
+    """Get the git hash of the current repo.
+
+    Args:
+        fallback (str, optional): The fallback string when git hash is
+            unavailable. Defaults to 'unknown'.
+        digits (int, optional): kept digits of the hash. Defaults to None,
+            meaning all digits are kept.
+
+    Returns:
+        str: Git commit hash.
+    """
+
+    if digits is not None and not isinstance(digits, int):
+        raise TypeError('digits must be None or an integer')
+
+    try:
+        out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+        sha = out.strip().decode('ascii')
+        if digits is not None:
+            sha = sha[:digits]
+    except OSError:
+        sha = fallback
+
+    return sha
diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..63518c23c306bef4c69388ce206d2dfa2c0861e0
--- /dev/null
+++ b/tests/test_model/test_wrappers/test_data_parallel.py
@@ -0,0 +1,126 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+import torch.nn as nn
+
+from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
+                                     is_model_wrapper)
+from mmengine.registry import MODEL_WRAPPERS
+
+
+def mock(*args, **kwargs):
+    pass
+
+
+@patch('torch.distributed._broadcast_coalesced', mock)
+@patch('torch.distributed.broadcast', mock)
+@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
+def test_is_model_wrapper():
+
+    class Model(nn.Module):
+
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(2, 2, 1)
+
+        def forward(self, x):
+            return self.conv(x)
+
+    # _verify_model_across_ranks was added in torch 1.9.0, so check whether it
+    # is a member of torch.distributed before mocking
+    if hasattr(torch.distributed, '_verify_model_across_ranks'):
+        torch.distributed._verify_model_across_ranks = mock
+
+    model = Model()
+    assert not is_model_wrapper(model)
+
+    mmdp = MMDataParallel(model)
+    assert is_model_wrapper(mmdp)
+
+    mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
+    assert is_model_wrapper(mmddp)
+
+    # test model wrapper registry
+    @MODEL_WRAPPERS.register_module()
+    class ModelWrapper(object):
+
+        def __init__(self, module):
+            self.module = module
+
+        def forward(self, *args, **kwargs):
+            return self.module(*args, **kwargs)
+
+    model_wrapper = ModelWrapper(model)
+    assert is_model_wrapper(model_wrapper)
+
+
+class TestMMDataParallel:
+
+    def setup_method(self):
+        """Set up a demo model before each test method.
+
+        This class does not inherit from ``unittest.TestCase``, so pytest's
+        xunit-style ``setup_method()`` is used instead of ``setUp()``.
+        """
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 2, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+            def train_step(self, x):
+                # the tests pass the batch as a single-element list
+                return self.forward(x[0])
+
+            def val_step(self, x):
+                return self.forward(x[0])
+
+        self.model = Model()
+
+    def test_train_step(self):
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 2, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = Model()
+        mmdp = MMDataParallel(model)
+
+        # test without train_step attribute
+        with pytest.raises(AssertionError):
+            mmdp.train_step(torch.zeros([1, 1, 3, 3]))
+
+        out = self.model.train_step([torch.zeros([1, 1, 3, 3])])
+        assert out.shape == (1, 2, 3, 3)
+
+    def test_val_step(self):
+
+        class Model(nn.Module):
+
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(1, 2, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = Model()
+        mmdp = MMDataParallel(model)
+
+        # test without val_step attribute
+        with pytest.raises(AssertionError):
+            mmdp.val_step(torch.zeros([1, 1, 3, 3]))
+
+        out = self.model.val_step([torch.zeros([1, 1, 3, 3])])
+        assert out.shape == (1, 2, 3, 3)
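
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates how the new MODEL_WRAPPERS registry and is_model_wrapper are meant to work together. Any wrapper class a downstream project registers is detected the same way as MMDataParallel, so generic code can unwrap `.module` before saving or inspecting a model. The CustomWrapper class and the surrounding script are hypothetical, written only against the APIs added in this diff.

import torch.nn as nn

from mmengine.model import MMDataParallel, is_model_wrapper
from mmengine.registry import MODEL_WRAPPERS


@MODEL_WRAPPERS.register_module()
class CustomWrapper(nn.Module):
    # A hypothetical downstream wrapper; registration alone is what makes
    # is_model_wrapper() recognize it.

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


model = nn.Conv2d(3, 3, 1)
assert not is_model_wrapper(model)

wrapped = MMDataParallel(model)  # built-in wrapper from this patch
custom = CustomWrapper(model)    # downstream wrapper, found via the registry
assert is_model_wrapper(wrapped) and is_model_wrapper(custom)

# Generic code (e.g. checkpoint saving) can always reach the raw model:
raw = wrapped.module if is_model_wrapper(wrapped) else wrapped
assert raw is model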
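
A second small sketch (also not part of the patch) of the version utilities that data_parallel.py relies on: digit_version turns version strings into comparable tuples, with pre-releases sorting below the final release, and TORCH_VERSION, digit_version and get_git_hash are re-exported from mmengine.utils by this diff.

from mmengine.utils import TORCH_VERSION, digit_version, get_git_hash

# Tuple comparison avoids the pitfalls of comparing raw version strings:
# as strings, '1.10.0' > '1.9.1' is False.
assert digit_version('1.10.0') > digit_version('1.9.1')
assert digit_version('1.7.0rc1') < digit_version('1.7.0')  # rc < release

# The same gate data_parallel.py uses before touching reducer internals.
if ('parrots' not in TORCH_VERSION
        and digit_version(TORCH_VERSION) >= digit_version('1.7')):
    print('running the PyTorch >= 1.7 code path')

# get_git_hash shells out to `git rev-parse HEAD`; the fallback is returned
# only when the git executable cannot be run (OSError).
print(get_git_hash(digits=7))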