Skip to content
Snippets Groups Projects
Unverified Commit 538ff48a authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Fix] Rename data_list and support loading from ceph in dataset (#240)

* rename datalist and support load ceph

* rename datalist and support load ceph

* remove check disk file path in _load_metainfo

* fix rename error

* fix rename error

* unit test error

* fix rename error

* remove unnecessary code

* fix lint
parent bd3c53b3
No related branches found
No related tags found
No related merge requests found
......@@ -12,7 +12,7 @@ from torch.utils.data import Dataset
from mmengine.fileio import list_from_file, load
from mmengine.registry import TRANSFORMS
from mmengine.utils import check_file_exist
from mmengine.utils import is_abs
class Compose:
......@@ -428,7 +428,6 @@ class BaseDataset(Dataset):
""" # noqa: E501
# `self.ann_file` denotes the absolute annotation file path if
# `self.data_root=None` or a relative path if `self.data_root=/path/to/data/`.
check_file_exist(self.ann_file)
annotations = load(self.ann_file)
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '
......@@ -482,7 +481,7 @@ class BaseDataset(Dataset):
Returns:
dict: Parsed meta information.
"""
# `cls.METAINFO` will be overwritten by in_meta
# avoid `cls.METAINFO` being overwritten by `metainfo`
cls_metainfo = copy.deepcopy(cls.METAINFO)
if metainfo is None:
return cls_metainfo
......@@ -491,13 +490,17 @@ class BaseDataset(Dataset):
f'metainfo should be a dict, but got {type(metainfo)}')
for k, v in metainfo.items():
if isinstance(v, str) and osp.isfile(v):
# if filename in metainfo, this key will be further parsed.
# nested filename will be ignored.
cls_metainfo[k] = list_from_file(v)
if isinstance(v, str):
# If the value is a string and can be loaded from the
# corresponding backend, it is treated as the filename of a meta file.
try:
cls_metainfo[k] = list_from_file(v)
except (TypeError, FileNotFoundError):
warnings.warn(f'{v} is not a meta file, simply parsed as '
'meta information')
cls_metainfo[k] = v
else:
cls_metainfo[k] = v
return cls_metainfo
def _join_prefix(self):
......@@ -526,7 +529,7 @@ class BaseDataset(Dataset):
"""
# Automatically join the annotation file path with `self.data_root`
# if `self.ann_file` is not an absolute path.
if not osp.isabs(self.ann_file) and self.ann_file:
if not is_abs(self.ann_file) and self.ann_file:
self.ann_file = osp.join(self.data_root, self.ann_file)
# Automatically join the data directory with `self.data_root` if a path
# value in `self.data_prefix` is not an absolute path.
......@@ -534,9 +537,11 @@ class BaseDataset(Dataset):
if prefix is None:
self.data_prefix[data_key] = self.data_root
elif isinstance(prefix, str):
if not osp.isabs(prefix):
if not is_abs(prefix):
self.data_prefix[data_key] = osp.join(
self.data_root, prefix)
else:
self.data_prefix[data_key] = prefix
else:
raise TypeError('prefix should be a string or None, but got '
f'{type(prefix)}')
......@@ -725,10 +730,9 @@ class BaseDataset(Dataset):
sub_data_list = self.data_list[indices:]
elif isinstance(indices, Sequence):
# Return the data information according to given indices.
subdata_list = []
sub_data_list = []
for idx in indices:
subdata_list.append(self.data_list[idx])
sub_data_list = subdata_list
sub_data_list.append(self.data_list[idx])
else:
raise TypeError('indices should be a int or sequence of int, '
f'but got {type(indices)}')
......
......@@ -10,8 +10,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning,
to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple,
tuple_cast)
from .parrots_wrapper import TORCH_VERSION
from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist,
scandir, symlink)
from .path import (check_file_exist, fopen, is_abs, is_filepath,
mkdir_or_exist, scandir, symlink)
from .setup_env import set_multi_processing
from .version_utils import digit_version, get_git_hash
......@@ -28,5 +28,5 @@ __all__ = [
'is_method_overridden', 'has_method', 'mmcv_full_available',
'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url',
'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin',
'set_multi_processing', 'has_batch_norm'
'set_multi_processing', 'has_batch_norm', 'is_abs'
]
......@@ -99,3 +99,18 @@ def find_vcs_root(path, markers=('.git', )):
return cur
prev, cur = cur, osp.split(cur)[0]
return None
def is_abs(path: str) -> bool:
    """Check whether ``path`` is an absolute path across backends.

    A path counts as absolute if it is absolute on the local filesystem
    or if it refers to a remote backend (an HTTP(S) or S3 URL).

    Args:
        path (str): Path of a directory or file.

    Returns:
        bool: Whether ``path`` is an absolute path.
    """
    # 'http' already matches every string that 'https' would, so a
    # separate 'https' prefix is redundant.
    # NOTE(review): this is a bare prefix test, so strings such as
    # 'httpfoo' are also treated as absolute — confirm this is intended.
    return osp.isabs(path) or path.startswith(('http', 's3'))
......@@ -85,6 +85,7 @@ class TestBaseDataset:
lazy_init=True)
assert not dataset._fully_initialized
assert not dataset.data_list
# Test the instantiation of self.base_dataset when ann_file does
# not exist.
with pytest.raises(FileNotFoundError):
......@@ -93,7 +94,7 @@ class TestBaseDataset:
data_prefix=dict(img='imgs'),
ann_file='annotations/not_existed_annotation.json')
# Use the default value of ann_file, i.e., ''
with pytest.raises(FileNotFoundError):
with pytest.raises(TypeError):
BaseDataset(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment