From c2982723c979c3af58690b8edc6dff2e0c5af12b Mon Sep 17 00:00:00 2001 From: shenmishajing <shenmishajing@Gmail.com> Date: Tue, 18 Oct 2022 17:58:51 +0800 Subject: [PATCH] [Feats]: add non_blocking feature to BaseDataPreprocessor (#618) * add non_blocking feature to BaseDataPreprocessor * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * move new parameters to the last to avoid bc issue * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> --- .../model/base_model/data_preprocessor.py | 20 ++++++++++++++----- .../test_base_model/test_data_preprocessor.py | 5 +++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2b14d75e..2832f86e 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -22,13 +22,19 @@ class BaseDataPreprocessor(nn.Module): forward method to implement custom data pre-processing, such as batch-resize, MixUp, or CutMix. + Args: + non_blocking (bool): Whether block current process + when transferring data to device. + New in version 0.3.0. + Note: Data dictionary returned by dataloader must be a dict and at least contain the ``inputs`` key. """ - def __init__(self): + def __init__(self, non_blocking: Optional[bool] = False): super().__init__() + self._non_blocking = non_blocking self._device = torch.device('cpu') def cast_data(self, data: CastData) -> CastData: @@ -48,9 +54,9 @@ class BaseDataPreprocessor(nn.Module): elif isinstance(data, Sequence): return [self.cast_data(sample) for sample in data] elif isinstance(data, torch.Tensor): - return data.to(self.device) + return data.to(self.device, non_blocking=self._non_blocking) elif isinstance(data, BaseDataElement): - return data.to(self.device) + return data.to(self.device, non_blocking=self._non_blocking) else: return data @@ -150,6 +156,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor): Defaults to False. rgb_to_bgr (bool): whether to convert image from RGB to RGB. Defaults to False. + non_blocking (bool): Whether block current process + when transferring data to device. + New in version v0.3.0. Note: if images do not need to be normalized, `std` and `mean` should be @@ -163,8 +172,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor): pad_size_divisor: int = 1, pad_value: Union[float, int] = 0, bgr_to_rgb: bool = False, - rgb_to_bgr: bool = False): - super().__init__() + rgb_to_bgr: bool = False, + non_blocking: Optional[bool] = False): + super().__init__(non_blocking) assert not (bgr_to_rgb and rgb_to_bgr), ( '`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time') assert (mean is None) == (std is None), ( diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index 15ba57d3..a1bc25d4 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -14,6 +14,11 @@ class TestBaseDataPreprocessor(TestCase): def test_init(self): base_data_preprocessor = BaseDataPreprocessor() self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._non_blocking, False) + + base_data_preprocessor = BaseDataPreprocessor(True) + self.assertEqual(base_data_preprocessor._device.type, 'cpu') + self.assertEqual(base_data_preprocessor._non_blocking, True) def test_forward(self): # Test cpu forward with list of data samples. -- GitLab