From d624fa9191e0f1a5c79f366a03d809a7a6be3859 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 28 Jun 2022 11:46:12 +0800 Subject: [PATCH] [Enhance] assert image shape before forward (#300) * assert image shape before forward * add unit test * enhance error message * allow gray image input * fix as comment * fix unit test * fix unit test --- .../model/base_model/data_preprocessor.py | 37 ++++++++++++------- .../test_base_model/test_data_preprocessor.py | 26 +++++++++++-- 2 files changed, 47 insertions(+), 16 deletions(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 9fa5b48e..8b5f7982 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -153,11 +153,15 @@ class ImgDataPreprocessor(BaseDataPreprocessor): Args: mean (Sequence[float or int], optional): The pixel mean of image channels. If ``bgr_to_rgb=True`` it means the mean value of R, - G, B channels. If it is not specified, images will not be - normalized. Defaults None. + G, B channels. If the length of `mean` is 1, it means all + channels have the same mean value, or the input is a gray image. + If it is not specified, images will not be normalized. Defaults + None. std (Sequence[float or int], optional): The pixel standard deviation of image channels. If ``bgr_to_rgb=True`` it means the standard - deviation of R, G, B channels. If it is not specified, images will + deviation of R, G, B channels. If the length of `std` is 1, + it means all channels have the same standard deviation, or the + input is a gray image. If it is not specified, images will not be normalized. Defaults None. pad_size_divisor (int): The size of padded image should be divisible by ``pad_size_divisor``. Defaults to 1. 
@@ -187,11 +191,11 @@ class ImgDataPreprocessor(BaseDataPreprocessor): 'mean and std should be both None or tuple') if mean is not None: assert len(mean) == 3 or len(mean) == 1, ( - 'The length of mean should be 1 or 3 to be compatible with ' - f'RGB or gray image, but got {len(mean)}') + '`mean` should have 1 or 3 values, to be compatible with ' + f'RGB or gray image, but got {len(mean)} values') assert len(std) == 3 or len(std) == 1, ( # type: ignore - 'The length of std should be 1 or 3 to be compatible with RGB ' # type: ignore # noqa: E501 - f'or gray image, but got {len(std)}') # type: ignore + '`std` should have 1 or 3 values, to be compatible with RGB ' # type: ignore # noqa: E501 + f'or gray image, but got {len(std)} values') self._enable_normalize = True self.register_buffer('mean', torch.tensor(mean).view(-1, 1, 1), False) @@ -221,12 +225,19 @@ class ImgDataPreprocessor(BaseDataPreprocessor): model input. """ inputs, batch_data_samples = self.collate_data(data) - # channel transform - if self.channel_conversion: - inputs = [_input[[2, 1, 0], ...] for _input in inputs] - # Normalization. - if self._enable_normalize: - inputs = [(_input - self.mean) / self.std for _input in inputs] + for idx, _input in enumerate(inputs): + # channel transform + if self.channel_conversion: + _input = _input[[2, 1, 0], ...] + # Normalization. + if self._enable_normalize: + if self.mean.shape[0] == 3: + assert _input.dim() == 3 and _input.shape[0] == 3, ( + 'If the mean has 3 values, the input tensor should in ' + 'shape of (3, H, W), but got the tensor with shape ' + f'{_input.shape}') + _input = (_input - self.mean) / self.std + inputs[idx] = _input # Pad and stack Tensor. 
batch_inputs = stack_batch(inputs, self.pad_size_divisor, self.pad_value) diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 6af60752..99cbf289 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -53,7 +53,7 @@ class TestBaseDataPreprocessor(TestCase): self.assertEqual(batch_inputs.device.type, 'cuda') -class TestImageDataPreprocessor(TestBaseDataPreprocessor): +class TestImgDataPreprocessor(TestBaseDataPreprocessor): def test_init(self): # initiate model without `preprocess_cfg` @@ -79,10 +79,10 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor): assert_allclose(data_processor.pad_value, torch.tensor(10)) self.assertEqual(data_processor.pad_size_divisor, 16) - with self.assertRaisesRegex(AssertionError, 'The length of mean'): + with self.assertRaisesRegex(AssertionError, '`mean` should have'): ImgDataPreprocessor(mean=(1, 2), std=(1, 2, 3)) - with self.assertRaisesRegex(AssertionError, 'The length of std'): + with self.assertRaisesRegex(AssertionError, '`std` should have'): ImgDataPreprocessor(mean=(1, 2, 3), std=(1, 2)) with self.assertRaisesRegex(AssertionError, '`bgr2rgb` and `rgb2bgr`'): @@ -121,6 +121,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor): std = torch.tensor([1, 2, 3]).view(-1, 1, 1) target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std target_inputs2 = (inputs2.clone()[[2, 1, 0], ...]
- 127.5) / std + target_inputs1 = F.pad(target_inputs1, (0, 6, 0, 6), value=10) target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10) @@ -155,6 +156,25 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor): assert_allclose(input_, target_input) assert_allclose(data_sample.bboxes, target_data_sample.bboxes) + # Test gray image with 3 dim mean will raise error + data_preprocessor = ImgDataPreprocessor( + mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5)) + data = [ + dict(inputs=torch.ones(10, 10)), + dict(inputs=torch.ones(10, 10)) + ] + with self.assertRaisesRegex(AssertionError, + 'If the mean has 3 values'): + data_preprocessor(data) + + data = [ + dict(inputs=torch.ones(1, 10, 10)), + dict(inputs=torch.ones(1, 10, 10)) + ] + with self.assertRaisesRegex(AssertionError, + 'If the mean has 3 values'): + data_preprocessor(data) + # Test empty `data_sample` data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())] data_preprocessor(data, True) -- GitLab