diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index ede27b7a745a09ed816263b6c0863cfa37368f9a..1155d999b532fcecf35c3af99ebc1a375eadb439 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -57,21 +57,21 @@ class BaseModel(BaseModule): >>> return dict(loss=loss) Args: - init_cfg (dict, optional): The weight initialized config for - :class:`BaseModule`. data_preprocessor (dict, optional): The pre-process config of :class:`BaseDataPreprocessor`. + init_cfg (dict, optional): The weight initialized config for + :class:`BaseModule`. Attributes: - init_cfg (dict, optional): Initialization config dict. data_preprocessor (:obj:`BaseDataPreprocessor`): Used for pre-processing data sampled by dataloader to the format accepted by :meth:`forward`. + init_cfg (dict, optional): Initialization config dict. """ def __init__(self, - init_cfg: Optional[dict] = None, - data_preprocessor: Optional[Union[dict, nn.Module]] = None): + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + init_cfg: Optional[dict] = None): super().__init__(init_cfg) if data_preprocessor is None: data_preprocessor = dict(type='BaseDataPreprocessor') diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 280fcead669cd8ead5361d2eff486e0aa89c6a67..222e703d53128af4d7172a1cc82fc65868f6c050 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -25,7 +25,7 @@ class CustomDataPreprocessor(BaseDataPreprocessor): class ToyModel(BaseModel): def __init__(self, data_preprocessor=None): - super().__init__(None, data_preprocessor=data_preprocessor) + super().__init__(data_preprocessor=data_preprocessor, init_cfg=None) self.conv = nn.Conv2d(3, 1, 1) def forward(self, batch_inputs, data_samples=None, mode='tensor'):