Skip to content
Snippets Groups Projects
Unverified Commit 576e5c8f authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Regist pytorch ddp and dp to `MODEL_WRAPPERS`, add unit test to `is_model_wrapper` (#474)

* regist pytorch ddp and dp, add unit test

* minor refine

* Support check custom wrapper

* enhance ut
parent d0a74f9a
No related branches found
No related tags found
No related merge requests found
......@@ -2,12 +2,15 @@
from typing import Any, Dict, Union
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmengine.optim import OptimWrapper
from mmengine.registry import MODEL_WRAPPERS
from ..utils import detect_anomalous_params
MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
MODEL_WRAPPERS.register_module(module=DataParallel)
@MODEL_WRAPPERS.register_module()
class MMDistributedDataParallel(DistributedDataParallel):
......
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import MODEL_WRAPPERS
import torch.nn as nn
from mmengine.registry import MODEL_WRAPPERS, Registry
def is_model_wrapper(model):
def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS):
"""Check if a module is a model wrapper.
The following 4 model in MMEngine (and their subclasses) are regarded as
......@@ -12,9 +14,17 @@ def is_model_wrapper(model):
Args:
model (nn.Module): The model to be checked.
registry (Registry): The parent registry to search for model wrappers.
Returns:
bool: True if the input model is a model wrapper.
"""
model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values())
return isinstance(model, model_wrappers)
module_wrappers = tuple(registry.module_dict.values())
if isinstance(model, module_wrappers):
return True
if not registry.children:
return False
for child in registry.children.values():
return is_model_wrapper(model, child)
# Copyright (c) OpenMMLab. All rights reserved.
import os
import pytest
import torch
import torch.nn as nn
from torch.distributed import destroy_process_group, init_process_group
from torch.nn.parallel import DataParallel, DistributedDataParallel
from mmengine.model import revert_sync_batchnorm
from mmengine.model import (MMDistributedDataParallel,
MMSeparateDistributedDataParallel,
is_model_wrapper, revert_sync_batchnorm)
from mmengine.registry import MODEL_WRAPPERS, Registry
@pytest.mark.skipif(
......@@ -18,3 +25,53 @@ def test_revert_syncbn():
conv = revert_sync_batchnorm(conv)
y = conv(x)
assert y.shape == (1, 8, 9, 9)
def test_is_model_wrapper():
# Test basic module wrapper.
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29510'
os.environ['RANK'] = str(0)
init_process_group(backend='gloo', rank=0, world_size=1)
model = nn.Linear(1, 1)
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel
]:
wrapper_model = wrapper(model)
assert is_model_wrapper(wrapper_model)
# Test `is_model_wrapper` can check model wrapper registered in custom
# registry.
CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS)
class CustomModelWrapper(nn.Module):
def __init__(self, model):
super().__init__()
self.module = model
pass
CHILD_REGISTRY.register_module(module=CustomModelWrapper)
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper
]:
wrapper_model = wrapper(model)
assert is_model_wrapper(wrapper_model)
# Test `is_model_wrapper` will not check model wrapper in parent
# registry from a child registry.
for wrapper in [
DistributedDataParallel, MMDistributedDataParallel,
MMSeparateDistributedDataParallel, DataParallel
]:
wrapper_model = wrapper(model)
assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
wrapper_model = CustomModelWrapper(model)
assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
destroy_process_group()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment