From c2982723c979c3af58690b8edc6dff2e0c5af12b Mon Sep 17 00:00:00 2001
From: shenmishajing <shenmishajing@Gmail.com>
Date: Tue, 18 Oct 2022 17:58:51 +0800
Subject: [PATCH] [Feats]: add non_blocking feature to BaseDataPreprocessor
 (#618)

* add non_blocking feature to BaseDataPreprocessor

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* move new parameters to the last to avoid bc issue

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 .../model/base_model/data_preprocessor.py     | 20 ++++++++++++++-----
 .../test_base_model/test_data_preprocessor.py |  5 +++++
 2 files changed, 20 insertions(+), 5 deletions(-)

diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py
index 2b14d75e..2832f86e 100644
--- a/mmengine/model/base_model/data_preprocessor.py
+++ b/mmengine/model/base_model/data_preprocessor.py
@@ -22,13 +22,19 @@ class BaseDataPreprocessor(nn.Module):
     forward method to implement custom data pre-processing, such as
     batch-resize, MixUp, or CutMix.
 
+    Args:
+        non_blocking (bool): Whether block current process
+            when transferring data to device.
+            New in version 0.3.0.
+
     Note:
         Data dictionary returned by dataloader must be a dict and at least
         contain the ``inputs`` key.
     """
 
-    def __init__(self):
+    def __init__(self, non_blocking: Optional[bool] = False):
         super().__init__()
+        self._non_blocking = non_blocking
         self._device = torch.device('cpu')
 
     def cast_data(self, data: CastData) -> CastData:
@@ -48,9 +54,9 @@ class BaseDataPreprocessor(nn.Module):
         elif isinstance(data, Sequence):
             return [self.cast_data(sample) for sample in data]
         elif isinstance(data, torch.Tensor):
-            return data.to(self.device)
+            return data.to(self.device, non_blocking=self._non_blocking)
         elif isinstance(data, BaseDataElement):
-            return data.to(self.device)
+            return data.to(self.device, non_blocking=self._non_blocking)
         else:
             return data
 
@@ -150,6 +156,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
             Defaults to False.
         rgb_to_bgr (bool): whether to convert image from RGB to RGB.
             Defaults to False.
+        non_blocking (bool): Whether block current process
+            when transferring data to device.
+            New in version v0.3.0.
 
     Note:
         if images do not need to be normalized, `std` and `mean` should be
@@ -163,8 +172,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
                  pad_size_divisor: int = 1,
                  pad_value: Union[float, int] = 0,
                  bgr_to_rgb: bool = False,
-                 rgb_to_bgr: bool = False):
-        super().__init__()
+                 rgb_to_bgr: bool = False,
+                 non_blocking: Optional[bool] = False):
+        super().__init__(non_blocking)
         assert not (bgr_to_rgb and rgb_to_bgr), (
             '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
         assert (mean is None) == (std is None), (
diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py
index 15ba57d3..a1bc25d4 100644
--- a/tests/test_model/test_base_model/test_data_preprocessor.py
+++ b/tests/test_model/test_base_model/test_data_preprocessor.py
@@ -14,6 +14,11 @@ class TestBaseDataPreprocessor(TestCase):
     def test_init(self):
         base_data_preprocessor = BaseDataPreprocessor()
         self.assertEqual(base_data_preprocessor._device.type, 'cpu')
+        self.assertEqual(base_data_preprocessor._non_blocking, False)
+
+        base_data_preprocessor = BaseDataPreprocessor(True)
+        self.assertEqual(base_data_preprocessor._device.type, 'cpu')
+        self.assertEqual(base_data_preprocessor._non_blocking, True)
 
     def test_forward(self):
         # Test cpu forward with list of data samples.
-- 
GitLab