diff --git a/mmengine/model/base_model/data_preprocessor.py b/mmengine/model/base_model/data_preprocessor.py index 2832f86e172f702c8ffb7ca0abaf4ac2de4d3e55..d7d27ec9078bfca0bedabe8e7ddca1c36ba7a829 100644 --- a/mmengine/model/base_model/data_preprocessor.py +++ b/mmengine/model/base_model/data_preprocessor.py @@ -11,7 +11,8 @@ from mmengine.structures import BaseDataElement from mmengine.utils import is_list_of from ..utils import stack_batch -CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list] +CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str, + None] @MODELS.register_module() @@ -48,17 +49,20 @@ class BaseDataPreprocessor(nn.Module): """ if isinstance(data, Mapping): return {key: self.cast_data(data[key]) for key in data} + elif isinstance(data, (str, bytes)) or data is None: + return data elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple - return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable + return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable elif isinstance(data, Sequence): - return [self.cast_data(sample) for sample in data] - elif isinstance(data, torch.Tensor): - return data.to(self.device, non_blocking=self._non_blocking) - elif isinstance(data, BaseDataElement): + return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable + elif isinstance(data, (torch.Tensor, BaseDataElement)): return data.to(self.device, non_blocking=self._non_blocking) else: - return data + raise TypeError( + '`BaseDataPreprocessor.cast_data`: batch data must contain ' + 'tensors, numpy arrays, numbers, dicts or lists, but ' + f'found {type(data)}') def forward(self, data: dict, training: bool = False) -> Union[dict, list]: """Preprocesses the data into the model input format. diff --git a/tests/test_model/test_base_model/test_data_preprocessor.py b/tests/test_model/test_base_model/test_data_preprocessor.py index a1bc25d41c37723013a76260a06c22598d7bab8a..377cb1753b7a5869d4636886203b513c56d9bc5c 100644 --- a/tests/test_model/test_base_model/test_data_preprocessor.py +++ b/tests/test_model/test_base_model/test_data_preprocessor.py @@ -28,8 +28,8 @@ class TestBaseDataPreprocessor(TestCase): label1 = torch.randn(1) label2 = torch.randn(1) + # Test with dict of batch inputs and batch data samples data = dict(inputs=[input1, input2], data_sample=[label1, label2]) - output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output['data_sample'] self.assertTrue(torch.is_floating_point(batch_inputs[0])) @@ -41,40 +41,54 @@ class TestBaseDataPreprocessor(TestCase): assert_allclose(label2, batch_labels[1]) # Test with tuple of batch inputs and batch data samples - data = dict( - inputs=torch.stack([input1, input2]), data_sample=[label1, label2]) - output = base_data_preprocessor(data)['inputs'] + data = (torch.stack([input1, input2]), (label1, label2)) + batch_inputs, batch_labels = base_data_preprocessor(data) + self.assertTrue(torch.is_floating_point(batch_inputs)) + self.assertEqual(batch_inputs[0].shape, (1, 3, 5)) + self.assertEqual(batch_inputs[1].shape, (1, 3, 5)) self.assertTrue(torch.is_floating_point(batch_inputs[0])) # Test cuda forward if torch.cuda.is_available(): # Test with list of data samples. + data = dict(inputs=[input1, input2], data_sample=[label1, label2]) base_data_preprocessor = base_data_preprocessor.cuda() output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cuda') + # Fallback to test with cpu. base_data_preprocessor = base_data_preprocessor.cpu() output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cpu') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cpu') + # Test `base_data_preprocessor` can be moved to cuda again. base_data_preprocessor = base_data_preprocessor.to('cuda:0') output = base_data_preprocessor(data) batch_inputs, batch_labels = output['inputs'], output[ 'data_sample'] - self.assertTrue(torch.is_floating_point(batch_inputs)) - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertTrue(torch.is_floating_point(batch_inputs[0])) + self.assertEqual(batch_inputs[0].device.type, 'cuda') # device of `base_data_preprocessor` is cuda, output should be # cuda tensor. - self.assertEqual(batch_inputs.device.type, 'cuda') + self.assertEqual(batch_inputs[0].device.type, 'cuda') self.assertEqual(batch_labels[0].device.type, 'cuda') + # Test forward with string value + data = dict(string='abc') + base_data_preprocessor(data) + + with self.assertRaisesRegex(TypeError, + '`BaseDataPreprocessor.cast_data`:'): + data = dict(string=object()) + base_data_preprocessor(data) + class TestImgDataPreprocessor(TestBaseDataPreprocessor):