Skip to content
Snippets Groups Projects
Unverified Commit 819e10c2 authored by RangiLyu's avatar RangiLyu Committed by GitHub
Browse files

[Fix] Fix image dtype when enable_normalize=False. (#301)

* [Fix] Fix image dtype when enable_normalize=False.

* update ut

* move to collate

* update ut
parent bcab8132
No related branches found
No related tags found
No related merge requests found
......@@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
tensor and list of labels at target device.
"""
inputs = [_data['inputs'].to(self._device) for _data in data]
inputs = [_data['inputs'].to(self._device).float() for _data in data]
batch_data_samples: List[BaseDataElement] = []
# Model can get predictions without any data samples.
for _data in data:
......
......@@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase):
]
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
assert_allclose(input1, batch_inputs[0])
......@@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase):
if torch.cuda.is_available():
base_data_preprocessor = base_data_preprocessor.cuda()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
base_data_preprocessor = base_data_preprocessor.cpu()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cpu')
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
......@@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
for input_, data_sample, target_input, target_data_sample in zip(
......@@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
for input_, data_sample, target_input, target_data_sample in zip(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment