From 3d830a28b6437d8ed4c352cc2e65c850bef26968 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 8 Apr 2022 22:18:23 +0800
Subject: [PATCH] [Fix]: Fix is_model_wrapper and add DistSamplerSeedHook to
 default hooks. (#172)

* [Fix]: Fix model_wrapper and add DistSamplerSeedHook as default hook.

* add comments
---
 mmengine/hooks/sampler_seed_hook.py           | 14 +++++---
 mmengine/model/wrappers/data_parallel.py      |  3 ++
 mmengine/runner/runner.py                     |  7 +++-
 .../test_wrappers/test_data_parallel.py       |  8 +++++
 tests/test_runner/test_runner.py              | 32 ++++++++++---------
 5 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py
index b90657c8..9ddfd7ab 100644
--- a/mmengine/hooks/sampler_seed_hook.py
+++ b/mmengine/hooks/sampler_seed_hook.py
@@ -20,11 +20,17 @@ class DistSamplerSeedHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
-            # in case the data loader uses `SequentialSampler` in Pytorch
+        if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr(
+                runner.train_loop.dataloader.sampler, 'set_epoch'):
+            # In case the` _SingleProcessDataLoaderIter` has no sampler,
+            # or data loader uses `SequentialSampler` in Pytorch.
             runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
-        elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
-                     'set_epoch'):
+
+        elif hasattr(runner.train_loop.dataloader,
+                     'batch_sampler') and hasattr(
+                         runner.train_loop.dataloader.batch_sampler.sampler,
+                         'set_epoch'):
+            # In case the` _SingleProcessDataLoaderIter` has no batch sampler.
             # batch sampler in pytorch warps the sampler as its attributes.
             runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
                 runner.epoch)
diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py
index c2967cea..d31b009c 100644
--- a/mmengine/model/wrappers/data_parallel.py
+++ b/mmengine/model/wrappers/data_parallel.py
@@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel,
 from mmengine.registry import MODEL_WRAPPERS
 from mmengine.utils import TORCH_VERSION, digit_version
 
+MODEL_WRAPPERS.register_module(module=DataParallel)
+MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
+
 
 @MODEL_WRAPPERS.register_module()
 class MMDataParallel(DataParallel):
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 28d339ca..12ecccd1 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1397,8 +1397,13 @@ class Runner:
         # Add comments to describe the usage of `after_load_ckpt`
         self.call_hook('after_load_ckpt', checkpoint=checkpoint)
 
+        if is_model_wrapper(self.model):
+            model = self.model.module
+        else:
+            model = self.model
+
         checkpoint = _load_checkpoint_to_model(
-            self.model, checkpoint, strict, revise_keys=revise_keys)
+            model, checkpoint, strict, revise_keys=revise_keys)
 
         self._has_loaded = True
 
diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py
index fa3e9993..34b0cad2 100644
--- a/tests/test_model/test_wrappers/test_data_parallel.py
+++ b/tests/test_model/test_wrappers/test_data_parallel.py
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 from mmengine.model.wrappers import (MMDataParallel, MMDistributedDataParallel,
                                      is_model_wrapper)
@@ -44,6 +46,12 @@ def test_is_model_wrapper():
     mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
     assert is_model_wrapper(mmddp)
 
+    torch_dp = DataParallel(model)
+    assert is_model_wrapper(torch_dp)
+
+    torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
+    assert is_model_wrapper(torch_ddp)
+
     # test model wrapper registry
     @MODEL_WRAPPERS.register_module()
     class ModelWrapper(object):
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index fa1b9e5b..d5f403b2 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
 from mmengine.config import Config
 from mmengine.data import DefaultSampler
 from mmengine.evaluator import BaseMetric, Evaluator
-from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
-                            ParamSchedulerHook)
+from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
+                            LoggerHook, OptimizerHook, ParamSchedulerHook)
 from mmengine.hooks.checkpoint_hook import CheckpointHook
 from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim.scheduler import MultiStepLR, StepLR
@@ -913,33 +913,35 @@ class TestRunner(TestCase):
 
         # register five hooks by default
         runner.register_default_hooks()
-        self.assertEqual(len(runner._hooks), 5)
-        # the forth registered hook should be `ParamSchedulerHook`
-        self.assertTrue(isinstance(runner._hooks[3], ParamSchedulerHook))
+        self.assertEqual(len(runner._hooks), 6)
+        # the third registered hook should be `DistSamplerSeedHook`
+        self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
+        # the fifth registered hook should be `ParamSchedulerHook`
+        self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
 
         runner._hooks = []
         # remove `ParamSchedulerHook` from default hooks
         runner.register_default_hooks(hooks=dict(timer=None))
-        self.assertEqual(len(runner._hooks), 4)
+        self.assertEqual(len(runner._hooks), 5)
         # `ParamSchedulerHook` was popped so the forth is `CheckpointHook`
-        self.assertTrue(isinstance(runner._hooks[3], CheckpointHook))
+        self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
 
         # add a new default hook
         runner._hooks = []
         runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_custom_hooks'
         runner = Runner.from_cfg(cfg)
 
-        self.assertEqual(len(runner._hooks), 5)
+        self.assertEqual(len(runner._hooks), 6)
         custom_hooks = [dict(type='ToyHook')]
         runner.register_custom_hooks(custom_hooks)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_register_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -949,9 +951,9 @@ class TestRunner(TestCase):
         runner._hooks = []
         custom_hooks = [dict(type='ToyHook')]
         runner.register_hooks(custom_hooks=custom_hooks)
-        # five default hooks + custom hook (ToyHook)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        # six default hooks + custom hook (ToyHook)
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_loop(self):
         # test custom loop with additional hook
-- 
GitLab