Skip to content
Snippets Groups Projects
Unverified Commit 576e5c8f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Regist pytorch ddp and dp to `MODEL_WRAPPERS`, add unit test to `is_model_wrapper` (#474)

* regist pytorch ddp and dp, add unit test

* minor refine

* Support check custom wrapper

* enhance ut
parent d0a74f9a
No related branches found
No related tags found
No related merge requests found
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
from typing import Any, Dict, Union from typing import Any, Dict, Union
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmengine.optim import OptimWrapper from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS from mmengine.registry import MODEL_WRAPPERS
from ..utils import detect_anomalous_params from ..utils import detect_anomalous_params
MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
MODEL_WRAPPERS.register_module(module=DataParallel)
@MODEL_WRAPPERS.register_module() @MODEL_WRAPPERS.register_module()
class MMDistributedDataParallel(DistributedDataParallel): class MMDistributedDataParallel(DistributedDataParallel):
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import MODEL_WRAPPERS import torch.nn as nn
from mmengine.registry import MODEL_WRAPPERS, Registry
def is_model_wrapper(model):
def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS):
"""Check if a module is a model wrapper. """Check if a module is a model wrapper.
The following 4 model in MMEngine (and their subclasses) are regarded as The following 4 model in MMEngine (and their subclasses) are regarded as
...@@ -12,9 +14,17 @@ def is_model_wrapper(model): ...@@ -12,9 +14,17 @@ def is_model_wrapper(model):
Args: Args:
model (nn.Module): The model to be checked. model (nn.Module): The model to be checked.
registry (Registry): The parent registry to search for model wrappers.
Returns: Returns:
bool: True if the input model is a model wrapper. bool: True if the input model is a model wrapper.
""" """
model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values()) module_wrappers = tuple(registry.module_dict.values())
return isinstance(model, model_wrappers) if isinstance(model, module_wrappers):
return True
if not registry.children:
return False
for child in registry.children.values():
return is_model_wrapper(model, child)
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import os
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmengine.model import revert_sync_batchnorm from mmengine.model import (MMDistributedDataParallel,
MMSeparateDistributedDataParallel,
is_model_wrapper, revert_sync_batchnorm)
from mmengine.registry import MODEL_WRAPPERS, Registry
@pytest.mark.skipif( @pytest.mark.skipif(
...@@ -18,3 +25,53 @@ def test_revert_syncbn(): ...@@ -18,3 +25,53 @@ def test_revert_syncbn():
conv = revert_sync_batchnorm(conv) conv = revert_sync_batchnorm(conv)
y = conv(x) y = conv(x)
assert y.shape == (1, 8, 9, 9) assert y.shape == (1, 8, 9, 9)
def test_is_model_wrapper():
# Test basic module wrapper.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'
os.environ['RANK'] = str(0)
init_process_group(backend='gloo', rank=0, world_size=1)
model = nn.Linear(1, 1)
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel
]:
wrapper_model = wrapper(model)
assert is_model_wrapper(wrapper_model)
# Test `is_model_wrapper` can check model wrapper registered in custom
# registry.
CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS)
class CustomModelWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.module = model
pass
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper
]:
wrapper_model = wrapper(model)
assert is_model_wrapper(wrapper_model)
# Test `is_model_wrapper` will not check model wrapper in parent
# registry from a child registry.
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel
]:
wrapper_model = wrapper(model)
assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
wrapper_model = CustomModelWrapper(model)
assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
destroy_process_group()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment