From 9b2a0e02da840b474f41c9bbc35335dab39eb844 Mon Sep 17 00:00:00 2001
From: Ma Zerun <mzr1996@163.com>
Date: Tue, 9 Aug 2022 11:25:29 +0800
Subject: [PATCH] [Enhance] Add `data_preprocessor` config as an argument of runner. (#343)

* [Enhance] Add `preprocess_cfg` as an argument of runner.

* Rename `preprocess_cfg` to `data_preprocessor`

* Fix docstring
---
 mmengine/runner/runner.py                     | 10 ++++++++++
 .../test_base_model/test_base_model.py        |  2 +-
 .../test_base_model/test_data_preprocessor.py |  4 ++--
 tests/test_runner/test_runner.py              | 19 ++++++++++++++++---
 4 files changed, 29 insertions(+), 6 deletions(-)

diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 5c31cb84..9b542779 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -141,6 +141,11 @@ class Runner:
         custom_hooks (list[dict] or list[Hook], optional): Hooks to execute
             custom actions like visualizing images processed by pipeline.
             Defaults to None.
+        data_preprocessor (dict, optional): The pre-process config of
+            :class:`BaseDataPreprocessor`. If the ``model`` argument is a dict
+            and doesn't contain the key ``data_preprocessor``, set the argument
+            as the ``data_preprocessor`` of the ``model`` dict.
+            Defaults to None.
         load_from (str, optional): The checkpoint file to load from.
             Defaults to None.
         resume (bool): Whether to resume training. Defaults to False. If
@@ -244,6 +249,7 @@ class Runner:
         test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
         default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
         custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
+        data_preprocessor: Union[nn.Module, Dict, None] = None,
         load_from: Optional[str] = None,
         resume: bool = False,
         launcher: str = 'none',
@@ -389,6 +395,9 @@ class Runner:
         self._has_loaded = False

         # build a model
+        if isinstance(model, dict) and data_preprocessor is not None:
+            # Merge the data_preprocessor into the model config.
+            model.setdefault('data_preprocessor', data_preprocessor)
         self.model = self.build_model(model)
         # wrap model
         self.model = self.wrap_model(
@@ -435,6 +444,7 @@ class Runner:
             test_evaluator=cfg.get('test_evaluator'),
             default_hooks=cfg.get('default_hooks'),
             custom_hooks=cfg.get('custom_hooks'),
+            data_preprocessor=cfg.get('data_preprocessor'),
             load_from=cfg.get('load_from'),
             resume=cfg.get('resume', False),
             launcher=cfg.get('launcher', 'none'),
diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py
index 0abac656..1883e95b 100644
--- a/tests/test_model/test_base_model/test_base_model.py
+++ b/tests/test_model/test_base_model/test_base_model.py
@@ -43,7 +43,7 @@ class ToyModel(BaseModel):
 class TestBaseModel(TestCase):

     def test_init(self):
-        # initiate model without `preprocess_cfg`
+        # initiate model without `data_preprocessor`
         model = ToyModel()
         self.assertIsInstance(model.data_preprocessor, BaseDataPreprocessor)
         data_preprocessor = dict(type='CustomDataPreprocessor')
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
index 99cbf289..0bdf2c36 100644
--- a/tests/test_model/test_base_model/test_data_preprocessor.py
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -56,14 +56,14 @@ class TestBaseDataPreprocessor(TestCase):
 class TestImgataPreprocessor(TestBaseDataPreprocessor):

     def test_init(self):
-        # initiate model without `preprocess_cfg`
+        # initiate model without `data_preprocessor`
         data_processor = ImgDataPreprocessor()
         self.assertFalse(data_processor.channel_conversion)
         self.assertFalse(hasattr(data_processor, 'mean'))
         self.assertFalse(hasattr(data_processor, 'std'))
         self.assertEqual(data_processor.pad_size_divisor, 1)
         assert_allclose(data_processor.pad_value, torch.tensor(0))
-        # initiate model with preprocess_cfg` and feat keys
+        # initiate model with `data_preprocessor` and feat keys
         data_processor = ImgDataPreprocessor(
             bgr_to_rgb=True,
             mean=[0, 0, 0],
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index ebceb5ce..3abd34c3 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -20,7 +20,7 @@ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             IterTimerHook, LoggerHook, ParamSchedulerHook,
                             RuntimeInfoHook)
 from mmengine.logging import LogProcessor, MessageHub, MMLogger
-from mmengine.model import BaseModel
+from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
                             OptimWrapper, OptimWrapperDict, StepLR)
 from mmengine.registry import (DATASETS, EVALUATOR, HOOKS, LOG_PROCESSORS,
@@ -38,8 +38,8 @@ from mmengine.visualization import Visualizer
 @MODELS.register_module()
 class ToyModel(BaseModel):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, data_preprocessor=None):
+        super().__init__(data_preprocessor=data_preprocessor)
         self.linear1 = nn.Linear(2, 2)
         self.linear2 = nn.Linear(2, 1)

@@ -291,6 +291,7 @@ class CustomRunner(Runner):
                  test_evaluator=None,
                  default_hooks=None,
                  custom_hooks=None,
+                 data_preprocessor=None,
                  load_from=None,
                  resume=False,
                  launcher='none',
@@ -360,6 +361,7 @@ class TestRunner(TestCase):
                 checkpoint=dict(
                     type='CheckpointHook', interval=1, by_epoch=True),
                 sampler_seed=dict(type='DistSamplerSeedHook')),
+            data_preprocessor=None,
             launcher='none',
             env_cfg=dict(dist_cfg=dict(backend='nccl')),
         )
@@ -764,6 +766,17 @@ class TestRunner(TestCase):
         cfg.experiment_name = 'test_build_model'
         runner = Runner.from_cfg(cfg)
         self.assertIsInstance(runner.model, ToyModel)
+        self.assertIsInstance(runner.model.data_preprocessor,
+                              BaseDataPreprocessor)
+
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.experiment_name = 'test_data_preprocessor'
+        cfg.data_preprocessor = dict(type='ImgDataPreprocessor')
+        runner = Runner.from_cfg(cfg)
+        # data_preprocessor is used if no `data_preprocessor` is given
+        # in the model config.
+        self.assertIsInstance(runner.model.data_preprocessor,
+                              ImgDataPreprocessor)

         # input should be a nn.Module object or dict
         with self.assertRaisesRegex(TypeError, 'model should be'):
--
GitLab
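
The merge rule this patch introduces can be illustrated with plain Python dicts. The snippet below is a minimal sketch rather than mmengine API: `model_cfg` and `runner_preprocessor` are illustrative names standing in for the `model` and `data_preprocessor` arguments of `Runner.__init__`. The key point is that `dict.setdefault` fills in `data_preprocessor` only when the model config does not already define one, so a preprocessor declared inside the model dict takes precedence over the runner-level argument.

# Minimal sketch of the merge rule added in Runner.__init__ above.
# `model_cfg` and `runner_preprocessor` are illustrative names only.
model_cfg = dict(type='ToyModel')
runner_preprocessor = dict(type='ImgDataPreprocessor', bgr_to_rgb=True)

if isinstance(model_cfg, dict) and runner_preprocessor is not None:
    # setdefault only fills the key when the model config has not set it.
    model_cfg.setdefault('data_preprocessor', runner_preprocessor)

assert model_cfg['data_preprocessor']['type'] == 'ImgDataPreprocessor'

# If the model config already defines its own preprocessor, the
# runner-level value is ignored:
model_cfg = dict(type='ToyModel',
                 data_preprocessor=dict(type='BaseDataPreprocessor'))
model_cfg.setdefault('data_preprocessor', runner_preprocessor)
assert model_cfg['data_preprocessor']['type'] == 'BaseDataPreprocessor'

In config terms, `cfg.data_preprocessor = dict(type='ImgDataPreprocessor')` (as in the new test case) therefore only takes effect when the `model` entry of the config leaves `data_preprocessor` unset.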