From a9afdad7a81b46d5b533dd4cbcf5347dcd7dbea6 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Thu, 9 Jun 2022 11:45:19 +0800 Subject: [PATCH] [Fix] Fix BaseDataPreprocessor and BaseModel (#285) * fix BaseDataPreprocessor * fix BaseDataPreprocessor * change device type to torch.device * change device type to torch.device * fix cpu method of base model --- mmengine/model/__init__.py | 4 +- mmengine/model/base_model/base_model.py | 22 +++++++--- .../model/base_model/data_preprocessor.py | 40 +++++++++++-------- mmengine/model/utils.py | 10 ++--- .../test_base_model/test_base_model.py | 8 ++-- .../test_base_model/test_data_preprocessor.py | 15 ++++++- 6 files changed, 66 insertions(+), 33 deletions(-) diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py index 0b7f08e7..47e2356c 100644 --- a/mmengine/model/__init__.py +++ b/mmengine/model/__init__.py @@ -3,7 +3,7 @@ from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA, StochasticWeightAverage) from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor from .base_module import BaseModule -from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs +from .utils import detect_anomalous_params, merge_dict, stack_batch from .wrappers import (MMDistributedDataParallel, MMSeparateDistributedDataParallel, is_model_wrapper) @@ -11,6 +11,6 @@ __all__ = [ 'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage', 'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor', 'ImgDataPreprocessor', - 'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs', + 'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch', 'merge_dict', 'detect_anomalous_params' ] diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py index 1155d999..fe970acb 100644 --- a/mmengine/model/base_model/base_model.py +++ b/mmengine/model/base_model/base_model.py @@ -179,8 +179,8 @@ class BaseModel(BaseModule): def to(self, device: Optional[Union[int, torch.device]], *args, **kwargs) -> nn.Module: - """Overrides this method to set the ``device`` attribute of - :obj:`BaseDataPreprocessor` additionally + """Overrides this method to call :meth:`BaseDataPreprocessor.to` + additionally. Args: device (int or torch.device, optional): the desired device of the @@ -189,19 +189,29 @@ class BaseModel(BaseModule): Returns: nn.Module: The model itself. """ - self.data_preprocessor.device = torch.device(device) + self.data_preprocessor.to(device) return super().to(device) def cuda(self, *args, **kwargs) -> nn.Module: - """Overrides this method to set the ``device`` attribute of - :obj:`BaseDataPreprocessor` additionally + """Overrides this method to call :meth:`BaseDataPreprocessor.cuda` + additionally. Returns: nn.Module: The model itself. """ - self.data_preprocessor.device = torch.cuda.current_device() + self.data_preprocessor.cuda() return super().cuda() + def cpu(self, *args, **kwargs) -> nn.Module: + """Overrides this method to call :meth:`BaseDataPreprocessor.cpu` + additionally. + + Returns: + nn.Module: The model itself. + """ + self.data_preprocessor.cpu() + return super().cpu() + @abstractmethod def forward(self, batch_inputs: torch.Tensor, diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2b9d2cb3..cf94b54d 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -6,7 +6,7 @@ import torch.nn as nn from mmengine.data import BaseDataElement from mmengine.registry import MODELS -from ..utils import stach_batch_imgs +from ..utils import stack_batch @MODELS.register_module() @@ -25,18 +25,15 @@ class BaseDataPreprocessor(nn.Module): forward method to implement custom data pre-processing, such as batch-resize, MixUp, or CutMix. - Args: - device (int or torch.device): Target device. - Warnings: Each item of data sampled from dataloader must be a dict and at least contain the ``inputs`` key. Furthermore, the value of ``inputs`` must be a ``Tensor`` with the same shape. """ - def __init__(self, device: Union[int, torch.device] = 'cpu'): + def __init__(self): super().__init__() - self.device = device + self._device = torch.device('cpu') def collate_data( self, @@ -56,7 +53,7 @@ class BaseDataPreprocessor(nn.Module): Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input tensor and list of labels at target device. """ - inputs = [_data['inputs'].to(self.device) for _data in data] + inputs = [_data['inputs'].to(self._device) for _data in data] batch_data_samples: List[BaseDataElement] = [] # Model can get predictions without any data samples. for _data in data: @@ -64,7 +61,7 @@ class BaseDataPreprocessor(nn.Module): batch_data_samples.append(_data['data_sample']) # Move data from CPU to corresponding device. batch_data_samples = [ - data_sample.to(self.device) for data_sample in batch_data_samples + data_sample.to(self._device) for data_sample in batch_data_samples ] if not batch_data_samples: @@ -93,6 +90,10 @@ class BaseDataPreprocessor(nn.Module): batch_inputs = torch.stack(inputs, dim=0) return batch_inputs, batch_data_samples + @property + def device(self): + return self._device + def to(self, device: Optional[Union[int, torch.device]], *args, **kwargs) -> nn.Module: """Overrides this method to set the :attr:`device` @@ -104,7 +105,7 @@ class BaseDataPreprocessor(nn.Module): Returns: nn.Module: The model itself. """ - self.device = torch.device(device) + self._device = torch.device(device) return super().to(device) def cuda(self, *args, **kwargs) -> nn.Module: @@ -113,9 +114,18 @@ class BaseDataPreprocessor(nn.Module): Returns: nn.Module: The model itself. """ - self.device = torch.cuda.current_device() + self._device = torch.device(torch.cuda.current_device()) return super().cuda() + def cpu(self, *args, **kwargs) -> nn.Module: + """Overrides this method to set the :attr:`device` + + Returns: + nn.Module: The model itself. + """ + self._device = torch.device('cpu') + return super().cpu() + @MODELS.register_module() class ImgDataPreprocessor(BaseDataPreprocessor): @@ -158,7 +168,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor): Defaults to False. rgb_to_bgr (bool): whether to convert image from RGB to RGB. Defaults to False. - device (int or torch.device): Target device. """ def __init__(self, @@ -167,9 +176,8 @@ class ImgDataPreprocessor(BaseDataPreprocessor): pad_size_divisor: int = 1, pad_value: Union[float, int] = 0, bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False, - device: Union[int, torch.device] = 'cpu'): - super().__init__(device) + rgb_to_bgr: bool = False): + super().__init__() assert len(mean) == 3 or len(mean) == 1, ( 'The length of mean should be 1 or 3 to be compatible with RGB ' f'or gray image, but got {len(mean)}') @@ -208,6 +216,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor): # Normalization. inputs = [(_input - self.mean) / self.std for _input in inputs] # Pad and stack Tensor. - batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor, - self.pad_value) + batch_inputs = stack_batch(inputs, self.pad_size_divisor, + self.pad_value) return batch_inputs, batch_data_samples diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py index 29cc779b..b3dd5e17 100644 --- a/mmengine/model/utils.py +++ b/mmengine/model/utils.py @@ -673,10 +673,10 @@ def trunc_normal_(tensor: Tensor, return _no_grad_trunc_normal_(tensor, mean, std, a, b) -def stach_batch_imgs(tensor_list: List[torch.Tensor], - pad_size_divisor: int = 1, - pad_value: Union[int, float] = 0) -> torch.Tensor: - """Stack multiple tensors to form a batch and pad the images to the max +def stack_batch(tensor_list: List[torch.Tensor], + pad_size_divisor: int = 1, + pad_value: Union[int, float] = 0) -> torch.Tensor: + """Stack multiple tensors to form a batch and pad the tensor to the max shape use the right bottom padding mode in these images. If ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is divisible by ``pad_size_divisor``. @@ -690,7 +690,7 @@ def stach_batch_imgs(tensor_list: List[torch.Tensor], pad_value (int, float): The padding value. Defaults to 0. Returns: - Tensor: The 4D-tensor. + Tensor: The n dim tensor. """ assert isinstance( tensor_list, diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py index 222e703d..0abac656 100644 --- a/tests/test_model/test_base_model/test_base_model.py +++ b/tests/test_model/test_base_model/test_base_model.py @@ -115,11 +115,13 @@ class TestBaseModel(TestCase): inputs = torch.randn(3, 1, 1).cuda() data = dict(inputs=inputs) model = ToyModel().cuda() - model.val_step([data]) + out = model.val_step([data]) + self.assertEqual(out.device.type, 'cuda') @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available') def test_to(self): - inputs = torch.randn(3, 1, 1).cuda() + inputs = torch.randn(3, 1, 1).to('cuda:0') data = dict(inputs=inputs) model = ToyModel().to(torch.cuda.current_device()) - model.val_step([data]) + out = model.val_step([data]) + self.assertEqual(out.device.type, 'cuda') diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 146ed35e..0639b59c 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -13,7 +13,7 @@ class TestBaseDataPreprocessor(TestCase): def test_init(self): base_data_preprocessor = BaseDataPreprocessor() - self.assertEqual(base_data_preprocessor.device, 'cpu') + self.assertEqual(base_data_preprocessor._device.type, 'cpu') def test_forward(self): base_data_preprocessor = BaseDataPreprocessor() @@ -35,6 +35,19 @@ class TestBaseDataPreprocessor(TestCase): assert_allclose(label1, batch_labels[0]) assert_allclose(label2, batch_labels[1]) + if torch.cuda.is_available(): + base_data_preprocessor = base_data_preprocessor.cuda() + batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertEqual(batch_inputs.device.type, 'cuda') + + base_data_preprocessor = base_data_preprocessor.cpu() + batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertEqual(batch_inputs.device.type, 'cpu') + + base_data_preprocessor = base_data_preprocessor.to('cuda:0') + batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertEqual(batch_inputs.device.type, 'cuda') + class TestImageDataPreprocessor(TestBaseDataPreprocessor): -- GitLab