From 576e5c8f9148913d1fc4b3545ac22708e141c318 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Sun, 28 Aug 2022 17:09:41 +0800 Subject: [PATCH] [Fix] Register pytorch ddp and dp to `MODEL_WRAPPERS`, add unit test to `is_model_wrapper` (#474) * register pytorch ddp and dp, add unit test * minor refine * Support check custom wrapper * enhance ut --- mmengine/model/wrappers/distributed.py | 5 ++- mmengine/model/wrappers/utils.py | 18 ++++++-- tests/test_model/test_model_utils.py | 59 +++++++++++++++++++++++++- 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py index b07b8210..889dd4c4 100644 --- a/mmengine/model/wrappers/distributed.py +++ b/mmengine/model/wrappers/distributed.py @@ -2,12 +2,15 @@ from typing import Any, Dict, Union import torch -from torch.nn.parallel.distributed import DistributedDataParallel +from torch.nn.parallel import DataParallel, DistributedDataParallel from mmengine.optim import OptimWrapper from mmengine.registry import MODEL_WRAPPERS from ..utils import detect_anomalous_params +MODEL_WRAPPERS.register_module(module=DistributedDataParallel) +MODEL_WRAPPERS.register_module(module=DataParallel) + @MODEL_WRAPPERS.register_module() class MMDistributedDataParallel(DistributedDataParallel): diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py index f952e9b1..843d5f31 100644 --- a/mmengine/model/wrappers/utils.py +++ b/mmengine/model/wrappers/utils.py @@ -1,8 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. -from mmengine.registry import MODEL_WRAPPERS +import torch.nn as nn +from mmengine.registry import MODEL_WRAPPERS, Registry -def is_model_wrapper(model): + +def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS): """Check if a module is a model wrapper. 
The following 4 model in MMEngine (and their subclasses) are regarded as @@ -12,9 +14,17 @@ def is_model_wrapper(model): Args: model (nn.Module): The model to be checked. + registry (Registry): The parent registry to search for model wrappers. Returns: bool: True if the input model is a model wrapper. """ - model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values()) - return isinstance(model, model_wrappers) + module_wrappers = tuple(registry.module_dict.values()) + if isinstance(model, module_wrappers): + return True + + if not registry.children: + return False + + for child in registry.children.values(): + return is_model_wrapper(model, child) diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py index 4861533c..24e9ca7f 100644 --- a/tests/test_model/test_model_utils.py +++ b/tests/test_model/test_model_utils.py @@ -1,9 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. +import os + import pytest import torch import torch.nn as nn +from torch.distributed import destroy_process_group, init_process_group +from torch.nn.parallel import DataParallel, DistributedDataParallel -from mmengine.model import revert_sync_batchnorm +from mmengine.model import (MMDistributedDataParallel, + MMSeparateDistributedDataParallel, + is_model_wrapper, revert_sync_batchnorm) +from mmengine.registry import MODEL_WRAPPERS, Registry @pytest.mark.skipif( @@ -18,3 +25,53 @@ def test_revert_syncbn(): conv = revert_sync_batchnorm(conv) y = conv(x) assert y.shape == (1, 8, 9, 9) + + +def test_is_model_wrapper(): + # Test basic module wrapper. 
+ os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = '29510' + os.environ['RANK'] = str(0) + init_process_group(backend='gloo', rank=0, world_size=1) + model = nn.Linear(1, 1) + + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel + ]: + wrapper_model = wrapper(model) + assert is_model_wrapper(wrapper_model) + + # Test `is_model_wrapper` can check model wrapper registered in custom + # registry. + CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS) + + class CustomModelWrapper(nn.Module): + + def __init__(self, model): + super().__init__() + self.module = model + + pass + + CHILD_REGISTRY.register_module(module=CustomModelWrapper) + + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper + ]: + wrapper_model = wrapper(model) + assert is_model_wrapper(wrapper_model) + + # Test `is_model_wrapper` will not check model wrapper in parent + # registry from a child registry. + for wrapper in [ + DistributedDataParallel, MMDistributedDataParallel, + MMSeparateDistributedDataParallel, DataParallel + ]: + wrapper_model = wrapper(model) + assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY) + + wrapper_model = CustomModelWrapper(model) + assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY) + destroy_process_group() -- GitLab