Skip to content
Snippets Groups Projects
Unverified Commit 92b94e8e authored by Tao Gong's avatar Tao Gong Committed by GitHub
Browse files

[Docs] Add docs for custom dataset wrapper (#222)

* add docs for custom dataset wrapper

* Update basedataset.md
parent 22d3b045
No related branches found
No related tags found
No related merge requests found
...@@ -219,7 +219,7 @@ class ToyVideoDataset(BaseDataset): ...@@ -219,7 +219,7 @@ class ToyVideoDataset(BaseDataset):
1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。 1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self, ann_file):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。 2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
## 数据集基类的其它特性 ## 数据集基类的其它特性
...@@ -391,3 +391,92 @@ toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e ...@@ -391,3 +391,92 @@ toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e
``` ```
上述例子将数据集的 `train` 部分以 `oversample_thr=1e-3` 重新采样,具体地,对于数据集中出现频率低于 `1e-3` 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 `ClassBalancedDataset` API 文档。 上述例子将数据集的 `train` 部分以 `oversample_thr=1e-3` 重新采样,具体地,对于数据集中出现频率低于 `1e-3` 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 `ClassBalancedDataset` API 文档。
### 自定义数据集类包装
由于数据集基类实现了懒加载的功能,因此在自定义数据集类包装时,需要遵循一些规则,下面以一个例子的方式来展示如何自定义数据集类包装:
```python
from mmengine.dataset import BaseDataset
from mmengine.registry import DATASETS
@DATASETS.register_module()
class ExampleDatasetWrapper:
def __init__(self, dataset, lazy_init = False, ...):
# 构建原数据集(self.dataset)
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)
elif isinstance(dataset, BaseDataset):
self.dataset = dataset
else:
raise TypeError(
'elements in datasets sequence should be config or '
f'`BaseDataset` instance, but got {type(dataset)}')
# 记录原数据集的元信息
self._metainfo = self.dataset.metainfo
'''
1. 在这里实现一些代码,来记录用于包装数据集的一些超参。
'''
self._fully_initialized = False
if not lazy_init:
self.full_init()
def full_init(self):
if self._fully_initialized:
return
# 将原数据集完全初始化
self.dataset.full_init()
'''
2. 在这里实现一些代码,来包装原数据集。
'''
self._fully_initialized = True
@force_full_init
def _get_ori_dataset_idx(self, idx: int):
'''
3. 在这里实现一些代码,来将包装的索引 `idx` 映射到原数据集的索引 `ori_idx`。
'''
ori_idx = ...
return ori_idx
# 提供与 `self.dataset` 一样的对外接口。
@force_full_init
def get_data_info(self, idx):
sample_idx = self._get_ori_dataset_idx(idx)
return self.dataset.get_data_info(sample_idx)
# 提供与 `self.dataset` 一样的对外接口。
def __getitem__(self, idx):
if not self._fully_initialized:
warnings.warn('Please call `full_init` method manually to '
'accelerate the speed.')
self.full_init()
sample_idx = self._get_ori_dataset_idx(idx)
return self.dataset[sample_idx]
# 提供与 `self.dataset` 一样的对外接口。
@force_full_init
def __len__(self):
'''
4. 在这里实现一些代码,来计算包装数据集之后的长度。
'''
len_wrapper = ...
return len_wrapper
# 提供与 `self.dataset` 一样的对外接口。
@property
def metainfo(self)
return copy.deepcopy(self._metainfo)
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment