From 6ee675430f603e764a37946c329956984846eb18 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Wed, 8 Jun 2022 13:28:00 +0800 Subject: [PATCH] [Refactor]: change order of BaseModel arguments (#282) --- mmengine/model/base_model/base_model.py | 10 +++++----- tests/test_model/test_base_model/test_base_model.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index ede27b7a..1155d999 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -57,21 +57,21 @@ class BaseModel(BaseModule): >>> return dict(loss=loss) Args: - init_cfg (dict, optional): The weight initialized config for - :class:`BaseModule`. data_preprocessor (dict, optional): The pre-process config of :class:`BaseDataPreprocessor`. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. Attributes: - init_cfg (dict, optional): Initialization config dict. data_preprocessor (:obj:`BaseDataPreprocessor`): Used for pre-processing data sampled by dataloader to the format accepted by :meth:`forward`. + init_cfg (dict, optional): Initialization config dict. """ def __init__(self, - init_cfg: Optional[dict] = None, - data_preprocessor: Optional[Union[dict, nn.Module]] = None): + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): super().__init__(init_cfg) if data_preprocessor is None: data_preprocessor = dict(type='BaseDataPreprocessor') diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 280fcead..222e703d 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -25,7 +25,7 @@ class CustomDataPreprocessor(BaseDataPreprocessor): class ToyModel(BaseModel): def __init__(self, data_preprocessor=None): - super().__init__(None, data_preprocessor=data_preprocessor) + super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) self.conv = nn.Conv2d(3, 1, 1) def forward(self, batch_inputs, data_samples=None, mode='tensor'): -- GitLab