Skip to content
Snippets Groups Projects
Unverified Commit a3b8d4ea authored by Tao Gong's avatar Tao Gong Committed by GitHub
Browse files

Refactor docs of basedataset (#175)

* refactor docs of basedataset

* fix ci

* fix comments

* fix comments

* fix comments

* fix comments

* fix comments

* set default value of ann_file to ''

* fix comments
parent 82a313d0
No related branches found
No related tags found
No related merge requests found
......@@ -6,25 +6,25 @@
因此 **MMEngine** 实现了一个数据集基类(BaseDataset)并定义了一些基本接口,且基于这套接口实现了一些数据集包装(DatasetWrapper)。OpenMMLab 算法库中的大部分数据集都会满足这套数据集基类定义的接口,并使用统一的数据集包装。
数据集基类的基本功能是加载数据集信息,这里我们将数据集信息分成两类,一种是元信息 (meta information),代表数据集自身相关的信息,有时需要被模型或其他外部组件获取,比如在图像分类任务中,数据集的元信息一般包含类别信息 `classes`,因为分类模型 `model` 一般需要记录数据集的类别信息;另一种为数据信息 (data information),在数据信息中,定义了具体样本的文件路径、对应标签等的信息。除此之外,数据集基类的另一个功能为将数据送入数据流水线(data pipeline)中,进行数据预处理。
数据集基类的基本功能是加载数据集信息,这里我们将数据集信息分成两类,一种是元信息 (meta information),代表数据集自身相关的信息,有时需要被模型或其他外部组件获取,比如在图像分类任务中,数据集的元信息一般包含类别信息 `classes`,因为分类模型 `model` 一般需要记录数据集的类别信息;另一种为数据信息 (data information),在数据信息中,定义了具体样本的文件路径、对应标签等的信息。除此之外,数据集基类的另一个功能为不断地将数据送入数据流水线(data pipeline)中,进行数据预处理。
### 数据标注文件规范
为了统一不同任务的数据集接口,便于多任务的算法模型训练,OpenMMLab 制定了 **OpenMMLab 2.0 数据集格式规范**, 数据集标注文件需符合该规范,数据集基类基于该规范去读取与解析数据标注文件。如果用户提供的数据标注文件不符合规定格式,用户应该将其转化为规定格式才能使用 OpenMMLab 的算法库基于该数据标注文件进行算法训练和测试。
为了统一不同任务的数据集接口,便于多任务的算法模型训练,OpenMMLab 制定了 **OpenMMLab 2.0 数据集格式规范**, 数据集标注文件需符合该规范,数据集基类基于该规范去读取与解析数据标注文件。如果用户提供的数据标注文件不符合规定格式,用户可以选择将其转化为规定格式,并使用 OpenMMLab 的算法库基于该数据标注文件进行算法训练和测试。
OpenMMLab 2.0 数据集格式规范规定,标注文件必须为 `json``yaml``yml``pickle``pkl` 格式;标注文件中存储的字典必须包含 `metadata``data_infos` 两个字段。其中 `metadata` 是一个字典,里面包含数据集的元信息;`data_infos` 是一个列表,列表中每个元素是一个字典,该字典定义了一个原始数据(raw data),每个原始数据包含一个或若干个训练/测试样本。
OpenMMLab 2.0 数据集格式规范规定,标注文件必须为 `json``yaml``yml``pickle``pkl` 格式;标注文件中存储的字典必须包含 `metainfo``data_list` 两个字段。其中 `metainfo` 是一个字典,里面包含数据集的元信息;`data_list` 是一个列表,列表中每个元素是一个字典,该字典定义了一个原始数据(raw data),每个原始数据包含一个或若干个训练/测试样本。
以下是一个 JSON 标注文件的例子(该例子中每个原始数据只包含一个训练/测试样本):
```json
{
'metadata':
'metainfo':
{
'classes': ('cat', 'dog'),
...
},
'data_infos':
'data_list':
[
{
'img_path': "xxx/xxx_0.jpg",
......@@ -55,45 +55,59 @@ data
### 数据集基类的初始化流程
数据集基类的初始化流程如下:
数据集基类的初始化流程如下图所示
1. 获取数据集的元信息,元信息有三种来源,优先级从高到低为:
![image](https://user-images.githubusercontent.com/26813582/164611564-af44e3f2-a50f-4ef1-a6db-eddd840e2f40.png)
- `__init__()` 方法中用户传入的 `meta` 字典;改动频率最高,因为用户可以在实例化数据集时,传入该参数;
1. `load metainfo`:获取数据集的元信息,元信息有三种来源,优先级从高到低为:
- 类属性 `BaseDataset.META` 字典;改动频率中等,因为用户可以改动自定义数据集类中的类属性 `BaseDataset.META`
- `__init__()` 方法中用户传入的 `metainfo` 字典;改动频率最高,因为用户可以在实例化数据集时,传入该参数
- 标注文件中包含的 `metadata` 字典;改动频率最低,因为标注文件一般不做改动。
- 类属性 `BaseDataset.METAINFO` 字典;改动频率中等,因为用户可以改动自定义数据集类中的类属性 `BaseDataset.METAINFO`
- 标注文件中包含的 `metainfo` 字典;改动频率最低,因为标注文件一般不做改动。
如果三种来源中有相同的字段,优先级最高的来源决定该字段的值;
2. 构建数据流水线(data pipeline),用于数据预处理与数据准备;
2. `join path`:处理数据与标注文件的路径;
3. `build pipeline`:构建数据流水线(data pipeline),用于数据预处理与数据准备;
3. 读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会有 `parse_annotations()` 方法,该方法负责解析标注文件里的每个原始数据;
4. `full init`:完全初始化数据集类,该步骤主要包含以下操作:
4. 过滤无用数据,比如不包含标注的样本等
- `load data list`:读取与解析满足 OpenMMLab 2.0 数据集格式规范的标注文件,该步骤中会调用 `parse_data_info()` 方法,该方法负责解析标注文件里的每个原始数据
5. 采样数据,比如只取前 10 个样本参与训练/测试
- `filter data` (可选):根据 `filter_cfg` 过滤无用数据,比如不包含标注的样本等;默认不做过滤操作,下游子类可以按自身所需对其进行重写
6. 序列化全部样本,以达到节省内存的效果,详情请参考[节省内存](#节省内存)
- `get subset` (可选):根据给定的索引或整数值采样数据,比如只取前 10 个样本参与训练/测试;默认不采样数据,即使用全部数据样本;
数据集基类中包含的 `parse_annotations()` 方法用于将标注文件里的一个原始数据处理成一个或若干个训练/测试样本的方法。因此对于自定义数据集类,用户需要实现 `parse_annotations()` 方法。
- `serialize data` (可选):序列化全部样本,以达到节省内存的效果,详情请参考[节省内存](#节省内存);默认操作为序列化全部样本。
数据集基类中包含的 `parse_data_info()` 方法用于将标注文件里的一个原始数据处理成一个或若干个训练/测试样本的方法。因此对于自定义数据集类,用户需要实现 `parse_data_info()` 方法。
### 数据集基类提供的接口
`torch.utils.data.Dataset` 类似,数据集初始化后,支持 `__getitem__` 方法,用来索引数据,以及 `__len__` 操作获取数据集大小,除此之外,OpenMMLab 的数据集基类主要提供了以下接口来访问具体信息:
- `meta` 返回元信息,返回值为字典
- `metainfo`:返回元信息,返回值为字典
- `get_data_info(idx)`:返回指定 `idx` 的样本全量信息,返回值为字典
- `get_data_info(idx)` 返回指定 `idx` 的样本全量信息,返回值为字典
- `__getitem__(idx)`返回指定 `idx` 的样本经过 pipeline 之后的结果(也就是送入模型的数据),返回值为字典
- `__getitem__(idx)` :返回指定 `idx` 的样本经过 pipeline 之后的结果(也就是送入模型的数据),返回值为字典
- `__len__()`:返回数据集长度,返回值为整数型
- `__len__()` 返回数据集长度,返回值为整数型
- `get_subset_(indices)`:根据 `indices` 以 inplace 的方式**修改原数据集类**。如果 `indices``int`,则原数据集类只包含前若干个数据样本;如果 `indices``Sequence[int]`,则原数据集类包含根据 `Sequence[int]` 指定的数据样本。
- `get_subset(indices)`:根据 `indices`**非** inplace 的方式**返回子数据集类**,即重新复制一份子数据集。如果 `indices``int`,则返回的子数据集类只包含前若干个数据样本;如果 `indices``Sequence[int]`,则返回的子数据集类包含根据 `Sequence[int]` 指定的数据样本。
## 使用数据集基类自定义数据集类
在了解了数据集基类的初始化流程与提供的接口之后,就可以基于数据集基类自定义数据集类,如上所述,对于满足 OpenMMLab 2.0 数据集格式规范的标注文件,用户可以重载 `parse_annotations()`来加载标签。以下是一个使用数据集基类来实现某一具体数据集的例子。
在了解了数据集基类的初始化流程与提供的接口之后,就可以基于数据集基类自定义数据集类。
### 对于满足 OpenMMLab 2.0 数据集格式规范的标注文件
如上所述,对于满足 OpenMMLab 2.0 数据集格式规范的标注文件,用户可以重载 `parse_data_info()` 来加载标签。以下是一个使用数据集基类来实现某一具体数据集的例子。
```python
import os.path as osp
......@@ -103,13 +117,13 @@ from mmengine.data import BaseDataset
class ToyDataset(BaseDataset):
# 以上面标注文件为例,在这里 raw_data_info 代表 `data_infos` 对应列表里的某个字典:
# 以上面标注文件为例,在这里 raw_data_info 代表 `data_list` 对应列表里的某个字典:
# {
# 'img_path': "xxx/xxx_0.jpg",
# 'img_label': 0,
# ...
# }
def parse_annotations(self, raw_data_info):
def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
if img_prefix is not None:
......@@ -119,7 +133,7 @@ class ToyDataset(BaseDataset):
```
### 使用自定义数据集类
#### 使用自定义数据集类
在定义了数据集类后,就可以通过如下配置实例化 `ToyDataset`
......@@ -155,13 +169,23 @@ len(toy_dataset)
toy_dataset[0]
# dict(img=xxx, label=0)
# `get_subset` 接口不对原数据集类做修改,即完全复制一份新的
sub_toy_dataset = toy_dataset.get_subset(1)
len(toy_dataset), len(sub_toy_dataset)
# 2, 1
# `get_subset_` 接口会对原数据集类做修改,即 inplace 的方式
toy_dataset.get_subset_(1)
len(toy_dataset)
# 1
```
经过以上步骤,可以了解基于数据集基类如何自定义新的数据集类,以及如何使用自定义数据集类。
### 自定义视频的数据集类
#### 自定义视频的数据集类
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_annotations()` 的返回值为 `list[dict]` 即可:
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_data_info()` 的返回值为 `list[dict]` 即可:
```python
from mmengine.data import BaseDataset
......@@ -170,8 +194,8 @@ from mmengine.data import BaseDataset
class ToyVideoDataset(BaseDataset):
# raw_data_info 仍为一个字典,但它包含了多个样本
def parse_annotations(self, raw_data_info):
data_infos = []
def parse_data_info(self, raw_data_info):
data_list = []
...
......@@ -181,14 +205,22 @@ class ToyVideoDataset(BaseDataset):
...
data_infos.append(data_info)
data_list.append(data_info)
return data_infos
return data_list
```
`ToyVideoDataset` 使用方法与 `ToyDataset` 类似,在此不做赘述。
### 对于不满足 OpenMMLab 2.0 数据集格式规范的标注文件
对于不满足 OpenMMLab 2.0 数据集格式规范的标注文件,有两种方式来使用数据集基类:
1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self, ann_file):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
## 数据集基类的其它特性
数据集基类还包含以下特性:
......@@ -213,9 +245,9 @@ toy_dataset = ToyDataset(
lazy_init=True)
```
`lazy_init=True` 时,`ToyDataset` 的初始化方法只执行了[数据集基类的初始化流程](#数据集基类的初始化流程)中的 1、2 步骤,此时 `toy_dataset` 并未被完全初始化,因为 `toy_dataset` 并不会读取与解析标注文件,只会设置数据集类的元信息(`meta`)。
`lazy_init=True` 时,`ToyDataset` 的初始化方法只执行了[数据集基类的初始化流程](#数据集基类的初始化流程)中的 1、2、3 步骤,此时 `toy_dataset` 并未被完全初始化,因为 `toy_dataset` 并不会读取与解析标注文件,只会设置数据集类的元信息(`metainfo`)。
自然的,如果之后需要访问具体的数据信息,可以手动调用 `toy_dataset.full_init()` 接口来执行完整的初始化过程,在这个过程中数据标注文件将被读取与解析。调用 `get_data_info(idx)`, `__len__()`, `__getitem__()` 接口也会自动地调用 `full_init()` 接口来执行完整的初始化过程(仅在第一次调用时,之后调用不会重复地调用 `full_init()` 接口):
自然的,如果之后需要访问具体的数据信息,可以手动调用 `toy_dataset.full_init()` 接口来执行完整的初始化过程,在这个过程中数据标注文件将被读取与解析。调用 `get_data_info(idx)`, `__len__()`, `__getitem__(idx)``get_subset_(indices)``get_subset(indices)` 接口也会自动地调用 `full_init()` 接口来执行完整的初始化过程(仅在第一次调用时,之后调用不会重复地调用 `full_init()` 接口):
```python
# 完整初始化
......@@ -234,9 +266,9 @@ toy_dataset[0] # dict(img=xxx, label=0)
### 节省内存
在具体的读取数据过程中,数据加载器(dataloader)通常会起多个 worker 来预取数据,多个 worker 都拥有完整的数据集类备份,因此内存中会存在多份相同的 `data_infos`,为了节省这部分内存消耗,数据集基类可以提前将 `data_infos` 序列化存入内存中,使得多个 worker 可以共享同一份 `data_infos`,以达到节省内存的目的。
在具体的读取数据过程中,数据加载器(dataloader)通常会起多个 worker 来预取数据,多个 worker 都拥有完整的数据集类备份,因此内存中会存在多份相同的 `data_list`,为了节省这部分内存消耗,数据集基类可以提前将 `data_list` 序列化存入内存中,使得多个 worker 可以共享同一份 `data_list`,以达到节省内存的目的。
数据集基类默认是将 `data_infos` 序列化存入内存,也可以通过 `serialize_data` 变量(默认为 `True`)来控制是否提前将 `data_infos` 序列化存入内存中:
数据集基类默认是将 `data_list` 序列化存入内存,也可以通过 `serialize_data` 变量(默认为 `True`)来控制是否提前将 `data_list` 序列化存入内存中:
```python
pipeline = [
......@@ -254,7 +286,7 @@ toy_dataset = ToyDataset(
serialize_data=False)
```
上面例子不会提前将 `data_infos` 序列化存入内存中,因此不建议在使用数据加载器开多个 worker 加载数据的情况下,使用这种方式实例化数据集类。
上面例子不会提前将 `data_list` 序列化存入内存中,因此不建议在使用数据加载器开多个 worker 加载数据的情况下,使用这种方式实例化数据集类。
## 数据集基类包装
......@@ -329,7 +361,7 @@ from mmengine.data import BaseDataset, ClassBalancedDataset
class ToyDataset(BaseDataset):
def parse_annotations(self, raw_data_info):
def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
if img_prefix is not None:
......
......@@ -118,12 +118,12 @@ class BaseDataset(Dataset):
.. code-block:: none
{
"metadata":
"metainfo":
{
"dataset_type": "test_dataset",
"task_name": "test_task"
},
"data_infos":
"data_list":
[
{
"img_path": "test_img.jpg",
......@@ -149,7 +149,7 @@ class BaseDataset(Dataset):
}
Args:
ann_file (str): Annotation file path.
ann_file (str): Annotation file path. Defaults to ''.
metainfo (dict, optional): Meta information for dataset, such as class
information. Defaults to None.
data_root (str, optional): The root directory for ``data_prefix`` and
......@@ -208,7 +208,7 @@ class BaseDataset(Dataset):
_fully_initialized: bool = False
def __init__(self,
ann_file: str,
ann_file: str = '',
metainfo: Optional[dict] = None,
data_root: Optional[str] = None,
data_prefix: dict = dict(img=None, ann=None),
......@@ -232,7 +232,7 @@ class BaseDataset(Dataset):
self.data_bytes: np.ndarray
# Set meta information.
self._metainfo = self._get_meta_info(copy.deepcopy(metainfo))
self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))
# Join paths.
if self.data_root is not None:
......@@ -429,21 +429,21 @@ class BaseDataset(Dataset):
if not isinstance(annotations, dict):
raise TypeError(f'The annotations loaded from annotation file '
f'should be a dict, but got {type(annotations)}!')
if 'data_infos' not in annotations or 'metadata' not in annotations:
raise ValueError('Annotation must have data_infos and metadata '
if 'data_list' not in annotations or 'metainfo' not in annotations:
raise ValueError('Annotation must have data_list and metainfo '
'keys')
meta_data = annotations['metadata']
raw_data_infos = annotations['data_infos']
metainfo = annotations['metainfo']
raw_data_list = annotations['data_list']
# Meta information load from annotation file will not influence the
# existed meta information load from `BaseDataset.METAINFO` and
# `metainfo` arguments defined in constructor.
for k, v in meta_data.items():
for k, v in metainfo.items():
self._metainfo.setdefault(k, v)
# load and parse data_infos.
data_list = []
for raw_data_info in raw_data_infos:
for raw_data_info in raw_data_list:
# parse raw data information to target format
data_info = self.parse_data_info(raw_data_info)
if isinstance(data_info, dict):
......@@ -467,11 +467,11 @@ class BaseDataset(Dataset):
return data_list
@classmethod
def _get_meta_info(cls, in_metainfo: dict = None) -> dict:
def _load_metainfo(cls, metainfo: dict = None) -> dict:
"""Collect meta information from the dictionary of meta.
Args:
in_metainfo (dict): Meta information dict. If ``in_metainfo``
metainfo (dict): Meta information dict. If ``metainfo``
contains existed filename, it will be parsed by
``list_from_file``.
......@@ -480,15 +480,15 @@ class BaseDataset(Dataset):
"""
# `cls.METAINFO` will be overwritten by in_meta
cls_metainfo = copy.deepcopy(cls.METAINFO)
if in_metainfo is None:
if metainfo is None:
return cls_metainfo
if not isinstance(in_metainfo, dict):
if not isinstance(metainfo, dict):
raise TypeError(
f'in_metainfo should be a dict, but got {type(in_metainfo)}')
f'metainfo should be a dict, but got {type(metainfo)}')
for k, v in in_metainfo.items():
for k, v in metainfo.items():
if isinstance(v, str) and osp.isfile(v):
# if filename in in_metainfo, this key will be further parsed.
# if filename in metainfo, this key will be further parsed.
# nested filename will be ignored.
cls_metainfo[k] = list_from_file(v)
else:
......
{
"metadata":
"metainfo":
{
"dataset_type": "test_dataset",
"task_name": "test_task",
"empty_list": []
},
"data_infos":
"data_list":
[
{
"img_path": "test_img.jpg",
......
......@@ -87,6 +87,11 @@ class TestBaseDataset:
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'),
ann_file='annotations/not_existed_annotation.json')
# Use the default value of ann_file, i.e., ''
with pytest.raises(FileNotFoundError):
self.dataset_type(
data_root=osp.join(osp.dirname(__file__), '../data/'),
data_prefix=dict(img='imgs'))
# test the instantiation of self.base_dataset when the ann_file is
# wrong
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment