From 538ff48aec7394238244e437ee7a3adce906c9f8 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Tue, 7 Jun 2022 17:09:33 +0800 Subject: [PATCH] [Fix] Rename data_list and support loading from ceph in dataset (#240) * rename datalist and support load ceph * rename datalist and support load ceph * remove check disk file path in _load_metainfo * fix rename error * fix rename error * unit test error * fix rename error * remove unnecessary code * fix lint --- mmengine/dataset/base_dataset.py | 30 ++++++++++++++++------------ mmengine/utils/__init__.py | 6 +++--- mmengine/utils/path.py | 15 ++++++++++++++ tests/test_data/test_base_dataset.py | 3 ++- 4 files changed, 37 insertions(+), 17 deletions(-) diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index fbcfe82c..9e85f902 100644 --- a/mmengine/dataset/base_dataset.py +++ b/mmengine/dataset/base_dataset.py @@ -12,7 +12,7 @@ from torch.utils.data import Dataset from mmengine.fileio import list_from_file, load from mmengine.registry import TRANSFORMS -from mmengine.utils import check_file_exist +from mmengine.utils import is_abs class Compose: @@ -428,7 +428,6 @@ class BaseDataset(Dataset): """ # noqa: E501 # `self.ann_file` denotes the absolute annotation file path if # `self.root=None` or relative path if `self.root=/path/to/data/`. - check_file_exist(self.ann_file) annotations = load(self.ann_file) if not isinstance(annotations, dict): raise TypeError(f'The annotations loaded from annotation file ' @@ -482,7 +481,7 @@ class BaseDataset(Dataset): Returns: dict: Parsed meta information. """ - # `cls.METAINFO` will be overwritten by in_meta + # avoid `cls.METAINFO` being overwritten by `metainfo` cls_metainfo = copy.deepcopy(cls.METAINFO) if metainfo is None: return cls_metainfo @@ -491,13 +490,17 @@ class BaseDataset(Dataset): f'metainfo should be a dict, but got {type(metainfo)}') for k, v in metainfo.items(): - if isinstance(v, str) and osp.isfile(v): - # if filename in metainfo, this key will be further parsed. - # nested filename will be ignored. - cls_metainfo[k] = list_from_file(v) + if isinstance(v, str): + # If type of value is string, and can be loaded from + # corresponding backend. it means the file name of meta file. + try: + cls_metainfo[k] = list_from_file(v) + except (TypeError, FileNotFoundError): + warnings.warn(f'{v} is not a meta file, simply parsed as ' + 'meta information') + cls_metainfo[k] = v else: cls_metainfo[k] = v - return cls_metainfo def _join_prefix(self): @@ -526,7 +529,7 @@ class BaseDataset(Dataset): """ # Automatically join annotation file path with `self.root` if # `self.ann_file` is not an absolute path. - if not osp.isabs(self.ann_file) and self.ann_file: + if not is_abs(self.ann_file) and self.ann_file: self.ann_file = osp.join(self.data_root, self.ann_file) # Automatically join data directory with `self.root` if path value in # `self.data_prefix` is not an absolute path. @@ -534,9 +537,11 @@ class BaseDataset(Dataset): if prefix is None: self.data_prefix[data_key] = self.data_root elif isinstance(prefix, str): - if not osp.isabs(prefix): + if not is_abs(prefix): self.data_prefix[data_key] = osp.join( self.data_root, prefix) + else: + self.data_prefix[data_key] = prefix else: raise TypeError('prefix should be a string or None, but got ' f'{type(prefix)}') @@ -725,10 +730,9 @@ class BaseDataset(Dataset): sub_data_list = self.data_list[indices:] elif isinstance(indices, Sequence): # Return the data information according to given indices. - subdata_list = [] + sub_data_list = [] for idx in indices: - subdata_list.append(self.data_list[idx]) - sub_data_list = subdata_list + sub_data_list.append(self.data_list[idx]) else: raise TypeError('indices should be a int or sequence of int, ' f'but got {type(indices)}') diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 690b3b39..56483df3 100644 --- a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -10,8 +10,8 @@ from .misc import (check_prerequisites, concat_list, deprecated_api_warning, to_1tuple, to_2tuple, to_3tuple, to_4tuple, to_ntuple, tuple_cast) from .parrots_wrapper import TORCH_VERSION -from .path import (check_file_exist, fopen, is_filepath, mkdir_or_exist, - scandir, symlink) +from .path import (check_file_exist, fopen, is_abs, is_filepath, + mkdir_or_exist, scandir, symlink) from .setup_env import set_multi_processing from .version_utils import digit_version, get_git_hash @@ -28,5 +28,5 @@ __all__ = [ 'is_method_overridden', 'has_method', 'mmcv_full_available', 'digit_version', 'get_git_hash', 'TORCH_VERSION', 'load_url', 'find_latest_checkpoint', 'ManagerMeta', 'ManagerMixin', - 'set_multi_processing', 'has_batch_norm' + 'set_multi_processing', 'has_batch_norm', 'is_abs' ] diff --git a/mmengine/utils/path.py b/mmengine/utils/path.py index 56808183..c1228f13 100644 --- a/mmengine/utils/path.py +++ b/mmengine/utils/path.py @@ -99,3 +99,18 @@ def find_vcs_root(path, markers=('.git', )): return cur prev, cur = cur, osp.split(cur)[0] return None + + +def is_abs(path: str) -> bool: + """Check if path is an absolute path in different backends. + + Args: + path (str): path of directory or file. + + Returns: + bool: whether path is an absolute path. + """ + if osp.isabs(path) or path.startswith(('http', 'https', 's3')): + return True + else: + return False diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py index a3cf586c..7c575e2a 100644 --- a/tests/test_data/test_base_dataset.py +++ b/tests/test_data/test_base_dataset.py @@ -85,6 +85,7 @@ class TestBaseDataset: lazy_init=True) assert not dataset._fully_initialized assert not dataset.data_list + # test the instantiation of self.base_dataset if ann_file is not # existed. with pytest.raises(FileNotFoundError): @@ -93,7 +94,7 @@ class TestBaseDataset: data_prefix=dict(img='imgs'), ann_file='annotations/not_existed_annotation.json') # Use the default value of ann_file, i.e., '' - with pytest.raises(FileNotFoundError): + with pytest.raises(TypeError): BaseDataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs')) -- GitLab