diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index b07b8210654503fea3b7bac8f2dcd6d737ebb3f4..889dd4c455c7db4054ec52e7536b093442ff26b7 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -2,12 +2,15 @@ from typing import Any, Dict, Union import torch -from torch.nn.parallel.distributed import DistributedDataParallel +from torch.nn.parallel import DataParallel, DistributedDataParallel from mmengine.optim import OptimWrapper from mmengine.registry import MODEL_WRAPPERS from ..utils import detect_anomalous_params +MODEL_WRAPPERS.register_module(module=DistributedDataParallel) +MODEL_WRAPPERS.register_module(module=DataParallel) + @MODEL_WRAPPERS.register_module() class MMDistributedDataParallel(DistributedDataParallel): diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py index f952e9b16c6f2a155087014e020cdfc80ddd3ffa..843d5f314bbd6d8c91996358d19aa3a65e5fc949 100644 --- a/mmengine/model/wrappers/utils.py +++ b/mmengine/model/wrappers/utils.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.registry import MODEL_WRAPPERS +import torch.nn as nn +from mmengine.registry import MODEL_WRAPPERS, Registry -def is_model_wrapper(model): + +def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS): """Check if a module is a model wrapper. The following 4 model in MMEngine (and their subclasses) are regarded as @@ -12,9 +14,17 @@ def is_model_wrapper(model): Args: model (nn.Module): The model to be checked. + registry (Registry): The parent registry to search for model wrappers. Returns: bool: True if the input model is a model wrapper. """ - model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values()) - return isinstance(model, model_wrappers) + module_wrappers = tuple(registry.module_dict.values()) + if isinstance(model, module_wrappers): + return True + + if not registry.children: + return False + + for child in registry.children.values(): + return is_model_wrapper(model, child) diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index 4861533cf308115c9a55422eb43bcca7b04d124d..24e9ca7f268275672c415894e0972bfe91c0ca1a 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -1,9 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os + import pytest import torch import torch.nn as nn +from torch.distributed import destroy_process_group, init_process_group +from torch.nn.parallel import DataParallel, DistributedDataParallel -from mmengine.model import revert_sync_batchnorm +from mmengine.model import (MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + is_model_wrapper, revert_sync_batchnorm) +from mmengine.registry import MODEL_WRAPPERS, Registry @pytest.mark.skipif( @@ -18,3 +25,53 @@ def test_revert_syncbn(): conv = revert_sync_batchnorm(conv) y = conv(x) assert y.shape == (1, 8, 9, 9) + + +def test_is_model_wrapper(): + # Test basic module wrapper. + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29510' + os.environ['RANK'] = str(0) + init_process_group(backend='gloo', rank=0, world_size=1) + model = nn.Linear(1, 1) + + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel + ]: + wrapper_model = wrapper(model) + assert is_model_wrapper(wrapper_model) + + # Test `is_model_wrapper` can check model wrapper registered in custom + # registry. + CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS) + + class CustomModelWrapper(nn.Module): + + def __init__(self, model): + super().__init__() + self.module = model + + pass + + CHILD_REGISTRY.register_module(module=CustomModelWrapper) + + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper + ]: + wrapper_model = wrapper(model) + assert is_model_wrapper(wrapper_model) + + # Test `is_model_wrapper` will not check model wrapper in parent + # registry from a child registry. + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel + ]: + wrapper_model = wrapper(model) + assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY) + + wrapper_model = CustomModelWrapper(model) + assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY) + destroy_process_group()