From 3d830a28b6437d8ed4c352cc2e65c850bef26968 Mon Sep 17 00:00:00 2001
From: RangiLyu <lyuchqi@gmail.com>
Date: Fri, 8 Apr 2022 22:18:23 +0800
Subject: [PATCH] [Fix]: Fix is_model_wrapper and add DistSamplerSeedHook to default hooks. (#172)

* [Fix]: Fix model_wrapper and add DistSamplerSeedHook as default hook.

* add comments
---
 mmengine/hooks/sampler_seed_hook.py      | 14 +++++---
 mmengine/model/wrappers/data_parallel.py |  3 ++
 mmengine/runner/runner.py                |  7 +++-
 .../test_wrappers/test_data_parallel.py  |  8 +++++
 tests/test_runner/test_runner.py         | 32 ++++++++++---------
 5 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/mmengine/hooks/sampler_seed_hook.py b/mmengine/hooks/sampler_seed_hook.py
index b90657c8..9ddfd7ab 100644
--- a/mmengine/hooks/sampler_seed_hook.py
+++ b/mmengine/hooks/sampler_seed_hook.py
@@ -20,11 +20,17 @@ class DistSamplerSeedHook(Hook):
         Args:
             runner (Runner): The runner of the training process.
         """
-        if hasattr(runner.train_loop.dataloader.sampler, 'set_epoch'):
-            # in case the data loader uses `SequentialSampler` in Pytorch
+        if hasattr(runner.train_loop.dataloader, 'sampler') and hasattr(
+                runner.train_loop.dataloader.sampler, 'set_epoch'):
+            # In case the `_SingleProcessDataLoaderIter` has no sampler,
+            # or the data loader uses `SequentialSampler` in PyTorch.
             runner.train_loop.dataloader.sampler.set_epoch(runner.epoch)
-        elif hasattr(runner.train_loop.dataloader.batch_sampler.sampler,
-                     'set_epoch'):
+
+        elif hasattr(runner.train_loop.dataloader,
+                     'batch_sampler') and hasattr(
+                         runner.train_loop.dataloader.batch_sampler.sampler,
+                         'set_epoch'):
+            # In case the `_SingleProcessDataLoaderIter` has no batch sampler.
             # batch sampler in pytorch warps the sampler as its attributes.
             runner.train_loop.dataloader.batch_sampler.sampler.set_epoch(
                 runner.epoch)
diff --git a/mmengine/model/wrappers/data_parallel.py b/mmengine/model/wrappers/data_parallel.py
index c2967cea..d31b009c 100644
--- a/mmengine/model/wrappers/data_parallel.py
+++ b/mmengine/model/wrappers/data_parallel.py
@@ -9,6 +9,9 @@ from torch.nn.parallel.distributed import (DistributedDataParallel,
 from mmengine.registry import MODEL_WRAPPERS
 from mmengine.utils import TORCH_VERSION, digit_version
 
+MODEL_WRAPPERS.register_module(module=DataParallel)
+MODEL_WRAPPERS.register_module(module=DistributedDataParallel)
+
 
 @MODEL_WRAPPERS.register_module()
 class MMDataParallel(DataParallel):
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 28d339ca..12ecccd1 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -1397,8 +1397,13 @@ class Runner:
         # Add comments to describe the usage of `after_load_ckpt`
         self.call_hook('after_load_ckpt', checkpoint=checkpoint)
 
+        if is_model_wrapper(self.model):
+            model = self.model.module
+        else:
+            model = self.model
+
         checkpoint = _load_checkpoint_to_model(
-            self.model, checkpoint, strict, revise_keys=revise_keys)
+            model, checkpoint, strict, revise_keys=revise_keys)
 
         self._has_loaded = True
 
diff --git a/tests/test_model/test_wrappers/test_data_parallel.py b/tests/test_model/test_wrappers/test_data_parallel.py
index fa3e9993..34b0cad2 100644
--- a/tests/test_model/test_wrappers/test_data_parallel.py
+++ b/tests/test_model/test_wrappers/test_data_parallel.py
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock, patch
 import pytest
 import torch
 import torch.nn as nn
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
 
 from mmengine.model.wrappers import (MMDataParallel,
                                      MMDistributedDataParallel, is_model_wrapper)
@@ -44,6 +46,12 @@ def test_is_model_wrapper():
     mmddp = MMDistributedDataParallel(model, process_group=MagicMock())
     assert is_model_wrapper(mmddp)
 
+    torch_dp = DataParallel(model)
+    assert is_model_wrapper(torch_dp)
+
+    torch_ddp = DistributedDataParallel(model, process_group=MagicMock())
+    assert is_model_wrapper(torch_ddp)
+
     # test model wrapper registry
     @MODEL_WRAPPERS.register_module()
     class ModelWrapper(object):
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index fa1b9e5b..d5f403b2 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
 from mmengine.config import Config
 from mmengine.data import DefaultSampler
 from mmengine.evaluator import BaseMetric, Evaluator
-from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
-                            ParamSchedulerHook)
+from mmengine.hooks import (DistSamplerSeedHook, Hook, IterTimerHook,
+                            LoggerHook, OptimizerHook, ParamSchedulerHook)
 from mmengine.hooks.checkpoint_hook import CheckpointHook
 from mmengine.logging import MessageHub, MMLogger
 from mmengine.optim.scheduler import MultiStepLR, StepLR
@@ -913,33 +913,35 @@ class TestRunner(TestCase):
 
         # register five hooks by default
         runner.register_default_hooks()
-        self.assertEqual(len(runner._hooks), 5)
-        # the forth registered hook should be `ParamSchedulerHook`
-        self.assertTrue(isinstance(runner._hooks[3], ParamSchedulerHook))
+        self.assertEqual(len(runner._hooks), 6)
+        # the third registered hook should be `DistSamplerSeedHook`
+        self.assertTrue(isinstance(runner._hooks[2], DistSamplerSeedHook))
+        # the fifth registered hook should be `ParamSchedulerHook`
+        self.assertTrue(isinstance(runner._hooks[4], ParamSchedulerHook))
 
         runner._hooks = []
         # remove `ParamSchedulerHook` from default hooks
         runner.register_default_hooks(hooks=dict(timer=None))
-        self.assertEqual(len(runner._hooks), 4)
+        self.assertEqual(len(runner._hooks), 5)
         # `ParamSchedulerHook` was popped so the forth is `CheckpointHook`
-        self.assertTrue(isinstance(runner._hooks[3], CheckpointHook))
+        self.assertTrue(isinstance(runner._hooks[4], CheckpointHook))
 
         # add a new default hook
         runner._hooks = []
         runner.register_default_hooks(hooks=dict(ToyHook=dict(type='ToyHook')))
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_custom_hooks'
         runner = Runner.from_cfg(cfg)
-        self.assertEqual(len(runner._hooks), 5)
+        self.assertEqual(len(runner._hooks), 6)
 
         custom_hooks = [dict(type='ToyHook')]
         runner.register_custom_hooks(custom_hooks)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_register_hooks(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -949,9 +951,9 @@ class TestRunner(TestCase):
         runner._hooks = []
         custom_hooks = [dict(type='ToyHook')]
         runner.register_hooks(custom_hooks=custom_hooks)
-        # five default hooks + custom hook (ToyHook)
-        self.assertEqual(len(runner._hooks), 6)
-        self.assertTrue(isinstance(runner._hooks[5], ToyHook))
+        # six default hooks + custom hook (ToyHook)
+        self.assertEqual(len(runner._hooks), 7)
+        self.assertTrue(isinstance(runner._hooks[6], ToyHook))
 
     def test_custom_loop(self):
         # test custom loop with additional hook
-- 
GitLab
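
A minimal sketch (illustrative, not part of the patch) of what the MODEL_WRAPPERS registrations above enable: once PyTorch's native DataParallel is registered, is_model_wrapper() recognizes it, so Runner.load_checkpoint can unwrap to the inner module as in the runner.py hunk. The toy nn.Linear model is arbitrary; the DistributedDataParallel case is omitted because it needs an initialized process group (the tests mock one).

import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmengine.model.wrappers import is_model_wrapper

model = nn.Linear(2, 2)        # arbitrary toy model
wrapped = DataParallel(model)

# Before this patch only MMDataParallel/MMDistributedDataParallel were
# registered in MODEL_WRAPPERS, so the native wrapper went undetected and
# checkpoints were loaded into the wrapper (whose state_dict keys carry a
# `module.` prefix) instead of into the model itself.
assert is_model_wrapper(wrapped)

# The unwrap-before-load pattern added to Runner.load_checkpoint:
target = wrapped.module if is_model_wrapper(wrapped) else wrapped
assert target is model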
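
Similarly, a standalone sketch of why DistSamplerSeedHook is worth registering by default: DistributedSampler seeds its shuffle with seed + epoch, so without a set_epoch call every epoch replays the same permutation. The num_replicas=1, rank=0 arguments are pinned only so the snippet runs without a process group; they are not how the hook is used in practice.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8.0))
# Pinned replica/rank values stand in for a real distributed setup.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler)

for epoch in range(2):
    # What the hook does before each training epoch; the hasattr guard
    # mirrors the patch (e.g. SequentialSampler has no set_epoch).
    if hasattr(loader.sampler, 'set_epoch'):
        loader.sampler.set_epoch(epoch)
    print(epoch, [int(batch[0]) for batch in loader])
# Without the set_epoch call, both epochs would print the same order.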