diff --git a/docs/zh_cn/tutorials/evaluator.md b/docs/zh_cn/tutorials/evaluator.md index 1a09f181dcb705643abdf8ce01b15c3238101b4b..1dfbb85af5a8e66723859e3941b2ef6f5dd3712c 100644 --- a/docs/zh_cn/tutorials/evaluator.md +++ b/docs/zh_cn/tutorials/evaluator.md @@ -40,7 +40,7 @@ validation_cfg=dict( dict(type='Accuracy', top_k=1), # use the classification accuracy metric dict(type='F1Score') # use the F1_score metric ], - main_metric='accuracy' + main_metric='accuracy', interval=10, by_epoch=True, ) @@ -94,13 +94,14 @@ validation_cfg=dict( The concrete implementation is as follows: ```python -from mmengine.evaluator import BaseEvaluator -from mmengine.registry import EVALUATORS +from mmengine.evaluator import BaseMetric +from mmengine.registry import METRICS import numpy as np -@EVALUATORS.register_module() -class Accuracy(BaseEvaluator): + +@METRICS.register_module() +class Accuracy(BaseMetric): """ Accuracy Evaluator Default prefix: ACC @@ -111,24 +112,24 @@ class Accuracy(BaseEvaluator): default_prefix = 'ACC' - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], - predictions: Sequence[BaseDataElement]): + def process(self, data_batch: Sequence[Tuple[Any, dict]], + predictions: Sequence[dict]): """Process one batch of data and predictions. The processed results should be stored in `self.results`, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data + data_batch (Sequence[Tuple[Any, dict]]): A batch of data from the dataloader. - predictions (Sequence[BaseDataElement]): A batch of outputs from + predictions (Sequence[dict]): A batch of outputs from the model. """ # extract the classification predictions and ground-truth labels - result = dict( - 'pred': predictions.pred_label, - 'gt': data_samples.gt_label - ) + result = { + 'pred': predictions['pred_label'], + 'gt': data_batch['gt_label'] + } # store the results of the current batch into self.results self.results.append(result) diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md index 09a3805881d69653bb116ed388a0ddcd4f14d81d..ced73f25162fa6984e4dd14e39f6b29686c65511 100644 --- a/docs/zh_cn/tutorials/registry.md +++ b/docs/zh_cn/tutorials/registry.md @@ -225,7 +225,7 @@ MMEngine's registries support cross-project calls, i.e. one project can use - OPTIMIZERS: registers all PyTorch `optimizer`s as well as custom `optimizer`s - OPTIMIZER_CONSTRUCTORS: constructors of `optimizer` - PARAM_SCHEDULERS: various parameter schedulers, such as `MultiStepLR` -- EVALUATORS: evaluators for validating model accuracy +- METRICS: metrics for validating model accuracy - TASK_UTILS: task-specific components, such as `AnchorGenerator`, `BboxCoder` - VISUALIZERS: manages visualization modules, e.g. `DetVisualizer` can draw predicted boxes on images - WRITERS: backends for storing training logs, such as `LocalWriter`, `TensorboardWriter` diff --git a/mmengine/data/base_data_element.py b/mmengine/data/base_data_element.py index 47e8d715794ccfdfed1d93f31874fa1beabf423e..3485c84a26f6ad39d8d124aac3ffffd907763902 100644 --- a/mmengine/data/base_data_element.py +++ b/mmengine/data/base_data_element.py @@ -497,6 +497,13 @@ class BaseDataElement: new_data.set_data(data) return new_data + def to_dict(self) -> dict: + """Convert BaseDataElement to dict.""" + return { + k: v.to_dict() if isinstance(v, BaseDataElement) else v + for k, v in self.items() + } + def __repr__(self) -> str: def _addindent(s_: str, num_spaces: int) -> str: diff --git a/mmengine/dataset/base_dataset.py b/mmengine/dataset/base_dataset.py index 0fdfbd0f95c4629dcb46a02889b95a482f979212..179d003836a7a9c7a03f653106a66ca5452cddf5 100644 --- a/mmengine/dataset/base_dataset.py +++
b/mmengine/dataset/base_dataset.py @@ -25,7 +25,10 @@ class Compose: def __init__(self, transforms: Sequence[Union[dict, Callable]]): self.transforms: List[Callable] = [] + for transform in transforms: + # `Compose` can be built with config dict with type and + # corresponding arguments. if isinstance(transform, dict): transform = TRANSFORMS.build(transform) if not callable(transform): @@ -50,6 +53,10 @@ class Compose: """ for t in self.transforms: data = t(data) + # The transform will return None when it failed to load images or + # cannot find suitable augmentation parameters to augment the data. + # Here we simply return None if the transform returns None and the + # dataset will handle it by randomly selecting another data sample. if data is None: return None return data @@ -79,12 +86,16 @@ def force_full_init(old_func: Callable) -> Any: Returns: Any: Depends on old_func. """ - # TODO This decorator can also be used in runner. + @functools.wraps(old_func) def wrapper(obj: object, *args, **kwargs): + # The instance must have `full_init` method. if not hasattr(obj, 'full_init'): raise AttributeError(f'{type(obj)} does not have full_init ' 'method.') + # If instance does not have `_fully_initialized` attribute or + # `_fully_initialized` is False, call `full_init` and set + # `_fully_initialized` to True if not getattr(obj, '_fully_initialized', False): warnings.warn('Attribute `_fully_initialized` is not defined in ' f'{type(obj)} or `type(obj)._fully_initialized is ' @@ -139,16 +150,16 @@ class BaseDataset(Dataset): Args: ann_file (str): Annotation file path. - meta (dict, optional): Meta information for dataset, such as class + metainfo (dict, optional): Meta information for dataset, such as class information. Defaults to None. data_root (str, optional): The root directory for ``data_prefix`` and ``ann_file``. Defaults to None. data_prefix (dict, optional): Prefix for training data. Defaults to dict(img=None, ann=None). filter_cfg (dict, optional): Config for filter data. Defaults to None. - num_samples (int, optional): Support using first few data in - annotation file to facilitate training/testing on a smaller - dataset. Defaults to -1 which means using all ``data_infos``. + indices (int or Sequence[int], optional): Support using first few + data in annotation file to facilitate training/testing on a smaller + dataset. Defaults to None which means using all ``data_infos``. serialize_data (bool, optional): Whether to hold memory using serialized objects, when enabled, data loader workers can use shared RAM from master process instead of making a copy. Defaults @@ -161,26 +172,48 @@ class BaseDataset(Dataset): information of the dataset is needed, which is not necessary to load annotation file. ``Basedataset`` can skip load annotations to save time by set ``lazy_init=False``. Defaults to False. - max_refetch (int, optional): The maximum number of cycles to get a - valid image. Defaults to 1000. + max_refetch (int, optional): If ``Basedataset.prepare_data`` get a + None img. The maximum extra number of cycles to get a valid + image. Defaults to 1000. Note: BaseDataset collects meta information from `annotation file` (the - lowest priority), ``BaseDataset.META``(medium) and `meta parameter` - (highest) passed to constructors. The lower priority meta information - will be overwritten by higher one. + lowest priority), ``BaseDataset.METAINFO``(medium) and `metainfo + parameter` (highest) passed to constructors. The lower priority meta + information will be overwritten by higher one. 
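The new comments in `Compose.__call__` and around `force_full_init` describe a contract between transforms and the dataset: a transform returns `None` when it cannot produce a valid sample (broken image, no suitable augmentation parameters), and the dataset reacts by retrying another, randomly chosen index. The standalone sketch below (hypothetical transform and function names, not part of this patch) illustrates that retry loop:

```python
import random
from typing import Callable, List, Optional


def flaky_decode(data: dict) -> Optional[dict]:
    # Stand-in for a real transform: return None when the sample cannot be
    # loaded or augmented, which is exactly the failure signal described above.
    if data.get('broken'):
        return None
    data['decoded'] = True
    return data


def fetch_with_refetch(data_list: List[dict],
                       pipeline: Callable[[dict], Optional[dict]],
                       idx: int,
                       max_refetch: int = 1000) -> dict:
    """Mimic ``BaseDataset.__getitem__``: retry a random index on ``None``."""
    for _ in range(max_refetch + 1):
        result = pipeline(dict(data_list[idx]))
        if result is not None:
            return result
        idx = random.randint(0, len(data_list) - 1)
    raise RuntimeError(f'Cannot find valid data after {max_refetch} retries')


samples = [dict(img_path='a.jpg'), dict(img_path='b.jpg', broken=True)]
print(fetch_with_refetch(samples, flaky_decode, idx=1))
```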
+ + Note: + Dataset wrapper such as ``ConcatDataset``, ``RepeatDataset`` .etc. + should not inherit from ``BaseDataset`` since ``get_subset`` and + ``get_subset_`` could produce ambiguous meaning sub-dataset which + conflicts with original dataset. + + Examples: + Assume the annotation file is given above. + >>> class CustomDataset(BaseDataset): + >>> METAINFO: dict = dict(task_name='custom_task', + >>> dataset_type='custom_type') + >>> metainfo=dict(task_name='custom_task_name') + >>> custom_dataset = CustomDataset( + >>> 'path/to/ann_file', + >>> metainfo=metainfo) + >>> # meta information of annotation file will be overwritten by + >>> # `CustomDataset.METAINFO`. The merged meta information will + >>> # further be overwritten by argument `metainfo`. + >>> custom_dataset.metainfo + {'task_name': custom_task_name, dataset_type: custom_type} """ - META: dict = dict() + METAINFO: dict = dict() _fully_initialized: bool = False def __init__(self, ann_file: str, - meta: Optional[dict] = None, + metainfo: Optional[dict] = None, data_root: Optional[str] = None, data_prefix: dict = dict(img=None, ann=None), filter_cfg: Optional[dict] = None, - num_samples: int = -1, + indices: Optional[Union[int, Sequence[int]]] = None, serialize_data: bool = True, pipeline: List[Union[dict, Callable]] = [], test_mode: bool = False, @@ -191,23 +224,23 @@ class BaseDataset(Dataset): self.data_prefix = copy.copy(data_prefix) self.ann_file = ann_file self.filter_cfg = copy.deepcopy(filter_cfg) - self._num_samples = num_samples + self._indices = indices self.serialize_data = serialize_data self.test_mode = test_mode self.max_refetch = max_refetch - self.data_infos: List[dict] = [] - self.data_infos_bytes: bytearray = bytearray() + self.data_list: List[dict] = [] + self.date_bytes: np.ndarray - # set meta information - self._meta = self._get_meta_data(copy.deepcopy(meta)) + # Set meta information. + self._metainfo = self._get_meta_info(copy.deepcopy(metainfo)) - # join paths + # Join paths. if self.data_root is not None: self._join_prefix() - # build pipeline + # Build pipeline. self.pipeline = Compose(pipeline) - + # Full initialize the dataset. if not lazy_init: self.full_init() @@ -225,11 +258,13 @@ class BaseDataset(Dataset): if self.serialize_data: start_addr = 0 if idx == 0 else self.data_address[idx - 1].item() end_addr = self.data_address[idx].item() - bytes = memoryview(self.data_infos_bytes[start_addr:end_addr]) - data_info = pickle.loads(bytes) + bytes = memoryview( + self.date_bytes[start_addr:end_addr]) # type: ignore + data_info = pickle.loads(bytes) # type: ignore else: - data_info = self.data_infos[idx] - # To record the real positive index of data information. + data_info = self.data_list[idx] + # Some codebase needs `sample_idx` of data information. Here we convert + # the idx to a positive number and save it in data information. if idx >= 0: data_info['sample_idx'] = idx else: @@ -248,69 +283,66 @@ class BaseDataset(Dataset): Several steps to initialize annotation: - - load_annotations: Load annotations from annotation file. + - load_data_list: Load annotations from annotation file. - filter data information: Filter annotations according to - filter_cfg. - - slice_data: Slice dataset according to ``self._num_samples`` - - serialize_data: Serialize ``self.data_infos`` if + filter_cfg. + - slice_data: Slice dataset according to ``self._indices`` + - serialize_data: Serialize ``self.data_list`` if ``self.serialize_data`` is True. 
""" if self._fully_initialized: return - # load data information - self.data_infos = self.load_annotations(self.ann_file) + self.data_list = self.load_data_list(self.ann_file) # filter illegal data, such as data that has no annotations. - self.data_infos = self.filter_data() - # if `num_samples > 0`, return the first `num_samples` data information - self.data_infos = self._slice_data() - # serialize data_infos + self.data_list = self.filter_data() + # Get subset data according to indices. + if self._indices is not None: + self.data_list = self._get_unserialized_subset(self._indices) + + # serialize data_list if self.serialize_data: - self.data_infos_bytes, self.data_address = self._serialize_data() - # Empty cache for preventing making multiple copies of - # `self.data_info` when loading data multi-processes. - self.data_infos.clear() - gc.collect() + self.date_bytes, self.data_address = self._serialize_data() + self._fully_initialized = True @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get meta information of dataset. Returns: - dict: meta information collected from ``BaseDataset.META``, - annotation file and meta parameter during instantiation. + dict: meta information collected from ``BaseDataset.METAINFO``, + annotation file and metainfo argument during instantiation. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) - def parse_annotations(self, - raw_data_info: dict) -> Union[dict, List[dict]]: + def parse_data_info(self, raw_data_info: dict) -> Union[dict, List[dict]]: """Parse raw annotation to target format. - ``parse_annotations`` should return ``dict`` or ``List[dict]``. Each - dict contains the annotations of a training sample. If the protocol of + This method should return dict or list of dict. Each dict or list + contains the data information of a training sample. If the protocol of the sample annotations is changed, this function can be overridden to update the parsing logic while keeping compatibility. Args: - raw_data_info (dict): Raw annotation load from ``ann_file`` + raw_data_info (dict): Raw data information load from ``ann_file`` Returns: - Union[dict, List[dict]]: Parsed annotation. + list or list[dict]: Parsed annotation. """ return raw_data_info def filter_data(self) -> List[dict]: """Filter annotations according to filter_cfg. Defaults return all - ``data_infos``. + ``data_list``. - If some ``data_infos`` could be filtered according to specific logic, + If some ``data_list`` could be filtered according to specific logic, the subclass should override this method. Returns: - List[dict]: Filtered results. + list[int]: Filtered results. """ - return self.data_infos + return self.data_list def get_cat_ids(self, idx: int) -> List[int]: """Get category ids by index. Dataset wrapped by ClassBalancedDataset @@ -323,26 +355,34 @@ class BaseDataset(Dataset): idx (int): The index of data. Returns: - List[int]: All categories in the image of specified index. + list[int]: All categories in the image of specified index. """ raise NotImplementedError(f'{type(self)} must implement `get_cat_ids` ' 'method') def __getitem__(self, idx: int) -> dict: - """Get the idx-th image of dataset after ``self.pipelines`` and - ``full_init`` will be called if the dataset has not been fully - initialized. + """Get the idx-th image and data information of dataset after + ``self.pipeline``, and ``full_init`` will be called if the dataset has + not been fully initialized. 
- During training phase, if ``self.pipelines`` get ``None``, + During training phase, if ``self.pipeline`` get ``None``, ``self._rand_another`` will be called until a valid image is fetched or the maximum limit of refetech is reached. Args: - idx (int): The index of self.data_infos + idx (int): The index of self.data_list. Returns: - dict: The idx-th image of dataset after ``self.pipelines``. + dict: The idx-th image and data information of dataset after + ``self.pipeline``. """ + # Performing full initialization by calling `__getitem__` will consume + # extra memory. If a dataset is not fully initialized by setting + # `lazy_init=True` and then fed into the dataloader. Different workers + # will simultaneously read and parse the annotation. It will cost more + # time and memory, although this may work. Therefore, it is recommended + # to manually call `full_init` before dataset fed into dataloader to + # ensure all workers use shared RAM from master process. if not self._fully_initialized: warnings.warn( 'Please call `full_init()` method manually to accelerate ' @@ -350,31 +390,39 @@ class BaseDataset(Dataset): self.full_init() if self.test_mode: - return self.prepare_data(idx) - - for _ in range(self.max_refetch): - data_sample = self.prepare_data(idx) - if data_sample is None: + data = self.prepare_data(idx) + if data is None: + raise Exception('Test time pipline should not get `None` ' + 'data_sample') + return data + + for _ in range(self.max_refetch + 1): + data = self.prepare_data(idx) + # Broken images or random augmentations may cause the returned data + # to be None + if data is None: idx = self._rand_another() continue - return data_sample + return data raise Exception(f'Cannot find valid image after {self.max_refetch}! ' - 'Please check your image path and pipelines') + 'Please check your image path and pipeline') - def load_annotations(self, ann_file: str) -> List[dict]: + def load_data_list(self, ann_file: str) -> List[dict]: """Load annotations from an annotation file. If the annotation file does not follow `OpenMMLab 2.0 format dataset <https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md>`_ . - The subclass must override this method for load annotations. + The subclass must override this method for load annotations. The meta + information of annotation file will be overwritten :attr:`METAINFO` + and ``metainfo`` argument of constructor. Args: ann_file (str): Absolute annotation file path if ``self.root=None`` or relative path if ``self.root=/path/to/data/``. Returns: - List[dict]: A list of annotation. + list[dict]: A list of annotation. """ # noqa: E501 check_file_exist(ann_file) annotations = load(ann_file) @@ -387,59 +435,66 @@ class BaseDataset(Dataset): meta_data = annotations['metadata'] raw_data_infos = annotations['data_infos'] - # update self._meta + # Meta information load from annotation file will not influence the + # existed meta information load from `BaseDataset.METAINFO` and + # `metainfo` arguments defined in constructor. for k, v in meta_data.items(): - # We only merge keys that are not contained in self._meta. - self._meta.setdefault(k, v) + self._metainfo.setdefault(k, v) - # load and parse data_infos - data_infos = [] + # load and parse data_infos. 
+ data_list = [] for raw_data_info in raw_data_infos: - data_info = self.parse_annotations(raw_data_info) + # parse raw data information to target format + data_info = self.parse_data_info(raw_data_info) if isinstance(data_info, dict): - data_infos.append(data_info) + # For image tasks, `data_info` should information if single + # image, such as dict(img_path='xxx', width=360, ...) + data_list.append(data_info) elif isinstance(data_info, list): - # data_info can also be a list of dict, which means one - # data_info contains multiple samples. + # For video tasks, `data_info` could contain image + # information of multiple frames, such as + # [dict(video_path='xxx', timestamps=...), + # dict(video_path='xxx', timestamps=...)] for item in data_info: if not isinstance(item, dict): - raise TypeError('data_info must be a list[dict], but ' + raise TypeError('data_info must be list of dict, but ' f'got {type(item)}') - data_infos.extend(data_info) + data_list.extend(data_info) else: - raise TypeError('data_info should be a dict or list[dict], ' + raise TypeError('data_info should be a dict or list of dict, ' f'but got {type(data_info)}') - return data_infos + return data_list @classmethod - def _get_meta_data(cls, in_meta: dict = None) -> dict: + def _get_meta_info(cls, in_metainfo: dict = None) -> dict: """Collect meta information from the dictionary of meta. Args: - in_meta (dict): Meta information dict. If ``in_meta`` contains - existed filename, it will be parsed by ``list_from_file``. + in_metainfo (dict): Meta information dict. If ``in_metainfo`` + contains existed filename, it will be parsed by + ``list_from_file``. Returns: dict: Parsed meta information. """ - # cls.META will be overwritten by in_meta - cls_meta = copy.deepcopy(cls.META) - if in_meta is None: - return cls_meta - if not isinstance(in_meta, dict): + # `cls.METAINFO` will be overwritten by in_meta + cls_metainfo = copy.deepcopy(cls.METAINFO) + if in_metainfo is None: + return cls_metainfo + if not isinstance(in_metainfo, dict): raise TypeError( - f'in_meta should be a dict, but got {type(in_meta)}') + f'in_metainfo should be a dict, but got {type(in_metainfo)}') - for k, v in in_meta.items(): + for k, v in in_metainfo.items(): if isinstance(v, str) and osp.isfile(v): - # if filename in in_meta, this key will be further parsed. - # nested filename will be ignored.: - cls_meta[k] = list_from_file(v) + # if filename in in_metainfo, this key will be further parsed. + # nested filename will be ignored. + cls_metainfo[k] = list_from_file(v) else: - cls_meta[k] = v + cls_metainfo[k] = v - return cls_meta + return cls_metainfo def _join_prefix(self): """Join ``self.data_root`` with ``self.data_prefix`` and @@ -465,9 +520,12 @@ class BaseDataset(Dataset): >>> self.ann_file 'a/b/c/f' """ + # Automatically join annotation file path with `self.root` if + # `self.ann_file` is not an absolute path. if not osp.isabs(self.ann_file): self.ann_file = osp.join(self.data_root, self.ann_file) - + # Automatically join data directory with `self.root` if path value in + # `self.data_prefix` is not an absolute path. for data_key, prefix in self.data_prefix.items(): if prefix is None: self.data_prefix[data_key] = self.data_root @@ -479,31 +537,208 @@ class BaseDataset(Dataset): raise TypeError('prefix should be a string or None, but got ' f'{type(prefix)}') - def _slice_data(self) -> List[dict]: - """Slice ``self.data_infos``. BaseDataset supports only using the first - few data. 
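The rewritten `_get_meta_info` and `load_data_list` implement the three-level priority noted earlier: meta information from the annotation file only fills missing keys, `METAINFO` provides class-level defaults, and the `metainfo` constructor argument wins. A standalone sketch of that merge (not the patch's exact code):

```python
import copy

CLASS_METAINFO = dict(dataset_type='toy_type', task_name='toy_task')


def merge_metainfo(ann_file_meta: dict, init_metainfo: dict) -> dict:
    # Start from the class-level defaults (`METAINFO`) ...
    metainfo = copy.deepcopy(CLASS_METAINFO)
    # ... let the constructor argument override them ...
    metainfo.update(init_metainfo)
    # ... and only fill keys that are still missing from the annotation file.
    for key, value in ann_file_meta.items():
        metainfo.setdefault(key, value)
    return metainfo


print(merge_metainfo(
    ann_file_meta=dict(task_name='from_file', classes=('cat', 'dog')),
    init_metainfo=dict(task_name='from_argument')))
# {'dataset_type': 'toy_type', 'task_name': 'from_argument',
#  'classes': ('cat', 'dog')}
```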
+ @force_full_init + def get_subset_(self, indices: Union[Sequence[int], int]) -> None: + """The in-place version of ``get_subset `` to convert dataset to a + subset of original dataset. + + This method will convert the original dataset to a subset of dataset. + If type of indices is int, ``get_subset_`` will return a subdataset + which contains the first or last few data information according to + indices is positive or negative. If type of indices is a sequence of + int, the subdataset will extract the data information according to + the index given in indices. + + Examples: + >>> dataset = BaseDataset('path/to/ann_file') + >>> len(dataset) + 100 + >>> dataset.get_subset_(90) + >>> len(dataset) + 90 + >>> # if type of indices is sequence, extract the corresponding + >>> # index data information + >>> dataset.get_subset_([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> len(dataset) + 10 + >>> dataset.get_subset_(-3) + >>> len(dataset) # Get the latest few data information. + 3 + + Args: + indices (int or Sequence[int]): If type of indices is int, indices + represents the first or last few data of dataset according to + indices is positive or negative. If type of indices is + Sequence, indices represents the target data information + index of dataset. + """ + # Get subset of data from serialized data or data information sequence + # according to `self.serialize_data`. + if self.serialize_data: + self.date_bytes, self.data_address = \ + self._get_serialized_subset(indices) + else: + self.data_list = self._get_unserialized_subset(indices) + + @force_full_init + def get_subset(self, indices: Union[Sequence[int], int]) -> 'BaseDataset': + """Return a subset of dataset. + + This method will return a subset of original dataset. If type of + indices is int, ``get_subset_`` will return a subdataset which + contains the first or last few data information according to + indices is positive or negative. If type of indices is a sequence of + int, the subdataset will extract the information according to the index + given in indices. + + Examples: + >>> dataset = BaseDataset('path/to/ann_file') + >>> len(dataset) + 100 + >>> subdataset = dataset.get_subset(90) + >>> len(sub_dataset) + 90 + >>> # if type of indices is list, extract the corresponding + >>> # index data information + >>> subdataset = dataset.get_subset([0, 1, 2, 3, 4, 5, 6, 7, + >>> 8, 9]) + >>> len(sub_dataset) + 10 + >>> subdataset = dataset.get_subset(-3) + >>> len(subdataset) # Get the latest few data information. + 3 + + Args: + indices (int or Sequence[int]): If type of indices is int, indices + represents the first or last few data of dataset according to + indices is positive or negative. If type of indices is + Sequence, indices represents the target data information + index of dataset. + + Returns: + BaseDataset: A subset of dataset. + """ + # Get subset of data from serialized data or data information list + # according to `self.serialize_data`. Since `_get_serialized_subset` + # will recalculate the subset data information, + # `_copy_without_annotation` will copy all attributes except data + # information. + sub_dataset = self._copy_without_annotation() + # Get subset of dataset with serialize and unserialized data. 
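The `get_subset`/`get_subset_` docstrings above define the indexing semantics: a positive int keeps the first N samples, a negative int keeps the last N, and a sequence of ints picks exactly those samples. The following standalone snippet replays those semantics on a plain list; it mirrors `_get_unserialized_subset`, which appears later in this patch:

```python
from typing import Sequence, Union


def get_unserialized_subset(data_list: list,
                            indices: Union[Sequence[int], int]) -> list:
    if isinstance(indices, int):
        # Positive -> first `indices` items, negative -> last `-indices` items.
        return data_list[:indices] if indices >= 0 else data_list[indices:]
    # A sequence selects exactly the given positions.
    return [data_list[idx] for idx in indices]


data = [dict(sample_idx=i) for i in range(100)]
assert len(get_unserialized_subset(data, 90)) == 90
assert len(get_unserialized_subset(data, -3)) == 3
assert [d['sample_idx'] for d in get_unserialized_subset(data, [0, 2, 4])] == [0, 2, 4]
```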
+ if self.serialize_data: + date_bytes, data_address = \ + self._get_serialized_subset(indices) + sub_dataset.date_bytes = date_bytes.copy() + sub_dataset.data_address = data_address.copy() + else: + data_list = self._get_unserialized_subset(indices) + sub_dataset.data_list = copy.deepcopy(data_list) + return sub_dataset + + def _get_serialized_subset(self, indices: Union[Sequence[int], int]) \ + -> Tuple[np.ndarray, np.ndarray]: + """Get subset of serialized data information list. + + Args: + indices (int or Sequence[int]): If type of indices is int, + indices represents the first or last few data of serialized + data information list. If type of indices is Sequence, indices + represents the target data information index which consist of + subset data information. Returns: - List[dict]: A slice of ``self.data_infos`` + Tuple[np.ndarray, np.ndarray]: subset of serialized data + information. """ - assert self._num_samples < len(self.data_infos), \ - f'Slice size({self._num_samples}) is larger than dataset ' \ - f'size({self.data_infos}, please keep `num_sample` smaller than' \ - f'{self.data_infos})' - if self._num_samples > 0: - return self.data_infos[:self._num_samples] + sub_date_bytes: Union[List, np.ndarray] + sub_data_address: Union[List, np.ndarray] + if isinstance(indices, int): + if indices >= 0: + assert indices < len(self.data_address), \ + f'{indices} is out of dataset length({len(self)}' + # Return the first few data information. + end_addr = self.data_address[indices - 1].item() \ + if indices > 0 else 0 + # Slicing operation of `np.ndarray` does not trigger a memory + # copy. + sub_date_bytes = self.date_bytes[:end_addr] + # Since the buffer size of first few data information is not + # changed, + sub_data_address = self.data_address[:indices] + else: + assert -indices <= len(self.data_address), \ + f'{indices} is out of dataset length({len(self)}' + # Return the last few data information. + ignored_bytes_size = self.data_address[indices - 1] + start_addr = self.data_address[indices - 1].item() + sub_date_bytes = self.date_bytes[start_addr:] + sub_data_address = self.data_address[indices:] + sub_data_address = sub_data_address - ignored_bytes_size + elif isinstance(indices, Sequence): + sub_date_bytes = [] + sub_data_address = [] + for idx in indices: + assert len(self) > idx >= -len(self) + start_addr = 0 if idx == 0 else \ + self.data_address[idx - 1].item() + end_addr = self.data_address[idx].item() + # Get data information by address. + sub_date_bytes.append(self.date_bytes[start_addr:end_addr]) + # Get data information size. + sub_data_address.append(end_addr - start_addr) + # Handle indices is an empty list. + if sub_date_bytes: + sub_date_bytes = np.concatenate(sub_date_bytes) + sub_data_address = np.cumsum(sub_data_address) + else: + sub_date_bytes = np.array([]) + sub_data_address = np.array([]) else: - return self.data_infos + raise TypeError('indices should be a int or sequence of int, ' + f'but got {type(indices)}') + return sub_date_bytes, sub_data_address # type: ignore + + def _get_unserialized_subset(self, indices: Union[Sequence[int], + int]) -> list: + """Get subset of data information list. + + Args: + indices (int or Sequence[int]): If type of indices is int, + indices represents the first or last few data of data + information. If indices of indices is Sequence, indices + represents the target data information index which consist + of subset data information. + + Returns: + Tuple[np.ndarray, np.ndarray]: subset of data information. 
+ """ + if isinstance(indices, int): + if indices >= 0: + # Return the first few data information. + sub_data_list = self.data_list[:indices] + else: + # Return the last few data information. + sub_data_list = self.data_list[indices:] + elif isinstance(indices, Sequence): + # Return the data information according to given indices. + subdata_list = [] + for idx in indices: + subdata_list.append(self.data_list[idx]) + sub_data_list = subdata_list + else: + raise TypeError('indices should be a int or sequence of int, ' + f'but got {type(indices)}') + return sub_data_list def _serialize_data(self) -> Tuple[np.ndarray, np.ndarray]: - """Serialize ``self.data_infos`` to save memory when launching multiple + """Serialize ``self.data_list`` to save memory when launching multiple workers in data loading. This function will be called in ``full_init``. Hold memory using serialized objects, and data loader workers can use shared RAM from master process instead of making a copy. Returns: - Tuple[np.ndarray, np.ndarray]: serialize result and corresponding + Tuple[np.ndarray, np.ndarray]: Serialized result and corresponding address. """ @@ -511,13 +746,19 @@ class BaseDataset(Dataset): buffer = pickle.dumps(data, protocol=4) return np.frombuffer(buffer, dtype=np.uint8) - serialized_data_infos_list = [_serialize(x) for x in self.data_infos] - address_list = np.asarray([len(x) for x in serialized_data_infos_list], - dtype=np.int64) + # Serialize data information list avoid making multiple copies of + # `self.data_list` when iterate `import torch.utils.data.dataloader` + # with multiple workers. + data_list = [_serialize(x) for x in self.data_list] + address_list = np.asarray([len(x) for x in data_list], dtype=np.int64) data_address: np.ndarray = np.cumsum(address_list) - serialized_data_infos = np.concatenate(serialized_data_infos_list) - - return serialized_data_infos, data_address + # TODO Check if np.concatenate is necessary + data_bytes = np.concatenate(data_list) + # Empty cache for preventing making multiple copies of + # `self.data_info` when loading data multi-processes. + self.data_list.clear() + gc.collect() + return data_bytes, data_address def _rand_another(self) -> int: """Get random index. @@ -550,4 +791,24 @@ class BaseDataset(Dataset): if self.serialize_data: return len(self.data_address) else: - return len(self.data_infos) + return len(self.data_list) + + def _copy_without_annotation(self, memo=dict()) -> 'BaseDataset': + """Deepcopy for all attributes other than ``data_list``, + ``data_address`` and ``date_bytes``. + + Args: + memo: Memory dict which used to reconstruct complex object + correctly. 
+ """ + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + + for key, value in self.__dict__.items(): + if key in ['data_list', 'data_address', 'date_bytes']: + continue + super(BaseDataset, other).__setattr__(key, + copy.deepcopy(value, memo)) + + return other diff --git a/mmengine/dataset/dataset_wrapper.py b/mmengine/dataset/dataset_wrapper.py index ba8c83cfbb0f511295f022b10b2ad279fbcadc5b..454ec96d0df47902b88c1d1f6f2469e944075a70 100644 --- a/mmengine/dataset/dataset_wrapper.py +++ b/mmengine/dataset/dataset_wrapper.py @@ -4,7 +4,7 @@ import copy import math import warnings from collections import defaultdict -from typing import List, Sequence, Tuple +from typing import List, Sequence, Tuple, Union from torch.utils.data.dataset import ConcatDataset as _ConcatDataset @@ -16,6 +16,13 @@ class ConcatDataset(_ConcatDataset): Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init. + Note: + ``ConcatDataset`` should not inherit from ``BaseDataset`` since + ``get_subset`` and ``get_subset_`` could produce ambiguous meaning + sub-dataset which conflicts with original dataset. If you want to use + a sub-dataset of ``ConcatDataset``, you should set ``indices`` + arguments for wrapped dataset which inherit from ``BaseDataset``. + Args: datasets (Sequence[BaseDataset]): A list of datasets which will be concatenated. @@ -26,12 +33,12 @@ class ConcatDataset(_ConcatDataset): def __init__(self, datasets: Sequence[BaseDataset], lazy_init: bool = False): - # Only use meta of first dataset. - self._meta = datasets[0].meta + # Only use metainfo of first dataset. + self._metainfo = datasets[0].metainfo self.datasets = datasets # type: ignore for i, dataset in enumerate(datasets, 1): - if self._meta != dataset.meta: - warnings.warn( + if self._metainfo != dataset.metainfo: + raise ValueError( f'The meta information of the {i}-th dataset does not ' 'match meta information of the first dataset') @@ -40,14 +47,14 @@ class ConcatDataset(_ConcatDataset): self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the first dataset in ``self.datasets``. Returns: dict: Meta information of first dataset. """ - # Prevent `self._meta` from being modified by outside. - return copy.deepcopy(self._meta) + # Prevent `self._metainfo` from being modified by outside. + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -77,8 +84,9 @@ class ConcatDataset(_ConcatDataset): f'absolute value of index({idx}) should not exceed dataset' f'length({len(self)}).') idx = len(self) + idx - # Get the inner index of single dataset + # Get `dataset_idx` to tell idx belongs to which dataset. dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + # Get the inner index of single dataset. if dataset_idx == 0: sample_idx = idx else: @@ -111,6 +119,26 @@ class ConcatDataset(_ConcatDataset): dataset_idx, sample_idx = self._get_ori_dataset_idx(idx) return self.datasets[dataset_idx][sample_idx] + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`ConcatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. 
If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ConcatDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ConcatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`ConcatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ConcatDataset`.') + class RepeatDataset: """A wrapper of repeated dataset. @@ -120,10 +148,17 @@ class RepeatDataset: is small. Using RepeatDataset can reduce the data loading time between epochs. + Note: + ``RepeatDataset`` should not inherit from ``BaseDataset`` since + ``get_subset`` and ``get_subset_`` could produce ambiguous meaning + sub-dataset which conflicts with original dataset. If you want to use + a sub-dataset of ``RepeatDataset``, you should set ``indices`` + arguments for wrapped dataset which inherit from ``BaseDataset``. + Args: dataset (BaseDataset): The dataset to be repeated. times (int): Repeat times. - lazy_init (bool, optional): Whether to load annotation during + lazy_init (bool): Whether to load annotation during instantiation. Defaults to False. """ @@ -133,20 +168,20 @@ class RepeatDataset: lazy_init: bool = False): self.dataset = dataset self.times = times - self._meta = dataset.meta + self._metainfo = dataset.metainfo self._fully_initialized = False if not lazy_init: self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the repeated dataset. Returns: dict: The meta information of repeated dataset. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -195,6 +230,26 @@ class RepeatDataset: def __len__(self): return self.times * self._ori_len + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`RepeatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `RepeatDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``RepeatDataset`` for the ambiguous meaning of sub- + dataset.""" + raise NotImplementedError( + '`RepeatDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `RepeatDataset`.') + class ClassBalancedDataset: """A wrapper of class balanced dataset. @@ -219,6 +274,14 @@ class ClassBalancedDataset: 3. For each image I, compute the image-level repeat factor: :math:`r(I) = max_{c in I} r(c)` + Note: + ``ClassBalancedDataset`` should not inherit from ``BaseDataset`` + since ``get_subset`` and ``get_subset_`` could produce ambiguous + meaning sub-dataset which conflicts with original dataset. 
If you + want to use a sub-dataset of ``ClassBalancedDataset``, you should set + ``indices`` arguments for wrapped dataset which inherit from + ``BaseDataset``. + Args: dataset (BaseDataset): The dataset to be repeated. oversample_thr (float): frequency threshold below which data is @@ -236,20 +299,20 @@ class ClassBalancedDataset: lazy_init: bool = False): self.dataset = dataset self.oversample_thr = oversample_thr - self._meta = dataset.meta + self._metainfo = dataset.metainfo self._fully_initialized = False if not lazy_init: self.full_init() @property - def meta(self) -> dict: + def metainfo(self) -> dict: """Get the meta information of the repeated dataset. Returns: dict: The meta information of repeated dataset. """ - return copy.deepcopy(self._meta) + return copy.deepcopy(self._metainfo) def full_init(self): """Loop to ``full_init`` each dataset.""" @@ -257,9 +320,12 @@ class ClassBalancedDataset: return self.dataset.full_init() - + # Get repeat factors for each image. repeat_factors = self._get_repeat_factors(self.dataset, self.oversample_thr) + # Repeat dataset's indices according to repeat_factors. For example, + # if `repeat_factors = [1, 2, 3]`, and the `len(dataset) == 3`, + # the repeated indices will be [1, 2, 2, 3, 3, 3]. repeat_indices = [] for dataset_index, repeat_factor in enumerate(repeat_factors): repeat_indices.extend([dataset_index] * math.ceil(repeat_factor)) @@ -362,3 +428,23 @@ class ClassBalancedDataset: @force_full_init def __len__(self): return len(self.repeat_indices) + + def get_subset_(self, indices: Union[List[int], int]) -> None: + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') + + def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset': + """Not supported in ``ClassBalancedDataset`` for the ambiguous meaning + of sub-dataset.""" + raise NotImplementedError( + '`ClassBalancedDataset` dose not support `get_subset` and ' + '`get_subset_` interfaces because this will lead to ambiguous ' + 'implementation of some methods. If you want to use `get_subset` ' + 'or `get_subset_` interfaces, please use them in the wrapped ' + 'dataset first and then use `ClassBalancedDataset`.') diff --git a/mmengine/evaluator/__init__.py b/mmengine/evaluator/__init__.py index 3866851031203d6fc76debc0fd4f150a179447d4..ac91828d98ba56345aa58214f8fcdcabca861ecc 100644 --- a/mmengine/evaluator/__init__.py +++ b/mmengine/evaluator/__init__.py @@ -1,9 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from .base import BaseEvaluator -from .builder import build_evaluator -from .composed_evaluator import ComposedEvaluator +from .evaluator import Evaluator +from .metric import BaseMetric from .utils import get_metric_value -__all__ = [ - 'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value' -] +__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value'] diff --git a/mmengine/evaluator/builder.py b/mmengine/evaluator/builder.py deleted file mode 100644 index 40fa03a3f240f6aff95942982df985d7ef5fafea..0000000000000000000000000000000000000000 --- a/mmengine/evaluator/builder.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Union - -from ..registry import EVALUATORS -from .base import BaseEvaluator -from .composed_evaluator import ComposedEvaluator - - -def build_evaluator( - cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]: - """Build function of evaluator. - - When the evaluator config is a list, it will automatically build composed - evaluators. - - Args: - cfg (dict | list): Config of evaluator. When the config is a list, it - will automatically build composed evaluators. - - Returns: - BaseEvaluator or ComposedEvaluator: The built evaluator. - """ - if isinstance(cfg, list): - evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg] - return ComposedEvaluator(evaluators=evaluators) - else: - return EVALUATORS.build(cfg) diff --git a/mmengine/evaluator/composed_evaluator.py b/mmengine/evaluator/composed_evaluator.py deleted file mode 100644 index b965b358eb731b1f8c21946e5ebf83f81a6ea764..0000000000000000000000000000000000000000 --- a/mmengine/evaluator/composed_evaluator.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Optional, Sequence, Tuple, Union - -from mmengine.data import BaseDataElement -from .base import BaseEvaluator - - -class ComposedEvaluator: - """Wrapper class to compose multiple :class:`BaseEvaluator` instances. - - Args: - evaluators (Sequence[BaseEvaluator]): The evaluators to compose. - collect_device (str): Device name used for collecting results from - different ranks during distributed training. Must be 'cpu' or - 'gpu'. Defaults to 'cpu'. - """ - - def __init__(self, - evaluators: Sequence[BaseEvaluator], - collect_device='cpu'): - self._dataset_meta: Union[None, dict] = None - self.collect_device = collect_device - self.evaluators = evaluators - - @property - def dataset_meta(self) -> Optional[dict]: - return self._dataset_meta - - @dataset_meta.setter - def dataset_meta(self, dataset_meta: dict) -> None: - self._dataset_meta = dataset_meta - for evaluator in self.evaluators: - evaluator.dataset_meta = dataset_meta - - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], - predictions: Sequence[BaseDataElement]): - """Invoke process method of each wrapped evaluator. - - Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data - from the dataloader. - predictions (Sequence[BaseDataElement]): A batch of outputs from - the model. - """ - - for evalutor in self.evaluators: - evalutor.process(data_batch, predictions) - - def evaluate(self, size: int) -> dict: - """Invoke evaluate method of each wrapped evaluator and collect the - metrics dict. - - Args: - size (int): Length of the entire validation dataset. When batch - size > 1, the dataloader may pad some data samples to make - sure all ranks have the same length of dataset slice. The - ``collect_results`` function will drop the padded data base on - this size. - - Returns: - dict: Evaluation metrics of all wrapped evaluators. The keys are - the names of the metrics, and the values are corresponding results. 
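Both the deleted `ComposedEvaluator.evaluate` and the new `Evaluator.evaluate` merge per-metric result dictionaries and refuse duplicated metric names. A standalone sketch of that merge step (simplified, hypothetical helper name):

```python
from typing import Dict, List


def merge_metric_results(per_metric_results: List[Dict[str, float]]) -> dict:
    merged: Dict[str, float] = {}
    for results in per_metric_results:
        for name in results:
            if name in merged:
                raise ValueError('There are multiple evaluation results '
                                 f'with the same metric name {name}')
        merged.update(results)
    return merged


print(merge_metric_results([{'ACC/top1': 0.9}, {'F1/score': 0.8}]))
# {'ACC/top1': 0.9, 'F1/score': 0.8}; a duplicated key would raise ValueError
```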
- """ - metrics = {} - for evaluator in self.evaluators: - _metrics = evaluator.evaluate(size) - - # Check metric name conflicts - for name in _metrics.keys(): - if name in metrics: - raise ValueError( - 'There are multiple evaluators with the same metric ' - f'name {name}') - - metrics.update(_metrics) - return metrics diff --git a/mmengine/evaluator/evaluator.py b/mmengine/evaluator/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..c653fb563aa90ce447ad0be964b14901b9aed657 --- /dev/null +++ b/mmengine/evaluator/evaluator.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union + +from mmengine.data import BaseDataElement +from ..registry.root import METRICS +from .metric import BaseMetric + + +class Evaluator: + """Wrapper class to compose multiple :class:`BaseMetric` instances. + + Args: + metrics (dict or BaseMetric or Sequence): The config of metrics. + """ + + def __init__(self, metrics: Union[dict, BaseMetric, Sequence]): + self._dataset_meta: Optional[dict] = None + if not isinstance(metrics, Sequence): + metrics = [metrics] + self.metrics: List[BaseMetric] = [] + for metric in metrics: + if isinstance(metric, BaseMetric): + self.metrics.append(metric) + elif isinstance(metric, dict): + self.metrics.append(METRICS.build(metric)) + else: + raise TypeError('metric should be a dict or a BaseMetric, ' + f'but got {metric}.') + + @property + def dataset_meta(self) -> Optional[dict]: + return self._dataset_meta + + @dataset_meta.setter + def dataset_meta(self, dataset_meta: dict) -> None: + self._dataset_meta = dataset_meta + for metric in self.metrics: + metric.dataset_meta = dataset_meta + + def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], + predictions: Sequence[BaseDataElement]): + """Convert ``BaseDataSample`` to dict and invoke process method of each + metric. + + Args: + data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data + from the dataloader. + predictions (Sequence[BaseDataElement]): A batch of outputs from + the model. + """ + _data_batch = [] + for input, data in data_batch: + if isinstance(data, BaseDataElement): + _data_batch.append((input, data.to_dict())) + else: + _data_batch.append((input, data)) + _predictions = [] + for pred in predictions: + if isinstance(pred, BaseDataElement): + _predictions.append(pred.to_dict()) + else: + _predictions.append(pred) + + for metric in self.metrics: + metric.process(_data_batch, _predictions) + + def evaluate(self, size: int) -> dict: + """Invoke ``evaluate`` method of each metric and collect the metrics + dictionary. + + Args: + size (int): Length of the entire validation dataset. When batch + size > 1, the dataloader may pad some data samples to make + sure all ranks have the same length of dataset slice. The + ``collect_results`` function will drop the padded data based on + this size. + + Returns: + dict: Evaluation results of all metrics. The keys are the names + of the metrics, and the values are corresponding results. + """ + metrics = {} + for metric in self.metrics: + _results = metric.evaluate(size) + + # Check metric name conflicts + for name in _results.keys(): + if name in metrics: + raise ValueError( + 'There are multiple evaluation results with the same ' + f'metric name {name}. 
Please make sure all metrics ' + 'have different prefixes.') + + metrics.update(_results) + return metrics + + def offline_evaluate(self, + data: Sequence, + predictions: Sequence, + chunk_size: int = 1): + """Offline evaluate the dumped predictions on the given data . + + Args: + data (Sequence): All data of the validation set. + predictions (Sequence): All predictions of the model on the + validation set. + chunk_size (int): The number of data samples and predictions to be + processed in a batch. + """ + + # support chunking iterable objects + def get_chunks(seq: Iterator, chunk_size=1): + stop = False + while not stop: + chunk = [] + for _ in range(chunk_size): + try: + chunk.append(next(seq)) + except StopIteration: + stop = True + break + if chunk: + yield chunk + + size = 0 + for data_chunk, pred_chunk in zip( + get_chunks(iter(data), chunk_size), + get_chunks(iter(predictions), chunk_size)): + size += len(data_chunk) + self.process(data_chunk, pred_chunk) + return self.evaluate(size) diff --git a/mmengine/evaluator/base.py b/mmengine/evaluator/metric.py similarity index 78% rename from mmengine/evaluator/base.py rename to mmengine/evaluator/metric.py index 2f31444b76ddeafe26ff194fb9b0806309e34a0e..e8a71488023a344792acfc22726c3b9582e9e276 100644 --- a/mmengine/evaluator/base.py +++ b/mmengine/evaluator/metric.py @@ -3,20 +3,19 @@ import warnings from abc import ABCMeta, abstractmethod from typing import Any, List, Optional, Sequence, Tuple, Union -from mmengine.data import BaseDataElement from mmengine.dist import (broadcast_object_list, collect_results, is_main_process) -class BaseEvaluator(metaclass=ABCMeta): - """Base class for an evaluator. +class BaseMetric(metaclass=ABCMeta): + """Base class for a metric. - The evaluator first processes each batch of data_samples and - predictions, and appends the processed results in to the results list. - Then it collects all results together from all ranks if distributed - training is used. Finally, it computes the metrics of the entire dataset. + The metric first processes each batch of data_samples and predictions, + and appends the processed results to the results list. Then it + collects all results together from all ranks if distributed training + is used. Finally, it computes the metrics of the entire dataset. - A subclass of class:`BaseEvaluator` should assign a meaningful value to the + A subclass of class:`BaseMetric` should assign a meaningful value to the class attribute `default_prefix`. See the argument `prefix` for details. Args: @@ -39,7 +38,7 @@ class BaseEvaluator(metaclass=ABCMeta): self.results: List[Any] = [] self.prefix = prefix or self.default_prefix if self.prefix is None: - warnings.warn('The prefix is not set in evaluator class ' + warnings.warn('The prefix is not set in metric class ' f'{self.__class__.__name__}.') @property @@ -51,16 +50,16 @@ class BaseEvaluator(metaclass=ABCMeta): self._dataset_meta = dataset_meta @abstractmethod - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], - predictions: Sequence[BaseDataElement]) -> None: + def process(self, data_batch: Sequence[Tuple[Any, dict]], + predictions: Sequence[dict]) -> None: """Process one batch of data samples and predictions. The processed results should be stored in ``self.results``, which will be used to compute the metrics when all batches have been processed. Args: - data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data + data_batch (Sequence[Tuple[Any, dict]]): A batch of data from the dataloader. 
- predictions (Sequence[BaseDataElement]): A batch of outputs from + predictions (Sequence[dict]): A batch of outputs from the model. """ @@ -84,7 +83,7 @@ class BaseEvaluator(metaclass=ABCMeta): size (int): Length of the entire validation dataset. When batch size > 1, the dataloader may pad some data samples to make sure all ranks have the same length of dataset slice. The - ``collect_results`` function will drop the padded data base on + ``collect_results`` function will drop the padded data based on this size. Returns: @@ -93,9 +92,9 @@ class BaseEvaluator(metaclass=ABCMeta): """ if len(self.results) == 0: warnings.warn( - f'{self.__class__.__name__} got empty `self._results`. Please ' + f'{self.__class__.__name__} got empty `self.results`. Please ' 'ensure that the processed results are properly added into ' - '`self._results` in `process` method.') + '`self.results` in `process` method.') results = collect_results(self.results, size, self.collect_device) diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py index 2299c17e2b20582985f44edc3a3ae1c284ad0b8f..ead8cb0afd7ba9e4a4800006fa188b866c7ceb8f 100644 --- a/mmengine/registry/__init__.py +++ b/mmengine/registry/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .default_scope import DefaultScope from .registry import Registry, build_from_cfg -from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS, +from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS) @@ -10,6 +10,6 @@ __all__ = [ 'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS', 'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS', 'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS', - 'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS', + 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS', 'DefaultScope' ] diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py index d692acc18e231f4104f6f03f61244c8ed62f9f59..571d55cbbc9454319945d465e3799edab1921c95 100644 --- a/mmengine/registry/root.py +++ b/mmengine/registry/root.py @@ -35,8 +35,8 @@ OPTIMIZERS = Registry('optimizer') OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor') # mangage all kinds of parameter schedulers like `MultiStepLR` PARAM_SCHEDULERS = Registry('parameter scheduler') -# manage all kinds of evaluators for computing metrics -EVALUATORS = Registry('evaluator') +# manage all kinds of metrics +METRICS = Registry('metric') # manage task-specific modules like anchor generators and box coders TASK_UTILS = Registry('task util') diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 9e04d640e3715d84fb6a59240895143e4b2dadf3..eb5c3454bb38c1b686b57518f205715a55ce2001 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -5,7 +5,7 @@ import torch from torch.utils.data import DataLoader from mmengine.data import BaseDataElement -from mmengine.evaluator import BaseEvaluator, build_evaluator +from mmengine.evaluator import Evaluator from mmengine.registry import LOOPS from mmengine.utils import is_list_of from .base_loop import BaseLoop @@ -165,19 +165,19 @@ class ValLoop(BaseLoop): runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. 
- evaluator (BaseEvaluator or dict or list): Used for computing metrics. + evaluator (Evaluator or dict or list): Used for computing metrics. interval (int): Validation interval. Defaults to 1. """ def __init__(self, runner, dataloader: Union[DataLoader, Dict], - evaluator: Union[BaseEvaluator, Dict, List], + evaluator: Union[Evaluator, Dict, List], interval: int = 1) -> None: super().__init__(runner, dataloader) if isinstance(evaluator, dict) or is_list_of(evaluator, dict): - self.evaluator = build_evaluator(evaluator) # type: ignore + self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore @@ -228,15 +228,15 @@ class TestLoop(BaseLoop): runner (Runner): A reference of runner. dataloader (Dataloader or dict): A dataloader object or a dict to build a dataloader. - evaluator (BaseEvaluator or dict or list): Used for computing metrics. + evaluator (Evaluator or dict or list): Used for computing metrics. """ def __init__(self, runner, dataloader: Union[DataLoader, Dict], - evaluator: Union[BaseEvaluator, Dict, List]): + evaluator: Union[Evaluator, Dict, List]): super().__init__(runner, dataloader) if isinstance(evaluator, dict) or is_list_of(evaluator, dict): - self.evaluator = build_evaluator(evaluator) # type: ignore + self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator # type: ignore diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index fea3dff2f3e34d0166626ad7fd5f3a679563ad42..a833570e235818ade3c7832aa78696b74bd655c4 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -23,8 +23,7 @@ from mmengine.config import Config, ConfigDict from mmengine.data import pseudo_collate, worker_init_fn from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only, sync_random_seed) -from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator, - build_evaluator) +from mmengine.evaluator import Evaluator from mmengine.hooks import Hook from mmengine.logging import MessageHub, MMLogger from mmengine.model import is_model_wrapper @@ -41,7 +40,6 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop from .priority import Priority, get_priority -EvaluatorType = Union[BaseEvaluator, ComposedEvaluator] ConfigType = Union[Dict, Config, ConfigDict] @@ -211,8 +209,8 @@ class Runner: test_cfg: Optional[Dict] = None, optimizer: Optional[Union[Optimizer, Dict]] = None, param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None, - val_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None, - test_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None, + val_evaluator: Optional[Union[Evaluator, Dict, List]] = None, + test_evaluator: Optional[Union[Evaluator, Dict, List]] = None, default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None, custom_hooks: Optional[List[Union[Hook, Dict]]] = None, load_from: Optional[str] = None, @@ -804,37 +802,35 @@ class Runner: return param_schedulers def build_evaluator( - self, evaluator: Union[Dict, List[Dict], - EvaluatorType]) -> EvaluatorType: + self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator: """Build evaluator. 
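Since `ValLoop`/`TestLoop` now delegate to `runner.build_evaluator`, which simply wraps metric configs or instances in an `Evaluator`, end-to-end usage looks roughly like the hedged sketch below. `ToyAccuracy` is a made-up metric, the snippet assumes `BaseMetric` exposes the usual `compute_metrics(results)` hook (not visible in this hunk), and the printed key follows the `prefix/name` convention in a non-distributed run:

```python
from mmengine.evaluator import BaseMetric, Evaluator


class ToyAccuracy(BaseMetric):
    default_prefix = 'toy'

    def process(self, data_batch, predictions):
        # Store one intermediate result per sample of the current batch.
        for (_, data), pred in zip(data_batch, predictions):
            self.results.append(int(pred['pred_label'] == data['gt_label']))

    def compute_metrics(self, results):
        return dict(accuracy=sum(results) / len(results))


evaluator = Evaluator(ToyAccuracy())  # a single metric or a list of metrics
evaluator.process(
    data_batch=[(None, dict(gt_label=0)), (None, dict(gt_label=1))],
    predictions=[dict(pred_label=0), dict(pred_label=1)])
print(evaluator.evaluate(size=2))  # expected: {'toy/accuracy': 1.0}
```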
Examples of ``evaluator``:: - evaluator = dict(type='ToyEvaluator') + evaluator = dict(type='ToyMetric') # evaluator can also be a list of dict evaluator = [ - dict(type='ToyEvaluator1'), + dict(type='ToyMetric1'), dict(type='ToyEvaluator2') ] Args: - evaluator (BaseEvaluator or ComposedEvaluator or dict or list): + evaluator (Evaluator or dict or list): An Evaluator object or a config dict or list of config dict - used to build evaluators. + used to build an Evaluator. Returns: - BaseEvaluator or ComposedEvaluator: Evaluators build from - ``evaluator``. + Evaluator: Evaluator build from ``evaluator``. """ - if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)): + if isinstance(evaluator, Evaluator): return evaluator elif isinstance(evaluator, dict) or is_list_of(evaluator, dict): - return build_evaluator(evaluator) # type: ignore + return Evaluator(evaluator) # type: ignore else: raise TypeError( - 'evaluator should be one of dict, list of dict, BaseEvaluator ' - f'and ComposedEvaluator, but got {evaluator}') + 'evaluator should be one of dict, list of dict, and Evaluator' + f', but got {evaluator}') def build_dataloader(self, dataloader: Union[DataLoader, Dict]) -> DataLoader: diff --git a/tests/data/annotations/dummy_annotation.json b/tests/data/annotations/dummy_annotation.json index d36109ce731d86e607bcb73925778101b0ffa07d..5fac907e8079cdb305535733c36435460cd0e877 100644 --- a/tests/data/annotations/dummy_annotation.json +++ b/tests/data/annotations/dummy_annotation.json @@ -46,6 +46,26 @@ "extra_anns": [4,5,6] } ] + }, + { + "img_path": "gray.jpg", + "height": 512, + "width": 512, + "instances": + [ + { + "bbox": [0, 0, 10, 20], + "bbox_label": 1, + "mask": [[0,0],[0,10],[10,20],[20,0]], + "extra_anns": [1,2,3] + }, + { + "bbox": [10, 10, 110, 120], + "bbox_label": 2, + "mask": [[10,10],[10,110],[110,120],[120,10]], + "extra_anns": [4,5,6] + } + ] } ] } diff --git a/tests/test_data/test_base_dataset.py b/tests/test_data/test_base_dataset.py index e695488c53dd51d967031b14e8bfceac7c229fe0..3dbbd91917d14d38c8e61661e48f91b4a2f124b6 100644 --- a/tests/test_data/test_base_dataset.py +++ b/tests/test_data/test_base_dataset.py @@ -33,12 +33,12 @@ class TestBaseDataset: filename='test_img.jpg', height=604, width=640, sample_idx=0) imgs = torch.rand((2, 3, 32, 32)) pipeline = MagicMock(return_value=dict(imgs=imgs)) - META: dict = dict() - parse_annotations = MagicMock(return_value=data_info) + METAINFO: dict = dict() + parse_data_info = MagicMock(return_value=data_info) def _init_dataset(self): - self.dataset_type.META = self.META - self.dataset_type.parse_annotations = self.parse_annotations + self.dataset_type.METAINFO = self.METAINFO + self.dataset_type.parse_data_info = self.parse_data_info def test_init(self): self._init_dataset() @@ -48,14 +48,14 @@ class TestBaseDataset: data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img=None), ann_file='annotations/dummy_annotation.json') assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') # test the instantiation of self.base_dataset with @@ -66,9 +66,9 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json', serialize_data=False) assert 
dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert not hasattr(dataset, 'data_address') - assert len(dataset) == 2 + assert len(dataset) == 3 assert dataset.get_data_info(0) == self.data_info # test the instantiation of self.base_dataset with lazy init @@ -78,7 +78,7 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json', lazy_init=True) assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # test the instantiation of self.base_dataset if ann_file is not # existed. @@ -106,9 +106,9 @@ class TestBaseDataset: data_prefix=dict(img=['img']), ann_file='annotations/annotation_wrong_format.json') - # test the instantiation of self.base_dataset when `parse_annotations` + # test the instantiation of self.base_dataset when `parse_data_info` # return `list[dict]` - self.dataset_type.parse_annotations = MagicMock( + self.dataset_type.parse_data_info = MagicMock( return_value=[self.data_info, self.data_info.copy()]) dataset = self.dataset_type( @@ -117,136 +117,137 @@ class TestBaseDataset: ann_file='annotations/dummy_annotation.json') dataset.pipeline = self.pipeline assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert hasattr(dataset, 'data_address') - assert len(dataset) == 4 + assert len(dataset) == 6 assert dataset[0] == dict(imgs=self.imgs) assert dataset.get_data_info(0) == self.data_info - # test the instantiation of self.base_dataset when `parse_annotations` + # test the instantiation of self.base_dataset when `parse_data_info` # return unsupported data. with pytest.raises(TypeError): - self.dataset_type.parse_annotations = MagicMock(return_value='xxx') + self.dataset_type.parse_data_info = MagicMock(return_value='xxx') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') with pytest.raises(TypeError): - self.dataset_type.parse_annotations = MagicMock( + self.dataset_type.parse_data_info = MagicMock( return_value=[self.data_info, 'xxx']) - dataset = self.dataset_type( + self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') def test_meta(self): self._init_dataset() - # test dataset.meta with setting the meta from annotation file as the - # meta of self.base_dataset + # test dataset.metainfo with setting the metainfo from annotation file + # as the metainfo of self.base_dataset. 
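For reference, the priority these metainfo assertions encode is: values parsed from the annotation file are the weakest, the class attribute ``METAINFO`` overrides them, and the ``metainfo`` argument passed to the constructor overrides both. A minimal sketch of that contract, assuming ``BaseDataset`` is importable from ``mmengine.dataset``; the subclass name is a placeholder and the paths mirror the test fixtures:

```python
from mmengine.dataset import BaseDataset  # import path is an assumption


class ToyDataset(BaseDataset):
    # Class-level meta information. It overrides keys read from the
    # annotation file and is itself overridden by the ``metainfo``
    # argument passed to the constructor.
    METAINFO = dict(classes=('dog', 'cat'))


dataset = ToyDataset(
    data_root='tests/data/',
    data_prefix=dict(img='imgs'),
    ann_file='annotations/dummy_annotation.json',
    metainfo=dict(classes=('dog', )))

# ``classes`` comes from the constructor argument; the remaining keys keep
# the values collected from METAINFO and the annotation file.
print(dataset.metainfo)
```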
dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type='test_dataset', task_name='test_task', empty_list=[]) - # test dataset.meta with setting META in self.base_dataset + # test dataset.metainfo with setting METAINFO in self.base_dataset dataset_type = 'new_dataset' - self.dataset_type.META = dict( + self.dataset_type.METAINFO = dict( dataset_type=dataset_type, classes=('dog', 'cat')) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json') - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='test_task', classes=('dog', 'cat'), empty_list=[]) - # test dataset.meta with passing meta into self.base_dataset - meta = dict(classes=('dog', ), task_name='new_task') + # test dataset.metainfo with passing metainfo into self.base_dataset + metainfo = dict(classes=('dog', ), task_name='new_task') dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) - assert self.dataset_type.META == dict( + metainfo=metainfo) + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat')) - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='new_task', classes=('dog', ), empty_list=[]) - # reset `base_dataset.META`, the `dataset.meta` should not change - self.dataset_type.META['classes'] = ('dog', 'cat', 'fish') - assert self.dataset_type.META == dict( + # reset `base_dataset.METAINFO`, the `dataset.metainfo` should not + # change + self.dataset_type.METAINFO['classes'] = ('dog', 'cat', 'fish') + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - assert dataset.meta == dict( + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='new_task', classes=('dog', ), empty_list=[]) - # test dataset.meta with passing meta containing a file into + # test dataset.metainfo with passing metainfo containing a file into # self.base_dataset - meta = dict( + metainfo = dict( classes=osp.join( osp.dirname(__file__), '../data/meta/classes.txt')) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) - assert dataset.meta == dict( + metainfo=metainfo) + assert dataset.metainfo == dict( dataset_type=dataset_type, task_name='test_task', classes=['dog'], empty_list=[]) - # test dataset.meta with passing unsupported meta into + # test dataset.metainfo with passing unsupported metainfo into # self.base_dataset with pytest.raises(TypeError): - meta = 'dog' + metainfo = 'dog' dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - meta=meta) + metainfo=metainfo) - # test dataset.meta with passing meta into self.base_dataset and - # lazy_init is True - meta = dict(classes=('dog', )) + # test dataset.metainfo with passing metainfo into self.base_dataset + # and lazy_init is True + metainfo = dict(classes=('dog', )) dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), 
ann_file='annotations/dummy_annotation.json', - meta=meta, + metainfo=metainfo, lazy_init=True) - # 'task_name' and 'empty_list' not in dataset.meta - assert dataset.meta == dict( + # 'task_name' and 'empty_list' not in dataset.metainfo + assert dataset.metainfo == dict( dataset_type=dataset_type, classes=('dog', )) - # test whether self.base_dataset.META is changed when a customize + # test whether self.base_dataset.METAINFO is changed when a customize # dataset inherit self.base_dataset - # test reset META in ToyDataset. + # test reset METAINFO in ToyDataset. class ToyDataset(self.dataset_type): - META = dict(xxx='xxx') + METAINFO = dict(xxx='xxx') - assert ToyDataset.META == dict(xxx='xxx') - assert self.dataset_type.META == dict( + assert ToyDataset.METAINFO == dict(xxx='xxx') + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) - # test update META in ToyDataset. + # test update METAINFO in ToyDataset. class ToyDataset(self.dataset_type): - META = copy.deepcopy(self.dataset_type.META) - META['classes'] = ('bird', ) + METAINFO = copy.deepcopy(self.dataset_type.METAINFO) + METAINFO['classes'] = ('bird', ) - assert ToyDataset.META == dict( + assert ToyDataset.METAINFO == dict( dataset_type=dataset_type, classes=('bird', )) - assert self.dataset_type.META == dict( + assert self.dataset_type.METAINFO == dict( dataset_type=dataset_type, classes=('dog', 'cat', 'fish')) @pytest.mark.parametrize('lazy_init', [True, False]) @@ -258,16 +259,16 @@ class TestBaseDataset: lazy_init=lazy_init) if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 else: # test `__len__()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # call `full_init()` automatically - assert len(dataset) == 2 + assert len(dataset) == 3 assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') def test_compose(self): # test callable transform @@ -308,30 +309,41 @@ class TestBaseDataset: dataset.pipeline = self.pipeline if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert dataset[0] == dict(imgs=self.imgs) else: - # test `__getitem__()` when lazy_init is True + # Test `__getitem__()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos - # call `full_init()` automatically + assert not dataset.data_list + # Call `full_init()` automatically assert dataset[0] == dict(imgs=self.imgs) assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') - # test with test mode - dataset.test_mode = True + # Test with test mode + dataset.test_mode = False assert dataset[0] == dict(imgs=self.imgs) + # Test cannot get a valid image. + dataset.prepare_data = MagicMock(return_value=None) + with pytest.raises(Exception): + dataset[0] + # Test get valid image by `_rand_another` - pipeline = MagicMock(return_value=None) - dataset.pipeline = pipeline - # test cannot get a valid image. 
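The mocked ``prepare_data`` in this test exercises the refetch behaviour: in training mode a ``None`` result triggers another randomly chosen index (up to ``max_refetch`` times), while test mode raises immediately. Roughly, and only as a sketch of the contract rather than the actual ``BaseDataset`` code:

```python
# Sketch of the __getitem__ behaviour the test relies on; the real
# implementation lives in mmengine/dataset/base_dataset.py.
def __getitem__(self, idx):
    data = self.prepare_data(idx)
    if data is not None:
        return data
    if self.test_mode:
        # Test-time pipelines must not silently drop samples.
        raise Exception(f'Cannot get a valid sample at index {idx} in test mode')
    for _ in range(self.max_refetch):
        # Randomly pick another index and try again.
        idx = self._rand_another()
        data = self.prepare_data(idx)
        if data is not None:
            return data
    raise Exception(
        f'Cannot get a valid sample after {self.max_refetch} retries')
```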
- dataset.test_mode = False + def fake_prepare_data(idx): + if idx == 0: + return None + else: + return 1 + + dataset.prepare_data = fake_prepare_data + dataset[0] + dataset.test_mode = True with pytest.raises(Exception): dataset[0] @pytest.mark.parametrize('lazy_init', [True, False]) def test_get_data_info(self, lazy_init): + self._init_dataset() dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), @@ -340,16 +352,16 @@ class TestBaseDataset: if not lazy_init: assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') assert dataset.get_data_info(0) == self.data_info else: # test `get_data_info()` when lazy_init is True assert not dataset._fully_initialized - assert not dataset.data_infos + assert not dataset.data_list # call `full_init()` automatically assert dataset.get_data_info(0) == self.data_info assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') + assert hasattr(dataset, 'data_list') def test_force_full_init(self): with pytest.raises(AttributeError): @@ -363,40 +375,178 @@ class TestBaseDataset: class_without_full_init = ClassWithoutFullInit() class_without_full_init.foo() - @pytest.mark.parametrize('lazy_init', [True, False]) - def test_full_init(self, lazy_init): + def test_full_init(self): + self._init_dataset() dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), ann_file='annotations/dummy_annotation.json', - lazy_init=lazy_init) + lazy_init=True) dataset.pipeline = self.pipeline - if not lazy_init: - assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 - assert dataset[0] == dict(imgs=self.imgs) - assert dataset.get_data_info(0) == self.data_info - else: - # test `full_init()` when lazy_init is True - assert not dataset._fully_initialized - assert not dataset.data_infos - # call `full_init()` manually - dataset.full_init() - assert dataset._fully_initialized - assert hasattr(dataset, 'data_infos') - assert len(dataset) == 2 - assert dataset[0] == dict(imgs=self.imgs) - assert dataset.get_data_info(0) == self.data_info + # test `full_init()` when lazy_init is True + assert not dataset._fully_initialized + assert not dataset.data_list + # call `full_init()` manually + dataset.full_init() + assert dataset._fully_initialized + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 + assert dataset[0] == dict(imgs=self.imgs) + assert dataset.get_data_info(0) == self.data_info - def test_slice_data(self): - # test the instantiation of self.base_dataset when passing num_samples + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + lazy_init=False) + + dataset.pipeline = self.pipeline + assert dataset._fully_initialized + assert hasattr(dataset, 'data_list') + assert len(dataset) == 3 + assert dataset[0] == dict(imgs=self.imgs) + assert dataset.get_data_info(0) == self.data_info + + # test the instantiation of self.base_dataset when passing indices + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json') + dataset_sliced = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json', + indices=1) + assert 
dataset_sliced[0] == dataset[0] + assert len(dataset_sliced) == 1 + + @pytest.mark.parametrize( + 'lazy_init, serialize_data', + ([True, False], [False, True], [True, True], [False, False])) + def test_get_subset_(self, lazy_init, serialize_data): + # Test positive int indices. + indices = 2 + dataset = self.dataset_type( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img=None), + ann_file='annotations/dummy_annotation.json', + lazy_init=lazy_init, + serialize_data=serialize_data) + + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(indices) + assert len(dataset_copy) == 2 + for i in range(len(dataset_copy)): + ori_data = dataset[i] + assert dataset_copy[i] == ori_data + + # Test negative int indices. + indices = -2 + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(indices) + assert len(dataset_copy) == 2 + for i in range(len(dataset_copy)): + ori_data = dataset[i + 1] + ori_data['sample_idx'] = i + assert dataset_copy[i] == ori_data + + # If indices is 0, return empty dataset. + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(0) + assert len(dataset_copy) == 0 + + # Test list indices with positive element. + indices = [1] + dataset_copy = copy.deepcopy(dataset) + ori_data = dataset[1] + ori_data['sample_idx'] = 0 + dataset_copy.get_subset_(indices) + assert len(dataset_copy) == 1 + assert dataset_copy[0] == ori_data + + # Test list indices with negative element. + indices = [-1] + dataset_copy = copy.deepcopy(dataset) + ori_data = dataset[2] + ori_data['sample_idx'] = 0 + dataset_copy.get_subset_(indices) + assert len(dataset_copy) == 1 + assert dataset_copy[0] == ori_data + + # Test empty list. + indices = [] + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(indices) + assert len(dataset_copy) == 0 + # Test list with multiple positive indices. + indices = [0, 1, 2] + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(indices) + for i in range(len(dataset_copy)): + ori_data = dataset[i] + ori_data['sample_idx'] = i + assert dataset_copy[i] == ori_data + # Test list with multiple negative indices. + indices = [-1, -2, 0] + dataset_copy = copy.deepcopy(dataset) + dataset_copy.get_subset_(indices) + for i in range(len(dataset_copy)): + ori_data = dataset[len(dataset) - i - 1] + ori_data['sample_idx'] = i + assert dataset_copy[i] == ori_data + + with pytest.raises(TypeError): + dataset.get_subset_(dict()) + + @pytest.mark.parametrize( + 'lazy_init, serialize_data', + ([True, False], [False, True], [True, True], [False, False])) + def test_get_subset(self, lazy_init, serialize_data): + # Test positive indices. + indices = 2 dataset = self.dataset_type( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img=None), ann_file='annotations/dummy_annotation.json', - num_samples=1) - assert len(dataset) == 1 + lazy_init=lazy_init, + serialize_data=serialize_data) + dataset_sliced = dataset.get_subset(indices) + assert len(dataset_sliced) == 2 + assert dataset_sliced[0] == dataset[0] + for i in range(len(dataset_sliced)): + assert dataset_sliced[i] == dataset[i] + # Test negative indices. + indices = -2 + dataset_sliced = dataset.get_subset(indices) + assert len(dataset_sliced) == 2 + for i in range(len(dataset_sliced)): + ori_data = dataset[i + 1] + ori_data['sample_idx'] = i + assert dataset_sliced[i] == ori_data + # If indices is 0 or empty list, return empty dataset. 
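Taken together, the assertions in this test encode the following user-facing behaviour of ``get_subset``/``get_subset_`` (a hypothetical usage sketch; the import path is an assumption and the fixture annotation file contains three samples):

```python
from mmengine.dataset import BaseDataset  # import path is an assumption

dataset = BaseDataset(
    data_root='tests/data/',
    data_prefix=dict(img=None),
    ann_file='annotations/dummy_annotation.json')  # 3 samples

first_two = dataset.get_subset(2)     # new dataset with the first 2 samples
last_two = dataset.get_subset(-2)     # new dataset with the last 2 samples
picked = dataset.get_subset([0, 2])   # new dataset with samples 0 and 2
assert len(dataset.get_subset(0)) == 0

dataset.get_subset_(2)                # in-place variant modifies ``dataset``
assert len(dataset) == 2
```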
+ assert len(dataset.get_subset(0)) == 0 + assert len(dataset.get_subset([])) == 0 + # test list indices. + indices = [1] + dataset_sliced = dataset.get_subset(indices) + ori_data = dataset[1] + ori_data['sample_idx'] = 0 + assert len(dataset_sliced) == 1 + assert dataset_sliced[0] == ori_data + # Test list with multiple positive index. + indices = [0, 1, 2] + dataset_sliced = dataset.get_subset(indices) + for i in range(len(dataset_sliced)): + ori_data = dataset[i] + ori_data['sample_idx'] = i + assert dataset_sliced[i] == ori_data + # Test list with multiple negative index. + indices = [-1, -2, 0] + dataset_sliced = dataset.get_subset(indices) + for i in range(len(dataset_sliced)): + ori_data = dataset[len(dataset) - i - 1] + ori_data['sample_idx'] = i + assert dataset_sliced[i] == ori_data def test_rand_another(self): # test the instantiation of self.base_dataset when passing num_samples @@ -404,7 +554,7 @@ class TestBaseDataset: data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img=None), ann_file='annotations/dummy_annotation.json', - num_samples=1) + indices=1) assert dataset._rand_another() >= 0 assert dataset._rand_another() < len(dataset) @@ -416,7 +566,7 @@ class TestConcatDataset: # create dataset_a data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) + dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) self.dataset_a = dataset( @@ -427,58 +577,44 @@ class TestConcatDataset: # create dataset_b data_info = dict(filename='gray.jpg', height=288, width=512) - dataset.parse_annotations = MagicMock(return_value=data_info) + dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) self.dataset_b = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), data_prefix=dict(img='imgs'), - ann_file='annotations/dummy_annotation.json', - meta=dict(classes=('dog', 'cat'))) + ann_file='annotations/dummy_annotation.json') self.dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) # test init self.cat_datasets = ConcatDataset( datasets=[self.dataset_a, self.dataset_b]) def test_full_init(self): - dataset = BaseDataset - - # create dataset_a - data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) - imgs = torch.rand((2, 3, 32, 32)) - - dataset_a = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), - ann_file='annotations/dummy_annotation.json') - dataset_a.pipeline = MagicMock(return_value=dict(imgs=imgs)) - - # create dataset_b - data_info = dict(filename='gray.jpg', height=288, width=512) - dataset.parse_annotations = MagicMock(return_value=data_info) - imgs = torch.rand((2, 3, 32, 32)) - dataset_b = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), - ann_file='annotations/dummy_annotation.json', - meta=dict(classes=('dog', 'cat'))) - dataset_b.pipeline = MagicMock(return_value=dict(imgs=imgs)) + self._init_dataset() # test init with lazy_init=True - cat_datasets = ConcatDataset( - datasets=[dataset_a, dataset_b], lazy_init=True) - cat_datasets.full_init() - assert len(cat_datasets) == 4 - cat_datasets.full_init() - cat_datasets._fully_initialized = False - cat_datasets[1] - assert len(cat_datasets) == 4 + self.cat_datasets.full_init() + assert len(self.cat_datasets) == 6 + self.cat_datasets.full_init() + 
self.cat_datasets._fully_initialized = False + self.cat_datasets[1] + assert len(self.cat_datasets) == 6 + + with pytest.raises(NotImplementedError): + self.cat_datasets.get_subset_(1) + + with pytest.raises(NotImplementedError): + self.cat_datasets.get_subset(1) + # Different meta information will raise error. + with pytest.raises(ValueError): + dataset_b = BaseDataset( + data_root=osp.join(osp.dirname(__file__), '../data/'), + data_prefix=dict(img='imgs'), + ann_file='annotations/dummy_annotation.json', + metainfo=dict(classes=('cat'))) + ConcatDataset(datasets=[self.dataset_a, dataset_b]) - def test_meta(self): + def test_metainfo(self): self._init_dataset() - assert self.cat_datasets.meta == self.dataset_a.meta - # meta of self.cat_datasets is from the first dataset when - # concatnating datasets with different metas. - assert self.cat_datasets.meta != self.dataset_b.meta + assert self.cat_datasets.metainfo == self.dataset_a.metainfo def test_length(self): self._init_dataset() @@ -524,7 +660,7 @@ class TestRepeatDataset: def _init_dataset(self): dataset = BaseDataset data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) + dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) self.dataset = dataset( data_root=osp.join(osp.dirname(__file__), '../data/'), @@ -538,31 +674,26 @@ class TestRepeatDataset: dataset=self.dataset, times=self.repeat_times) def test_full_init(self): - dataset = BaseDataset - data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) - imgs = torch.rand((2, 3, 32, 32)) - dataset = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), - ann_file='annotations/dummy_annotation.json') - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) + self._init_dataset() - repeat_times = 5 - # test init - repeat_datasets = RepeatDataset( - dataset=dataset, times=repeat_times, lazy_init=True) + self.repeat_datasets.full_init() + assert len( + self.repeat_datasets) == self.repeat_times * len(self.dataset) + self.repeat_datasets.full_init() + self.repeat_datasets._fully_initialized = False + self.repeat_datasets[1] + assert len(self.repeat_datasets) == \ + self.repeat_times * len(self.dataset) - repeat_datasets.full_init() - assert len(repeat_datasets) == repeat_times * len(dataset) - repeat_datasets.full_init() - repeat_datasets._fully_initialized = False - repeat_datasets[1] - assert len(repeat_datasets) == repeat_times * len(dataset) + with pytest.raises(NotImplementedError): + self.repeat_datasets.get_subset_(1) - def test_meta(self): + with pytest.raises(NotImplementedError): + self.repeat_datasets.get_subset(1) + + def test_metainfo(self): self._init_dataset() - assert self.repeat_datasets.meta == self.dataset.meta + assert self.repeat_datasets.metainfo == self.dataset.metainfo def test_length(self): self._init_dataset() @@ -587,7 +718,7 @@ class TestClassBalancedDataset: def _init_dataset(self): dataset = BaseDataset data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) + dataset.parse_data_info = MagicMock(return_value=data_info) imgs = torch.rand((2, 3, 32, 32)) dataset.get_cat_ids = MagicMock(return_value=[0]) self.dataset = dataset( @@ -603,34 +734,26 @@ class TestClassBalancedDataset: self.cls_banlanced_datasets.repeat_indices = self.repeat_indices def 
test_full_init(self): - dataset = BaseDataset - data_info = dict(filename='test_img.jpg', height=604, width=640) - dataset.parse_annotations = MagicMock(return_value=data_info) - imgs = torch.rand((2, 3, 32, 32)) - dataset.get_cat_ids = MagicMock(return_value=[0]) - dataset = dataset( - data_root=osp.join(osp.dirname(__file__), '../data/'), - data_prefix=dict(img='imgs'), - ann_file='annotations/dummy_annotation.json') - dataset.pipeline = MagicMock(return_value=dict(imgs=imgs)) + self._init_dataset() - repeat_indices = [0, 0, 1, 1, 1] - # test init - cls_banlanced_datasets = ClassBalancedDataset( - dataset=dataset, oversample_thr=1e-3, lazy_init=True) - - cls_banlanced_datasets.full_init() - cls_banlanced_datasets.repeat_indices = repeat_indices - assert len(cls_banlanced_datasets) == len(repeat_indices) - cls_banlanced_datasets.full_init() - cls_banlanced_datasets._fully_initialized = False - cls_banlanced_datasets[1] - cls_banlanced_datasets.repeat_indices = repeat_indices - assert len(cls_banlanced_datasets) == len(repeat_indices) + self.cls_banlanced_datasets.full_init() + self.cls_banlanced_datasets.repeat_indices = self.repeat_indices + assert len(self.cls_banlanced_datasets) == len(self.repeat_indices) + self.cls_banlanced_datasets.full_init() + self.cls_banlanced_datasets._fully_initialized = False + self.cls_banlanced_datasets[1] + self.cls_banlanced_datasets.repeat_indices = self.repeat_indices + assert len(self.cls_banlanced_datasets) == len(self.repeat_indices) - def test_meta(self): + with pytest.raises(NotImplementedError): + self.cls_banlanced_datasets.get_subset_(1) + + with pytest.raises(NotImplementedError): + self.cls_banlanced_datasets.get_subset(1) + + def test_metainfo(self): self._init_dataset() - assert self.cls_banlanced_datasets.meta == self.dataset.meta + assert self.cls_banlanced_datasets.metainfo == self.dataset.metainfo def test_length(self): self._init_dataset() diff --git a/tests/test_data/test_data_element.py b/tests/test_data/test_data_element.py index 275b8a3ccd3797d11056245b2e78b48ce4d5a7e9..7f93383f905cac22d0ddf26e5aba7648d9aa91e4 100644 --- a/tests/test_data/test_data_element.py +++ b/tests/test_data/test_data_element.py @@ -417,3 +417,12 @@ class TestBaseDataElement(TestCase): # test_items assert len(dict(instances.items())) == len(dict(data.items())) + + def test_to_dict(self): + metainfo, data = self.setup_data() + instances = BaseDataElement(metainfo=metainfo, **data) + dict_instances = instances.to_dict() + # test convert BaseDataElement to dict + assert isinstance(dict_instances, dict) + assert isinstance(dict_instances['gt_instances'], dict) + assert isinstance(dict_instances['pred_instances'], dict) diff --git a/tests/test_evaluator/test_base_evaluator.py b/tests/test_evaluator/test_evaluator.py similarity index 72% rename from tests/test_evaluator/test_base_evaluator.py rename to tests/test_evaluator/test_evaluator.py index 042d2fb8a0a1298f7dec93e59c276718349f7368..61be034be2d5d3934f3d23eb60fd7eaa0eec143b 100644 --- a/tests/test_evaluator/test_base_evaluator.py +++ b/tests/test_evaluator/test_evaluator.py @@ -6,12 +6,12 @@ from unittest import TestCase import numpy as np from mmengine.data import BaseDataElement -from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value -from mmengine.registry import EVALUATORS +from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value +from mmengine.registry import METRICS -@EVALUATORS.register_module() -class ToyEvaluator(BaseEvaluator): +@METRICS.register_module() +class 
ToyMetric(BaseMetric): """Evaluaotr that calculates the metric `accuracy` from predictions and labels. Alternatively, this evaluator can return arbitrary dummy metrics set in the config. @@ -39,8 +39,8 @@ class ToyEvaluator(BaseEvaluator): def process(self, data_batch, predictions): results = [{ - 'pred': pred.pred, - 'label': data[1].label + 'pred': pred.get('pred'), + 'label': data[1].get('label') } for pred, data in zip(predictions, data_batch)] self.results.extend(results) @@ -61,13 +61,13 @@ class ToyEvaluator(BaseEvaluator): return metrics -@EVALUATORS.register_module() -class NonPrefixedEvaluator(BaseEvaluator): +@METRICS.register_module() +class NonPrefixedMetric(BaseMetric): """Evaluator with unassigned `default_prefix` to test the warning information.""" - def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]], - predictions: Sequence[BaseDataElement]) -> None: + def process(self, data_batch: Sequence[Tuple[Any, dict]], + predictions: Sequence[dict]) -> None: pass def compute_metrics(self, results: list) -> dict: @@ -85,11 +85,11 @@ def generate_test_results(size, batch_size, pred, label): yield (data_batch, predictions) -class TestBaseEvaluator(TestCase): +class TestEvaluator(TestCase): - def test_single_evaluator(self): - cfg = dict(type='ToyEvaluator') - evaluator = build_evaluator(cfg) + def test_single_metric(self): + cfg = dict(type='ToyMetric') + evaluator = Evaluator(cfg) size = 10 batch_size = 4 @@ -103,18 +103,18 @@ class TestBaseEvaluator(TestCase): self.assertEqual(metrics['Toy/size'], size) # Test empty results - cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0)) - evaluator = build_evaluator(cfg) - with self.assertWarnsRegex(UserWarning, 'got empty `self._results`.'): + cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0)) + evaluator = Evaluator(cfg) + with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'): evaluator.evaluate(0) - def test_composed_evaluator(self): + def test_composed_metrics(self): cfg = [ - dict(type='ToyEvaluator'), - dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)) + dict(type='ToyMetric'), + dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) ] - evaluator = build_evaluator(cfg) + evaluator = Evaluator(cfg) size = 10 batch_size = 4 @@ -129,14 +129,13 @@ class TestBaseEvaluator(TestCase): self.assertAlmostEqual(metrics['Toy/mAP'], 0.0) self.assertEqual(metrics['Toy/size'], size) - def test_ambiguate_metric(self): - + def test_ambiguous_metric(self): cfg = [ - dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)), - dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)) + dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)), + dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) ] - evaluator = build_evaluator(cfg) + evaluator = Evaluator(cfg) size = 10 batch_size = 4 @@ -147,28 +146,42 @@ class TestBaseEvaluator(TestCase): with self.assertRaisesRegex( ValueError, - 'There are multiple evaluators with the same metric name'): + 'There are multiple evaluation results with the same metric ' + 'name'): _ = evaluator.evaluate(size=size) def test_dataset_meta(self): dataset_meta = dict(classes=('cat', 'dog')) cfg = [ - dict(type='ToyEvaluator'), - dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)) + dict(type='ToyMetric'), + dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)) ] - evaluator = build_evaluator(cfg) + evaluator = Evaluator(cfg) evaluator.dataset_meta = dataset_meta self.assertDictEqual(evaluator.dataset_meta, dataset_meta) - for _evaluator in evaluator.evaluators: + for 
_evaluator in evaluator.metrics: self.assertDictEqual(_evaluator.dataset_meta, dataset_meta) + def test_collect_device(self): + cfg = [ + dict(type='ToyMetric', collect_device='cpu'), + dict( + type='ToyMetric', + collect_device='gpu', + dummy_metrics=dict(mAP=0.0)) + ] + + evaluator = Evaluator(cfg) + self.assertEqual(evaluator.metrics[0].collect_device, 'cpu') + self.assertEqual(evaluator.metrics[1].collect_device, 'gpu') + def test_prefix(self): - cfg = dict(type='NonPrefixedEvaluator') + cfg = dict(type='NonPrefixedMetric') with self.assertWarnsRegex(UserWarning, 'The prefix is not set'): - _ = build_evaluator(cfg) + _ = Evaluator(cfg) def test_get_metric_value(self): @@ -208,3 +221,14 @@ class TestBaseEvaluator(TestCase): indicator = 'metric_2' # unmatched indicator with self.assertRaisesRegex(ValueError, 'can not match any metric'): _ = get_metric_value(indicator, metrics) + + def test_offline_evaluate(self): + cfg = dict(type='ToyMetric') + evaluator = Evaluator(cfg) + + size = 10 + + all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1)) + for _ in range(size)] + all_predictions = [BaseDataElement(pred=0) for _ in range(size)] + evaluator.offline_evaluate(all_data, all_predictions) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index e6e0316e15e8ffbbf16b7314b937ee410b73f249..25fbdbcfa22c129247e3a7c15aa7926dd3bb5cae 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -13,16 +13,14 @@ from torch.utils.data import DataLoader, Dataset from mmengine.config import Config from mmengine.data import DefaultSampler -from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator, - build_evaluator) +from mmengine.evaluator import BaseMetric, Evaluator from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook, ParamSchedulerHook) from mmengine.hooks.checkpoint_hook import CheckpointHook from mmengine.logging import MessageHub, MMLogger from mmengine.optim.scheduler import MultiStepLR, StepLR -from mmengine.registry import (DATASETS, EVALUATORS, HOOKS, LOOPS, - MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS, - Registry) +from mmengine.registry import (DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS, + MODELS, PARAM_SCHEDULERS, Registry) from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop, Runner, TestLoop, ValLoop) from mmengine.runner.priority import Priority, get_priority @@ -80,8 +78,8 @@ class ToyDataset(Dataset): return self.data[index], self.label[index] -@EVALUATORS.register_module() -class ToyEvaluator1(BaseEvaluator): +@METRICS.register_module() +class ToyMetric1(BaseMetric): def __init__(self, collect_device='cpu', dummy_metrics=None): super().__init__(collect_device=collect_device) @@ -95,8 +93,8 @@ class ToyEvaluator1(BaseEvaluator): return dict(acc=1) -@EVALUATORS.register_module() -class ToyEvaluator2(BaseEvaluator): +@METRICS.register_module() +class ToyMetric2(BaseMetric): def __init__(self, collect_device='cpu', dummy_metrics=None): super().__init__(collect_device=collect_device) @@ -145,7 +143,7 @@ class CustomValLoop(BaseLoop): self._runner = runner if isinstance(evaluator, dict) or is_list_of(evaluator, dict): - self.evaluator = build_evaluator(evaluator) # type: ignore + self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator @@ -161,7 +159,7 @@ class CustomTestLoop(BaseLoop): self._runner = runner if isinstance(evaluator, dict) or is_list_of(evaluator, dict): - self.evaluator = build_evaluator(evaluator) # type: 
ignore + self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: self.evaluator = evaluator @@ -197,8 +195,8 @@ class TestRunner(TestCase): num_workers=0), optimizer=dict(type='SGD', lr=0.01), param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]), - val_evaluator=dict(type='ToyEvaluator1'), - test_evaluator=dict(type='ToyEvaluator1'), + val_evaluator=dict(type='ToyMetric1'), + test_evaluator=dict(type='ToyMetric1'), train_cfg=dict(by_epoch=True, max_epochs=3), val_cfg=dict(interval=1), test_cfg=dict(), @@ -355,14 +353,14 @@ class TestRunner(TestCase): self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) self.assertIsInstance(runner.val_loop, BaseLoop) self.assertIsInstance(runner.val_loop.dataloader, DataLoader) - self.assertIsInstance(runner.val_loop.evaluator, ToyEvaluator1) + self.assertIsInstance(runner.val_loop.evaluator, Evaluator) # After calling runner.test(), test_dataloader should be initialized self.assertIsInstance(runner.test_loop, dict) runner.test() self.assertIsInstance(runner.test_loop, BaseLoop) self.assertIsInstance(runner.test_loop.dataloader, DataLoader) - self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1) + self.assertIsInstance(runner.test_loop.evaluator, Evaluator) # 4. initialize runner with objects rather than config model = ToyModel() @@ -385,10 +383,10 @@ class TestRunner(TestCase): param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]), val_cfg=dict(interval=1), val_dataloader=val_dataloader, - val_evaluator=ToyEvaluator1(), + val_evaluator=ToyMetric1(), test_cfg=dict(), test_dataloader=test_dataloader, - test_evaluator=ToyEvaluator1(), + test_evaluator=ToyMetric1(), default_hooks=dict(param_scheduler=toy_hook), custom_hooks=[toy_hook2], experiment_name='test_init14') @@ -585,20 +583,28 @@ class TestRunner(TestCase): runner = Runner.from_cfg(cfg) # input is a BaseEvaluator or ComposedEvaluator object - evaluator = ToyEvaluator1() + evaluator = Evaluator(ToyMetric1()) self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator)) - evaluator = ComposedEvaluator([ToyEvaluator1(), ToyEvaluator2()]) + evaluator = Evaluator([ToyMetric1(), ToyMetric2()]) self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator)) - # input is a dict or list of dict - evaluator = dict(type='ToyEvaluator1') - self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator1) + # input is a dict + evaluator = dict(type='ToyMetric1') + self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator) + + # input is a list of dict + evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')] + self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator) - # input is a invalid type - evaluator = [dict(type='ToyEvaluator1'), dict(type='ToyEvaluator2')] - self.assertIsInstance( - runner.build_evaluator(evaluator), ComposedEvaluator) + # test collect device + evaluator = [ + dict(type='ToyMetric1', collect_device='cpu'), + dict(type='ToyMetric2', collect_device='gpu') + ] + _evaluator = runner.build_evaluator(evaluator) + self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') + self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') def test_build_dataloader(self): cfg = copy.deepcopy(self.epoch_based_cfg)