Skip to content
Snippets Groups Projects
Unverified Commit c2982723 authored by shenmishajing's avatar shenmishajing Committed by GitHub
Browse files

[Feats]: add non_blocking feature to BaseDataPreprocessor (#618)


* add non_blocking feature to BaseDataPreprocessor

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* move new parameters to the last to avoid bc issue

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmengine/model/base_model/data_preprocessor.py

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 09a195b2
No related branches found
No related tags found
No related merge requests found
......@@ -22,13 +22,19 @@ class BaseDataPreprocessor(nn.Module):
forward method to implement custom data pre-processing, such as
batch-resize, MixUp, or CutMix.
Args:
non_blocking (bool): Whether block current process
when transferring data to device.
New in version 0.3.0.
Note:
Data dictionary returned by dataloader must be a dict and at least
contain the ``inputs`` key.
"""
def __init__(self):
def __init__(self, non_blocking: Optional[bool] = False):
super().__init__()
self._non_blocking = non_blocking
self._device = torch.device('cpu')
def cast_data(self, data: CastData) -> CastData:
......@@ -48,9 +54,9 @@ class BaseDataPreprocessor(nn.Module):
elif isinstance(data, Sequence):
return [self.cast_data(sample) for sample in data]
elif isinstance(data, torch.Tensor):
return data.to(self.device)
return data.to(self.device, non_blocking=self._non_blocking)
elif isinstance(data, BaseDataElement):
return data.to(self.device)
return data.to(self.device, non_blocking=self._non_blocking)
else:
return data
......@@ -150,6 +156,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
Defaults to False.
rgb_to_bgr (bool): whether to convert image from RGB to RGB.
Defaults to False.
non_blocking (bool): Whether block current process
when transferring data to device.
New in version v0.3.0.
Note:
if images do not need to be normalized, `std` and `mean` should be
......@@ -163,8 +172,9 @@ class ImgDataPreprocessor(BaseDataPreprocessor):
pad_size_divisor: int = 1,
pad_value: Union[float, int] = 0,
bgr_to_rgb: bool = False,
rgb_to_bgr: bool = False):
super().__init__()
rgb_to_bgr: bool = False,
non_blocking: Optional[bool] = False):
super().__init__(non_blocking)
assert not (bgr_to_rgb and rgb_to_bgr), (
'`bgr2rgb` and `rgb2bgr` cannot be set to True at the same time')
assert (mean is None) == (std is None), (
......
......@@ -14,6 +14,11 @@ class TestBaseDataPreprocessor(TestCase):
def test_init(self):
base_data_preprocessor = BaseDataPreprocessor()
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
self.assertEqual(base_data_preprocessor._non_blocking, False)
base_data_preprocessor = BaseDataPreprocessor(True)
self.assertEqual(base_data_preprocessor._device.type, 'cpu')
self.assertEqual(base_data_preprocessor._non_blocking, True)
def test_forward(self):
# Test cpu forward with list of data samples.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment