Unverified commit 5ac3c233, authored by Mashiro, committed by GitHub

[Fix]: fix MMSeparateDistributedDataParallel (#338)

parent d624fa91
@@ -46,16 +46,16 @@ class MMSeparateDistributedDataParallel(DistributedDataParallel):
         device = get_device()
         # Wrap the submodule with parameters of `self.module` to
         # `MMDistributedDataParallel`
-        for name, _module in module._modules.items():
+        for name, sub_module in module._modules.items():
             # module without parameters.
-            if next(_module.parameters(), None) is None:
-                _module = _module.to(device)
-            elif all(not p.requires_grad for p in module.parameters()):
-                _module = _module.to(device)
+            if next(sub_module.parameters(), None) is None:
+                sub_module = sub_module.to(device)
+            elif all(not p.requires_grad for p in sub_module.parameters()):
+                sub_module = sub_module.to(device)
             else:
-                _module = MMDistributedDataParallel(
-                    module=_module.to(device), *args, **kwargs)
-            module._modules[name] = _module
+                sub_module = MMDistributedDataParallel(
+                    module=sub_module.to(device), *args, **kwargs)
+            module._modules[name] = sub_module

     def train_step(self, data: List[dict],
                    optim_wrapper: OptimWrapperDict) -> Dict[str, torch.Tensor]:
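The substantive change is in the `elif` branch: before this commit it inspected `module.parameters()` (the whole parent model) instead of `sub_module.parameters()`, so a submodule whose own parameters are all frozen, such as an exponential-moving-average copy, could still be wrapped in `MMDistributedDataParallel`. The loop variable is also renamed from `_module` to `sub_module`. The following is a minimal sketch, not part of this diff, of the corrected dispatch rule run outside a distributed environment; the submodule names and the plain string labels standing in for the actual DDP wrapping are illustrative only.

import torch.nn as nn

# Toy parent module with the three kinds of submodules the wrapper distinguishes.
model = nn.Module()
model.conv = nn.Conv2d(3, 8, 3)   # trainable parameters
model.ema = nn.Conv2d(3, 8, 3)    # all parameters frozen, like an EMA copy
for p in model.ema.parameters():
    p.requires_grad = False
model.act = nn.ReLU()             # no parameters at all

for name, sub_module in model._modules.items():
    if next(sub_module.parameters(), None) is None:
        decision = 'moved to device only (no parameters)'
    elif all(not p.requires_grad for p in sub_module.parameters()):
        decision = 'moved to device only (all parameters frozen)'
    else:
        decision = 'would be wrapped in MMDistributedDataParallel'
    print(f'{name}: {decision}')
# conv: would be wrapped in MMDistributedDataParallel
# ema: moved to device only (all parameters frozen)
# act: moved to device only (no parameters)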
@@ -11,6 +11,7 @@ from torch.optim import SGD
 from mmengine.dist import all_gather
 from mmengine.model import (BaseModel, MMDistributedDataParallel,
                             MMSeparateDistributedDataParallel)
+from mmengine.model.averaged_model import ExponentialMovingAverage
 from mmengine.optim import AmpOptimWrapper, OptimWrapper, OptimWrapperDict
 from mmengine.testing import assert_allclose
 from mmengine.testing._internal import MultiProcessTestCase
@@ -18,7 +19,7 @@ from mmengine.utils.parrots_wrapper import TORCH_VERSION
 from mmengine.utils.version_utils import digit_version

 if digit_version(TORCH_VERSION) >= digit_version('1.11.0'):
-    from mmengine.model import MMFullyShardedDataParallel
+    from mmengine.model import MMFullyShardedDataParallel  # noqa: F401


 class ToyModel(BaseModel):
@@ -132,6 +133,17 @@ class TestDistributedDataParallel(MultiProcessTestCase):
     not torch.cuda.is_available(), reason='cuda should be available')
 class TestMMSeparateDistributedDataParallel(TestDistributedDataParallel):

+    def test_init(self):
+        self._init_dist_env(self.rank, self.world_size)
+        model = ComplexModel()
+        model.ema = ExponentialMovingAverage(nn.Conv2d(1, 1, 1))
+        model.act = nn.ReLU()
+        ddp_model = MMSeparateDistributedDataParallel(model.cuda())
+        self.assertIsInstance(ddp_model.module.ema, ExponentialMovingAverage)
+        self.assertIsInstance(ddp_model.module.conv1,
+                              MMDistributedDataParallel)
+        self.assertIsInstance(ddp_model.module.act, nn.ReLU)
+
     def test_train_step(self):
         self._init_dist_env(self.rank, self.world_size)
         # Test `optim_wrapper` is a dict. In this case,
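The new test_init exercises exactly the case the fix targets: the ComplexModel's trainable conv1 is wrapped in MMDistributedDataParallel, while the attached ExponentialMovingAverage and the parameter-less nn.ReLU are left unwrapped. A small sketch, independent of the distributed setup used in the test, of constructing such an averaged submodule; it assumes mmengine is installed, and the expectation that the averaged copy's parameters are frozen is inferred from what the test asserts, not stated elsewhere in this diff.

import torch.nn as nn
from mmengine.model.averaged_model import ExponentialMovingAverage

# An averaged (EMA) copy of a small conv layer, as in the new test.
ema = ExponentialMovingAverage(nn.Conv2d(1, 1, 1))
# Its parameters are not optimized directly, which is why
# MMSeparateDistributedDataParallel should leave it unwrapped.
print(any(p.requires_grad for p in ema.parameters()))  # expected: False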