diff --git a/mmengine/model/wrappers/seperate_distributed.py b/mmengine/model/wrappers/seperate_distributed.py
index 8e04f05983686e185e8f363de66d5147b2732d04..d1e2caa2b328d45809fc2cbd6144a605438324ee 100644
--- a/mmengine/model/wrappers/seperate_distributed.py
+++ b/mmengine/model/wrappers/seperate_distributed.py
@@ -46,16 +46,16 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
         device = get_device()
         # Wrap the submodule with parameters of `self.module` to
         # `MMDistributedDataParallel`
-        for name, _module in module._modules.items():
+        for name, sub_module in module._modules.items():
             # module without parameters.
-            if next(_module.parameters(), None) is None:
-                _module = _module.to(device)
-            elif all(not p.requires_grad for p in module.parameters()):
-                _module = _module.to(device)
+            if next(sub_module.parameters(), None) is None:
+                sub_module = sub_module.to(device)
+            elif all(not p.requires_grad for p in sub_module.parameters()):
+                sub_module = sub_module.to(device)
             else:
-                _module = MMDistributedDataParallel(
-                    module=_module.to(device), *args, **kwargs)
-            module._modules[name] = _module
+                sub_module = MMDistributedDataParallel(
+                    module=sub_module.to(device), *args, **kwargs)
+            module._modules[name] = sub_module
 
     def train_step(self, data: List[dict],
                    optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
diff --git a/tests/test_model/test_wrappers/test_model_wrapper.py b/tests/test_model/test_wrappers/test_model_wrapper.py
index 90ae3643778ad20b09a3ca3f19041527f1d4dd0c..7634fe7d6571375f01191a69063d44523a87c5fa 100644
--- a/tests/test_model/test_wrappers/test_model_wrapper.py
+++ b/tests/test_model/test_wrappers/test_model_wrapper.py
@@ -11,6 +11,7 @@ from torch.optim import SGD
 from mmengine.dist import all_gather
 from mmengine.model import (BaseModel, MMDistributedDataParallel,
                             MMSeparateDistributedDataParallel)
+from mmengine.model.averaged_model import ExponentialMovingAverage
 from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
@@ -18,7 +19,7 @@ from mmengine.utils.parrots_wrapper import TORCH_VERSION
 from mmengine.utils.version_utils import digit_version
 
 if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
-    from mmengine.model import MMFullyShardedDataParallel
+    from mmengine.model import MMFullyShardedDataParallel  # noqa: F401
 
 
 class ToyModel(BaseModel):
@@ -132,6 +133,17 @@ class TestDistributedDataParallel(MultiProcessTestCase):
     not torch.cuda.is_available(), reason='cuda should be available')
 class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):
 
+    def test_init(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        model.ema = ExponentialMovingAverage(nn.Conv2d(1, 1, 1))
+        model.act = nn.ReLU()
+        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
+        self.assertIsInstance(ddp_model.module.ema, ExponentialMovingAverage)
+        self.assertIsInstance(ddp_model.module.conv1,
+                              MMDistributedDataParallel)
+        self.assertIsInstance(ddp_model.module.act, nn.ReLU)
+
     def test_train_step(self):
         self._init_dist_env(self.rank, self.world_size)
         # Test `optim_wrapper` is a dict. In this case,