diff --git a/mmengine/model/__init__.py b/mmengine/model/__init__.py
index 0b7f08e7d75be808db9ec13137febd011e1a639a..47e2356c270d1b181bec01910107ddb79d65acdd 100644
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -3,7 +3,7 @@ from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from .base_module import BaseModule
-from .utils import detect_anomalous_params, merge_dict, stach_batch_imgs
+from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
 
@@ -11,6 +11,6 @@ __all__ = [
     'MMDistributedDataParallel', 'is_model_wrapper', 'StochasticWeightAverage',
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
-    'MMSeparateDistributedDataParallel', 'BaseModule', 'stach_batch_imgs',
+    'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
     'merge_dict', 'detect_anomalous_params'
 ]
diff --git a/mmengine/model/base_model/base_model.py b/mmengine/model/base_model/base_model.py
index 1155d999b532fcecf35c3af99ebc1a375eadb439..fe970acbfd3921b4eeaba628ff0031351420a0d1 100644
--- a/mmengine/model/base_model/base_model.py
+++ b/mmengine/model/base_model/base_model.py
@@ -179,8 +179,8 @@ class BaseModel(BaseModule):
 
     def to(self, device: Optional[Union[int, torch.device]], *args,
            **kwargs) -> nn.Module:
-        """Overrides this method to set the ``device`` attribute of
-        :obj:`BaseDataPreprocessor` additionally
+        """Overrides this method to call :meth:`BaseDataPreprocessor.to`
+        additionally.
 
         Args:
             device (int or torch.device, optional): the desired device of the
@@ -189,19 +189,29 @@ class BaseModel(BaseModule):
         Returns:
             nn.Module: The model itself.
         """
-        self.data_preprocessor.device = torch.device(device)
+        self.data_preprocessor.to(device)
         return super().to(device)
 
     def cuda(self, *args, **kwargs) -> nn.Module:
-        """Overrides this method to set the ``device`` attribute of
-        :obj:`BaseDataPreprocessor` additionally
+        """Overrides this method to call :meth:`BaseDataPreprocessor.cuda`
+        additionally.
 
         Returns:
             nn.Module: The model itself.
         """
-        self.data_preprocessor.device = torch.cuda.current_device()
+        self.data_preprocessor.cuda()
         return super().cuda()
 
+    def cpu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to call :meth:`BaseDataPreprocessor.cpu`
+        additionally.
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self.data_preprocessor.cpu()
+        return super().cpu()
+
     @abstractmethod
     def forward(self,
                 batch_inputs: torch.Tensor,
diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 2b9d2cb3eefa2f51a782bd8d7c243208fb26917d..cf94b54de29d24713f88deb8ff7ea4718aa80f12 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 
 from mmengine.data import BaseDataElement
 from mmengine.registry import MODELS
-from ..utils import stach_batch_imgs
+from ..utils import stack_batch
 
 
 @MODELS.register_module()
@@ -25,18 +25,24 @@
     forward method to implement custom data pre-processing, such as
     batch-resize, MixUp, or CutMix.
 
-    Args:
-        device (int or torch.device): Target device.
-
     Warnings:
         Each item of data sampled from dataloader must be a dict and at least
         contain the ``inputs`` key. Furthermore, the value of ``inputs``
         must be a ``Tensor`` with the same shape.
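+
+    Example:
+        >>> # A sketch of the expected input format: a list of dicts, each
+        >>> # holding an ``inputs`` tensor and, optionally, a ``data_sample``.
+        >>> data = [dict(inputs=torch.rand(3, 8, 8)),
+        ...         dict(inputs=torch.rand(3, 8, 8))]
+        >>> batch_inputs, data_samples = BaseDataPreprocessor()(data)
+        >>> batch_inputs.shape
+        torch.Size([2, 3, 8, 8])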
     """
 
-    def __init__(self, device: Union[int, torch.device] = 'cpu'):
+    def __init__(self):
         super().__init__()
-        self.device = device
+        self._device = torch.device('cpu')
 
     def collate_data(
             self,
@@ -56,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
             Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
             tensor and list of labels at target device.
         """
-        inputs = [_data['inputs'].to(self.device) for _data in data]
+        inputs = [_data['inputs'].to(self._device) for _data in data]
         batch_data_samples: List[BaseDataElement] = []
         # Model can get predictions without any data samples.
         for _data in data:
@@ -64,7 +61,7 @@ class BaseDataPreprocessor(nn.Module):
                 batch_data_samples.append(_data['data_sample'])
         # Move data from CPU to corresponding device.
         batch_data_samples = [
-            data_sample.to(self.device) for data_sample in batch_data_samples
+            data_sample.to(self._device) for data_sample in batch_data_samples
         ]
 
         if not batch_data_samples:
@@ -93,6 +90,11 @@
         batch_inputs = torch.stack(inputs, dim=0)
         return batch_inputs, batch_data_samples
 
+    @property
+    def device(self):
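+        """torch.device: The device to which the collated data is moved."""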
+        return self._device
+
     def to(self, device: Optional[Union[int, torch.device]], *args,
            **kwargs) -> nn.Module:
         """Overrides this method to set the :attr:`device`
@@ -104,7 +105,7 @@ class BaseDataPreprocessor(nn.Module):
         Returns:
             nn.Module: The model itself.
         """
-        self.device = torch.device(device)
+        self._device = torch.device(device)
         return super().to(device)
 
     def cuda(self, *args, **kwargs) -> nn.Module:
@@ -113,9 +114,18 @@ class BaseDataPreprocessor(nn.Module):
         Returns:
             nn.Module: The model itself.
         """
-        self.device = torch.cuda.current_device()
+        self._device = torch.device(torch.cuda.current_device())
         return super().cuda()
 
+    def cpu(self, *args, **kwargs) -> nn.Module:
+        """Overrides this method to set the :attr:`device`
+
+        Returns:
+            nn.Module: The model itself.
+        """
+        self._device = torch.device('cpu')
+        return super().cpu()
+
 
 @MODELS.register_module()
 class ImgDataPreprocessor(BaseDataPreprocessor):
@@ -158,7 +168,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             Defaults to False.
         rgb_to_bgr (bool): whether to convert image from RGB to RGB.
             Defaults to False.
-        device (int or torch.device): Target device.
     """
 
     def __init__(self,
@@ -167,9 +176,8 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
                  pad_size_divisor: int = 1,
                  pad_value: Union[float, int] = 0,
                  bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False,
-                 device: Union[int, torch.device] = 'cpu'):
-        super().__init__(device)
+                 rgb_to_bgr: bool = False):
+        super().__init__()
         assert len(mean) == 3 or len(mean) == 1, (
             'The length of mean should be 1 or 3 to be compatible with RGB '
             f'or gray image, but got {len(mean)}')
@@ -208,6 +216,6 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
         # Normalization.
         inputs = [(_input - self.mean) / self.std for _input in inputs]
         # Pad and stack Tensor.
-        batch_inputs = stach_batch_imgs(inputs, self.pad_size_divisor,
-                                        self.pad_value)
+        batch_inputs = stack_batch(inputs, self.pad_size_divisor,
+                                   self.pad_value)
         return batch_inputs, batch_data_samples
diff --git a/mmengine/model/utils.py b/mmengine/model/utils.py
index 29cc779b9f250e8d6293f2ff73492cddb17752d8..b3dd5e17fa7e2a9f4e2bc492147732b80ea724af 100644
--- a/mmengine/model/utils.py
+++ b/mmengine/model/utils.py
@@ -673,10 +673,10 @@ def trunc_normal_(tensor: Tensor,
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
 
 
-def stach_batch_imgs(tensor_list: List[torch.Tensor],
-                     pad_size_divisor: int = 1,
-                     pad_value: Union[int, float] = 0) -> torch.Tensor:
-    """Stack multiple tensors to form a batch and pad the images to the max
-    shape use the right bottom padding mode in these images. If
+def stack_batch(tensor_list: List[torch.Tensor],
+                pad_size_divisor: int = 1,
+                pad_value: Union[int, float] = 0) -> torch.Tensor:
+    """Stack multiple tensors to form a batch and pad the tensors to the max
+    shape using the right-bottom padding mode. If
     ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
     divisible by ``pad_size_divisor``.
@@ -690,7 +690,14 @@ def stach_batch_imgs(tensor_list: List[torch.Tensor],
         pad_value (int, float): The padding value. Defaults to 0.
 
     Returns:
-       Tensor: The 4D-tensor.
+       Tensor: The stacked and padded n-dim tensor.
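+
+    Example:
+        >>> # Illustrative sketch: two 3-channel images with different
+        >>> # spatial sizes are padded to the per-dim maximum and stacked.
+        >>> inputs = [torch.rand(3, 10, 10), torch.rand(3, 12, 8)]
+        >>> stack_batch(inputs).shape
+        torch.Size([2, 3, 12, 10])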
     """
     assert isinstance(
         tensor_list,
diff --git a/tests/test_model/test_base_model/test_base_model.py b/tests/test_model/test_base_model/test_base_model.py
index 222e703d53128af4d7172a1cc82fc65868f6c050..0abac656a0b0d14be1208d4e6e936849bdb928a7 100644
--- a/tests/test_model/test_base_model/test_base_model.py
+++ b/tests/test_model/test_base_model/test_base_model.py
@@ -115,11 +115,13 @@ class TestBaseModel(TestCase):
         inputs = torch.randn(3, 1, 1).cuda()
         data = dict(inputs=inputs)
         model = ToyModel().cuda()
-        model.val_step([data])
+        out = model.val_step([data])
+        self.assertEqual(out.device.type, 'cuda')
 
     @unittest.skipIf(not torch.cuda.is_available(), 'cuda should be available')
     def test_to(self):
-        inputs = torch.randn(3, 1, 1).cuda()
+        inputs = torch.randn(3, 1, 1).to('cuda:0')
         data = dict(inputs=inputs)
         model = ToyModel().to(torch.cuda.current_device())
-        model.val_step([data])
+        out = model.val_step([data])
+        self.assertEqual(out.device.type, 'cuda')
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
index 146ed35e209aeab31802c2b3d31e142086f2f0bc..0639b59c9997ccc7cb94b328ff9841f9d36ab9d1 100644
--- a/tests/test_model/test_base_model/test_data_preprocessor.py
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -13,7 +13,7 @@ class TestBaseDataPreprocessor(TestCase):
 
     def test_init(self):
         base_data_preprocessor = BaseDataPreprocessor()
-        self.assertEqual(base_data_preprocessor.device, 'cpu')
+        self.assertEqual(base_data_preprocessor.device.type, 'cpu')
 
     def test_forward(self):
         base_data_preprocessor = BaseDataPreprocessor()
@@ -35,6 +35,21 @@
         assert_allclose(label1, batch_labels[0])
         assert_allclose(label2, batch_labels[1])
 
+        if torch.cuda.is_available():
+            base_data_preprocessor = base_data_preprocessor.cuda()
+            batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertEqual(batch_inputs.device.type, 'cuda')
+
+            base_data_preprocessor = base_data_preprocessor.cpu()
+            batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertEqual(batch_inputs.device.type, 'cpu')
+
+            base_data_preprocessor = base_data_preprocessor.to('cuda:0')
+            batch_inputs, batch_labels = base_data_preprocessor(data)
+            self.assertEqual(batch_inputs.device.type, 'cuda')
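+            # the new device property should track the module's device
+            self.assertEqual(base_data_preprocessor.device.type, 'cuda')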
+
 
 class TestImageDataPreprocessor(TestBaseDataPreprocessor):