Unverified commit 2fdca03f authored by Mashiro, committed by GitHub

[Enhancement] Support get sub dataset and rename method and variable. (#145)

* add get_subset method, add comment, rename variable

* add unit test

* Please mypy

* Fix as comment, support negative index, and fix index access error

* add and refine docstring, handle indices=0

* handle indices=0

* add empty list indices test

* rename in_meta in docstring and comments to in_metainfo

* clean meta naming

* Fix negative indices error

* test empty list of get_subset

* fix comments and docstring

* add unit test

* Fix as comment

* Fix as comment

* add docstring to mention wrapped dataset should not inherit from BaseDataset

* Fix wrapped dataset docstring

* Fix wrapped dataset docstring

* Fix method name, docstring, and comments

* Fix comments

* Fix comments

* Fix comments
parent 8b4d7dda
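Before the diff, a quick orientation: the commit adds a ``get_subset``/``get_subset_`` pair to ``BaseDataset`` (the wrapped-dataset class), while the wrappers below deliberately reject both methods. The sketch below is illustrative only: ``ToyDataset`` and the annotation path are hypothetical, the import path ``mmengine.dataset`` is assumed, and the slicing semantics (head slice for a positive int, tail slice for a negative int, explicit index lists) are inferred from the commit messages above.

from mmengine.dataset import BaseDataset


class ToyDataset(BaseDataset):
    """Hypothetical BaseDataset subclass, only to make the example concrete."""


# Hypothetical annotation file in the standard BaseDataset format.
dataset = ToyDataset(ann_file='annotations/toy.json')

# `get_subset` returns a new dataset and leaves `dataset` untouched.
head = dataset.get_subset(10)           # first 10 samples
tail = dataset.get_subset(-5)           # negative int: last 5 samples
picked = dataset.get_subset([0, 2, 4])  # explicit list of indices

# `get_subset_` (note the trailing underscore) modifies the dataset in place.
dataset.get_subset_(100)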
@@ -4,7 +4,7 @@ import copy
 import math
 import warnings
 from collections import defaultdict
-from typing import List, Sequence, Tuple
+from typing import List, Sequence, Tuple, Union
 
 from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
@@ -16,6 +16,13 @@ class ConcatDataset(_ConcatDataset):
     Same as ``torch.utils.data.dataset.ConcatDataset`` and support lazy_init.
 
+    Note:
+        ``ConcatDataset`` should not inherit from ``BaseDataset`` since
+        ``get_subset`` and ``get_subset_`` could produce a sub-dataset with
+        ambiguous meaning that conflicts with the original dataset. If you
+        want a sub-dataset of ``ConcatDataset``, set the ``indices`` argument
+        of the wrapped datasets, which inherit from ``BaseDataset``.
+
     Args:
         datasets (Sequence[BaseDataset]): A list of datasets which will be
             concatenated.
@@ -26,12 +33,12 @@ class ConcatDataset(_ConcatDataset):
     def __init__(self,
                  datasets: Sequence[BaseDataset],
                  lazy_init: bool = False):
-        # Only use meta of first dataset.
-        self._meta = datasets[0].meta
+        # Only use metainfo of first dataset.
+        self._metainfo = datasets[0].metainfo
         self.datasets = datasets  # type: ignore
         for i, dataset in enumerate(datasets, 1):
-            if self._meta != dataset.meta:
-                warnings.warn(
+            if self._metainfo != dataset.metainfo:
+                raise ValueError(
                     f'The meta information of the {i}-th dataset does not '
                     'match meta information of the first dataset')
@@ -40,14 +47,14 @@ class ConcatDataset(_ConcatDataset):
             self.full_init()
 
     @property
-    def meta(self) -> dict:
+    def metainfo(self) -> dict:
         """Get the meta information of the first dataset in ``self.datasets``.
 
         Returns:
             dict: Meta information of first dataset.
         """
-        # Prevent `self._meta` from being modified by outside.
-        return copy.deepcopy(self._meta)
+        # Prevent `self._metainfo` from being modified by outside.
+        return copy.deepcopy(self._metainfo)
 
     def full_init(self):
         """Loop to ``full_init`` each dataset."""
@@ -77,8 +84,9 @@ class ConcatDataset(_ConcatDataset):
                 f'absolute value of index({idx}) should not exceed dataset'
                 f'length({len(self)}).')
             idx = len(self) + idx
-        # Get the inner index of single dataset
+        # Get `dataset_idx` to tell which dataset `idx` belongs to.
         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+        # Get the inner index of the single dataset.
         if dataset_idx == 0:
             sample_idx = idx
         else:
@@ -111,6 +119,26 @@ class ConcatDataset(_ConcatDataset):
         dataset_idx, sample_idx = self._get_ori_dataset_idx(idx)
         return self.datasets[dataset_idx][sample_idx]
 
+    def get_subset_(self, indices: Union[List[int], int]) -> None:
+        """Not supported in ``ConcatDataset`` for the ambiguous meaning of
+        sub-dataset."""
+        raise NotImplementedError(
+            '`ConcatDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `ConcatDataset`.')
+
+    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
+        """Not supported in ``ConcatDataset`` for the ambiguous meaning of
+        sub-dataset."""
+        raise NotImplementedError(
+            '`ConcatDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `ConcatDataset`.')
+
 
 class RepeatDataset:
     """A wrapper of repeated dataset.
@@ -120,10 +148,17 @@ class RepeatDataset:
     is small. Using RepeatDataset can reduce the data loading time between
     epochs.
 
+    Note:
+        ``RepeatDataset`` should not inherit from ``BaseDataset`` since
+        ``get_subset`` and ``get_subset_`` could produce a sub-dataset with
+        ambiguous meaning that conflicts with the original dataset. If you
+        want a sub-dataset of ``RepeatDataset``, set the ``indices`` argument
+        of the wrapped dataset, which inherits from ``BaseDataset``.
+
     Args:
         dataset (BaseDataset): The dataset to be repeated.
         times (int): Repeat times.
-        lazy_init (bool, optional): Whether to load annotation during
+        lazy_init (bool): Whether to load annotation during
             instantiation. Defaults to False.
     """
@@ -133,20 +168,20 @@ class RepeatDataset:
                  lazy_init: bool = False):
         self.dataset = dataset
         self.times = times
-        self._meta = dataset.meta
+        self._metainfo = dataset.metainfo
         self._fully_initialized = False
         if not lazy_init:
             self.full_init()
 
     @property
-    def meta(self) -> dict:
+    def metainfo(self) -> dict:
         """Get the meta information of the repeated dataset.
 
         Returns:
             dict: The meta information of repeated dataset.
         """
-        return copy.deepcopy(self._meta)
+        return copy.deepcopy(self._metainfo)
 
     def full_init(self):
         """Loop to ``full_init`` each dataset."""
@@ -195,6 +230,26 @@ class RepeatDataset:
     def __len__(self):
         return self.times * self._ori_len
 
+    def get_subset_(self, indices: Union[List[int], int]) -> None:
+        """Not supported in ``RepeatDataset`` for the ambiguous meaning of
+        sub-dataset."""
+        raise NotImplementedError(
+            '`RepeatDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `RepeatDataset`.')
+
+    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
+        """Not supported in ``RepeatDataset`` for the ambiguous meaning of
+        sub-dataset."""
+        raise NotImplementedError(
+            '`RepeatDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `RepeatDataset`.')
+
 
 class ClassBalancedDataset:
     """A wrapper of class balanced dataset.
@@ -219,6 +274,14 @@ class ClassBalancedDataset:
     3. For each image I, compute the image-level repeat factor:
        :math:`r(I) = max_{c in I} r(c)`
 
+    Note:
+        ``ClassBalancedDataset`` should not inherit from ``BaseDataset``
+        since ``get_subset`` and ``get_subset_`` could produce a sub-dataset
+        with ambiguous meaning that conflicts with the original dataset. If
+        you want a sub-dataset of ``ClassBalancedDataset``, set the
+        ``indices`` argument of the wrapped dataset, which inherits from
+        ``BaseDataset``.
+
     Args:
         dataset (BaseDataset): The dataset to be repeated.
         oversample_thr (float): frequency threshold below which data is
@@ -236,20 +299,20 @@ class ClassBalancedDataset:
                  lazy_init: bool = False):
         self.dataset = dataset
         self.oversample_thr = oversample_thr
-        self._meta = dataset.meta
+        self._metainfo = dataset.metainfo
         self._fully_initialized = False
         if not lazy_init:
             self.full_init()
 
     @property
-    def meta(self) -> dict:
+    def metainfo(self) -> dict:
         """Get the meta information of the repeated dataset.
 
         Returns:
             dict: The meta information of repeated dataset.
         """
-        return copy.deepcopy(self._meta)
+        return copy.deepcopy(self._metainfo)
 
     def full_init(self):
         """Loop to ``full_init`` each dataset."""
@@ -257,9 +320,12 @@ class ClassBalancedDataset:
             return
 
         self.dataset.full_init()
+        # Get repeat factors for each image.
         repeat_factors = self._get_repeat_factors(self.dataset,
                                                   self.oversample_thr)
+        # Repeat dataset's indices according to repeat_factors. For example,
+        # if `repeat_factors = [1, 2, 3]` and `len(dataset) == 3`, the
+        # repeated indices will be [0, 1, 1, 2, 2, 2].
         repeat_indices = []
         for dataset_index, repeat_factor in enumerate(repeat_factors):
             repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))
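To make the resampling in the hunk above concrete, here is a self-contained sketch of the computation, assuming the standard repeat-factor formula ``r(c) = max(1, sqrt(oversample_thr / f(c)))`` that the class docstring's numbered steps refer to; the toy ``samples`` list and variable names are hypothetical, and the real logic lives in ``_get_repeat_factors`` and ``full_init``.

import math
from collections import defaultdict

# Hypothetical toy data: per-sample category labels (not the real dataset API).
samples = [{'cat_ids': [0]}, {'cat_ids': [0, 1]}, {'cat_ids': [2]}]
oversample_thr = 0.5

# 1. Category frequency f(c): fraction of samples containing category c.
category_freq: dict = defaultdict(float)
for sample in samples:
    for cat_id in set(sample['cat_ids']):
        category_freq[cat_id] += 1
for cat_id in category_freq:
    category_freq[cat_id] /= len(samples)

# 2. Category-level repeat factor r(c) = max(1, sqrt(thr / f(c))).
category_repeat = {
    cat_id: max(1.0, math.sqrt(oversample_thr / freq))
    for cat_id, freq in category_freq.items()
}

# 3. Sample-level repeat factor r(I) = max over categories present in I.
repeat_factors = [
    max(category_repeat[c] for c in set(sample['cat_ids']))
    for sample in samples
]

# 4. Expand indices: sample i appears ceil(r(I)) times in the resampled list.
repeat_indices = []
for dataset_index, repeat_factor in enumerate(repeat_factors):
    repeat_indices.extend([dataset_index] * math.ceil(repeat_factor))

print(repeat_factors, repeat_indices)  # e.g. [1.0, 1.22, 1.22] -> [0, 1, 1, 2, 2]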
@@ -362,3 +428,23 @@ class ClassBalancedDataset:
     @force_full_init
     def __len__(self):
         return len(self.repeat_indices)
+
+    def get_subset_(self, indices: Union[List[int], int]) -> None:
+        """Not supported in ``ClassBalancedDataset`` for the ambiguous
+        meaning of sub-dataset."""
+        raise NotImplementedError(
+            '`ClassBalancedDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `ClassBalancedDataset`.')
+
+    def get_subset(self, indices: Union[List[int], int]) -> 'BaseDataset':
+        """Not supported in ``ClassBalancedDataset`` for the ambiguous
+        meaning of sub-dataset."""
+        raise NotImplementedError(
+            '`ClassBalancedDataset` does not support `get_subset` and '
+            '`get_subset_` interfaces because this will lead to ambiguous '
+            'implementation of some methods. If you want to use `get_subset` '
+            'or `get_subset_` interfaces, please use them in the wrapped '
+            'dataset first and then use `ClassBalancedDataset`.')
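As the NotImplementedError messages above instruct, the subset has to be taken on the wrapped ``BaseDataset`` first and the wrapper applied afterwards. A minimal sketch of that order, with hypothetical datasets and annotation paths, assuming the wrappers are importable from ``mmengine.dataset``:

from mmengine.dataset import BaseDataset, ConcatDataset


class ToyDataset(BaseDataset):
    """Hypothetical BaseDataset subclass used only for illustration."""


train_a = ToyDataset(ann_file='annotations/train_a.json')
train_b = ToyDataset(ann_file='annotations/train_b.json')

# Take sub-datasets on the wrapped datasets first ...
subset_a = train_a.get_subset(100)  # new dataset with the first 100 samples
train_b.get_subset_([0, 2, 4])      # in place: keep samples 0, 2 and 4

# ... and only then wrap them. Calling get_subset/get_subset_ on the wrapper
# itself raises NotImplementedError, as implemented in the diff above.
concat = ConcatDataset(datasets=[subset_a, train_b])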
@@ -46,6 +46,26 @@
                     "extra_anns": [4,5,6]
                 }
             ]
+        },
+        {
+            "img_path": "gray.jpg",
+            "height": 512,
+            "width": 512,
+            "instances":
+            [
+                {
+                    "bbox": [0, 0, 10, 20],
+                    "bbox_label": 1,
+                    "mask": [[0,0],[0,10],[10,20],[20,0]],
+                    "extra_anns": [1,2,3]
+                },
+                {
+                    "bbox": [10, 10, 110, 120],
+                    "bbox_label": 2,
+                    "mask": [[10,10],[10,110],[110,120],[120,10]],
+                    "extra_anns": [4,5,6]
+                }
+            ]
         }
     ]
 }