diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 9fa5b48e1b904a5c3f1bff0b91be8dd4e7e650ca..8b5f79824c9d178b043db7ed02c997c28969bcf4 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -153,11 +153,15 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
     Args:
         mean (Sequence[float or int], optional): The pixel mean of image
             channels. If ``bgr_to_rgb=True`` it means the mean value of R,
-            G, B channels. If it is not specified, images will not be
-            normalized. Defaults None.
+            G, B channels. If the length of `mean` is 1, it means all
+            channels have the same mean value, or the input is a gray image.
+            If it is not specified, images will not be normalized. Defaults
+            None.
         std (Sequence[float or int], optional): The pixel standard deviation of
             image channels. If ``bgr_to_rgb=True`` it means the standard
-            deviation of R, G, B channels. If it is not specified, images will
+            deviation of R, G, B channels. If the length of `std` is 1,
+            it means all channels have the same standard deviation, or the
+            input is a gray image.  If it is not specified, images will
             not be normalized. Defaults None.
         pad_size_divisor (int): The size of padded image should be
             divisible by ``pad_size_divisor``. Defaults to 1.
@@ -187,11 +191,11 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             'mean and std should be both None or tuple')
         if mean is not None:
             assert len(mean) == 3 or len(mean) == 1, (
-                'The length of mean should be 1 or 3 to be compatible with '
-                f'RGB or gray image, but got {len(mean)}')
+                '`mean` should have 1 or 3 values, to be compatible with '
+                f'RGB or gray image, but got {len(mean)} values')
             assert len(std) == 3 or len(std) == 1, (  # type: ignore
-                'The length of std should be 1 or 3 to be compatible with RGB '  # type: ignore # noqa: E501
-                f'or gray image, but got {len(std)}')  # type: ignore
+                '`std` should have 1 or 3 values, to be compatible with RGB '  # type: ignore # noqa: E501
+                f'or gray image, but got {len(std)} values')
             self._enable_normalize = True
             self.register_buffer('mean',
                                  torch.tensor(mean).view(-1, 1, 1), False)
@@ -221,12 +225,19 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             model input.
         """
         inputs, batch_data_samples = self.collate_data(data)
-        # channel transform
-        if self.channel_conversion:
-            inputs = [_input[[2, 1, 0], ...] for _input in inputs]
-        # Normalization.
-        if self._enable_normalize:
-            inputs = [(_input - self.mean) / self.std for _input in inputs]
+        for idx, _input in enumerate(inputs):
+            # channel transform
+            if self.channel_conversion:
+                _input = _input[[2, 1, 0], ...]
+            # Normalization.
+            if self._enable_normalize:
+                if self.mean.shape[0] == 3:
+                    assert _input.dim() == 3 and _input.shape[0] == 3, (
+                        'If the mean has 3 values, the input tensor should in '
+                        'shape of (3, H, W), but got the tensor with shape '
+                        f'{_input.shape}')
+                _input = (_input - self.mean) / self.std
+            inputs[idx] = _input
         # Pad and stack Tensor.
         batch_inputs = stack_batch(inputs, self.pad_size_divisor,
                                    self.pad_value)
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
index 6af6075210054c3ee2850c6c2b564be43834b992..99cbf2895a7f3760705f36ad5ea59e8ab2bbcb17 100644
--- a/tests/test_model/test_base_model/test_data_preprocessor.py
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -53,7 +53,7 @@ class TestBaseDataPreprocessor(TestCase):
             self.assertEqual(batch_inputs.device.type, 'cuda')
 
 
-class TestImageDataPreprocessor(TestBaseDataPreprocessor):
+class TestImgataPreprocessor(TestBaseDataPreprocessor):
 
     def test_init(self):
         # initiate model without `preprocess_cfg`
@@ -79,10 +79,10 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
         assert_allclose(data_processor.pad_value, torch.tensor(10))
         self.assertEqual(data_processor.pad_size_divisor, 16)
 
-        with self.assertRaisesRegex(AssertionError, 'The length of mean'):
+        with self.assertRaisesRegex(AssertionError, '`mean` should have'):
             ImgDataPreprocessor(mean=(1, 2), std=(1, 2, 3))
 
-        with self.assertRaisesRegex(AssertionError, 'The length of std'):
+        with self.assertRaisesRegex(AssertionError, '`std` should have'):
             ImgDataPreprocessor(mean=(1, 2, 3), std=(1, 2))
 
         with self.assertRaisesRegex(AssertionError, '`bgr2rgb` and `rgb2bgr`'):
@@ -121,6 +121,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
         std = torch.tensor([1, 2, 3]).view(-1, 1, 1)
         target_inputs1 = (inputs1.clone()[[2, 1, 0], ...] - 127.5) / std
         target_inputs2 = (inputs2.clone()[[2, 1, 0], ...] - 127.5) / std
+
         target_inputs1 = F.pad(target_inputs1, (0, 6, 0, 6), value=10)
         target_inputs2 = F.pad(target_inputs2, (0, 1, 0, 1), value=10)
 
@@ -155,6 +156,25 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
             assert_allclose(input_, target_input)
             assert_allclose(data_sample.bboxes, target_data_sample.bboxes)
 
+        # Test gray image with 3 dim mean will raise error
+        data_preprocessor = ImgDataPreprocessor(
+            mean=(127.5, 127.5, 127.5), std=(127.5, 127.5, 127.5))
+        data = [
+            dict(inputs=torch.ones(10, 10)),
+            dict(inputs=torch.ones(10, 10))
+        ]
+        with self.assertRaisesRegex(AssertionError,
+                                    'If the mean has 3 values'):
+            data_preprocessor(data)
+
+        data = [
+            dict(inputs=torch.ones(1, 10, 10)),
+            dict(inputs=torch.ones(1, 10, 10))
+        ]
+        with self.assertRaisesRegex(AssertionError,
+                                    'If the mean has 3 values'):
+            data_preprocessor(data)
+
         # Test empty `data_sample`
         data = [dict(inputs=inputs1.clone()), dict(inputs=inputs2.clone())]
         data_preprocessor(data, True)