From 5ac3c23338a401606afb7d27f047bee5d9c5ab67 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Tue, 28 Jun 2022 22:20:20 +0800
Subject: [PATCH] [Fix]: fix MMSeparateDistributedDataParallel (#338)

---
 mmengine/model/wrappers/seperate_distributed.py | 16 ++++++++--------
 .../test_wrappers/test_model_wrapper.py         | 14 +++++++++++++-
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
index 8e04f059..d1e2caa2 100644
--- a/mmengine/model/wrappers/seperate_distributed.py
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -46,16 +46,16 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
         device = get_device()
         # Wrap the submodule with parameters of `self.module` to
         # `MMDistributedDataParallel`
-        for name, _module in module._modules.items():
+        for name, sub_module in module._modules.items():
             # module without parameters.
-            if next(_module.parameters(), None) is None:
-                _module = _module.to(device)
-            elif all(not p.requires_grad for p in module.parameters()):
-                _module = _module.to(device)
+            if next(sub_module.parameters(), None) is None:
+                sub_module = sub_module.to(device)
+            elif all(not p.requires_grad for p in sub_module.parameters()):
+                sub_module = sub_module.to(device)
             else:
-                _module = MMDistributedDataParallel(
-                    module=_module.to(device), *args, **kwargs)
-            module._modules[name] = _module
+                sub_module = MMDistributedDataParallel(
+                    module=sub_module.to(device), *args, **kwargs)
+            module._modules[name] = sub_module
 
     def train_step(self, data: List[dict],
                    optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py
index 90ae3643..7634fe7d 100644
--- a/tests/test_model/test_wrappers/test_model_wrapper.py
+++ b/tests/test_model/test_wrappers/test_model_wrapper.py
@@ -11,6 +11,7 @@ from torch.optim import SGD
 from mmengine.dist import all_gather
 from mmengine.model import (BaseModel, MMDistributedDataParallel,
                             MMSeparateDistributedDataParallel)
+from mmengine.model.averaged_model import ExponentialMovingAverage
 from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
@@ -18,7 +19,7 @@ from mmengine.utils.parrots_wrapper import TORCH_VERSION
 from mmengine.utils.version_utils import digit_version
 
 if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
-    from mmengine.model import MMFullyShardedDataParallel
+    from mmengine.model import MMFullyShardedDataParallel  # noqa: F401
 
 
 class ToyModel(BaseModel):
@@ -132,6 +133,17 @@ class TestDistributedDataParallel(MultiProcessTestCase):
     not torch.cuda.is_available(), reason='cuda should be available')
 class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
 
+    def test_init(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        model.ema = ExponentialMovingAverage(nn.Conv2d(1, 1, 1))
+        model.act = nn.ReLU()
+        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
+        self.assertIsInstance(ddp_model.module.ema, ExponentialMovingAverage)
+        self.assertIsInstance(ddp_model.module.conv1,
+                              MMDistributedDataParallel)
+        self.assertIsInstance(ddp_model.module.act, nn.ReLU)
+
     def test_train_step(self):
         self._init_dist_env(self.rank, self.world_size)
         # Test `optim_wrapper` is a dict. In this case,
-- 
GitLab
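
Root cause of the fix: the old `elif` tested `module.parameters()` (the parent model) instead of `sub_module.parameters()`, so a fully frozen submodule such as an EMA copy was still handed to DDP, which raises an error when no wrapped parameter requires grad. Below is a minimal self-contained sketch of the corrected per-submodule wrapping rule; `wrap_submodules` and the `wrap` callable are hypothetical names, not mmengine API, and the decision logic simply mirrors the fixed code above.

import torch.nn as nn

def wrap_submodules(module: nn.Module, wrap, device: str = 'cpu') -> None:
    """Decide, per direct submodule, whether distributed wrapping is needed."""
    for name, sub_module in module._modules.items():
        if next(sub_module.parameters(), None) is None:
            # No parameters at all (e.g. nn.ReLU): just move to the device.
            sub_module = sub_module.to(device)
        elif all(not p.requires_grad for p in sub_module.parameters()):
            # Every parameter is frozen (e.g. an EMA copy): DDP would
            # error out here, so move the submodule without wrapping it.
            sub_module = sub_module.to(device)
        else:
            # Trainable submodule: wrap it for distributed training.
            sub_module = wrap(sub_module.to(device))
        module._modules[name] = sub_module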