Skip to content
Snippets Groups Projects
Unverified Commit cc8a6b86 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Fix BaseDataset: join prefix in parse_data_info (#226)

* implement parse_data_info

* add unit test

* fix join prefix of ann_file

* fix docstring
parent f5867f84
No related branches found
No related tags found
No related merge requests found
...@@ -155,7 +155,7 @@ class BaseDataset(Dataset): ...@@ -155,7 +155,7 @@ class BaseDataset(Dataset):
data_root (str, optional): The root directory for ``data_prefix`` and data_root (str, optional): The root directory for ``data_prefix`` and
``ann_file``. Defaults to None. ``ann_file``. Defaults to None.
data_prefix (dict, optional): Prefix for training data. Defaults to data_prefix (dict, optional): Prefix for training data. Defaults to
dict(img=None, ann=None). dict(img_path=None, seg_path=None).
filter_cfg (dict, optional): Config for filter data. Defaults to None. filter_cfg (dict, optional): Config for filter data. Defaults to None.
indices (int or Sequence[int], optional): Support using first few indices (int or Sequence[int], optional): Support using first few
data in annotation file to facilitate training/testing on a smaller data in annotation file to facilitate training/testing on a smaller
...@@ -211,7 +211,7 @@ class BaseDataset(Dataset): ...@@ -211,7 +211,7 @@ class BaseDataset(Dataset):
ann_file: str = '', ann_file: str = '',
metainfo: Optional[dict] = None, metainfo: Optional[dict] = None,
data_root: Optional[str] = None, data_root: Optional[str] = None,
data_prefix: dict = dict(img=None, ann=None), data_prefix: dict = dict(img_path=None, seg_path=None),
filter_cfg: Optional[dict] = None, filter_cfg: Optional[dict] = None,
indices: Optional[Union[int, Sequence[int]]] = None, indices: Optional[Union[int, Sequence[int]]] = None,
serialize_data: bool = True, serialize_data: bool = True,
...@@ -330,6 +330,12 @@ class BaseDataset(Dataset): ...@@ -330,6 +330,12 @@ class BaseDataset(Dataset):
Returns: Returns:
list or list[dict]: Parsed annotation. list or list[dict]: Parsed annotation.
""" """
for prefix_key, prefix in self.data_prefix.items():
assert prefix_key in raw_data_info, (
f'raw_data_info: {raw_data_info} dose not contain prefix key'
f'{prefix_key}, please check your data_prefix.')
raw_data_info[prefix_key] = osp.join(prefix,
raw_data_info[prefix_key])
return raw_data_info return raw_data_info
def filter_data(self) -> List[dict]: def filter_data(self) -> List[dict]:
...@@ -520,7 +526,7 @@ class BaseDataset(Dataset): ...@@ -520,7 +526,7 @@ class BaseDataset(Dataset):
""" """
# Automatically join annotation file path with `self.root` if # Automatically join annotation file path with `self.root` if
# `self.ann_file` is not an absolute path. # `self.ann_file` is not an absolute path.
if not osp.isabs(self.ann_file): if not osp.isabs(self.ann_file) and self.ann_file:
self.ann_file = osp.join(self.data_root, self.ann_file) self.ann_file = osp.join(self.data_root, self.ann_file)
# Automatically join data directory with `self.root` if path value in # Automatically join data directory with `self.root` if path value in
# `self.data_prefix` is not an absolute path. # `self.data_prefix` is not an absolute path.
......
...@@ -39,11 +39,13 @@ class TestBaseDataset: ...@@ -39,11 +39,13 @@ class TestBaseDataset:
filename='test_img.jpg', height=604, width=640, sample_idx=0) filename='test_img.jpg', height=604, width=640, sample_idx=0)
self.imgs = torch.rand((2, 3, 32, 32)) self.imgs = torch.rand((2, 3, 32, 32))
self.ori_meta = BaseDataset.METAINFO self.ori_meta = BaseDataset.METAINFO
self.ori_parse_data_info = BaseDataset.parse_data_info
BaseDataset.parse_data_info = MagicMock(return_value=self.data_info) BaseDataset.parse_data_info = MagicMock(return_value=self.data_info)
self.pipeline = MagicMock(return_value=dict(imgs=self.imgs)) self.pipeline = MagicMock(return_value=dict(imgs=self.imgs))
def teardown(self): def teardown(self):
BaseDataset.METAINFO = self.ori_meta BaseDataset.METAINFO = self.ori_meta
BaseDataset.parse_data_info = self.ori_parse_data_info
def test_init(self): def test_init(self):
# test the instantiation of self.base_dataset # test the instantiation of self.base_dataset
...@@ -83,7 +85,6 @@ class TestBaseDataset: ...@@ -83,7 +85,6 @@ class TestBaseDataset:
lazy_init=True) lazy_init=True)
assert not dataset._fully_initialized assert not dataset._fully_initialized
assert not dataset.data_list assert not dataset.data_list
# test the instantiation of self.base_dataset if ann_file is not # test the instantiation of self.base_dataset if ann_file is not
# existed. # existed.
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
...@@ -147,6 +148,15 @@ class TestBaseDataset: ...@@ -147,6 +148,15 @@ class TestBaseDataset:
data_root=osp.join(osp.dirname(__file__), '../data/'), data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'), data_prefix=dict(img='imgs'),
ann_file='annotations/dummy_annotation.json') ann_file='annotations/dummy_annotation.json')
# test the instantiation of self.base_dataset without `ann_file`
BaseDataset.parse_data_info = self.ori_parse_data_info
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='',
serialize_data=False,
lazy_init=True)
assert not dataset.ann_file
def test_meta(self): def test_meta(self):
# test dataset.metainfo with setting the metainfo from annotation file # test dataset.metainfo with setting the metainfo from annotation file
...@@ -369,6 +379,16 @@ class TestBaseDataset: ...@@ -369,6 +379,16 @@ class TestBaseDataset:
assert dataset.get_data_info(0) == self.data_info assert dataset.get_data_info(0) == self.data_info
assert dataset._fully_initialized assert dataset._fully_initialized
assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_list')
# Test parse_data_info with `data_prefix`
BaseDataset.parse_data_info = self.ori_parse_data_info
data_root = osp.join(osp.dirname(__file__), '../data/')
dataset = BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img_path='imgs'),
ann_file='annotations/dummy_annotation.json')
data_info = dataset.get_data_info(0)
assert data_info['img_path'] == osp.join(data_root, 'imgs',
'test_img.jpg')
def test_force_full_init(self): def test_force_full_init(self):
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment