diff --git a/mmengine/model/wrappers/distributed.py b/mmengine/model/wrappers/distributed.py
index b07b8210654503fea3b7bac8f2dcd6d737ebb3f4..889dd4c455c7db4054ec52e7536b093442ff26b7 100644
--- a/mmengine/model/wrappers/distributed.py
+++ b/mmengine/model/wrappers/distributed.py
@@ -2,12 +2,15 @@
 from typing import Any, Dict, Union
 
 import torch
-from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.nn.parallel import DataParallel, DistributedDataParallel
 
 from mmengine.optim import OptimWrapper
 from mmengine.registry import MODEL_WRAPPERS
 from ..utils import detect_anomalous_params
 
+MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
+MODEL_WRAPPERS.register_module(module=DataParallel)
+
 
 @MODEL_WRAPPERS.register_module()
 class MMDistributedDataParallel(DistributedDataParallel):
diff --git a/mmengine/model/wrappers/utils.py b/mmengine/model/wrappers/utils.py
index f952e9b16c6f2a155087014e020cdfc80ddd3ffa..843d5f314bbd6d8c91996358d19aa3a65e5fc949 100644
--- a/mmengine/model/wrappers/utils.py
+++ b/mmengine/model/wrappers/utils.py
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from mmengine.registry import MODEL_WRAPPERS
+import torch.nn as nn
 
+from mmengine.registry import MODEL_WRAPPERS, Registry
 
-def is_model_wrapper(model):
+
+def is_model_wrapper(model: nn.Module, registry: Registry = MODEL_WRAPPERS):
     """Check if a module is a model wrapper.
 
     The following 4 model in MMEngine (and their subclasses) are regarded as
@@ -12,9 +14,17 @@ def is_model_wrapper(model):
 
     Args:
         model (nn.Module): The model to be checked.
+        registry (Registry): The parent registry to search for model wrappers.
 
     Returns:
         bool: True if the input model is a model wrapper.
     """
-    model_wrappers = tuple(MODEL_WRAPPERS.module_dict.values())
-    return isinstance(model, model_wrappers)
+    module_wrappers = tuple(registry.module_dict.values())
+    if isinstance(model, module_wrappers):
+        return True
+
+    if not registry.children:
+        return False
+
+    for child in registry.children.values():
+        return is_model_wrapper(model, child)
diff --git a/tests/test_model/test_model_utils.py b/tests/test_model/test_model_utils.py
index 4861533cf308115c9a55422eb43bcca7b04d124d..24e9ca7f268275672c415894e0972bfe91c0ca1a 100644
--- a/tests/test_model/test_model_utils.py
+++ b/tests/test_model/test_model_utils.py
@@ -1,9 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import os
+
 import pytest
 import torch
 import torch.nn as nn
+from torch.distributed import destroy_process_group, init_process_group
+from torch.nn.parallel import DataParallel, DistributedDataParallel
 
-from mmengine.model import revert_sync_batchnorm
+from mmengine.model import (MMDistributedDataParallel,
+                            MMSeparateDistributedDataParallel,
+                            is_model_wrapper, revert_sync_batchnorm)
+from mmengine.registry import MODEL_WRAPPERS, Registry
 
 
 @pytest.mark.skipif(
@@ -18,3 +25,53 @@ def test_revert_syncbn():
     conv = revert_sync_batchnorm(conv)
     y = conv(x)
     assert y.shape == (1, 8, 9, 9)
+
+
+def test_is_model_wrapper():
+    # Test basic module wrapper.
+    os.environ['MASTER_ADDR'] = '127.0.0.1'
+    os.environ['MASTER_PORT'] = '29510'
+    os.environ['RANK'] = str(0)
+    init_process_group(backend='gloo', rank=0, world_size=1)
+    model = nn.Linear(1, 1)
+
+    for wrapper in [
+            DistributedDataParallel, MMDistributedDataParallel,
+            MMSeparateDistributedDataParallel, DataParallel
+    ]:
+        wrapper_model = wrapper(model)
+        assert is_model_wrapper(wrapper_model)
+
+    # Test `is_model_wrapper` can check model wrapper registered in custom
+    # registry.
+    CHILD_REGISTRY = Registry('test_is_model_wrapper', parent=MODEL_WRAPPERS)
+
+    class CustomModelWrapper(nn.Module):
+
+        def __init__(self, model):
+            super().__init__()
+            self.module = model
+
+        pass
+
+    CHILD_REGISTRY.register_module(module=CustomModelWrapper)
+
+    for wrapper in [
+            DistributedDataParallel, MMDistributedDataParallel,
+            MMSeparateDistributedDataParallel, DataParallel, CustomModelWrapper
+    ]:
+        wrapper_model = wrapper(model)
+        assert is_model_wrapper(wrapper_model)
+
+    # Test `is_model_wrapper` will not check model wrapper in parent
+    # registry from a child registry.
+    for wrapper in [
+            DistributedDataParallel, MMDistributedDataParallel,
+            MMSeparateDistributedDataParallel, DataParallel
+    ]:
+        wrapper_model = wrapper(model)
+        assert not is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
+
+    wrapper_model = CustomModelWrapper(model)
+    assert is_model_wrapper(wrapper_model, registry=CHILD_REGISTRY)
+    destroy_process_group()